# ---------- Model ----------
import torch
import torch.nn as nn
from transformers import AutoModel


class EnsembleBertBiLSTMRegressor(nn.Module):
    """
    Two BERT encoders (MCQ + Clinical) -> concat token reps -> BiLSTM -> pooled -> 2-dim regression
    Outputs: preds[:, 0] = difficulty (0..1 typical), preds[:, 1] = discrimination (~-1..+1 typical)
    """

    def __init__(self, model_name_mcq, model_name_clinical, hidden_dim=768, extra_dim=0, dropout=0.2):
        super().__init__()
        self.bert_mcq = AutoModel.from_pretrained(model_name_mcq)
        self.bert_clin = AutoModel.from_pretrained(model_name_clinical)
        emb_dim = self.bert_mcq.config.hidden_size + self.bert_clin.config.hidden_size
        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)
        self.feat_proj = nn.Linear(extra_dim, 64) if extra_dim and extra_dim > 0 else None
        in_dim = hidden_dim * 2 + (64 if self.feat_proj is not None else 0)
        self.fc_out = nn.Linear(in_dim, 2)  # [difficulty, discrimination]
        # Initialize final layer with smaller weights for stability
        nn.init.xavier_uniform_(self.fc_out.weight, gain=0.1)
        nn.init.zeros_(self.fc_out.bias)

    def forward(self, ids_mcq, mask_mcq, ids_clin, mask_clin, extra_feats=None):
        # Ensure masks are 0/1 integers (or bool)
        mask_mcq = mask_mcq.long().clamp_(0, 1)
        mask_clin = mask_clin.long().clamp_(0, 1)

        # BERT forward passes
        out_mcq = self.bert_mcq(input_ids=ids_mcq, attention_mask=mask_mcq)
        out_clin = self.bert_clin(input_ids=ids_clin, attention_mask=mask_clin)

        # Concat along hidden dimension; enforce contiguity before LSTM
        seq_out = torch.cat([out_mcq.last_hidden_state, out_clin.last_hidden_state], dim=-1).contiguous()

        # BiLSTM processing
        lstm_out, _ = self.lstm(seq_out)  # [B, T, 2*hidden_dim]
        pooled = lstm_out.mean(dim=1)     # [B, 2*hidden_dim]

        # Optional: add extra features
        if (extra_feats is not None) and (self.feat_proj is not None):
            # Ensure extra features are clean
            extra_feats = torch.nan_to_num(extra_feats, nan=0.0, posinf=0.0, neginf=0.0)
            extra_feats = extra_feats.clamp(-10.0, 10.0)
            pooled = torch.cat([pooled, self.feat_proj(extra_feats)], dim=1)

        pooled = self.dropout(pooled)
        preds = self.fc_out(pooled)  # [B, 2]

        # Optional: clip predictions to reasonable ranges for stability
        # difficulty: typically 0-1, discrimination: typically -3 to +3
        # Uncomment if you want to enforce these constraints:
        # preds[:, 0] = preds[:, 0].clamp(0.0, 1.0)   # difficulty
        # preds[:, 1] = preds[:, 1].clamp(-3.0, 3.0)  # discrimination
        return preds

    def freeze_bert_layers(self, num_layers_to_freeze=None):
        """
        Optionally freeze early BERT layers for faster training and reduced memory.
        If num_layers_to_freeze is None, freezes only the embeddings.
        """
        # Always freeze the embedding tables of both encoders
        for param in self.bert_mcq.embeddings.parameters():
            param.requires_grad = False
        for param in self.bert_clin.embeddings.parameters():
            param.requires_grad = False
        # Additionally freeze the first N encoder layers if requested
        if num_layers_to_freeze is not None:
            for i in range(num_layers_to_freeze):
                for param in self.bert_mcq.encoder.layer[i].parameters():
                    param.requires_grad = False
                for param in self.bert_clin.encoder.layer[i].parameters():
                    param.requires_grad = False
        print(f"Frozen BERT layers: {'embeddings only' if num_layers_to_freeze is None else f'embeddings + {num_layers_to_freeze} layers'}")

    def unfreeze_all(self):
        """Unfreeze all parameters."""
        for param in self.parameters():
            param.requires_grad = True
        print("All layers unfrozen")