from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import joblib
from transformers import AutoTokenizer
from model_inference import load_model, predict_from_input

app = FastAPI(title="Question Difficulty/Discrimination Predictor")

# CORS for frontend usage (Next.js, Streamlit, etc.)
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
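# For production, consider pinning allow_origins to the frontend's actual URL
# instead of the wildcard, especially when allow_credentials is enabled.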

# Global variables – will be loaded at startup
model = None
device = None
encoder = None
scaler = None
tok_mcq = None
tok_clin = None

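# NOTE: on_event("startup") is deprecated in newer FastAPI releases in favor
# of lifespan handlers, but it still works and is kept here for simplicity.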
@app.on_event("startup")
def load_all_resources():
    """
    ✅ Load model + tokenizers + encoders only once at startup.
    Avoids slow import times & prevents “Space in Error”.
    """
    global model, device, encoder, scaler, tok_mcq, tok_clin

    print("🚀 Loading model and dependencies...")

    # Load model from local or Hugging Face
    model, device = load_model("assets/best_checkpoint_regression.pt")

    # Load pretrained scaler + encoder
    encoder = joblib.load("assets/onehot_encoder.pkl")
    scaler = joblib.load("assets/scaler.pkl")

    # Load both tokenizers up front (tok_mcq: PubMedBERT, tok_clin: Bio_ClinicalBERT)
    tok_mcq = AutoTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
    tok_clin = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")

    print("✅ All resources successfully loaded.")

# Input schema
class QuestionInput(BaseModel):
    StemText: str
    LeadIn: str
    OptionA: str
    OptionB: str
    OptionC: str
    OptionD: str
    DepartmentName: str
    CourseName: str
    BloomLevel: str

@app.get("/health")
def health():
    return {"status": "ok"}

@app.post("/predict")
def predict(input_data: QuestionInput):
    """
    ✅ Main prediction endpoint.
    """
    if model is None:
        raise HTTPException(status_code=503, detail="Model not loaded yet. Try again in a few seconds.")

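    # Delegate preprocessing + inference to model_inference.predict_from_input.
    # Note: .dict() is the Pydantic v1 API; on Pydantic v2 use .model_dump().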
    pred = predict_from_input(
        input_data.dict(), model, device,
        tok_mcq, tok_clin, encoder, scaler
    )
    return pred
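
# Minimal sketch of a local dev entrypoint (assumes uvicorn is installed;
# port 7860 follows the Hugging Face Spaces convention, any free port works).
# Example request once the server is up:
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"StemText": "...", "LeadIn": "...", ...}'
if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860)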