Abaryan
committed on
Update app.py
app.py CHANGED
@@ -8,6 +8,18 @@ import re
 # Load model and tokenizer
 # model_name = "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B"
 model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
+
+SYSTEM_PROMPT = """
+You're a medical expert. Answer the question with careful analysis and explain why the selected option is correct in 150 words without repeating.
+Respond in the following format:
+<answer>
+[correct answer]
+</answer>
+<reasoning>
+[explain why the selected option is correct]
+</reasoning>
+"""
+
 model = AutoModelForCausalLM.from_pretrained(model_name)
 tokenizer = AutoTokenizer.from_pretrained(model_name)
 
@@ -21,8 +33,8 @@ model.eval()
 
 def get_random_question():
     """Get a random question from the dataset"""
-    index = random.randint(0, len(dataset['
-    question_data = dataset['
+    index = random.randint(0, len(dataset['validation']) - 1)
+    question_data = dataset['validation'][index]
     return (
         question_data['question'],
         question_data['opa'],
@@ -33,49 +45,46 @@ def get_random_question():
         question_data.get('exp', None)  # Explanation
     )
 
-def extract_answer(prediction: str) -> tuple:
-    """Extract answer and reasoning from model output"""
-    # Try to find the answer part
-    answer_match = re.search(r"Answer:\s*([A-D])", prediction, re.IGNORECASE)
-    answer = answer_match.group(1).upper() if answer_match else "Not found"
-
-    # Try to find reasoning part
-    reasoning = ""
-    if "Reasoning:" in prediction:
-        reasoning = prediction.split("Reasoning:")[-1].strip()
-    elif "Explanation:" in prediction:
-        reasoning = prediction.split("Explanation:")[-1].strip()
-
-    return answer, reasoning
-
 def predict(question: str, option_a: str, option_b: str, option_c: str, option_d: str,
             correct_option: int = None, explanation: str = None,
-            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int =
-    # Format the
-
+            temperature: float = 0.6, top_p: float = 0.9, max_tokens: int = 256):
+    # Format the question with options
+    formatted_question = f"Question: {question}\n\nOptions:\nA. {option_a}\nB. {option_b}\nC. {option_c}\nD. {option_d}"
+
+    # Create chat-style prompt
+    prompt = [
+        {'role': 'system', 'content': SYSTEM_PROMPT},
+        {'role': 'user', 'content': formatted_question}
+    ]
+
+    # Use apply_chat_template for better formatting
+    text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
 
     # Tokenize and generate
-
-    inputs = {k: v.to(device) for k, v in inputs.items()}
+    model_inputs = tokenizer([text], return_tensors="pt").to(device)
 
-    with torch.
-
-        **
+    with torch.inference_mode():
+        generated_ids = model.generate(
+            **model_inputs,
             max_new_tokens=max_tokens,
             temperature=temperature,
             top_p=top_p,
-
-            # pad_token_id=tokenizer.eos_token_id
+            # repetition_penalty=1.1,
         )
 
-    # Get
-
-
+    # Get only the generated response (excluding the prompt)
+    generated_ids = generated_ids[0, model_inputs.input_ids.shape[1]:]
+    model_response = tokenizer.decode(generated_ids, skip_special_tokens=True)
 
     # Format output with evaluation if available
-    output =
+    output = model_response
+
     if correct_option is not None:
         correct_letter = chr(65 + correct_option)  # Convert 0-3 to A-D
+        # Extract answer from model response for evaluation
+        answer_match = re.search(r"<answer>\s*([A-D])\s*</answer>", model_response, re.IGNORECASE)
+        model_answer = answer_match.group(1).upper() if answer_match else "Not found"
+
         is_correct = model_answer == correct_letter
         output += f"\n\n---\nEvaluation:\n"
         output += f"Correct Answer: {correct_letter}\n"
@@ -95,10 +104,13 @@ with gr.Blocks(title="Medical-QA (MedMCQA) Predictor") as demo:
         with gr.Column():
             # Input fields
             question = gr.Textbox(label="Question", lines=3, interactive=True)
-
-
-
-
+
+            # Options in an expandable accordion
+            with gr.Accordion("Options", open=False):
+                option_a = gr.Textbox(label="Option A", interactive=True)
+                option_b = gr.Textbox(label="Option B", interactive=True)
+                option_c = gr.Textbox(label="Option C", interactive=True)
+                option_d = gr.Textbox(label="Option D", interactive=True)
 
             # Generation parameters
             with gr.Accordion("Generation Parameters", open=False):
@@ -119,12 +131,12 @@ with gr.Blocks(title="Medical-QA (MedMCQA) Predictor") as demo:
                     info="Higher values allow more diverse tokens, lower values more focused"
                 )
                 max_tokens = gr.Slider(
-                    minimum=
+                    minimum=50,
                     maximum=512,
-                    value=
+                    value=256,
                     step=32,
                     label="Max Tokens",
-                    info="Maximum length of the generated response"
+                    info="Maximum length of the generated response (recommended: 256)"
                 )
 
         # Hidden fields for correct answer and explanation
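
For reference, below is a minimal standalone sketch of the inference path this commit introduces (chat-template prompt, generation, and <answer> parsing), runnable outside the Gradio UI. The model name, system prompt, sampling defaults, and parsing regex mirror the diff above; the example question, the explicit do_sample=True, and the device selection are illustrative assumptions, not part of the committed code.

# Standalone sketch (assumptions noted inline); mirrors the predict() flow from the diff.
import re
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "rgb2gbr/BioXP-0.5B-MedMCQA"
device = "cuda" if torch.cuda.is_available() else "cpu"  # assumption: the app defines `device` similarly

model = AutoModelForCausalLM.from_pretrained(model_name).to(device)
model.eval()
tokenizer = AutoTokenizer.from_pretrained(model_name)

SYSTEM_PROMPT = """
You're a medical expert. Answer the question with careful analysis and explain why the selected option is correct in 150 words without repeating.
Respond in the following format:
<answer>
[correct answer]
</answer>
<reasoning>
[explain why the selected option is correct]
</reasoning>
"""

# Hypothetical question and options, for illustration only.
formatted_question = (
    "Question: Deficiency of which vitamin causes scurvy?\n\n"
    "Options:\nA. Vitamin A\nB. Vitamin B12\nC. Vitamin C\nD. Vitamin D"
)

prompt = [
    {'role': 'system', 'content': SYSTEM_PROMPT},
    {'role': 'user', 'content': formatted_question},
]
text = tokenizer.apply_chat_template(prompt, tokenize=False, add_generation_prompt=True)
model_inputs = tokenizer([text], return_tensors="pt").to(device)

with torch.inference_mode():
    generated_ids = model.generate(
        **model_inputs,
        max_new_tokens=256,
        temperature=0.6,
        top_p=0.9,
        do_sample=True,  # not set in the commit; added here so temperature/top_p actually apply
    )

# Keep only the newly generated tokens, then parse the <answer> tag the same way predict() does.
generated_ids = generated_ids[0, model_inputs.input_ids.shape[1]:]
model_response = tokenizer.decode(generated_ids, skip_special_tokens=True)
answer_match = re.search(r"<answer>\s*([A-D])\s*</answer>", model_response, re.IGNORECASE)
print("Answer:", answer_match.group(1).upper() if answer_match else "Not found")
print(model_response)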