Spaces:
Runtime error
Runtime error
| import logging | |
| from pathlib import Path | |
| from typing import List, Optional, Union | |
| from relik.common.utils import is_package_available | |
| from relik.inference.annotator import Relik | |
| if not is_package_available("fastapi"): | |
| raise ImportError( | |
| "FastAPI is not installed. Please install FastAPI with `pip install relik[serve]`." | |
| ) | |
| from fastapi import FastAPI, HTTPException | |
| if not is_package_available("ray"): | |
| raise ImportError( | |
| "Ray is not installed. Please install Ray with `pip install relik[serve]`." | |
| ) | |
| from ray import serve | |
| from relik.common.log import get_logger | |
| from relik.inference.serve.backend.utils import ( | |
| RayParameterManager, | |
| ServerParameterManager, | |
| ) | |
| from relik.retriever.data.utils import batch_generator | |
logger = get_logger(__name__, level=logging.INFO)

# Load the package version: ``version.py`` (four directory levels above this
# module) is executed into the ``VERSION`` dict, which afterwards holds the
# names that file defines; ``VERSION["VERSION"]`` is read for the app below.
VERSION = {}  # type: ignore
exec(
    (Path(__file__).parent.parent.parent.parent / "version.py").read_text(),
    VERSION,
)

# Server / Ray configuration holders (presumably populated from environment
# variables -- see ServerParameterManager and RayParameterManager).
SERVER_MANAGER = ServerParameterManager()
RAY_MANAGER = RayParameterManager()

