# NOTE(review): the lines here originally read "Spaces: / Sleeping / Sleeping" —
# Hugging Face Space status-banner text captured along with the source, not
# application code. Converted to a comment so the module parses.
# Third-party web/server stack
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch  # required by the model checkpoint loaded in load_model()
import joblib  # loads the pickled one-hot encoder and scaler
from transformers import AutoTokenizer
# Local helper module: model construction + the actual inference routine.
from model_inference import load_model, predict_from_input
import os  # NOTE(review): unused in the code shown here — confirm before removing
app = FastAPI(title="Question Difficulty/Discrimination Predictor")

# CORS for frontend usage (Next.js, Streamlit, etc.)
# NOTE(review): browsers reject the combination of allow_origins=["*"] with
# allow_credentials=True per the CORS spec; pin explicit origins if
# cookie/credential-based requests are ever needed.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Global variables – will be loaded at startup by load_all_resources().
model = None     # torch model returned by load_model()
device = None    # torch device the model runs on
encoder = None   # fitted one-hot encoder (joblib pickle, assets/onehot_encoder.pkl)
scaler = None    # fitted feature scaler (joblib pickle, assets/scaler.pkl)
tok_mcq = None   # PubMedBERT tokenizer (MCQ text)
tok_clin = None  # Bio_ClinicalBERT tokenizer (clinical text)
@app.on_event("startup")
def load_all_resources():
    """Load the model, tokenizers, and preprocessing artifacts exactly once.

    Registered as a FastAPI startup handler so every global is populated
    before the first request is served. The original definition was never
    registered or called anywhere, so all globals stayed ``None`` and every
    ``predict`` call short-circuited with "Model not loaded".

    Side effects: assigns the module-level globals below; prints progress.
    """
    global model, device, encoder, scaler, tok_mcq, tok_clin
    print("🚀 Loading model and dependencies...")
    # Model checkpoint (local path inside the Space repository, or resolved
    # by load_model itself — see model_inference.load_model).
    model, device = load_model("assets/best_checkpoint_regression.pt")
    # Pre-fitted preprocessing artifacts; must match the versions used at
    # training time or predict_from_input will produce garbage features.
    encoder = joblib.load("assets/onehot_encoder.pkl")
    scaler = joblib.load("assets/scaler.pkl")
    # Tokenizers for the two text encoders consumed by the model.
    tok_mcq = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
    tok_clin = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
    print("✅ All resources successfully loaded.")
# Input schema
class QuestionInput(BaseModel):
    """Request body for the prediction endpoint: one MCQ plus its metadata.

    Field names are the external wire format — do not rename.
    """

    StemText: str        # question stem
    LeadIn: str          # lead-in sentence
    OptionA: str         # answer options A–D
    OptionB: str
    OptionC: str
    OptionD: str
    DepartmentName: str  # presumably fed to the one-hot encoder — confirm in predict_from_input
    CourseName: str
    BloomLevel: str      # presumably a Bloom's-taxonomy label — confirm accepted values
@app.get("/health")
def health():
    """Liveness probe.

    NOTE(review): the route decorator was missing in the captured source —
    without it this function was never exposed as an endpoint. ``/health``
    is the conventional path; adjust if the frontend probes a different one.
    """
    return {"status": "ok"}
@app.post("/predict")
def predict(input_data: QuestionInput):
    """Predict difficulty/discrimination for one question.

    Returns whatever dict ``predict_from_input`` produces, or a soft error
    payload while the startup handler has not finished loading resources.

    NOTE(review): the route decorator was missing in the captured source;
    ``POST /predict`` is the conventional mapping for a body-carrying
    prediction endpoint.
    """
    # Guard: globals are populated by the startup handler. Until then, fail
    # softly with a JSON error instead of raising deep inside inference.
    if model is None:
        return {"error": "Model not loaded. Try again in a few seconds."}
    pred = predict_from_input(
        # NOTE(review): .dict() is deprecated under Pydantic v2 — switch to
        # model_dump() if/when the Space pins pydantic>=2.
        input_data.dict(),
        model, device,
        tok_mcq, tok_clin, encoder, scaler,
    )
    return pred