# sqb-predict-api / app.py
# Author: Ahmad Hathim bin Ahmad Azman
# Last commit: "change setting" (b8fb185)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import joblib
from transformers import AutoTokenizer
from model_inference import load_model, predict_from_input
import os
# --- Application setup ---------------------------------------------------

app = FastAPI(title="Question Difficulty/Discrimination Predictor")

# Open CORS so browser frontends (Next.js, Streamlit, etc.) can call the API.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# Module-level resource slots, filled exactly once by the startup hook.
model = None     # regression model (loaded from checkpoint)
device = None    # torch device the model was placed on
encoder = None   # pre-fitted one-hot encoder
scaler = None    # pre-fitted feature scaler
tok_mcq = None   # PubMedBERT tokenizer
tok_clin = None  # Bio_ClinicalBERT tokenizer
@app.on_event("startup")
def load_all_resources():
    """Load model, preprocessing artifacts, and tokenizers once at startup.

    Populates the module-level globals so request handlers can reuse the
    resources instead of re-loading them per request (avoids slow imports
    and the HF "Space in Error" state).
    """
    global model, device, encoder, scaler, tok_mcq, tok_clin

    print("🚀 Loading model and dependencies...")

    # Model checkpoint — load_model returns the model and its torch device.
    model, device = load_model("assets/best_checkpoint_regression.pt")

    # Pre-fitted preprocessing artifacts.
    encoder = joblib.load("assets/onehot_encoder.pkl")
    scaler = joblib.load("assets/scaler.pkl")

    # Tokenizers for the two text encoders (downloads on first run).
    tok_mcq = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
    tok_clin = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    print("✅ All resources successfully loaded.")
# Input schema
class QuestionInput(BaseModel):
    """Request body for /predict: one MCQ item plus its metadata."""
    # Question text fields
    StemText: str
    LeadIn: str
    OptionA: str
    OptionB: str
    OptionC: str
    OptionD: str
    # Categorical metadata (presumably encoded downstream — see model_inference)
    DepartmentName: str
    CourseName: str
    BloomLevel: str
@app.get("/health")
def health():
    """Liveness probe: always reports the service as up."""
    return {"status": "ok"}
@app.post("/predict")
def predict(input_data: QuestionInput):
    """Predict difficulty/discrimination metrics for one question.

    Returns the prediction payload from ``predict_from_input``, or an
    ``{"error": ...}`` dict when startup has not finished loading resources.
    """
    # Guard against requests that arrive mid-startup: the startup hook
    # assigns `model` first, so checking only `model` (as before) could let
    # a request through while encoder/scaler/tokenizers are still None and
    # crash inside predict_from_input. Check every required resource.
    resources = (model, device, tok_mcq, tok_clin, encoder, scaler)
    if any(r is None for r in resources):
        return {"error": "Model not loaded. Try again in a few seconds."}

    pred = predict_from_input(
        input_data.dict(), model, device,
        tok_mcq, tok_clin, encoder, scaler
    )
    return pred