File size: 3,247 Bytes
beae064
 
 
 
 
aa7ac47
f3ce8a7
aa7ac47
 
ee19cd7
aa7ac47
ee19cd7
aa7ac47
ee19cd7
 
 
 
 
 
beae064
d25635d
ee19cd7
8438377
ee19cd7
8438377
beae064
 
 
 
ee19cd7
beae064
d25635d
 
beae064
 
 
 
 
 
8438377
beae064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
import torch
import numpy as np
import textstat
from utils.preprocess import compute_text_features
from model_architecture import EnsembleBertBiLSTMRegressor 
from huggingface_hub import hf_hub_download
import os

# Hugging Face Hub repository that hosts the trained model artifacts.
HF_REPO: str = "hathimazman/sqb-predict"

# Local download directory; /tmp is writable on Hugging Face Spaces.
MODEL_CACHE: str = "/tmp/models"

# Make sure the download directory exists at import time (no-op if present).
os.makedirs(name=MODEL_CACHE, exist_ok=True)

def download_from_hf(filename: str) -> str:
    """
    Return a local path to ``filename`` from ``HF_REPO``, downloading it if needed.

    The file is stored under ``MODEL_CACHE``. If a regular file already exists at
    ``MODEL_CACHE/filename``, it is reused and ``hf_hub_download`` is not called.

    Args:
        filename: Path of the file inside the Hub repository.

    Returns:
        Local filesystem path of the file.
    """
    local_path = os.path.join(MODEL_CACHE, filename)
    # isfile (not exists): a directory with the same name is not a usable cache hit.
    if os.path.isfile(local_path):
        return local_path

    print(f"⬇ Downloading {filename} from {HF_REPO}")
    # Return the path hf_hub_download reports rather than assuming the on-disk
    # layout, so callers always receive the file that was actually written.
    return hf_hub_download(repo_id=HF_REPO, filename=filename, local_dir=MODEL_CACHE)

def load_model():
    """
    Fetch the regression checkpoint, rebuild the ensemble model and ready it for inference.

    Returns:
        tuple: ``(model, device)``. ``model`` is an ``EnsembleBertBiLSTMRegressor``
        with the checkpoint's ``"model_state"`` weights loaded and set to eval mode.
        It has already been moved to ``device``, which is CUDA when available and
        CPU otherwise.
    """
    # Deserialize onto CPU first; the model is moved to the target device below.
    # NOTE(review): recent torch releases default torch.load to weights_only=True;
    # if this checkpoint stores non-tensor objects it may need an explicit
    # weights_only=False (only acceptable for trusted files) — confirm.
    state = torch.load(
        download_from_hf("best_checkpoint_regression.pt"),
        map_location="cpu",
    )

    net = EnsembleBertBiLSTMRegressor(
        model_name_mcq="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
        model_name_clinical="emilyalsentzer/Bio_ClinicalBERT",
        hidden_dim=768,
        extra_dim=67,
    )
    net.load_state_dict(state["model_state"])
    # Switch to inference mode (affects layers such as dropout).
    net.eval()

    if torch.cuda.is_available():
        target = torch.device("cuda")
    else:
        target = torch.device("cpu")
    net.to(target)
    return net, target



def _scale_numeric_features(features, scaler):
    """Scale one sample's engineered numeric features with the fitted ``scaler``."""
    # scikit-learn style transformers reject 1-D input, so promote a flat
    # feature vector to a single-row batch. Inputs that are already 2-D
    # (arrays, DataFrames) are passed through unchanged.
    if np.ndim(features) == 1:
        features = np.asarray(features).reshape(1, -1)
    return scaler.transform(features)


def _encode_categorical_features(data, encoder, fields):
    """Encode ``fields`` of ``data`` with ``encoder``, mapping values unseen at fit time to "Other"."""
    known = encoder.categories_
    row = [
        data[field] if data[field] in known[i] else "Other"
        for i, field in enumerate(fields)
    ]
    encoded = encoder.transform([row])
    # A OneHotEncoder left at its default sparse output returns a SciPy sparse
    # matrix, which np.hstack cannot combine with a dense array; densify it.
    if hasattr(encoded, "toarray"):
        encoded = encoded.toarray()
    return encoded


def _tokenize_to_device(tokenizer, text, device):
    """Tokenize ``text`` padded/truncated to 512 tokens; return (input_ids, attention_mask) on ``device``."""
    enc = tokenizer(
        text,
        truncation=True,
        padding="max_length",
        max_length=512,
        return_tensors="pt",
    )
    return enc["input_ids"].to(device), enc["attention_mask"].to(device)


def predict_from_input(data, model, device, tok_mcq, tok_clin, encoder, scaler):
    """
    Predict difficulty and discrimination index for a single MCQ item.

    Combines the concatenated question text (tokenized separately for both
    text encoders), engineered numeric text features, and encoded categorical
    features, then runs one forward pass without gradient tracking.

    Args:
        data: Mapping with keys "StemText", "LeadIn", "OptionA".."OptionD",
            "DepartmentName", "CourseName" and "BloomLevel".
        model: Callable as ``model(ids_mcq, mask_mcq, ids_clin, mask_clin, extra_feats)``.
        device: torch device that all input tensors are moved to.
        tok_mcq: Tokenizer for the MCQ text branch.
        tok_clin: Tokenizer for the clinical text branch.
        encoder: Fitted categorical encoder exposing ``categories_`` and ``transform``.
        scaler: Fitted scaler for the numeric text features.

    Returns:
        dict with float values "difficulty" and "discrimination", rounded to 3 decimals.
    """
    stem, lead_in = data["StemText"], data["LeadIn"]
    options = [data["OptionA"], data["OptionB"], data["OptionC"], data["OptionD"]]

    # Both text branches receive the same space-joined question text.
    text = " ".join([stem, lead_in, *options])

    # Engineered numeric text features (the original author notes 8 are expected).
    features = compute_text_features(stem, lead_in, options)
    num_feats_scaled = _scale_numeric_features(features, scaler)
    cat_enc = _encode_categorical_features(
        data, encoder, ["DepartmentName", "CourseName", "BloomLevel"]
    )

    # Scaled numeric columns first, then encoded categoricals — presumably the
    # same layout the model was trained with; confirm before reordering.
    feats = np.hstack([num_feats_scaled, cat_enc])
    extra_feats = torch.tensor(feats, dtype=torch.float32).to(device)

    ids_mcq, mask_mcq = _tokenize_to_device(tok_mcq, text, device)
    ids_clin, mask_clin = _tokenize_to_device(tok_clin, text, device)

    with torch.no_grad():
        preds = model(ids_mcq, mask_mcq, ids_clin, mask_clin, extra_feats)

    # Unpacking expects exactly two outputs after squeezing singleton dims.
    diff, disc = preds.squeeze().tolist()

    return {
        "difficulty": round(float(diff), 3),
        "discrimination": round(float(disc), 3)
    }