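# Gradio demo for multiple-choice medical QA on MedMCQA.
# A 0.5B-parameter causal LM fine-tuned on MedMCQA answers a question given
# four options; the app can also compare the model's answer against the
# dataset's correct option and expert explanation.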
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import random
import re

# Load model and tokenizer
# model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"

SYSTEM_PROMPT = """
You're a medical expert. Answer the question with careful analysis and explain why the selected option is correct in 150 words without repeating yourself.
Respond in the following format:
<answer>
[correct answer]
</answer>
<reasoning>
[explain why the selected option is correct]
</reasoning>
"""

model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Load dataset
dataset = load_dataset("openlifescienceai/medmcqa")
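# Each MedMCQA row provides 'question', options 'opa'-'opd', the correct
# option index 'cop' (0-3), and an expert explanation 'exp'.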

# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()


def get_random_question():
    """Get a random question from the dataset"""
    index = random.randint(0, len(dataset['validation']) - 1)
    question_data = dataset['validation'][index]
    return (
        question_data['question'],
        question_data['opa'],
        question_data['opb'],
        question_data['opc'],
        question_data['opd'],
        question_data.get('cop', None),  # Correct option (0-3)
        question_data.get('exp', None)   # Explanation
    )


def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str,
            correct_option: int = None, explanation: str = None,
            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 256):
    # Format the question with options
    formatted_question = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}"

    # Create chat-style prompt
    prompt = [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': formatted_question}
    ]

    # Use apply_chat_template for better formatting
    text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)

    # Tokenize and generate
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
            do_sample=True,  # enable sampling so temperature/top_p actually take effect
            # repetition_penalty=1.1,
        )

    # Get only the generated response (excluding the prompt)
    generated_ids = generated_ids[0, model_inputs.input_ids.shape[1]:]
    model_response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Format output with evaluation if available
    output = model_response
    if correct_option is not None:
        correct_letter = chr(65 + int(correct_option))  # Convert 0-3 to A-D (gr.Number passes a float)
        # Extract answer from model response for evaluation
        answer_match = re.search(r"<answer>\s*([A-D])\s*</answer>", model_response, re.IGNORECASE)
        model_answer = answer_match.group(1).upper() if answer_match else "Not found"
        is_correct = model_answer == correct_letter

        output += "\n\n---\nEvaluation:\n"
        output += f"Correct Answer: {correct_letter}\n"
        output += f"Model's Answer: {model_answer}\n"
        output += f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n"
        if explanation:
            output += f"\nExpert Explanation:\n{explanation}"

    return output


# Create Gradio interface with Blocks for more control
with gr.Blocks(title="Medical-QA (MedMCQA) Predictor") as demo:
    gr.Markdown("# Medical-QA (MedMCQA) Predictor")
    gr.Markdown("Get a random medical question or enter your own question and options.")

    with gr.Row():
        with gr.Column():
            # Input fields
            question = gr.Textbox(label="Question", lines=3, interactive=True)

            # Options in an expandable accordion
            with gr.Accordion("Options", open=False):
                option_a = gr.Textbox(label="Option A", interactive=True)
                option_b = gr.Textbox(label="Option B", interactive=True)
                option_c = gr.Textbox(label="Option C", interactive=True)
                option_d = gr.Textbox(label="Option D", interactive=True)

            # Generation parameters
            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.6,
                    step=0.1,
                    label="Temperature",
                    info="Higher values make output more random, lower values more focused"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top P",
                    info="Higher values allow more diverse tokens, lower values more focused"
                )
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=512,
                    value=256,
                    step=32,
                    label="Max Tokens",
                    info="Maximum length of the generated response (recommended: 256)"
                )

            # Hidden fields for correct answer and explanation
            correct_option = gr.Number(visible=False)
            expert_explanation = gr.Textbox(visible=False)

            # Buttons
            with gr.Row():
                predict_btn = gr.Button("Predict", variant="primary")
                random_btn = gr.Button("Get Random Question", variant="secondary")

    # Output
    output = gr.Textbox(label="Model's Answer", lines=10)

    # Set up button actions
    predict_btn.click(
        fn=predict,
        inputs=[
            question, option_a, option_b, option_c, option_d,
            correct_option, expert_explanation,
            temperature, top_p, max_tokens
        ],
        outputs=output
    )

    random_btn.click(
        fn=get_random_question,
        inputs=[],
        outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation]
    )

# Launch the app
if __name__ == "__main__":
    demo.launch()