File size: 4,479 Bytes
beae064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
# ---------- Model ----------
import torch
import torch.nn as nn
from transformers import AutoModel


class EnsembleBertBiLSTMRegressor(nn.Module):
    """
    Two BERT encoders (MCQ + Clinical) -> concat token reps -> BiLSTM -> pooled -> 2-dim regression.

    Outputs: preds[:, 0] = difficulty (0..1 typical), preds[:, 1] = discrimination (~-1..+1 typical).
    The output head is a plain linear layer, so predictions are NOT bounded to these ranges;
    clip downstream if a hard range is required.

    Args:
        model_name_mcq: model id/path passed to ``AutoModel.from_pretrained`` for the MCQ encoder.
        model_name_clinical: model id/path passed to ``AutoModel.from_pretrained`` for the
            clinical encoder.
        hidden_dim: hidden size of each LSTM direction (pooled width is ``2 * hidden_dim``).
        extra_dim: width of optional extra features; 0/None disables the feature branch.
        dropout: dropout probability applied to the pooled vector before the output head.
        masked_pooling: if True (default), mean-pool only over non-padding positions.
            Set False to reproduce the legacy unmasked mean over all T positions
            (e.g. for checkpoints trained with that behavior).
    """
    def __init__(self, model_name_mcq, model_name_clinical, hidden_dim=768, extra_dim=0, dropout=0.2,
                 masked_pooling=True):
        super().__init__()
        self.bert_mcq  = AutoModel.from_pretrained(model_name_mcq)
        self.bert_clin = AutoModel.from_pretrained(model_name_clinical)

        # Token states of both encoders are concatenated feature-wise at each position.
        emb_dim = self.bert_mcq.config.hidden_size + self.bert_clin.config.hidden_size

        self.lstm = nn.LSTM(
            input_size=emb_dim,
            hidden_size=hidden_dim,
            num_layers=1,
            bidirectional=True,
            batch_first=True,
        )
        self.dropout   = nn.Dropout(dropout)
        self.feat_proj = nn.Linear(extra_dim, 64) if extra_dim and extra_dim > 0 else None
        self.masked_pooling = bool(masked_pooling)

        in_dim = hidden_dim * 2 + (64 if self.feat_proj is not None else 0)
        self.fc_out = nn.Linear(in_dim, 2)  # [difficulty, discrimination]

        # Initialize final layer with smaller weights for stability
        nn.init.xavier_uniform_(self.fc_out.weight, gain=0.1)
        nn.init.zeros_(self.fc_out.bias)

    def forward(self, ids_mcq, mask_mcq, ids_clin, mask_clin, extra_feats=None):
        """
        Predict [difficulty, discrimination] for a batch.

        Args:
            ids_mcq, mask_mcq: token ids / attention mask for the MCQ encoder, [B, T].
            ids_clin, mask_clin: token ids / attention mask for the clinical encoder, [B, T].
                Both inputs must share T because token states are concatenated position-wise.
            extra_feats: optional [B, extra_dim] features. Required when the model was built
                with extra_dim > 0; ignored when the model has no feature branch.

        Returns:
            Tensor of shape [B, 2]: column 0 = difficulty, column 1 = discrimination.

        Raises:
            ValueError: if the two inputs have different sequence lengths, or if the model
                expects extra_feats but none were given.
        """
        # Fail fast, before two expensive encoder passes, on inputs that cannot be combined.
        if ids_mcq.shape[1] != ids_clin.shape[1]:
            raise ValueError(
                "ids_mcq and ids_clin must have the same sequence length for position-wise "
                f"concatenation; got {ids_mcq.shape[1]} and {ids_clin.shape[1]}"
            )
        if self.feat_proj is not None and extra_feats is None:
            raise ValueError("Model was built with extra_dim > 0 but extra_feats is None")

        # Normalize masks to 0/1 int64. Out-of-place clamp on purpose: `.long()` returns the
        # same tensor when the mask is already int64, so an in-place clamp_ would mutate the
        # caller's mask.
        mask_mcq  = mask_mcq.long().clamp(0, 1)
        mask_clin = mask_clin.long().clamp(0, 1)

        # BERT forward passes
        out_mcq  = self.bert_mcq(input_ids=ids_mcq,  attention_mask=mask_mcq)
        out_clin = self.bert_clin(input_ids=ids_clin, attention_mask=mask_clin)

        # Concat along hidden dimension; enforce contiguity before LSTM
        seq_out = torch.cat([out_mcq.last_hidden_state, out_clin.last_hidden_state], dim=-1).contiguous()

        # BiLSTM processing.
        # NOTE(review): the LSTM still runs over padded positions; with right padding the
        # backward direction's states at real tokens are influenced by trailing pad states.
        # Packing (pack_padded_sequence) would avoid that if padding is substantial.
        lstm_out, _ = self.lstm(seq_out)        # [B, T, 2*hidden_dim]

        if self.masked_pooling:
            # A position counts as real if either encoder has a non-padding token there.
            valid = (mask_mcq | mask_clin).unsqueeze(-1).to(lstm_out.dtype)   # [B, T, 1]
            # clamp(min=1) keeps all-padding rows finite (they pool to zeros).
            pooled = (lstm_out * valid).sum(dim=1) / valid.sum(dim=1).clamp(min=1.0)
        else:
            # Legacy behavior: plain mean over all T positions, padding included.
            pooled = lstm_out.mean(dim=1)       # [B, 2*hidden_dim]

        # Optional extra features (silently ignored when the model has no feature branch).
        if (extra_feats is not None) and (self.feat_proj is not None):
            # Replace NaN/inf and bound magnitudes so a bad feature row cannot blow up the head.
            extra_feats = torch.nan_to_num(extra_feats, nan=0.0, posinf=0.0, neginf=0.0)
            extra_feats = extra_feats.clamp(-10.0, 10.0)
            pooled = torch.cat([pooled, self.feat_proj(extra_feats)], dim=1)

        pooled = self.dropout(pooled)
        # Unbounded linear output; callers clip if they need hard ranges.
        preds = self.fc_out(pooled)              # [B, 2]
        return preds

    def freeze_bert_layers(self, num_layers_to_freeze=None):
        """
        Optionally freeze early BERT layers for faster training and reduced memory.

        Args:
            num_layers_to_freeze: None freezes only the embeddings of both encoders;
                an int N freezes the embeddings plus the first N encoder layers of each.

        Raises:
            ValueError: if N is negative or exceeds either encoder's depth. Validation
                runs before any parameter is touched, so a failed call freezes nothing.
        """
        encoders = (self.bert_mcq, self.bert_clin)
        n_layers = 0 if num_layers_to_freeze is None else num_layers_to_freeze

        if n_layers < 0:
            raise ValueError(f"num_layers_to_freeze must be >= 0, got {n_layers}")
        if n_layers > 0:
            # The two encoders may have different depths; check both up front.
            for bert in encoders:
                depth = len(bert.encoder.layer)
                if n_layers > depth:
                    raise ValueError(
                        f"num_layers_to_freeze={n_layers} exceeds encoder depth {depth}"
                    )

        for bert in encoders:
            for param in bert.embeddings.parameters():
                param.requires_grad = False
            # Only touch `encoder.layer` when asked, so embeddings-only freezing works on
            # models that do not expose that attribute.
            if n_layers > 0:
                for layer in bert.encoder.layer[:n_layers]:
                    for param in layer.parameters():
                        param.requires_grad = False

        print(f"Frozen BERT layers: {'embeddings only' if num_layers_to_freeze is None else f'embeddings + {num_layers_to_freeze} layers'}")

    def unfreeze_all(self):
        """Unfreeze all parameters (both encoders, LSTM, and heads)."""
        for param in self.parameters():
            param.requires_grad = True
        print("All layers unfrozen")