import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import random
import re
from typing import Optional

SYSTEM_PROMPT = """
You are a medical expert. Answer the medical question with careful analysis and explain why the selected option is correct in 2 sentences without repeating.
Respond in the following format:
<answer>
[correct answer]
</answer>
<reasoning>
[explain why the selected option is correct]
</reasoning>
"""
model_name = "abaryan/BioXP-0.5B-MedMCQA"
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
dataset = load_dataset("openlifescienceai/medmcqa")
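# MedMCQA fields used in this app: 'question', options 'opa'-'opd',
# 'cop' (index of the correct option, 0-3) and 'exp' (expert explanation).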

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()


def get_random_question():
    """Get a random question from the dataset"""
    index = random.randint(0, len(dataset['validation']) - 1)
    question_data = dataset['validation'][index]
    return (
        question_data['question'],
        question_data['opa'],
        question_data['opb'],
        question_data['opc'],
        question_data['opd'],
        question_data.get('cop', None),  # Correct option (0-3)
        question_data.get('exp', None)   # Explanation
    )


def predict(question: str, option_a: str = "", option_b: str = "", option_c: str = "", option_d: str = "",
            correct_option: Optional[int] = None, explanation: Optional[str] = None,
            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 256):
    # Determine if this is an MCQ by checking if any option is provided
    is_mcq = any(opt.strip() for opt in [option_a, option_b, option_c, option_d])

    # MCQ and free-form questions share the same system prompt
    system_prompt = SYSTEM_PROMPT
    if is_mcq:
        options = []
        if option_a.strip(): options.append(f"A. {option_a}")
        if option_b.strip(): options.append(f"B. {option_b}")
        if option_c.strip(): options.append(f"C. {option_c}")
        if option_d.strip(): options.append(f"D. {option_d}")
        formatted_question = f"Question: {question}\n\nOptions:\n" + "\n".join(options)
    else:
        # Format regular question
        formatted_question = f"Question: {question}"

    prompt = [
        {'role': 'system', 'content': system_prompt},
        {'role': 'user', 'content': formatted_question}
    ]
    text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
    model_inputs = tokenizer([text], return_tensors="pt").to(device)

    with torch.inference_mode():
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_tokens,
            do_sample=True,  # without sampling, temperature/top_p would be ignored
            temperature=temperature,
            top_p=top_p,
        )
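    # Keep only the newly generated tokens (strip the echoed prompt)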
    generated_ids = generated_ids[0, model_inputs.input_ids.shape[1]:]
    model_response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Clean up the response by removing tags and formatting
    cleaned_response = model_response
    cleaned_response = re.sub(r'<answer>\s*([A-D])\s*</answer>', r'Answer: \1', cleaned_response, flags=re.IGNORECASE)
    cleaned_response = re.sub(r'<reasoning>\s*(.*?)\s*</reasoning>', r'Reasoning:\n\1', cleaned_response, flags=re.IGNORECASE | re.DOTALL)
    # Format output with evaluation if available (only for MCQs)
    output = cleaned_response
    # if is_mcq and correct_option is not None:
    #     correct_letter = chr(65 + correct_option)
    #     answer_match = re.search(r"Answer:\s*([A-D])", cleaned_response, re.IGNORECASE)
    #     model_answer = answer_match.group(1).upper() if answer_match else "Not found"
    #     is_correct = model_answer == correct_letter
    #     output += "\n\n---\nEvaluation:\n"
    #     output += f"Correct Answer: {correct_letter}\n"
    #     output += f"Model's Answer: {model_answer}\n"
    #     output += f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n"
    #     if explanation:
    #         output += f"\nExpert Explanation:\n{explanation}"
    return output

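
# Example call (hypothetical question, for a quick local sanity check):
#   print(predict("Which vitamin deficiency causes scurvy?",
#                 "Vitamin A", "Vitamin B12", "Vitamin C", "Vitamin D"))
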
with gr.Blocks(
    title="BioXP Medical MCQ Assistant",
    theme=gr.themes.Soft(
        primary_hue="blue",
        secondary_hue="blue",
        neutral_hue="slate",
        radius_size="md",
        font=["Inter", "ui-sans-serif", "system-ui", "sans-serif"],
    )
) as demo:
    gr.Markdown("""
    # BioXP Medical MCQ Assistant
    A specialized AI assistant for medical multiple-choice questions.
    """)
    with gr.Row():
        with gr.Column(scale=1):
            question = gr.Textbox(
                label="Medical Question",
                placeholder="Enter your medical question here...",
                lines=3,
                interactive=True,
                elem_classes=["mobile-input"]
            )
            with gr.Accordion("Options", open=True):
                option_a = gr.Textbox(
                    label="Option A",
                    placeholder="Enter option A...",
                    interactive=True,
                    elem_classes=["mobile-input"]
                )
                option_b = gr.Textbox(
                    label="Option B",
                    placeholder="Enter option B...",
                    interactive=True,
                    elem_classes=["mobile-input"]
                )
                option_c = gr.Textbox(
                    label="Option C",
                    placeholder="Enter option C...",
                    interactive=True,
                    elem_classes=["mobile-input"]
                )
                option_d = gr.Textbox(
                    label="Option D",
                    placeholder="Enter option D...",
                    interactive=True,
                    elem_classes=["mobile-input"]
                )
| with gr.Accordion("Advanced Settings", open=False): | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| temperature = gr.Slider( | |
| minimum=0.1, | |
| maximum=1.0, | |
| value=0.6, | |
| step=0.1, | |
| label="Temperature", | |
| info="Higher = more creative, Lower = more focused" | |
| ) | |
| with gr.Column(scale=1): | |
| top_p = gr.Slider( | |
| minimum=0.1, | |
| maximum=1.0, | |
| value=0.9, | |
| step=0.1, | |
| label="Top P", | |
| info="Controls response diversity" | |
| ) | |
| max_tokens = gr.Slider( | |
| minimum=50, | |
| maximum=512, | |
| value=256, | |
| step=32, | |
| label="Max Response Length", | |
| info="Maximum length of the response" | |
| ) | |
            # Hidden fields
            correct_option = gr.Number(visible=False)
            expert_explanation = gr.Textbox(visible=False)

            with gr.Row():
                predict_btn = gr.Button("Get Answer", variant="primary", size="lg", elem_classes=["mobile-button"])
                random_btn = gr.Button("Random Question", variant="secondary", size="lg", elem_classes=["mobile-button"])
        with gr.Column(scale=1):
            output = gr.Textbox(
                label="Model's Response",
                lines=12,
                elem_classes=["response-box", "mobile-output"]
            )

    # Set up button actions
    predict_btn.click(
        fn=predict,
        inputs=[
            question, option_a, option_b, option_c, option_d,
            correct_option, expert_explanation,
            temperature, top_p, max_tokens
        ],
        outputs=output
    )
    random_btn.click(
        fn=get_random_question,
        inputs=[],
        outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation]
    )
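    # get_random_question() returns a 7-tuple; its order must match the
    # outputs list above (question, four options, correct option, explanation).
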
    gr.HTML("""
    <style>
        .container {
            max-width: 100%;
            padding: 0.5rem;
        }
        /* Input styling */
        .mobile-input textarea {
            font-size: 1rem;
            padding: 0.75rem;
            border-radius: 0.5rem;
            min-height: 2.5rem;
        }
        /* Button styling */
        .mobile-button {
            width: 100%;
            margin: 0.5rem 0;
            padding: 0.75rem;
            font-size: 1rem;
            font-weight: 500;
        }
        .response-box {
            font-family: 'Inter', sans-serif;
            line-height: 1.6;
        }
        .response-box textarea {
            font-size: 1rem;
            padding: 1rem;
            border-radius: 0.5rem;
        }
        /* Mobile-specific adjustments */
        @media (max-width: 768px) {
            .gr-form {
                padding: 0.75rem;
            }
            .gr-box {
                margin: 0.5rem 0;
            }
            .gr-button {
                min-height: 2.5rem;
            }
            .gr-accordion {
                margin: 0.5rem 0;
            }
            .gr-input {
                margin-bottom: 0.5rem;
            }
        }
        /* Dark mode support */
        @media (prefers-color-scheme: dark) {
            .gr-box {
                background-color: #1a1a1a;
            }
            .mobile-input textarea,
            .response-box textarea {
                background-color: #2a2a2a;
                color: #ffffff;
            }
        }
    </style>
    """)

# Launch the app
if __name__ == "__main__":
    demo.launch(share=False)
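    # Pass share=True to launch() for a temporary public Gradio link.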