import torch
import numpy as np

from utils.preprocess import compute_text_features
from model_architecture import EnsembleBertBiLSTMRegressor


def load_model(path):
    checkpoint = torch.load(path, map_location="cpu")

    # Recreate the same model architecture used at training time
    model = EnsembleBertBiLSTMRegressor(
        model_name_mcq="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
        model_name_clinical="emilyalsentzer/Bio_ClinicalBERT",
        hidden_dim=768,
        extra_dim=67,  # width of the extra feature vector: 8 engineered numeric + one-hot categorical columns (67 total here)
    )

    # Load saved weights and switch to inference mode
    model.load_state_dict(checkpoint["model_state"])
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return model, device


def predict_from_input(data, model, device, tok_mcq, tok_clin, encoder, scaler):
    """
    Predict difficulty and discrimination index for a single MCQ item.

    Combines question text, engineered numeric features, and one-hot
    encoded categorical features.
    """
    # Combine the question parts into a single text string
    text = " ".join([
        data["StemText"], data["LeadIn"],
        data["OptionA"], data["OptionB"], data["OptionC"], data["OptionD"],
    ])

    # Compute text-derived numeric features (8 total)
    features = compute_text_features(
        data["StemText"], data["LeadIn"],
        [data["OptionA"], data["OptionB"], data["OptionC"], data["OptionD"]],
    )

    # Encode categorical features safely: fall back to "Other" for unseen
    # categories ("Other" must have been a known category at fit time)
    known = encoder.categories_
    fields = ["DepartmentName", "CourseName", "BloomLevel"]
    cat_data = [[
        data[f] if data[f] in known[i] else "Other"
        for i, f in enumerate(fields)
    ]]
    # Assumes the encoder returns a dense array
    # (e.g. OneHotEncoder(sparse_output=False))
    cat_enc = encoder.transform(cat_data)

    # The scaler expects a 2-D array of shape (1, 8)
    num_feats_scaled = scaler.transform(np.asarray(features).reshape(1, -1))
    # Combine scaled numeric features with the one-hot encoded categoricals
    feats = np.hstack([num_feats_scaled, cat_enc])
    extra_feats = torch.tensor(feats, dtype=torch.float32).to(device)

    # Tokenize the text once per backbone
    enc_mcq = tok_mcq(text, truncation=True, padding="max_length",
                      max_length=512, return_tensors="pt")
    enc_clin = tok_clin(text, truncation=True, padding="max_length",
                        max_length=512, return_tensors="pt")

    ids_mcq = enc_mcq["input_ids"].to(device)
    mask_mcq = enc_mcq["attention_mask"].to(device)
    ids_clin = enc_clin["input_ids"].to(device)
    mask_clin = enc_clin["attention_mask"].to(device)

    # Forward pass through the ensemble model
    with torch.no_grad():
        preds = model(ids_mcq, mask_mcq, ids_clin, mask_clin, extra_feats)

    # Extract the two predicted values
    diff, disc = preds.squeeze().tolist()
    return {
        "difficulty": round(float(diff), 3),
        "discrimination": round(float(disc), 3),
    }
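

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only). The checkpoint and artifact paths below
# are hypothetical, as is the sample `data` dict and the assumption that the
# fitted encoder/scaler were persisted with joblib; adapt them to how your
# training pipeline actually saved its artifacts.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    import joblib
    from transformers import AutoTokenizer

    model, device = load_model("checkpoints/ensemble_mcq.pt")  # hypothetical path

    # Load the same pretrained tokenizers the model was trained with
    tok_mcq = AutoTokenizer.from_pretrained(
        "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
    tok_clin = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    # Fitted preprocessing artifacts (hypothetical filenames)
    encoder = joblib.load("artifacts/onehot_encoder.joblib")
    scaler = joblib.load("artifacts/numeric_scaler.joblib")

    # Hypothetical MCQ item showing the expected input keys
    sample = {
        "StemText": "A 45-year-old man presents with crushing chest pain.",
        "LeadIn": "What is the most likely diagnosis?",
        "OptionA": "Myocardial infarction",
        "OptionB": "Pulmonary embolism",
        "OptionC": "Aortic dissection",
        "OptionD": "Pericarditis",
        "DepartmentName": "Internal Medicine",
        "CourseName": "Cardiology",
        "BloomLevel": "Application",
    }

    print(predict_from_input(sample, model, device,
                             tok_mcq, tok_clin, encoder, scaler))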