update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
| import json | |
| import logging | |
| import os | |
| import shutil | |
| from itertools import islice | |
| from typing import Iterator, List, Optional, Sequence, Tuple | |
| from langchain.storage import create_kv_docstore | |
| from langchain_core.documents import Document as LCDocument | |
| from langchain_core.stores import BaseStore, ByteStore | |
| from pie_datasets import Dataset, DatasetDict | |
| from .pie_document_store import PieDocumentStore | |
| logger = logging.getLogger(__name__) | |
class BasicPieDocumentStore(PieDocumentStore):
    """PIE Document store that uses a client to store and retrieve documents.

    The store delegates all key/value operations to an underlying langchain
    ``BaseStore`` client and adds (de)serialization of the whole store to/from
    a directory on disk.
    """

    def __init__(
        self,
        client: Optional[BaseStore[str, LCDocument]] = None,
        byte_store: Optional[ByteStore] = None,
    ):
        """Create the store from a ready-made client or a raw byte store.

        Args:
            client: a key/value store mapping document ids to LangChain documents.
            byte_store: a raw byte store; if given, it is wrapped via
                ``create_kv_docstore`` and takes precedence over ``client``.

        Raises:
            ValueError: if neither ``client`` nor ``byte_store`` is provided.
        """
        if byte_store is not None:
            # byte_store takes precedence over an explicitly passed client
            client = create_kv_docstore(byte_store)
        elif client is None:
            # fix: raise the conventional exception type for invalid arguments
            # (ValueError subclasses Exception, so existing handlers still match)
            raise ValueError("You must pass a `byte_store` parameter.")
        self.client = client

    def mget(self, keys: Sequence[str]) -> List[LCDocument]:
        """Fetch the documents for ``keys`` (delegates to the client).

        NOTE(review): per the langchain ``BaseStore`` contract the client may
        return ``None`` for keys that are absent — confirm callers handle that.
        """
        return self.client.mget(keys)

    def mset(self, items: Sequence[Tuple[str, LCDocument]]) -> None:
        """Store the given ``(key, document)`` pairs (delegates to the client)."""
        self.client.mset(items)

    def mdelete(self, keys: Sequence[str]) -> None:
        """Delete the documents for ``keys`` (delegates to the client)."""
        self.client.mdelete(keys)

    def yield_keys(self, *, prefix: Optional[str] = None) -> Iterator[str]:
        """Yield all keys in the store, optionally filtered by ``prefix``."""
        return self.client.yield_keys(prefix=prefix)

    def _save_to_directory(self, path: str, batch_size: Optional[int] = None, **kwargs) -> None:
        """Persist all stored documents to ``path`` in batches.

        Writes three artifacts:
        - ``pie_documents/``: the unwrapped PIE documents as a JSON dataset
          (single "train" split, appended batch by batch),
        - ``doc_ids.json``: all document ids in storage order,
        - ``metadata.json``: per-document metadata without the wrapped PIE
          document entry (re-attached on load).

        Args:
            path: target directory.
            batch_size: number of documents per write batch (default: 1000).
        """
        all_doc_ids: List[str] = []
        all_metadata: List[dict] = []
        pie_documents_path = os.path.join(path, "pie_documents")
        if os.path.exists(pie_documents_path):
            # remove existing directory so stale batches are not mixed with new ones
            logger.warning(f"Removing existing directory: {pie_documents_path}")
            shutil.rmtree(pie_documents_path)
        os.makedirs(pie_documents_path, exist_ok=True)
        doc_ids_iter = iter(self.client.yield_keys())
        # fix: hoist the loop-invariant default out of the while condition
        effective_batch_size = batch_size or 1000
        mode = "w"
        while batch_doc_ids := list(islice(doc_ids_iter, effective_batch_size)):
            all_doc_ids.extend(batch_doc_ids)
            docs = self.client.mget(batch_doc_ids)
            pie_docs = []
            for doc in docs:
                # unwrap the PIE document; the remaining metadata entries are
                # collected separately so they can be re-attached on load
                pie_docs.append(doc.metadata[self.METADATA_KEY_PIE_DOCUMENT])
                all_metadata.append(
                    {k: v for k, v in doc.metadata.items() if k != self.METADATA_KEY_PIE_DOCUMENT}
                )
            pie_dataset = Dataset.from_documents(pie_docs)
            DatasetDict({"train": pie_dataset}).to_json(path=pie_documents_path, mode=mode)
            mode = "a"  # append after the first batch
        if all_doc_ids:
            doc_ids_path = os.path.join(path, "doc_ids.json")
            with open(doc_ids_path, "w") as f:
                json.dump(all_doc_ids, f)
        if all_metadata:
            metadata_path = os.path.join(path, "metadata.json")
            with open(metadata_path, "w") as f:
                json.dump(all_metadata, f)

    def _load_from_directory(self, path: str, **kwargs) -> None:
        """Load documents previously written by ``_save_to_directory``.

        Reads the PIE documents from ``pie_documents/``, re-attaches metadata
        from ``metadata.json`` (empty metadata if missing) and restores the
        original ids from ``doc_ids.json`` (falls back to each document's own
        ``id``), then writes everything into the client store.

        Args:
            path: source directory.
        """
        pie_documents_path = os.path.join(path, "pie_documents")
        if not os.path.exists(pie_documents_path):
            logger.warning(
                f"Directory {pie_documents_path} does not exist, don't load any documents."
            )
            return None
        pie_dataset = DatasetDict.from_json(data_dir=pie_documents_path)
        pie_docs = pie_dataset["train"]
        metadata_path = os.path.join(path, "metadata.json")
        if os.path.exists(metadata_path):
            with open(metadata_path, "r") as f:
                all_metadata = json.load(f)
        else:
            logger.warning(f"File {metadata_path} does not exist, don't load any metadata.")
            all_metadata = [{} for _ in pie_docs]
        # fix: a length mismatch was previously truncated silently by zip(),
        # which could pair documents with the wrong metadata unnoticed
        if len(all_metadata) != len(pie_docs):
            logger.warning(
                f"Number of metadata entries ({len(all_metadata)}) does not match number of "
                f"documents ({len(pie_docs)}); excess entries will be ignored."
            )
        docs = [
            self.wrap(pie_doc, **metadata) for pie_doc, metadata in zip(pie_docs, all_metadata)
        ]
        doc_ids_path = os.path.join(path, "doc_ids.json")
        if os.path.exists(doc_ids_path):
            with open(doc_ids_path, "r") as f:
                all_doc_ids = json.load(f)
        else:
            logger.warning(f"File {doc_ids_path} does not exist, don't load any document ids.")
            all_doc_ids = [doc.id for doc in pie_docs]
        self.client.mset(zip(all_doc_ids, docs))
        logger.info(f"Loaded {len(docs)} documents from {path} into docstore")