refactor: bbh/tmlu test case
- tests/test_bbh_parser.py +29 -65
- tests/test_tmlu_parser.py +26 -62
tests/test_bbh_parser.py
CHANGED
@@ -160,74 +160,38 @@ def test_different_tasks_parsing(bbh_parser, task_name):
    assert all(isinstance(entry.answer, str) for entry in parsed_data)


-def
-    """Test
-    metrics = bbh_parser.get_evaluation_metrics()
-
-    # Check basic structure
-    assert isinstance(metrics, list)
-    assert len(metrics) > 0
-
-    # Check each metric has required fields
-    required_fields = ["name", "type", "description", "implementation", "primary"]
-    for metric in metrics:
-        for field in required_fields:
-            assert field in metric, f"Missing field {field} in metric {metric['name']}"
-
-        # Check field types
-        assert isinstance(metric["name"], str)
-        assert isinstance(metric["type"], str)
-        assert isinstance(metric["description"], str)
-        assert isinstance(metric["implementation"], str)
-        assert isinstance(metric["primary"], bool)
-
-    # Check specific metrics exist
-    metric_names = {m["name"] for m in metrics}
-    expected_metrics = {
-        "accuracy",
-        "human_eval_delta",
-        "per_task_accuracy",
-        "exact_match",
-    }
-    assert expected_metrics.issubset(metric_names)
-
-    # Check primary metrics
-    primary_metrics = {m["name"] for m in metrics if m["primary"]}
-    assert "accuracy" in primary_metrics
-    assert "human_eval_delta" in primary_metrics
-
-
-def test_dataset_description_citation_format(bbh_parser):
-    """Test that the citation in dataset description is properly formatted."""
+def test_get_dataset_description(bbh_parser):
+    """Test dataset description generation."""
    description = bbh_parser.get_dataset_description()
-    citation = description["citation"]
-
-    # Check citation structure
-    assert citation.startswith("@article{")
-    assert "title=" in citation
-    assert "author=" in citation
-    assert "journal=" in citation
-    assert "year=" in citation

-
-    assert "
-    assert
-    assert
-    assert "
+    assert description.name == "Big Bench Hard (BBH)"
+    assert "challenging BIG-Bench tasks" in description.purpose
+    assert description.language == "English"
+    assert description.format == "Multiple choice questions with single correct answers"
+    assert "Tasks require complex multi-step reasoning" in description.characteristics
+    assert "suzgun2022challenging" in description.citation
+    assert description.additional_info is not None
+    assert "model_performance" in description.additional_info
+    assert "size" in description.additional_info


-def
-    """Test
+def test_get_evaluation_metrics(bbh_parser):
+    """Test evaluation metrics generation."""
    metrics = bbh_parser.get_evaluation_metrics()

-
-
-
-
-
-
-
-
-
-
-
+    assert len(metrics) == 4  # Check total number of metrics
+
+    # Check primary metrics
+    primary_metrics = [m for m in metrics if m.primary]
+    assert len(primary_metrics) == 2
+    assert any(m.name == "accuracy" for m in primary_metrics)
+    assert any(m.name == "human_eval_delta" for m in primary_metrics)
+
+    # Check specific metric properties
+    accuracy_metric = next(m for m in metrics if m.name == "accuracy")
+    assert accuracy_metric.type == "classification"
+    assert "evaluate.load('accuracy')" in accuracy_metric.implementation
+
+    # Check non-primary metrics
+    assert any(m.name == "per_task_accuracy" and not m.primary for m in metrics)
+    assert any(m.name == "exact_match" and not m.primary for m in metrics)
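Note: the refactored assertions read metric fields as attributes (m.name, m.primary, m.implementation) instead of dictionary keys, so get_evaluation_metrics() presumably now returns structured objects. A minimal sketch of the shape these tests appear to rely on follows; only the field names and the accuracy values ("classification", "evaluate.load('accuracy')") come from the assertions above, while the class name EvaluationMetric and the dataclass form are illustrative assumptions, not the repository's actual code.

# Minimal sketch (assumption): object shape implied by the new attribute-style
# assertions. "EvaluationMetric" and the dataclass form are hypothetical.
from dataclasses import dataclass


@dataclass
class EvaluationMetric:
    name: str            # e.g. "accuracy", "human_eval_delta", "per_task_accuracy", "exact_match"
    type: str            # the accuracy metric is asserted to have type == "classification"
    description: str     # free-text description (not asserted on in these tests)
    implementation: str  # asserted to contain "evaluate.load('accuracy')" for BBH accuracy
    primary: bool        # exactly two primary metrics per parser in the new tests


accuracy = EvaluationMetric(
    name="accuracy",
    type="classification",
    description="proportion of exactly correct answers",  # hypothetical wording
    implementation="evaluate.load('accuracy')",
    primary=True,
)
assert accuracy.primary and accuracy.type == "classification"

Attribute access of this kind fails loudly on a typo (AttributeError) rather than raising a KeyError deep inside a dict lookup, which is presumably part of the motivation for dropping the field-by-field isinstance checks in the old tests.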
tests/test_tmlu_parser.py
CHANGED
@@ -170,76 +170,40 @@ def test_metadata_handling(tmlu_parser, sample_tmlu_entries):
    assert entry.metadata["source"] == "AST chinese - 108"


-def
-    """Test dataset description
+def test_get_dataset_description(tmlu_parser):
+    """Test dataset description generation."""
    description = tmlu_parser.get_dataset_description()

-
-
-
-
-
-
-
-        "size",
-        "domain",
-        "characteristics",
-        "reference",
-    ]
-
-    for field in required_fields:
-        assert field in description, f"Missing required field: {field}"
+    assert description.name == "Taiwan Multiple-choice Language Understanding (TMLU)"
+    assert description.language == "Traditional Chinese"
+    assert "Taiwan-specific educational" in description.purpose
+    assert "Various Taiwan standardized tests" in description.source
+    assert description.format == "Multiple choice questions (A/B/C/D)"
+    assert "Advanced Subjects Test (AST)" in description.characteristics
+    assert "DBLP:journals/corr/abs-2403-20180" in description.citation

-    assert description["language"] == "Traditional Chinese"
-    assert "TMLU" in description["name"]
-    assert "miulab/tmlu" in description["reference"]
-    assert "AST" in description["characteristics"]
-    assert "GSAT" in description["characteristics"]

-
-
-    """Test evaluation metrics structure and content."""
+def test_get_evaluation_metrics(tmlu_parser):
+    """Test evaluation metrics generation."""
    metrics = tmlu_parser.get_evaluation_metrics()

-    # Check
-    assert len(metrics) > 0
+    assert len(metrics) == 5  # Check total number of metrics

-    # Check
-
-
-
-
-        "implementation",
-        "primary",
-    ]
+    # Check primary metrics
+    primary_metrics = [m for m in metrics if m.primary]
+    assert len(primary_metrics) == 2
+    assert any(m.name == "accuracy" for m in primary_metrics)
+    assert any(m.name == "per_subject_accuracy" for m in primary_metrics)

-
-
-
-
-
-
-
-
-    assert isinstance(metric["implementation"], str)
-    assert isinstance(metric["primary"], bool)
-
-    # Check for TMLU-specific metrics
-    metric_names = {m["name"] for m in metrics}
-    expected_metrics = {
-        "accuracy",
-        "per_subject_accuracy",
+    # Check specific metric properties
+    accuracy_metric = next(m for m in metrics if m.name == "accuracy")
+    assert accuracy_metric.type == "classification"
+    assert "datasets.load_metric('accuracy')" in accuracy_metric.implementation
+
+    # Check non-primary metrics
+    non_primary_metrics = {m.name for m in metrics if not m.primary}
+    assert non_primary_metrics == {
        "per_difficulty_accuracy",
+        "confusion_matrix",
        "explanation_quality",
    }
-
-    for expected in expected_metrics:
-        assert expected in metric_names, f"Missing expected metric: {expected}"
-
-    # Verify primary metrics
-    primary_metrics = [m for m in metrics if m["primary"]]
-    assert (
-        len(primary_metrics) >= 2
-    )  # Should have at least accuracy and per_subject_accuracy
-    assert any(m["name"] == "accuracy" for m in primary_metrics)
-    assert any(m["name"] == "per_subject_accuracy" for m in primary_metrics)
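Note: both new test_get_dataset_description tests likewise read the description through attributes (description.name, description.language, description.citation, description.additional_info), so get_dataset_description() presumably returns a structured object as well. A hedged sketch of that shape follows; the field names are collected from the assertions in the two files, while the class name DatasetDescription, the dataclass form, the Optional typing, and every value marked "hypothetical" are assumptions (the real model could equally be a NamedTuple or pydantic model).

# Minimal sketch (assumption): field names taken from the assertions in both
# refactored tests; "DatasetDescription" and the example values marked
# "hypothetical" are illustrative only.
from dataclasses import dataclass
from typing import Any, Dict, Optional


@dataclass
class DatasetDescription:
    name: str             # e.g. "Taiwan Multiple-choice Language Understanding (TMLU)"
    purpose: str          # substring-matched, e.g. "Taiwan-specific educational"
    language: str         # "English" for BBH, "Traditional Chinese" for TMLU
    format: str           # e.g. "Multiple choice questions (A/B/C/D)"
    characteristics: str  # substring-matched, e.g. "Advanced Subjects Test (AST)"
    citation: str         # BibTeX entry; TMLU test checks for "DBLP:journals/corr/abs-2403-20180"
    source: Optional[str] = None                      # asserted for TMLU only
    additional_info: Optional[Dict[str, Any]] = None  # BBH test expects "size" and "model_performance" keys


tmlu_description = DatasetDescription(
    name="Taiwan Multiple-choice Language Understanding (TMLU)",
    purpose="Evaluate Taiwan-specific educational knowledge",                # hypothetical wording
    language="Traditional Chinese",
    format="Multiple choice questions (A/B/C/D)",
    characteristics="Drawn from the Advanced Subjects Test (AST) and GSAT",  # hypothetical wording
    citation="@article{DBLP:journals/corr/abs-2403-20180, ...}",             # hypothetical wording
    source="Various Taiwan standardized tests",
)
assert "Advanced Subjects Test (AST)" in tmlu_description.characteristics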