import os

import numpy as np
import torch
from huggingface_hub import hf_hub_download

from model_architecture import EnsembleBertBiLSTMRegressor
from utils.preprocess import compute_text_features

HF_REPO = "hathimazman/sqb-predict"
MODEL_CACHE = "/tmp/models"  # ✅ Writable on Hugging Face Spaces
os.makedirs(MODEL_CACHE, exist_ok=True)


def download_from_hf(filename: str) -> str:
    """Download a file from the HF Hub into MODEL_CACHE, reusing a cached copy if present."""
    local_path = os.path.join(MODEL_CACHE, filename)
    if not os.path.exists(local_path):
        print(f"⬇ Downloading {filename} from {HF_REPO}")
        hf_hub_download(repo_id=HF_REPO, filename=filename, local_dir=MODEL_CACHE)
    return local_path


def load_model():
    """Load the regression checkpoint and move the model to the best available device."""
    model_path = download_from_hf("best_checkpoint_regression.pt")
    checkpoint = torch.load(model_path, map_location="cpu")

    model = EnsembleBertBiLSTMRegressor(
        model_name_mcq="microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract",
        model_name_clinical="emilyalsentzer/Bio_ClinicalBERT",
        hidden_dim=768,
        extra_dim=67,
    )
    model.load_state_dict(checkpoint["model_state"])
    model.eval()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    return model, device


def predict_from_input(data, model, device, tok_mcq, tok_clin, encoder, scaler):
    """
    Predict difficulty and discrimination indices for a single MCQ item.

    Combines the raw question text, engineered numeric features, and
    one-hot-encoded categorical features into a single forward pass.
    """
    # Concatenate the question parts into one text passage
    text = " ".join([
        data["StemText"], data["LeadIn"],
        data["OptionA"], data["OptionB"], data["OptionC"], data["OptionD"],
    ])

    # Compute the 8 text-derived numeric features
    features = compute_text_features(
        data["StemText"], data["LeadIn"],
        [data["OptionA"], data["OptionB"], data["OptionC"], data["OptionD"]],
    )

    # Encode categorical fields, mapping values unseen at training time to "Other"
    known = encoder.categories_
    fields = ["DepartmentName", "CourseName", "BloomLevel"]
    cat_data = [[
        data[f] if data[f] in known[i] else "Other"
        for i, f in enumerate(fields)
    ]]
    cat_enc = encoder.transform(cat_data)
    if hasattr(cat_enc, "toarray"):  # a OneHotEncoder may return a sparse matrix
        cat_enc = cat_enc.toarray()

    # The scaler expects a 2-D array with exactly the 8 numeric features
    num_feats_scaled = scaler.transform(np.asarray(features, dtype=float).reshape(1, -1))
    # Combine scaled numeric + one-hot encoded features into the extra-feature vector
    feats = np.hstack([num_feats_scaled, cat_enc])
    extra_feats = torch.tensor(feats, dtype=torch.float32).to(device)

    # Tokenize the text once per backbone encoder
    enc_mcq = tok_mcq(text, truncation=True, padding="max_length", max_length=512, return_tensors="pt")
    enc_clin = tok_clin(text, truncation=True, padding="max_length", max_length=512, return_tensors="pt")
    ids_mcq = enc_mcq["input_ids"].to(device)
    mask_mcq = enc_mcq["attention_mask"].to(device)
    ids_clin = enc_clin["input_ids"].to(device)
    mask_clin = enc_clin["attention_mask"].to(device)

    # Forward pass through the ensemble model
    with torch.no_grad():
        preds = model(ids_mcq, mask_mcq, ids_clin, mask_clin, extra_feats)

    # The model outputs (difficulty, discrimination) for the single item
    diff, disc = preds.squeeze().tolist()

    return {
        "difficulty": round(float(diff), 3),
        "discrimination": round(float(disc), 3),
    }
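

# --- Usage sketch (illustrative, not part of the module's API) ---
# A minimal end-to-end example, assuming the two tokenizers are loaded with
# transformers.AutoTokenizer and that the fitted OneHotEncoder and scaler are
# shipped in the same HF repo as joblib pickles. The filenames "encoder.pkl"
# and "scaler.pkl" are hypothetical; substitute whatever artifacts the repo
# actually contains. The sample item below is made up for illustration.
if __name__ == "__main__":
    import joblib
    from transformers import AutoTokenizer

    model, device = load_model()

    # Tokenizers for the two backbones named in load_model above
    tok_mcq = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
    tok_clin = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    # Hypothetical artifact names; the fitted preprocessing objects must match training
    encoder = joblib.load(download_from_hf("encoder.pkl"))
    scaler = joblib.load(download_from_hf("scaler.pkl"))

    sample = {
        "StemText": "A 45-year-old man presents with crushing chest pain radiating to the left arm.",
        "LeadIn": "What is the most likely diagnosis?",
        "OptionA": "Acute myocardial infarction",
        "OptionB": "Gastroesophageal reflux disease",
        "OptionC": "Pulmonary embolism",
        "OptionD": "Costochondritis",
        "DepartmentName": "Internal Medicine",
        "CourseName": "Cardiology",
        "BloomLevel": "C3",
    }

    print(predict_from_input(sample, model, device, tok_mcq, tok_clin, encoder, scaler))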