from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel
import torch
from transformers import AutoModelForMultipleChoice, AutoTokenizer
import os
from datasets import load_dataset
import random
from typing import Optional, List
import gradio as gr
app = FastAPI()

# Add CORS middleware for Gradio
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# Define input models
class QuestionRequest(BaseModel):
    question: str
    options: List[str]  # List of exactly 4 options

class DatasetQuestion(BaseModel):
    question: str
    opa: str
    opb: str
    opc: str
    opd: str
    cop: Optional[int] = None  # Correct option (0-3)
    exp: Optional[str] = None  # Explanation if available
# Global variables
model = None
tokenizer = None
dataset = None
def load_model():
    global model, tokenizer, dataset
    try:
        # Load the fine-tuned model and tokenizer
        model_name = os.getenv("BioXP-0.5b", "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B")
        model = AutoModelForMultipleChoice.from_pretrained(model_name)
        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Load the MedMCQA dataset
        dataset = load_dataset("openlifescienceai/medmcqa")

        # Move the model to GPU if available
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model = model.to(device)
        model.eval()
    except Exception as e:
        raise Exception(f"Error loading model: {str(e)}")
def predict_gradio(question: str, option_a: str, option_b: str, option_c: str, option_d: str):
    """Gradio interface prediction function"""
    try:
        options = [option_a, option_b, option_c, option_d]

        # Pair the question with each option; the multiple-choice head scores each pair
        inputs = []
        for option in options:
            inputs.append(f"{question} {option}")

        encodings = tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

        # AutoModelForMultipleChoice expects inputs of shape
        # (batch_size, num_choices, seq_len), so add a batch dimension
        device = next(model.parameters()).device
        encodings = {k: v.unsqueeze(0).to(device) for k, v in encodings.items()}

        with torch.no_grad():
            outputs = model(**encodings)
            logits = outputs.logits  # shape: (1, 4)

        probabilities = torch.softmax(logits, dim=1)[0].tolist()
        predicted_class = torch.argmax(logits, dim=1).item()

        # Format the output for Gradio
        result = f"Predicted Answer: {options[predicted_class]}\n\n"
        result += "Confidence Scores:\n"
        for opt, prob in zip(options, probabilities):
            result += f"{opt}: {prob:.2%}\n"
        return result
    except Exception as e:
        return f"Error: {str(e)}"
def get_random_question():
    """Get a random question for the Gradio interface"""
    if dataset is None:
        return "Error: Dataset not loaded", "", "", "", ""
    index = random.randint(0, len(dataset['train']) - 1)
    question_data = dataset['train'][index]
    return (
        question_data['question'],
        question_data['opa'],
        question_data['opb'],
        question_data['opc'],
        question_data['opd']
    )
# Create Gradio interface
with gr.Blocks(title="Medical MCQ Predictor") as demo:
    gr.Markdown("# Medical MCQ Predictor")
    gr.Markdown("Enter a medical question and its options, or get a random question from the MedMCQA dataset.")

    with gr.Row():
        with gr.Column():
            question = gr.Textbox(label="Question", lines=3)
            option_a = gr.Textbox(label="Option A")
            option_b = gr.Textbox(label="Option B")
            option_c = gr.Textbox(label="Option C")
            option_d = gr.Textbox(label="Option D")

    with gr.Row():
        predict_btn = gr.Button("Predict")
        random_btn = gr.Button("Get Random Question")

    output = gr.Textbox(label="Prediction", lines=5)

    predict_btn.click(
        fn=predict_gradio,
        inputs=[question, option_a, option_b, option_c, option_d],
        outputs=output
    )
    random_btn.click(
        fn=get_random_question,
        inputs=[],
        outputs=[question, option_a, option_b, option_c, option_d]
    )
# The Gradio app is mounted at "/" at the end of the file: Starlette matches
# routes in registration order, so mounting it here would shadow every API
# route defined below.

@app.on_event("startup")
async def startup_event():
    load_model()
@app.get("/dataset/question")  # assumed route path
async def get_dataset_question(index: Optional[int] = None, random_question: bool = False):
    """Get a question from the MedMCQA dataset"""
    try:
        if dataset is None:
            raise HTTPException(status_code=500, detail="Dataset not loaded")
        if random_question:
            index = random.randint(0, len(dataset['train']) - 1)
        elif index is None:
            raise HTTPException(status_code=400, detail="Either index or random_question must be provided")

        question_data = dataset['train'][index]
        return DatasetQuestion(
            question=question_data['question'],
            opa=question_data['opa'],
            opb=question_data['opb'],
            opc=question_data['opc'],
            opd=question_data['opd'],
            cop=question_data.get('cop'),
            exp=question_data.get('exp')
        )
    except HTTPException:
        # Re-raise HTTP errors as-is instead of converting them to 500s
        raise
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
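# Example requests against the route above (the /dataset/question path is an
# assumption introduced with the decorator):
#   GET /dataset/question?random_question=true
#   GET /dataset/question?index=42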
@app.post("/predict")  # assumed route path
async def predict(request: QuestionRequest):
    if len(request.options) != 4:
        raise HTTPException(status_code=400, detail="Exactly 4 options are required")
    try:
        # Pair the question with each option, as in predict_gradio
        inputs = []
        for option in request.options:
            inputs.append(f"{request.question} {option}")

        encodings = tokenizer(
            inputs,
            padding=True,
            truncation=True,
            max_length=512,
            return_tensors="pt"
        )

        # Add the batch dimension expected by the multiple-choice head
        device = next(model.parameters()).device
        encodings = {k: v.unsqueeze(0).to(device) for k, v in encodings.items()}

        with torch.no_grad():
            outputs = model(**encodings)
            logits = outputs.logits

        probabilities = torch.softmax(logits, dim=1)[0].tolist()
        predicted_class = torch.argmax(logits, dim=1).item()

        return {
            "predicted_option": request.options[predicted_class],
            "option_index": predicted_class,
            "confidence": probabilities[predicted_class],
            "probabilities": {
                f"option_{i}": prob for i, prob in enumerate(probabilities)
            }
        }
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e))
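# Example request against the (assumed) /predict route:
#   curl -X POST http://localhost:7860/predict \
#     -H "Content-Type: application/json" \
#     -d '{"question": "...", "options": ["opt A", "opt B", "opt C", "opt D"]}'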
@app.get("/health")  # assumed route path
async def health_check():
    return {
        "status": "healthy",
        "model_loaded": model is not None,
        "dataset_loaded": dataset is not None
    }

# Mount the Gradio app last so the API routes above take precedence
app = gr.mount_gradio_app(app, demo, path="/")
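# Minimal local entry point, added as a development convenience (an assumption;
# Hugging Face Spaces typically starts the server itself). 7860 is the port
# Spaces conventionally serves on.
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)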