Spaces:
Build error
Build error
import argparse
import os

# OpenMP runtimes (used by faiss, and by torch via transformers) read
# OMP_NUM_THREADS when they are initialised, which can already happen while
# the library is being imported. Set it before those imports so the thread
# limit actually takes effect; setting it afterwards may be silently ignored.
os.environ["OMP_NUM_THREADS"] = "8"

from datasets import load_dataset  # noqa: E402
from transformers import DPRContextEncoderTokenizer, DPRContextEncoder  # noqa: E402
from general_utils import embed_passages, embed_passages_haystack  # noqa: E402
import faiss  # noqa: E402
from haystack.nodes import DensePassageRetriever  # noqa: E402
from haystack.document_stores import InMemoryDocumentStore  # noqa: E402
def create_faiss_index(args):
    """Embed the Spanish biomedical corpus with DPR and save a FAISS index.

    Loads the ``train`` split of ``IIC/spanish_biomedical_crawled_corpus``,
    keeps only passages whose ``text`` is longer than ``minchars`` characters,
    embeds them with a Haystack ``DensePassageRetriever``, builds a compressed
    FAISS index over the resulting ``embeddings`` column and writes it to disk.

    Args:
        args: Parsed CLI namespace. Only ``args.index_file_name`` (output path
            of the FAISS index) is read here; ``ctx_encoder_name`` and
            ``device`` are not used by this function.
    """
    # Passages with this many characters or fewer are dropped before embedding.
    minchars = 200

    # DPR is a bi-encoder: questions and passages are embedded by two separate,
    # jointly trained models. Passages must go through the *passage* encoder
    # (previously the question encoder was used for both sides).
    query_model = "IIC/dpr-spanish-question_encoder-allqa-base"
    passage_model = "IIC/dpr-spanish-passage_encoder-allqa-base"

    dpr = DensePassageRetriever(
        document_store=InMemoryDocumentStore(),
        query_embedding_model=query_model,
        passage_embedding_model=passage_model,
        max_seq_len_query=64,
        max_seq_len_passage=256,
        batch_size=512,
    )

    dataset = load_dataset("IIC/spanish_biomedical_crawled_corpus", split="train")
    dataset = dataset.filter(lambda example: len(example["text"]) > minchars)

    def embed_passages_retrieval(examples):
        # NOTE(review): assumes embed_passages_haystack returns a dict that
        # includes an "embeddings" column (read by add_faiss_index below) —
        # confirm in general_utils.
        return embed_passages_haystack(dpr, examples)

    dataset = dataset.map(embed_passages_retrieval, batched=True, batch_size=8192)

    dataset.add_faiss_index(
        column="embeddings",
        # faiss index_factory string: OPQ rotation down to 128 dims, an IVF
        # coarse quantizer with 4898 lists, then 4-bit fast-scan PQ with 64
        # sub-quantizers on the residuals.
        string_factory="OPQ64_128,IVF4898,PQ64x4fsr",
        # Train the OPQ/IVF/PQ quantizers on every kept passage.
        train_size=len(dataset),
    )
    dataset.save_faiss_index("embeddings", args.index_file_name)
if __name__ == "__main__":
    # The corpus indexed by create_faiss_index is the Spanish biomedical
    # crawled corpus, not Wikipedia; the description now says so.
    parser = argparse.ArgumentParser(
        description="Creates Faiss index file for the Spanish biomedical corpus"
    )
    # NOTE(review): create_faiss_index in this file does not read
    # --ctx_encoder_name or --device; they are still accepted so existing
    # invocations keep working.
    parser.add_argument(
        "--ctx_encoder_name",
        default="IIC/dpr-spanish-passage_encoder-squades-base",
        help="Encoding model to use for passage encoding",
    )
    parser.add_argument(
        "--index_file_name",
        default="dpr_index_bio_splitted.faiss",
        help="Faiss index file with passage embeddings",
    )
    parser.add_argument(
        "--device", default="cuda:0", help="The device to index data on."
    )
    # parse_known_args ignores unrecognised command-line arguments instead of
    # exiting with an error (presumably to tolerate flags added by the hosting
    # runtime — confirm before switching to parse_args).
    main_args, _ = parser.parse_known_args()
    create_faiss_index(main_args)