feat: update base_parser
- .pre-commit-config.yaml  +8 -1
- llmdataparser/__init__.py  +33 -0
- llmdataparser/base_parser.py  +100 -0
- pyproject.toml  +5 -3
.pre-commit-config.yaml CHANGED

@@ -12,6 +12,7 @@ repos:
     hooks:
       - id: flake8
         additional_dependencies: ["typing-extensions>=4.8.0"]
+        args: ["--ignore=E203, E501, W503"]
   - repo: https://github.com/PyCQA/isort
     rev: 5.12.0
     hooks:
@@ -21,7 +22,13 @@ repos:
     rev: v1.5.1
     hooks:
       - id: mypy
-        args:
+        args:
+          [
+            "--python-version=3.11",
+            "--install-types",
+            "--non-interactive",
+            "--ignore-missing-imports",
+          ]
         additional_dependencies:
           - "typing-extensions>=4.8.0"
   - repo: https://github.com/pre-commit/pre-commit-hooks
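With these hooks in place, the whole suite can be run locally against the repository; for example, with the standard pre-commit CLI (not part of this commit):

    pre-commit run --all-files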
llmdataparser/__init__.py ADDED

@@ -0,0 +1,33 @@
+# llmdataparser/__init__.py
+from typing import Type
+
+from .base_parser import DatasetParser
+from .mmlu_parser import MMLUDatasetParser
+
+
+class ParserRegistry:
+    """
+    Registry to keep track of available parsers and provide them on request.
+    """
+
+    _registry: dict = {}
+
+    @classmethod
+    def register_parser(cls, name: str, parser_class: Type[DatasetParser]) -> None:
+        cls._registry[name.lower()] = parser_class
+
+    @classmethod
+    def get_parser(cls, name: str, **kwargs) -> DatasetParser:
+        parser_class = cls._registry.get(name.lower())
+        if parser_class is None:
+            raise ValueError(f"Parser '{name}' is not registered.")
+        return parser_class(**kwargs)
+
+    @classmethod
+    def list_parsers(cls) -> list[str]:
+        """Returns a list of available parser names."""
+        return list(cls._registry.keys())
+
+
+# Register parsers
+ParserRegistry.register_parser("mmlu", MMLUDatasetParser)
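For reference, downstream code would use the registry roughly as follows; a minimal sketch based only on the interfaces defined in this commit:

    from llmdataparser import ParserRegistry

    print(ParserRegistry.list_parsers())        # ['mmlu']

    parser = ParserRegistry.get_parser("mmlu")  # instantiates MMLUDatasetParser
    parser.load()                               # fetches the dataset from Hugging Face
    parser.parse()                              # fills parser._parsed_data
    entries = parser.get_parsed_data            # a property, so no call parentheses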
llmdataparser/base_parser.py ADDED

@@ -0,0 +1,100 @@
+from abc import ABC, abstractmethod
+from dataclasses import dataclass
+from functools import lru_cache
+from typing import Any, Generic, TypeVar
+
+import datasets
+
+# Define the generic type variable
+T = TypeVar("T", bound="ParseEntry")
+
+
+@dataclass(frozen=True)
+class ParseEntry:
+    """A simple base class for entries, customizable by each dataset parser."""
+
+
+class DatasetParser(ABC, Generic[T]):
+    """
+    Abstract base class defining the interface for all dataset parsers.
+    """
+
+    def __init__(self):
+        self._parsed_data: list[T] = []
+
+    @abstractmethod
+    def load(self, **kwargs: Any) -> None:
+        pass
+
+    @abstractmethod
+    def parse(self, split_names: str | list[str] | None = None, **kwargs: Any) -> None:
+        """
+        Parse the loaded dataset into self._parsed_data.
+        """
+
+    @property
+    def get_parsed_data(self) -> list[T]:
+        if not hasattr(self, "_parsed_data") or not self._parsed_data:
+            raise ValueError("Parsed data has not been initialized.")
+        return self._parsed_data
+
+    @abstractmethod
+    def process_entry(self, row: dict[str, Any]) -> T:
+        pass
+
+
+# Base class for Hugging Face datasets
+class HuggingFaceDatasetParser(DatasetParser[T]):
+    """
+    Base class for parsers that use datasets from Hugging Face.
+    """
+
+    _data_source: str  # Class variable for the dataset name
+
+    def __init__(self):
+        self.raw_data = None
+        self.task_names = []
+        super().__init__()
+
+    def get_task_names(self) -> list[str]:
+        return self.task_names
+
+    @staticmethod
+    @lru_cache(maxsize=3)
+    def load_dataset_cached(
+        data_source: str, config_name: str = "default", **kwargs: Any
+    ):
+        """
+        Cached static method to load a dataset from Hugging Face.
+        """
+        return datasets.load_dataset(data_source, config_name, **kwargs)
+
+    def load(
+        self,
+        data_source: str | None = None,
+        config_name: str = "all",
+        trust_remote_code: bool = True,
+        split: str | None = None,
+        **kwargs: Any,
+    ) -> None:
+        """
+        Load the dataset using the Hugging Face datasets library.
+        """
+        # Use class-level _data_source if not provided (getattr avoids AttributeError)
+        data_source = data_source or getattr(self, "_data_source", None)
+        if not data_source:
+            raise ValueError("The '_data_source' class variable must be defined.")
+
+        # Call the cached static method
+        self.raw_data = self.load_dataset_cached(
+            data_source,
+            config_name=config_name,
+            trust_remote_code=trust_remote_code,
+            split=split,
+            **kwargs,
+        )
+        self.task_names = list(self.raw_data.keys())
+        print(
+            f"Loaded dataset with {len(self.task_names)} tasks: {', '.join(self.task_names)}."
+        )
+        # Additional common initialization can be added here
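The mmlu_parser module imported in __init__.py is not included in this diff. Purely as an illustrative sketch of how the pieces fit together, a concrete subclass might look like the following; the entry fields, the row keys ("question", "answer"), and the "cais/mmlu" dataset id are assumptions, not contents of this commit:

    from dataclasses import dataclass
    from typing import Any

    from llmdataparser.base_parser import HuggingFaceDatasetParser, ParseEntry


    @dataclass(frozen=True)
    class MMLUParseEntry(ParseEntry):
        prompt: str  # assumed field
        answer: str  # assumed field


    class MMLUDatasetParser(HuggingFaceDatasetParser[MMLUParseEntry]):
        _data_source = "cais/mmlu"  # assumed Hugging Face dataset id

        def process_entry(self, row: dict[str, Any]) -> MMLUParseEntry:
            # Map one raw row to an immutable entry (row keys assumed).
            return MMLUParseEntry(prompt=row["question"], answer=str(row["answer"]))

        def parse(
            self, split_names: str | list[str] | None = None, **kwargs: Any
        ) -> None:
            if isinstance(split_names, str):
                split_names = [split_names]
            # Default to every task discovered by load().
            for split in split_names or self.task_names:
                for row in self.raw_data[split]:
                    self._parsed_data.append(self.process_entry(row))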
pyproject.toml CHANGED

@@ -49,11 +49,13 @@ profile = "black"
 line_length = 88
 known_first_party = ["llmdataparser"]
 
+# .flake8
 [tool.flake8]
-
-
-
+ignore = ['E231', 'E241', "E501"]
+per-file-ignores = [
+    '__init__.py:F401',
 ]
+count = true
 
 [tool.ruff]
 line-length = 88