"""FastAPI service that predicts question difficulty/discrimination.

Loads a model plus preprocessing artifacts (one-hot encoder, scaler) and
two tokenizers once at import time, then serves:

* ``GET  /health``  — liveness probe.
* ``POST /predict`` — run inference on a single MCQ payload.
"""

import os

# BUGFIX: these environment variables must be set BEFORE importing
# `transformers` / `nltk` / `model_inference` — huggingface_hub and
# transformers resolve HF_HOME / TRANSFORMERS_CACHE at import time, so
# the original ordering (imports first, env vars after) meant the cache
# was never actually redirected to /tmp.
os.environ["TRANSFORMERS_CACHE"] = "/tmp/hf_cache"
os.environ["HF_HOME"] = "/tmp/hf_cache"

nltk_data_dir = "/tmp/nltk_data"
os.makedirs(nltk_data_dir, exist_ok=True)
os.environ["NLTK_DATA"] = nltk_data_dir

import joblib
import nltk
import torch
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
from transformers import AutoTokenizer

from model_inference import load_model, predict_from_input, download_from_hf

# Make NLTK look in the writable /tmp location as well.
nltk.data.path.append(nltk_data_dir)

app = FastAPI(title="Question Difficulty/Discrimination Predictor")

# Open CORS policy: the API is presumably consumed by a browser frontend
# on another origin. NOTE(review): tighten `allow_origins` for production.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)

# Load the model and preprocessing artifacts once at startup so each
# request only pays the per-inference cost.
model, device = load_model()
encoder_path = download_from_hf("onehot_encoder.pkl")
scaler_path = download_from_hf("scaler.pkl")
encoder = joblib.load(encoder_path)
scaler = joblib.load(scaler_path)
tok_mcq = AutoTokenizer.from_pretrained(
    "microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract"
)
tok_clin = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")


class QuestionInput(BaseModel):
    """Request payload for /predict: one multiple-choice question plus metadata."""

    StemText: str
    LeadIn: str
    OptionA: str
    OptionB: str
    OptionC: str
    OptionD: str
    DepartmentName: str
    CourseName: str
    BloomLevel: str


@app.get("/health")
def health():
    """Liveness probe: always returns ``{"status": "ok"}``."""
    return {"status": "ok"}


@app.post("/predict")
def predict(input_data: QuestionInput):
    """Run the loaded model on one question and return its prediction.

    The pydantic payload is converted to a plain dict and handed to
    ``predict_from_input`` together with the model, tokenizers, and
    preprocessing artifacts loaded at startup.
    """
    # NOTE(review): `.dict()` is the pydantic-v1 spelling (deprecated in
    # v2 in favor of `.model_dump()`); kept as-is to avoid assuming the
    # installed pydantic version.
    pred = predict_from_input(
        input_data.dict(),
        model,
        device,
        tok_mcq,
        tok_clin,
        encoder,
        scaler,
    )
    return pred