from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import joblib
from transformers import AutoTokenizer

from model_inference import load_model, predict_from_input

app = FastAPI(title="Question Difficulty/Discrimination Predictor")

# CORS for frontend usage (Next.js, Streamlit, etc.)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Global variables – will be loaded at startup
model = None
device = None
encoder = None
scaler = None
tok_mcq = None
tok_clin = None
# Note: @app.on_event is deprecated in newer FastAPI releases in favor of
# lifespan handlers, but it still works and keeps this startup hook simple.
@app.on_event("startup")
def load_all_resources():
    """
    ✅ Load the model, tokenizers, and encoders only once at startup.
    Avoids slow per-request setup & prevents the Space from ending up
    in an error state.
    """
    global model, device, encoder, scaler, tok_mcq, tok_clin
    print("🚀 Loading model and dependencies...")

    # Load the regression model from a local checkpoint or Hugging Face
    model, device = load_model("assets/best_checkpoint_regression.pt")

    # Load the fitted one-hot encoder and feature scaler
    encoder = joblib.load("assets/onehot_encoder.pkl")
    scaler = joblib.load("assets/scaler.pkl")

    # Load both tokenizers once here so requests never pay the loading cost
    tok_mcq = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
    tok_clin = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    print("✅ All resources successfully loaded.")

# Input schema
class QuestionInput(BaseModel):
    StemText: str
    LeadIn: str
    OptionA: str
    OptionB: str
    OptionC: str
    OptionD: str
    DepartmentName: str
    CourseName: str
    BloomLevel: str
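
# Example /predict payload matching the schema above; every value below is an
# illustrative placeholder, not real exam content:
# {
#     "StemText": "A 45-year-old man presents with crushing chest pain...",
#     "LeadIn": "What is the most likely diagnosis?",
#     "OptionA": "Myocardial infarction",
#     "OptionB": "Pulmonary embolism",
#     "OptionC": "Aortic dissection",
#     "OptionD": "Stable angina",
#     "DepartmentName": "Internal Medicine",
#     "CourseName": "Cardiology",
#     "BloomLevel": "Apply"
# }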

@app.get("/health")
def health():
    return {"status": "ok"}

@app.post("/predict")
def predict(input_data: QuestionInput):
    """
    ✅ Main prediction endpoint.
    """
    if model is None:
        # Startup has not finished (or failed); report it instead of crashing.
        return {"error": "Model not loaded. Try again in a few seconds."}

    # .dict() is the Pydantic v1 API; on Pydantic v2 use .model_dump() instead.
    pred = predict_from_input(
        input_data.dict(), model, device,
        tok_mcq, tok_clin, encoder, scaler
    )
    return pred
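
# Local dev entry point: a convenience sketch assuming uvicorn is installed.
# Hugging Face Spaces normally launches the server itself (port 7860 is the
# Spaces default), so this block only matters when running `python app.py`
# on your own machine.
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)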