# FastAPI application that carries the REST API metadata.
app = FastAPI(
    title="ReLiK",
    description="ReLiK REST API",
    version=VERSION["VERSION"],
)
@serve.deployment
@serve.ingress(app)
class RelikServer:
    """Ray Serve deployment that exposes a ReLiK pipeline through ``app``.

    The module binds this class with ``RelikServer.bind(...)``, which requires
    it to be a Ray Serve deployment. Hence the ``serve.deployment`` and
    ``serve.ingress(app)`` decorators: the latter routes the FastAPI ``app``
    to the ``@app.post`` methods below.

    NOTE(review): ``RAY_MANAGER`` is created at module level but not used here.
    GPU and autoscaling settings (e.g. ``ray_actor_options={"num_gpus": ...}``,
    ``autoscaling_config``) presumably belong in ``serve.deployment(...)``.
    Confirm this before deploying with CUDA devices.
    """

    def __init__(
        self,
        question_encoder: str,
        document_index: str,
        passage_encoder: Optional[str] = None,
        reader_encoder: Optional[str] = None,
        top_k: int = 100,
        retriver_device: str = "cpu",
        reader_device: str = "cpu",
        index_device: Optional[str] = None,
        precision: int = 32,
        index_precision: Optional[int] = None,
        use_faiss: bool = False,
        window_batch_size: int = 32,
        window_size: int = 32,
        window_stride: int = 16,
        split_on_spaces: bool = False,
    ):
        """Store the server parameters and build the underlying ``Relik`` model.

        Args:
            question_encoder: Question encoder passed to ``Relik``.
            document_index: Document index passed to ``Relik``.
            passage_encoder: Optional passage encoder passed to ``Relik``.
            reader_encoder: Optional reader passed to ``Relik`` as ``reader``.
            top_k: ``top_k`` forwarded on every ``Relik`` call.
            retriver_device: Retriever device (the parameter name keeps the
                original spelling because callers pass it by keyword).
            reader_device: Reader device.
            index_device: Document-index device; defaults to ``retriver_device``.
            precision: Retriever and reader precision.
            index_precision: Document-index precision; defaults to ``precision``.
            use_faiss: Stored only; not forwarded to ``Relik`` in this class.
            window_batch_size: Forwarded as ``batch_size`` on every call.
            window_size: Window size forwarded on every call.
            window_stride: Window stride forwarded on every call.
            split_on_spaces: Stored and logged only.
                NOTE(review): never forwarded to ``Relik``; confirm whether it
                should be.
        """
        # parameters
        self.question_encoder = question_encoder
        self.passage_encoder = passage_encoder
        self.reader_encoder = reader_encoder
        self.document_index = document_index
        self.top_k = top_k
        self.retriver_device = retriver_device
        self.index_device = index_device or retriver_device
        self.reader_device = reader_device
        self.precision = precision
        self.index_precision = index_precision or precision
        self.use_faiss = use_faiss
        self.window_batch_size = window_batch_size
        self.window_size = window_size
        self.window_stride = window_stride
        self.split_on_spaces = split_on_spaces

        # log the effective configuration for debugging
        logger.info("Initializing RelikServer with parameters:")
        logger.info(f"QUESTION_ENCODER: {self.question_encoder}")
        logger.info(f"PASSAGE_ENCODER: {self.passage_encoder}")
        logger.info(f"READER_ENCODER: {self.reader_encoder}")
        logger.info(f"DOCUMENT_INDEX: {self.document_index}")
        logger.info(f"TOP_K: {self.top_k}")
        logger.info(f"RETRIEVER_DEVICE: {self.retriver_device}")
        logger.info(f"READER_DEVICE: {self.reader_device}")
        logger.info(f"INDEX_DEVICE: {self.index_device}")
        logger.info(f"PRECISION: {self.precision}")
        logger.info(f"INDEX_PRECISION: {self.index_precision}")
        logger.info(f"WINDOW_BATCH_SIZE: {self.window_batch_size}")
        logger.info(f"WINDOW_SIZE: {self.window_size}")
        logger.info(f"WINDOW_STRIDE: {self.window_stride}")
        logger.info(f"SPLIT_ON_SPACES: {self.split_on_spaces}")

        self.relik = Relik(
            question_encoder=self.question_encoder,
            passage_encoder=self.passage_encoder,
            document_index=self.document_index,
            reader=self.reader_encoder,
            retriever_device=self.retriver_device,
            document_index_device=self.index_device,
            reader_device=self.reader_device,
            retriever_precision=self.precision,
            document_index_precision=self.index_precision,
            reader_precision=self.precision,
        )

    # NOTE: not wrapped with ``@serve.batch()``; each request's documents are
    # processed as one call.
    async def handle_batch(self, documents: List[str]) -> List:
        """Run ``Relik`` on ``documents`` with the configured windowing options.

        NOTE(review): ``self.relik(...)`` is a synchronous call inside an async
        method, so it blocks the replica's event loop while it runs.
        """
        return self.relik(
            documents,
            top_k=self.top_k,
            window_size=self.window_size,
            window_stride=self.window_stride,
            batch_size=self.window_batch_size,
        )

    # NOTE(review): route path assumed from the method name; confirm with clients.
    @app.post("/api/entities")
    async def entities_endpoint(
        self,
        documents: Union[str, List[str]],
    ):
        """Annotate one or more documents with the ReLiK pipeline.

        Args:
            documents: A single document or a list of documents.

        Returns:
            The result of ``handle_batch`` for the normalized document list.

        Raises:
            HTTPException: Status 500 wrapping any error raised during processing.
        """
        try:
            # normalize input: a single string becomes a one-element batch
            if isinstance(documents, str):
                documents = [documents]
            # handle_batch accepts only the documents (no per-document topics)
            return await self.handle_batch(documents)
        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}") from e

    # NOTE(review): route path assumed from the method name; confirm with clients.
    @app.post("/api/gerbil")
    async def gerbil_endpoint(self, documents: Union[str, List[str]]):
        """Split documents into windows and attach retrieved candidate passages.

        Args:
            documents: A single document or a list of documents.

        Returns:
            The document windows, each with ``window_candidates`` set to the
            retrieved passage labels. Each label is truncated at its first
            ``" <def>"`` marker.

        Raises:
            HTTPException: Status 501 when the windowing/retrieval helpers this
                endpoint depends on are not available on the instance. Status
                500 wrapping any error raised during processing.
        """
        # __init__ does not create these attributes. Fail with an explicit 501
        # instead of an opaque AttributeError reported as a 500.
        missing = [
            name
            for name in ("window_manager", "tokenizer", "handle_batch_retriever")
            if not hasattr(self, name)
        ]
        if missing:
            raise HTTPException(
                status_code=501,
                detail=f"GERBIL endpoint is not available: missing {', '.join(missing)}",
            )
        try:
            # normalize input
            if isinstance(documents, str):
                documents = [documents]
            # one list of passage labels per window, in window order
            windows_passages = []
            # split documents into windows
            document_windows = [
                window
                for doc_id, document in enumerate(documents)
                for window in self.window_manager(
                    self.tokenizer,
                    document,
                    window_size=self.window_size,
                    stride=self.window_stride,
                    doc_id=doc_id,
                )
            ]
            # (text, topic) pairs fed to the retriever
            model_inputs = [
                (window.text, window.doc_topic) for window in document_windows
            ]
            for batch in batch_generator(
                model_inputs, batch_size=self.window_batch_size
            ):
                text, text_pair = zip(*batch)
                batch_predictions = await self.handle_batch_retriever(text, text_pair)
                windows_passages.extend(
                    [p.label for p in predictions] for predictions in batch_predictions
                )
            # attach passages to their windows
            for window, passages in zip(document_windows, windows_passages):
                # keep only the part before the first " <def>" marker
                window.window_candidates = [
                    c.split(" <def>", 1)[0] for c in passages
                ]
            return document_windows
        except Exception as e:
            # log the entire stack trace
            logger.exception(e)
            raise HTTPException(status_code=500, detail=f"Server Error: {e}") from e
# Build the Ray Serve application. The attributes collected by SERVER_MANAGER
# are forwarded as keyword arguments to RelikServer's constructor. ``bind`` is
# only available when RelikServer is a Ray Serve deployment.
_server_init_kwargs = vars(SERVER_MANAGER)
server = RelikServer.bind(**_server_init_kwargs)