# sqb-predict-api / app.py
# Author: Ahmad Hathim bin Ahmad Azman
# Last change: "Fixed model loading" (commit f36bbe6)
# Redirect HuggingFace & NLTK caches to /tmp BEFORE importing transformers:
# transformers resolves TRANSFORMERS_CACHE / HF_HOME at import time, so
# setting these env vars after the import (as the original code did) has
# no effect and model downloads land in a possibly read-only default dir.
import os

os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HOME"] = "/tmp/hf_cache"

nltk_data_dir = "/tmp/nltk_data"
os.makedirs(nltk_data_dir, exist_ok=True)
os.environ["NLTK_DATA"] = nltk_data_dir

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
import joblib
import nltk
from transformers import AutoTokenizer

from model_inference import load_model, predict_from_input, download_from_hf

# Also register the directory on NLTK's runtime search path (belt and braces:
# NLTK honours both NLTK_DATA and nltk.data.path).
nltk.data.path.append(nltk_data_dir)
# FastAPI application instance with fully permissive CORS so that any
# frontend origin may call the prediction endpoints.
app = FastAPI(title="Question Difficulty/Discrimination Predictor")

cors_settings = {
    "allow_origins": ["*"],
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **cors_settings)
# βœ… Load model on startup
# Load all inference artifacts once at module import time (i.e. process
# startup), so each request only pays for the forward pass.
model, device = load_model()
# Fetch the fitted preprocessing objects from the Hugging Face Hub.
encoder_path = download_from_hf("onehot_encoder.pkl")
scaler_path = download_from_hf("scaler.pkl")
encoder = joblib.load(encoder_path)  # presumably a fitted one-hot encoder for the categorical fields — confirm against model_inference
scaler = joblib.load(scaler_path)    # presumably the feature scaler fitted at training time
# Two tokenizers are loaded; names suggest PubMedBERT handles the MCQ text
# and ClinicalBERT the clinical text stream — verify in predict_from_input.
tok_mcq = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
tok_clin = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
class QuestionInput(BaseModel):
    """Request body for POST /predict: one multiple-choice question item.

    All fields are free-form strings; they are forwarded verbatim (as a
    dict) to ``predict_from_input``.
    """

    StemText: str        # question stem / vignette text
    LeadIn: str          # lead-in line of the question
    OptionA: str         # answer options A-D
    OptionB: str
    OptionC: str
    OptionD: str
    DepartmentName: str  # department label — presumably a categorical feature for the encoder
    CourseName: str      # course label — presumably categorical as well
    BloomLevel: str      # Bloom's taxonomy level of the question
@app.get("/health")
def health():
    """Liveness probe endpoint: always reports the service as up."""
    payload = {"status": "ok"}
    return payload
@app.post("/predict")
def predict(input_data: QuestionInput):
    """Predict difficulty/discrimination for one question.

    Converts the validated request body to a plain dict and delegates to
    ``predict_from_input`` with the module-level model and preprocessors.
    """
    payload = input_data.dict()
    return predict_from_input(payload, model, device, tok_mcq, tok_clin, encoder, scaler)