# GERBIL-compatible HTTP server exposing a ReLiK annotator.
import argparse
import json
import os
import re
import sys
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Iterator, List, Optional, Tuple

from relik.inference.annotator import Relik
from relik.inference.data.objects import RelikOutput

# sys.path += ['../']
# Make the repository root importable when this file is run as a script.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "../")))

import logging

logger = logging.getLogger(__name__)
class GerbilAlbyManager:
    """Holds the ReLiK annotator plus optional label mapping and response logging."""

    def __init__(
        self,
        annotator: Optional[Relik] = None,
        response_logger_dir: Optional[str] = None,
    ) -> None:
        # The annotator may be injected after construction (see main()).
        self.annotator = annotator
        # Directory where each prediction bundle is dumped; None disables logging.
        self.response_logger_dir = response_logger_dir
        # Monotonic counter used to name the per-prediction JSON files.
        self.predictions_counter = 0
        # Optional {predicted label -> mapped label} dict, see set_mapping_file().
        self.labels_mapping = None

    def annotate(self, document: str):
        """Run the annotator on `document` and return (start, end, label) triples."""
        relik_output: RelikOutput = self.annotator(document)
        annotations = [(ss, se, label) for ss, se, label, _ in relik_output.labels]
        if self.labels_mapping is not None:
            # Remap labels when a mapping is configured; unknown labels pass through.
            return [
                (ss, se, self.labels_mapping.get(label, label))
                for ss, se, label in annotations
            ]
        return annotations

    def set_mapping_file(self, mapping_file_path: str):
        """Load a JSON label-mapping file and store its inverted mapping."""
        with open(mapping_file_path) as f:
            labels_mapping = json.load(f)
        # The file maps target -> source; lookups need source -> target.
        self.labels_mapping = {v: k for k, v in labels_mapping.items()}

    def write_response_bundle(
        self,
        document: str,
        new_document: str,
        annotations: list,
        mapped_annotations: list,
    ) -> None:
        """Dump one prediction (documents, annotations, resolved spans) to disk."""
        if self.response_logger_dir is None:
            return

        # makedirs is race-safe and creates missing parents (os.mkdir did neither).
        os.makedirs(self.response_logger_dir, exist_ok=True)

        out_path = os.path.join(
            self.response_logger_dir, f"{self.predictions_counter}.json"
        )
        out_json_obj = dict(
            document=document,
            new_document=new_document,
            annotations=annotations,
            mapped_annotations=mapped_annotations,
        )
        # Materialize the annotated surface forms for easier debugging.
        out_json_obj["span_annotations"] = [
            (ss, se, document[ss:se], label) for (ss, se, label) in annotations
        ]
        out_json_obj["span_mapped_annotations"] = [
            (ss, se, new_document[ss:se], label)
            for (ss, se, label) in mapped_annotations
        ]
        with open(out_path, "w") as f:
            json.dump(out_json_obj, f, indent=2)

        self.predictions_counter += 1
# Module-level singleton shared by the HTTP handler and annotate().
manager = GerbilAlbyManager()
def preprocess_document(document: str) -> Tuple[str, List[Tuple[int, int]]]:
    """Normalize PTB-style bracket tokens in `document` and record offset shifts.

    Returns (new_document, char2offset) where char2offset is a list of
    (position_in_new_document, cumulative_offset) pairs: adding the cumulative
    offset to any new-document index at or past that position recovers the
    corresponding index in the original document.
    """
    pattern_subs = {
        "-LPR- ": " (",
        "-RPR-": ")",
        "\n\n": "\n",
        "-LRB-": "(",
        "-RRB-": ")",
        '","': ",",
    }
    document_acc = document
    curr_offset = 0
    char2offset = []

    # re.escape keeps the alternation literal even if a pattern ever contains
    # regex metacharacters (the current ones happen to be safe already).
    pattern = "({})".format("|".join(re.escape(p) for p in pattern_subs))
    matchings = re.finditer(pattern, document)

    # Rewrite document_acc left to right, tracking how far later characters
    # have shifted as each (longer) token is replaced by its (shorter) form.
    for span_matching in sorted(matchings, key=lambda x: x.span()[0]):
        span_start, span_end = span_matching.span()
        # Translate the match span from original to current coordinates.
        span_start -= curr_offset
        span_end -= curr_offset

        span_text = document_acc[span_start:span_end]
        span_sub = pattern_subs[span_text]
        document_acc = document_acc[:span_start] + span_sub + document_acc[span_end:]

        # Everything after the replacement shifts left by this delta.
        offset = len(span_text) - len(span_sub)
        curr_offset += offset

        char2offset.append((span_start + len(span_sub), curr_offset))

    return document_acc, char2offset
