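"""Gradio demo for rgb2gbr/BioXP-0.5B-MedMCQA.

Loads the fine-tuned causal LM, samples questions from the MedMCQA
validation split, and scores the model's answer against the ground truth.
"""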
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from datasets import load_dataset
import random
import re
# Load model and tokenizer
# model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
SYSTEM_PROMPT = """
You're a medical expert. Answer the question with careful analysis and explain why the selected option is correct in 150 words without repeating yourself.
Respond in the following format:
<answer>
[correct answer]
</answer>
<reasoning>
[explain why the selected option is correct]
</reasoning>
"""
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Load dataset
dataset = load_dataset("openlifescienceai/medmcqa")
# Move model to GPU if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)
model.eval()
def get_random_question():
    """Get a random question from the dataset's validation split."""
    index = random.randint(0, len(dataset['validation']) - 1)
    question_data = dataset['validation'][index]
    return (
        question_data['question'],
        question_data['opa'],
        question_data['opb'],
        question_data['opc'],
        question_data['opd'],
        question_data.get('cop', None),  # Correct option index (0-3)
        question_data.get('exp', None)   # Expert explanation
    )
def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str,
            correct_option: int = None, explanation: str = None,
            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 256):
    # Format the question with options
    formatted_question = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}"

    # Create chat-style prompt
    prompt = [
        {'role': 'system', 'content': SYSTEM_PROMPT},
        {'role': 'user', 'content': formatted_question}
    ]

    # Use apply_chat_template for better formatting
    text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)

    # Tokenize and generate
    model_inputs = tokenizer([text], return_tensors="pt").to(device)
    with torch.inference_mode():
        generated_ids = model.generate(
            **model_inputs,
            max_new_tokens=max_tokens,
            do_sample=True,  # ensure sampling so temperature/top_p take effect
            temperature=temperature,
            top_p=top_p,
            # repetition_penalty=1.1,
        )

    # Keep only the generated response (excluding the prompt tokens)
    generated_ids = generated_ids[0, model_inputs.input_ids.shape[1]:]
    model_response = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Append an evaluation section when the ground truth is available
    output = model_response
    if correct_option is not None:
        # gr.Number delivers a float, so cast before converting 0-3 to A-D
        correct_letter = chr(65 + int(correct_option))
        # Extract the model's answer letter for comparison
        answer_match = re.search(r"<answer>\s*([A-D])\s*</answer>", model_response, re.IGNORECASE)
        model_answer = answer_match.group(1).upper() if answer_match else "Not found"
        is_correct = model_answer == correct_letter

        output += f"\n\n---\nEvaluation:\n"
        output += f"Correct Answer: {correct_letter}\n"
        output += f"Model's Answer: {model_answer}\n"
        output += f"Result: {'✅ Correct' if is_correct else '❌ Incorrect'}\n"
        if explanation:
            output += f"\nExpert Explanation:\n{explanation}"
    return output
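# Quick sanity check without the UI (a sketch, commented out so the Space
# only runs the Gradio app; uncomment to print one prediction at startup):
# q, a, b, c, d, cop, exp = get_random_question()
# print(predict(q, a, b, c, d, cop, exp))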
# Create Gradio interface with Blocks for more control
with gr.Blocks(title="Medical-QA (MedMCQA) Predictor") as demo:
    gr.Markdown("# Medical-QA (MedMCQA) Predictor")
    gr.Markdown("Get a random medical question or enter your own question and options.")

    with gr.Row():
        with gr.Column():
            # Input fields
            question = gr.Textbox(label="Question", lines=3, interactive=True)

            # Options in an expandable accordion
            with gr.Accordion("Options", open=False):
                option_a = gr.Textbox(label="Option A", interactive=True)
                option_b = gr.Textbox(label="Option B", interactive=True)
                option_c = gr.Textbox(label="Option C", interactive=True)
                option_d = gr.Textbox(label="Option D", interactive=True)

            # Generation parameters
            with gr.Accordion("Generation Parameters", open=False):
                temperature = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.6,
                    step=0.1,
                    label="Temperature",
                    info="Higher values make output more random, lower values more focused"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.1,
                    label="Top P",
                    info="Higher values allow more diverse tokens, lower values more focused"
                )
                max_tokens = gr.Slider(
                    minimum=50,
                    maximum=512,
                    value=256,
                    step=32,
                    label="Max Tokens",
                    info="Maximum length of the generated response (recommended: 256)"
                )

            # Hidden fields for the correct answer and expert explanation
            correct_option = gr.Number(visible=False)
            expert_explanation = gr.Textbox(visible=False)

    # Buttons
    with gr.Row():
        predict_btn = gr.Button("Predict", variant="primary")
        random_btn = gr.Button("Get Random Question", variant="secondary")

    # Output
    output = gr.Textbox(label="Model's Answer", lines=10)

    # Wire up button actions
    predict_btn.click(
        fn=predict,
        inputs=[
            question, option_a, option_b, option_c, option_d,
            correct_option, expert_explanation,
            temperature, top_p, max_tokens
        ],
        outputs=output
    )
    random_btn.click(
        fn=get_random_question,
        inputs=[],
        outputs=[question, option_a, option_b, option_c, option_d, correct_option, expert_explanation]
    )
# Launch the app
if __name__ == "__main__":
    demo.launch()
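    # Optional alternative (standard Gradio parameter): when running outside
    # a hosted Space, share=True also exposes a temporary public URL.
    # demo.launch(share=True)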