# ---------- Model ----------
import torch
import torch.nn as nn
from transformers import AutoModel


class EnsembleBertBiLSTMRegressor(nn.Module):
    """
    Two BERT encoders (MCQ + Clinical) -> concat token reps -> BiLSTM -> pooled -> 2-dim regression
    Outputs: preds[:,0] = difficulty (0..1 typical), preds[:,1] = discrimination (~-1..+1 typical)
    """

    def __init__(self, model_name_mcq, model_name_clinical, hidden_dim=768, extra_dim=0, dropout=0.2):
        super().__init__()
        self.bert_mcq = AutoModel.from_pretrained(model_name_mcq)
        self.bert_clin = AutoModel.from_pretrained(model_name_clinical)

        emb_dim = self.bert_mcq.config.hidden_size + self.bert_clin.config.hidden_size
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)

        self.feat_proj = nn.Linear(extra_dim, 64) if extra_dim and extra_dim > 0 else None
        in_dim = hidden_dim * 2 + (64 if self.feat_proj is not None else 0)
        self.fc_out = nn.Linear(in_dim, 2)  # [difficulty, discrimination]

        # Initialize final layer with smaller weights for stability
        nn.init.xavier_uniform_(self.fc_out.weight, gain=0.1)
        nn.init.zeros_(self.fc_out.bias)

    def forward(self, ids_mcq, mask_mcq, ids_clin, mask_clin, extra_feats=None):
        # Ensure masks are 0/1 integers (or bool)
        mask_mcq = mask_mcq.long().clamp_(0, 1)
        mask_clin = mask_clin.long().clamp_(0, 1)

        # BERT forward passes
        out_mcq = self.bert_mcq(input_ids=ids_mcq, attention_mask=mask_mcq)
        out_clin = self.bert_clin(input_ids=ids_clin, attention_mask=mask_clin)

        # Concat along hidden dimension; enforce contiguity before LSTM
        seq_out = torch.cat(
            [out_mcq.last_hidden_state, out_clin.last_hidden_state], dim=-1
        ).contiguous()

        # BiLSTM processing
        lstm_out, _ = self.lstm(seq_out)  # [B, T, 2*hidden_dim]
        pooled = lstm_out.mean(dim=1)     # [B, 2*hidden_dim]

        # Optional: Add extra features
        if (extra_feats is not None) and (self.feat_proj is not None):
            # Ensure extra features are clean
            extra_feats = torch.nan_to_num(extra_feats, nan=0.0, posinf=0.0, neginf=0.0)
            extra_feats = extra_feats.clamp(-10.0, 10.0)
            pooled = torch.cat([pooled, self.feat_proj(extra_feats)], dim=1)

        pooled = self.dropout(pooled)
        preds = self.fc_out(pooled)  # [B, 2]

        # Optional: Clip predictions to reasonable ranges for stability
        # difficulty: typically 0-1, discrimination: typically -3 to +3
        # Uncomment if you want to enforce these constraints:
        # preds[:, 0] = preds[:, 0].clamp(0.0, 1.0)   # difficulty
        # preds[:, 1] = preds[:, 1].clamp(-3.0, 3.0)  # discrimination

        return preds

    def freeze_bert_layers(self, num_layers_to_freeze=None):
        """
        Optionally freeze early BERT layers for faster training and reduced memory.
        If num_layers_to_freeze is None, freezes only embeddings.
        """
        if num_layers_to_freeze is None:
            # Freeze only embeddings
            for param in self.bert_mcq.embeddings.parameters():
                param.requires_grad = False
            for param in self.bert_clin.embeddings.parameters():
                param.requires_grad = False
        else:
            # Freeze embeddings + first N encoder layers
            for param in self.bert_mcq.embeddings.parameters():
                param.requires_grad = False
            for param in self.bert_clin.embeddings.parameters():
                param.requires_grad = False
            for i in range(num_layers_to_freeze):
                for param in self.bert_mcq.encoder.layer[i].parameters():
                    param.requires_grad = False
                for param in self.bert_clin.encoder.layer[i].parameters():
                    param.requires_grad = False

        print(f"Frozen BERT layers: {'embeddings only' if num_layers_to_freeze is None else f'embeddings + {num_layers_to_freeze} layers'}")

    def unfreeze_all(self):
        """Unfreeze all parameters."""
        for param in self.parameters():
            param.requires_grad = True
        print("All layers unfrozen")