def map_back_annotations(
    annotations: List[Tuple[int, int, str]], char_mapping: List[Tuple[int, int]]
) -> Iterator[Tuple[int, int, str]]:
    """Shift annotation spans back into the original document's coordinates."""

    def _shift(idx: int) -> int:
        # Use the cumulative offset of the last mapping entry at or before idx.
        applicable = 0
        for boundary, cumulative in char_mapping:
            if idx < boundary:
                break
            applicable = cumulative
        return idx + applicable

    for start, end, label in annotations:
        yield _shift(start), _shift(end), label
def annotate(document: str) -> List[Tuple[int, int, str]]:
    """Preprocess `document`, run the shared annotator, and map spans back.

    Returns (start, length, label) triples in the coordinates of the original
    document, or None when the span consistency check fails.
    """
    new_document, mapping = preprocess_document(document)
    logger.info("Mapping: %s", mapping)
    logger.info("Document: %s", document)

    # GERBIL expects underscore-separated labels.
    annotations = [
        (cs, ce, label.replace(" ", "_"))
        for cs, ce, label in manager.annotate(new_document)
    ]
    logger.info("New document: %s", new_document)

    mapped_annotations = (
        list(map_back_annotations(annotations, mapping))
        if len(mapping) > 0
        else annotations
    )
    logger.info(
        "Annotations: %s",
        [(ss, se, document[ss:se], ann) for ss, se, ann in mapped_annotations],
    )

    # NOTE: mapped_annotations (original-document coordinates) intentionally go
    # into the slot that is sliced against `document` inside the bundle writer.
    manager.write_response_bundle(
        document, new_document, mapped_annotations, annotations
    )

    # Sanity check: each mapped span must select the same surface form in the
    # original document as its source span does in the preprocessed one.
    mismatches = [
        (new_document[ss:se], document[mss:mse])
        for (mss, mse, _), (ss, se, _) in zip(mapped_annotations, annotations)
        if new_document[ss:se] != document[mss:mse]
    ]
    if mismatches:
        # Previously computed but never reported; log it so failures are debuggable.
        logger.error("Span mismatch after offset mapping: %s", mismatches)
        return None

    # GERBIL consumes (start, length, label) rather than (start, end, label).
    return [(cs, ce - cs, label) for cs, ce, label in mapped_annotations]
class GetHandler(BaseHTTPRequestHandler):
    """Minimal handler: accept a GERBIL POST, annotate it, reply with JSON."""

    def do_POST(self):
        body_length = int(self.headers["Content-Length"])
        raw_body = self.rfile.read(body_length)

        # Acknowledge the request up front, then stream the JSON payload back.
        self.send_response(200)
        self.end_headers()

        document_text = read_json(raw_body)
        annotation_response = annotate(document_text)
        self.wfile.write(bytes(json.dumps(annotation_response), "utf-8"))
def read_json(post_data):
    """Decode a GERBIL request body and pull out the document text."""
    payload = json.loads(post_data.decode("utf-8"))
    # Only the raw text is used; span hints from GERBIL are ignored here.
    return payload["text"]
def parse_args() -> argparse.Namespace:
    """Build and evaluate the command-line interface for the server."""
    cli = argparse.ArgumentParser()
    cli.add_argument("--relik-model-name", required=True)
    cli.add_argument("--responses-log-dir")
    cli.add_argument("--log-file", default="logs/logging.txt")
    cli.add_argument("--mapping-file")
    return cli.parse_args()
def main():
    """Configure the manager and logging, then serve GERBIL requests forever."""
    args = parse_args()

    # Configure file logging first so startup messages are captured.
    # Create the log file's directory if needed (FileHandler does not).
    log_dir = os.path.dirname(args.log_file)
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)
    file_handler = logging.FileHandler(args.log_file)
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(
        logging.Formatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    logger.addHandler(file_handler)
    # Without lowering the logger's own level (default WARNING), the INFO
    # records emitted throughout this module never reach the handler.
    logger.setLevel(logging.DEBUG)

    # init manager
    manager.response_logger_dir = args.responses_log_dir
    # manager.annotator = Relik.from_pretrained(args.relik_model_name)
    # TODO(review): args.relik_model_name is currently ignored in favor of this
    # hard-coded debugging configuration — restore the line above for production.
    print("Debugging, not using your relik model but a hardcoded one.")
    manager.annotator = Relik(
        question_encoder="riccorl/relik-retriever-aida-blink-pretrain-omniencoder",
        document_index="riccorl/index-relik-retriever-aida-blink-pretrain-omniencoder",
        reader="relik/reader/models/relik-reader-deberta-base-new-data",
        window_size=32,
        window_stride=16,
        candidates_preprocessing_fn=(lambda x: x.split("<def>")[0].strip()),
    )

    if args.mapping_file is not None:
        manager.set_mapping_file(args.mapping_file)

    port = 6654
    server = HTTPServer(("localhost", port), GetHandler)
    logger.info("Starting server at http://localhost:%s", port)

    try:
        server.serve_forever()
    except KeyboardInterrupt:
        exit(0)
# Script entry point.
if __name__ == "__main__":
    main()