# sqb-predict-api / model_architecture.py
# ---------- Model ----------
import torch
import torch.nn as nn
from transformers import AutoModel

class EnsembleBertBiLSTMRegressor(nn.Module):
    """
    Two BERT encoders (MCQ + Clinical) -> concat token reps -> BiLSTM -> pooled -> 2-dim regression
    Outputs: preds[:,0] = difficulty (0..1 typical), preds[:,1] = discrimination (~-1..+1 typical)
    """
    def __init__(self, model_name_mcq, model_name_clinical, hidden_dim=768, extra_dim=0, dropout=0.2):
        super().__init__()
        self.bert_mcq = AutoModel.from_pretrained(model_name_mcq)
        self.bert_clin = AutoModel.from_pretrained(model_name_clinical)
        emb_dim = self.bert_mcq.config.hidden_size + self.bert_clin.config.hidden_size

        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout = nn.Dropout(dropout)

        # Optional projection for hand-crafted extra features (appended after pooling)
        self.feat_proj = nn.Linear(extra_dim, 64) if extra_dim and extra_dim > 0 else None
        in_dim = hidden_dim * 2 + (64 if self.feat_proj is not None else 0)
        self.fc_out = nn.Linear(in_dim, 2)  # [difficulty, discrimination]

        # Initialize final layer with smaller weights for stability
        nn.init.xavier_uniform_(self.fc_out.weight, gain=0.1)
        nn.init.zeros_(self.fc_out.bias)
    def forward(self, ids_mcq, mask_mcq, ids_clin, mask_clin, extra_feats=None):
        # Ensure attention masks are 0/1 integer tensors
        mask_mcq = mask_mcq.long().clamp_(0, 1)
        mask_clin = mask_clin.long().clamp_(0, 1)

        # BERT forward passes
        out_mcq = self.bert_mcq(input_ids=ids_mcq, attention_mask=mask_mcq)
        out_clin = self.bert_clin(input_ids=ids_clin, attention_mask=mask_clin)

        # Concat along hidden dimension; enforce contiguity before LSTM
        # (both inputs must be padded/truncated to the same sequence length)
        seq_out = torch.cat([out_mcq.last_hidden_state, out_clin.last_hidden_state], dim=-1).contiguous()

        # BiLSTM processing followed by mean pooling over the token dimension
        lstm_out, _ = self.lstm(seq_out)  # [B, T, 2*hidden_dim]
        pooled = lstm_out.mean(dim=1)     # [B, 2*hidden_dim]

        # Optional: append projected extra features
        if (extra_feats is not None) and (self.feat_proj is not None):
            # Ensure extra features are clean and bounded
            extra_feats = torch.nan_to_num(extra_feats, nan=0.0, posinf=0.0, neginf=0.0)
            extra_feats = extra_feats.clamp(-10.0, 10.0)
            pooled = torch.cat([pooled, self.feat_proj(extra_feats)], dim=1)

        pooled = self.dropout(pooled)
        preds = self.fc_out(pooled)  # [B, 2]

        # Optional: clip predictions to reasonable ranges for stability
        # difficulty: typically 0-1, discrimination: typically -3 to +3
        # Uncomment if you want to enforce these constraints:
        # preds[:, 0] = preds[:, 0].clamp(0.0, 1.0)   # difficulty
        # preds[:, 1] = preds[:, 1].clamp(-3.0, 3.0)  # discrimination
        return preds
    def freeze_bert_layers(self, num_layers_to_freeze=None):
        """
        Optionally freeze early BERT layers for faster training and reduced memory.
        If num_layers_to_freeze is None, freezes only embeddings.
        """
        # Embeddings are frozen in both modes
        for bert in (self.bert_mcq, self.bert_clin):
            for param in bert.embeddings.parameters():
                param.requires_grad = False

        # Additionally freeze the first N encoder layers of each encoder
        if num_layers_to_freeze is not None:
            for bert in (self.bert_mcq, self.bert_clin):
                for i in range(num_layers_to_freeze):
                    for param in bert.encoder.layer[i].parameters():
                        param.requires_grad = False

        desc = "embeddings only" if num_layers_to_freeze is None else f"embeddings + {num_layers_to_freeze} layers"
        print(f"Frozen BERT layers: {desc}")
    def unfreeze_all(self):
        """Unfreeze all parameters."""
        for param in self.parameters():
            param.requires_grad = True
        print("All layers unfrozen")
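
# ---------- Usage sketch (illustrative only) ----------
# A minimal smoke test showing how the regressor is wired together and called.
# The checkpoint name below is a stand-in; the actual MCQ/clinical encoders used
# by the prediction API are configured elsewhere. extra_dim=4 and max_length=64
# are likewise arbitrary example values.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_name = "bert-base-uncased"  # placeholder checkpoint for both encoders
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = EnsembleBertBiLSTMRegressor(model_name, model_name, extra_dim=4)
    model.freeze_bert_layers(num_layers_to_freeze=6)
    model.eval()

    # Both inputs are padded to the same length so token reps can be concatenated
    enc_mcq = tokenizer(["Which drug is first-line for ...?"], padding="max_length",
                        truncation=True, max_length=64, return_tensors="pt")
    enc_clin = tokenizer(["A 45-year-old patient presents with ..."], padding="max_length",
                         truncation=True, max_length=64, return_tensors="pt")
    extra = torch.zeros(1, 4)  # dummy hand-crafted features

    with torch.no_grad():
        preds = model(enc_mcq["input_ids"], enc_mcq["attention_mask"],
                      enc_clin["input_ids"], enc_clin["attention_mask"],
                      extra_feats=extra)
    print(preds.shape)  # torch.Size([1, 2]): [difficulty, discrimination]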