refactor: tmlu/tw_legal parser
- llmdataparser/tmlu_parser.py +0 -21
- llmdataparser/tw_legal_parser.py +35 -1
- tests/test_tmlu_parser.py +1 -14
- tests/test_tw_legal_parser.py +24 -0
llmdataparser/tmlu_parser.py
CHANGED
@@ -169,27 +169,6 @@ class TMLUDatasetParser(HuggingFaceDatasetParser[TMLUParseEntry]):
                 implementation="custom_subject_accuracy",
                 primary=True,
             ),
-            EvaluationMetric.create(
-                name="per_difficulty_accuracy",
-                type="classification",
-                description="Accuracy broken down by test difficulty levels",
-                implementation="custom_difficulty_accuracy",
-                primary=False,
-            ),
-            EvaluationMetric.create(
-                name="confusion_matrix",
-                type="classification",
-                description="Distribution of predicted vs actual answers",
-                implementation="datasets.load_metric('confusion_matrix')",
-                primary=False,
-            ),
-            EvaluationMetric.create(
-                name="explanation_quality",
-                type="text",
-                description="Quality assessment of model explanations when available",
-                implementation="custom_explanation_metric",
-                primary=False,
-            ),
         ]
 
 
llmdataparser/tw_legal_parser.py
CHANGED
@@ -1,7 +1,12 @@
 from dataclasses import dataclass
 from typing import Any, Final
 
-from llmdataparser.base_parser import HuggingFaceDatasetParser, HuggingFaceParseEntry
+from llmdataparser.base_parser import (
+    DatasetDescription,
+    EvaluationMetric,
+    HuggingFaceDatasetParser,
+    HuggingFaceParseEntry,
+)
 from llmdataparser.prompts import TW_LEGAL_SYSTEM_PROMPT
 
 TW_LEGAL_VALID_ANSWERS: Final[set[str]] = {"A", "B", "C", "D"}
@@ -70,6 +75,35 @@ class TWLegalDatasetParser(HuggingFaceDatasetParser[TWLegalParseEntry]):
             task_name=task,
         )
 
+    def get_dataset_description(self) -> DatasetDescription:
+        """Returns description of the Taiwan Legal Benchmark dataset."""
+        return DatasetDescription.create(
+            name="Taiwan Legal Benchmark",
+            language="Traditional Chinese",
+            purpose="Evaluate models on Taiwan-specific legal knowledge and understanding",
+            source="Taiwan Bar Examination questions",
+            format="Multiple choice questions (A/B/C/D)",
+            characteristics=(
+                "Contains questions from Taiwan's bar examination, testing understanding "
+                "of Taiwan's legal system, terminology, and concepts"
+            ),
+            citation="""
+            url={https://huggingface.co/datasets/lianghsun/tw-legal-benchmark-v1}
+            }""",
+        )
+
+    def get_evaluation_metrics(self) -> list[EvaluationMetric]:
+        """Returns recommended evaluation metrics for Taiwan Legal Benchmark."""
+        return [
+            EvaluationMetric.create(
+                name="accuracy",
+                type="classification",
+                description="Overall percentage of correctly answered legal questions",
+                implementation="datasets.load_metric('accuracy')",
+                primary=True,
+            ),
+        ]
+
 
 if __name__ == "__main__":
     # Example usage
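For context, a minimal sketch of how the two methods added above might be exercised. The import path and the bare TWLegalDatasetParser() constructor are assumptions inferred from the hunk headers; this diff does not show how the parser is constructed.

    from llmdataparser.tw_legal_parser import TWLegalDatasetParser

    # Assumed: the parser can be constructed without arguments.
    parser = TWLegalDatasetParser()

    description = parser.get_dataset_description()
    print(description.name)      # "Taiwan Legal Benchmark"
    print(description.language)  # "Traditional Chinese"

    # Only one metric is defined for this benchmark: overall accuracy (primary).
    for metric in parser.get_evaluation_metrics():
        print(metric.name, metric.type, metric.primary)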
tests/test_tmlu_parser.py
CHANGED
@@ -187,23 +187,10 @@ def test_get_evaluation_metrics(tmlu_parser):
     """Test evaluation metrics generation."""
     metrics = tmlu_parser.get_evaluation_metrics()
 
-    assert len(metrics) == 5
+    assert len(metrics) == 2  # Check total number of metrics
 
     # Check primary metrics
     primary_metrics = [m for m in metrics if m.primary]
     assert len(primary_metrics) == 2
     assert any(m.name == "accuracy" for m in primary_metrics)
     assert any(m.name == "per_subject_accuracy" for m in primary_metrics)
-
-    # Check specific metric properties
-    accuracy_metric = next(m for m in metrics if m.name == "accuracy")
-    assert accuracy_metric.type == "classification"
-    assert "datasets.load_metric('accuracy')" in accuracy_metric.implementation
-
-    # Check non-primary metrics
-    non_primary_metrics = {m.name for m in metrics if not m.primary}
-    assert non_primary_metrics == {
-        "per_difficulty_accuracy",
-        "confusion_matrix",
-        "explanation_quality",
-    }
tests/test_tw_legal_parser.py
CHANGED
@@ -138,3 +138,27 @@ def test_system_prompt_override(tw_legal_parser):
 
     entry = parser.process_entry(test_entry)
     assert custom_prompt in entry.prompt
+
+
+def test_get_dataset_description(tw_legal_parser):
+    """Test getting dataset description for Taiwan Legal parser."""
+    description = tw_legal_parser.get_dataset_description()
+
+    assert description.name == "Taiwan Legal Benchmark"
+    assert description.language == "Traditional Chinese"
+    assert "Taiwan's legal system" in description.characteristics
+    assert (
+        "huggingface.co/datasets/lianghsun/tw-legal-benchmark-v1"
+        in description.citation
+    )
+
+
+def test_get_evaluation_metrics(tw_legal_parser):
+    """Test getting evaluation metrics for Taiwan Legal parser."""
+    metrics = tw_legal_parser.get_evaluation_metrics()
+
+    assert len(metrics) == 1
+    metric = metrics[0]
+    assert metric.name == "accuracy"
+    assert metric.type == "classification"
+    assert metric.primary is True