Spaces:
Sleeping
Sleeping
Commit
·
05dae02
1
Parent(s):
da6fc76
group by model
Browse files
app.py
CHANGED
|
@@ -61,9 +61,12 @@ def response(num_responses, model, correct, prompt_ids):
|
|
| 61 |
responses = responses[responses['correct'].isin(correct)]
|
| 62 |
if prompt_ids:
|
| 63 |
responses = responses[responses['prompt'].isin(prompt_ids)]
|
| 64 |
-
if num_responses > len(responses):
|
| 65 |
-
|
| 66 |
-
|
|
|
|
|
|
|
|
|
|
| 67 |
|
| 68 |
|
| 69 |
def barplot_for_prompt_id(prompt_ids, models):
|
|
|
|
| 61 |
responses = responses[responses['correct'].isin(correct)]
|
| 62 |
if prompt_ids:
|
| 63 |
responses = responses[responses['prompt'].isin(prompt_ids)]
|
| 64 |
+
# if num_responses > len(responses):
|
| 65 |
+
# num_responses = len(responses)
|
| 66 |
+
|
| 67 |
+
# sample num_responses for each model
|
| 68 |
+
responses = responses.groupby('model').apply(lambda x: x.sample(num_responses) if num_responses < len(x) else x).reset_index()
|
| 69 |
+
return responses[['model', 'prompt', 'model_response', 'correct']]
|
| 70 |
|
| 71 |
|
| 72 |
def barplot_for_prompt_id(prompt_ids, models):
|