from typing import List, Optional

import torch
import streamlit as st
import pandas as pd
import time
import logging
import shutil
from json import JSONDecodeError

from transformers import AutoModelForSequenceClassification, AutoTokenizer, AutoConfig

from haystack.document_stores import FAISSDocumentStore
from haystack.modeling.utils import initialize_device_settings
from haystack.nodes import EmbeddingRetriever
from haystack.pipelines import Pipeline
from haystack.nodes.base import BaseComponent
from haystack.schema import Document

from config import (
    RETRIEVER_TOP_K,
    RETRIEVER_MODEL,
    NLI_MODEL,
)


class EntailmentChecker(BaseComponent):
    """
    This node checks the entailment between each document's content (premise) and the statement (hypothesis).
    It enriches the documents' metadata with entailment information and also returns aggregate entailment information.
    """
    outgoing_edges = 1

    def __init__(
        self,
        model_name_or_path: str = "roberta-large-mnli",
        model_version: Optional[str] = None,
        tokenizer: Optional[str] = None,
        use_gpu: bool = True,
        batch_size: int = 100,
        entailment_contradiction_consideration: float = 0.7,
        entailment_contradiction_threshold: float = 0.95,
    ):
        """
        Load a Natural Language Inference model from Transformers.

        :param model_name_or_path: Directory of a saved model or the name of a public model.
            See https://huggingface.co/models for a full list of available models.
        :param model_version: The version of the model to use from the Hugging Face model hub. Can be a tag name, branch name, or commit hash.
        :param tokenizer: Name of the tokenizer (usually the same as the model).
        :param use_gpu: Whether to use GPU (if available).
        :param batch_size: Number of Documents to be processed at a time.
        :param entailment_contradiction_consideration: Only keep documents whose entailment or contradiction score is greater than this value.
        :param entailment_contradiction_threshold: Stop considering further documents once the running average of entailment or contradiction exceeds this value.
        """
        super().__init__()

        self.devices, _ = initialize_device_settings(use_cuda=use_gpu, multi_gpu=False)
        tokenizer = tokenizer or model_name_or_path
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer)
        self.model = AutoModelForSequenceClassification.from_pretrained(
            pretrained_model_name_or_path=model_name_or_path, revision=model_version
        )
        self.batch_size = batch_size
        self.entailment_contradiction_threshold = entailment_contradiction_threshold
        self.entailment_contradiction_consideration = entailment_contradiction_consideration
        self.model.to(str(self.devices[0]))

        # Map the model's output indices to lowercased label names,
        # e.g. ["contradiction", "neutral", "entailment"]
        id2label = AutoConfig.from_pretrained(model_name_or_path).id2label
        self.labels = [id2label[k].lower() for k in sorted(id2label)]
        if "entailment" not in self.labels:
            raise ValueError("The model config must contain an 'entailment' value in the id2label dict.")

    def run(self, query: str, documents: List[Document]):
        scores, agg_con, agg_neu, agg_ent = 0, 0, 0, 0
        premise_batch = [doc.content for doc in documents]
        hypothesis_batch = [query] * len(documents)
        entailment_info_batch = self.get_entailment_batch(
            premise_batch=premise_batch, hypothesis_batch=hypothesis_batch
        )
        considered_documents = []
        for doc, entailment_info in zip(documents, entailment_info_batch):
            doc.meta["entailment_info"] = entailment_info
            con, neu, ent = (
                entailment_info["contradiction"],
                entailment_info["neutral"],
                entailment_info["entailment"],
            )
            if (con > self.entailment_contradiction_consideration) or (ent > self.entailment_contradiction_consideration):
                considered_documents.append(doc)
            agg_con += con
            agg_neu += neu
            agg_ent += ent
            scores += 1
            # If the first (most relevant) documents already show strong evidence of
            # entailment/contradiction, there is no need to consider less relevant documents.
            if max(agg_con, agg_ent) / scores > self.entailment_contradiction_threshold:
                break

        if scores > 0:
            aggregate_entailment_info = {
                "contradiction": round(agg_con / scores, 2),
                "neutral": round(agg_neu / scores, 2),
                "entailment": round(agg_ent / scores, 2),
            }
        else:
            aggregate_entailment_info = {
                "contradiction": 0,
                "neutral": 0,
                "entailment": 0,
            }
        entailment_checker_result = {
            "documents": considered_documents,
            "aggregate_entailment_info": aggregate_entailment_info,
        }
        return entailment_checker_result, "output_1"

    def run_batch(self, queries: List[str], documents: List[Document]):
        entailment_checker_result_batch = []
        # Pass document contents (strings), not Document objects, to the tokenizer
        entailment_info_batch = self.get_entailment_batch(
            premise_batch=[doc.content for doc in documents], hypothesis_batch=queries
        )
        for doc, entailment_info in zip(documents, entailment_info_batch):
            doc.meta["entailment_info"] = entailment_info
            # With a single document per query, the aggregate is just that document's scores
            aggregate_entailment_info = {
                "contradiction": round(entailment_info["contradiction"], 2),
                "neutral": round(entailment_info["neutral"], 2),
                "entailment": round(entailment_info["entailment"], 2),
            }
            entailment_checker_result_batch.append(
                {
                    "documents": [doc],
                    "aggregate_entailment_info": aggregate_entailment_info,
                }
            )
        return entailment_checker_result_batch, "output_1"

    def get_entailment_dict(self, probs):
        # self.labels are already lowercased; cast numpy floats to plain Python floats
        return {label: float(prob) for label, prob in zip(self.labels, probs)}

    def get_entailment_batch(self, premise_batch: List[str], hypothesis_batch: List[str]):
        # Tokenize premise/hypothesis as sentence pairs, so the tokenizer inserts the
        # separator pattern the NLI model was trained with (e.g. "</s></s>" for RoBERTa)
        with torch.inference_mode():
            inputs = self.tokenizer(
                premise_batch, hypothesis_batch, return_tensors="pt", padding=True, truncation=True
            ).to(self.devices[0])
            out = self.model(**inputs)
            logits = out.logits
            probs_batch = torch.nn.functional.softmax(logits, dim=-1).cpu().numpy()
        return [self.get_entailment_dict(probs) for probs in probs_batch]
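

# --- Illustrative sketch (not called by the app): how EntailmentChecker can be used
# on its own, outside a pipeline. The model name and example sentences below are
# assumptions for demonstration purposes only.
def _demo_entailment_checker():
    checker = EntailmentChecker(model_name_or_path="roberta-large-mnli", use_gpu=False)
    docs = [
        Document(content="The ocean covers most of the Earth's surface."),
        Document(content="Most of the Earth's surface is dry land."),
    ]
    result, _ = checker.run(query="Water covers most of our planet.", documents=docs)
    # Each document now carries per-label probabilities in meta["entailment_info"],
    # and the aggregate averages them over the documents considered
    print(result["aggregate_entailment_info"])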


# Cached so the index and models are loaded only once, at startup
# (st.cache_resource requires Streamlit >= 1.18)
@st.cache_resource(show_spinner=False)
def start_haystack():
    """
    Load the document store, retriever, and entailment checker, and create the pipeline.
    """
    shutil.copy("./data/final_faiss_document_store.db", ".")
    document_store = FAISSDocumentStore(
        faiss_index_path="./data/my_faiss_index.faiss",
        faiss_config_path="./data/my_faiss_index.json",
    )
    print(f"Index size: {document_store.get_document_count()}")

    retriever = EmbeddingRetriever(
        document_store=document_store,
        embedding_model=RETRIEVER_MODEL,
    )
    entailment_checker = EntailmentChecker(
        model_name_or_path=NLI_MODEL,
        use_gpu=False,
    )

    pipe = Pipeline()
    pipe.add_node(component=retriever, name="retriever", inputs=["Query"])
    pipe.add_node(component=entailment_checker, name="ec", inputs=["retriever"])
    return pipe
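

# --- Illustrative sketch (an assumption, not part of this app's runtime): how an index
# like the one loaded above could be built and saved with Haystack. The sql_url,
# embedding_dim, and file paths are placeholders.
def _build_index_sketch(texts: List[str]):
    store = FAISSDocumentStore(
        sql_url="sqlite:///final_faiss_document_store.db",
        embedding_dim=768,  # must match the dimension of RETRIEVER_MODEL's embeddings
    )
    store.write_documents([Document(content=text) for text in texts])
    retriever = EmbeddingRetriever(document_store=store, embedding_model=RETRIEVER_MODEL)
    store.update_embeddings(retriever)
    store.save(index_path="./data/my_faiss_index.faiss", config_path="./data/my_faiss_index.json")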


pipe = start_haystack()


def check_statement(statement: str, retriever_top_k: int = 5):
    """Run the query and verify the statement."""
    params = {"retriever": {"top_k": retriever_top_k}}
    return pipe.run(query=statement, params=params)
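
# Based on EntailmentChecker.run above, the pipeline result has roughly this shape
# (in addition to standard Haystack keys such as "query"):
# {
#     "documents": [<Document with meta["entailment_info"]>, ...],
#     "aggregate_entailment_info": {"contradiction": 0.03, "neutral": 0.12, "entailment": 0.85},
# }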


def set_state_if_absent(key, value):
    if key not in st.session_state:
        st.session_state[key] = value


# Small callback to reset the interface in case the text of the question changes
def reset_results(*args):
    st.session_state.answer = None
    st.session_state.results = None
    st.session_state.raw_json = None


def create_df_for_relevant_snippets(docs):
    """
    Create a dataframe that contains all relevant snippets,
    or return a message if none were found.
    """
    if len(docs) == 0:
        return "No information was found in the database of true sentences."
    rows = []
    for doc in docs:
        row = {
            "Content": doc.content,
            "con": f"{doc.meta['entailment_info']['contradiction']:.2f}",
            "neu": f"{doc.meta['entailment_info']['neutral']:.2f}",
            "ent": f"{doc.meta['entailment_info']['entailment']:.2f}",
        }
        rows.append(row)
    df = pd.DataFrame(rows)
    df["Content"] = df["Content"].str.wrap(75)
    df = df.style.apply(highlight_cols)
    return df


def highlight_cols(s):
    # Color-code the contradiction / neutral / entailment columns
    coldict = {"con": "#FFA07A", "neu": "#E5E4E2", "ent": "#a9d39e"}
    if s.name in coldict:
        return [f"background-color: {coldict[s.name]}"] * len(s)
    return [""] * len(s)


def main():
    # Persistent state
    set_state_if_absent("statement", "")
    set_state_if_absent("answer", "")
    set_state_if_absent("results", None)
    set_state_if_absent("raw_json", None)

    st.write("# Sentence Verification about the Amazônia Azul")
    st.markdown(
        """
        ##### Enter a sentence about the Amazônia Azul.
        """
    )
    # Search bar
    statement = st.text_input("", max_chars=100, on_change=reset_results)
    st.markdown("<style>.stButton button {width:100%;}</style>", unsafe_allow_html=True)
    run_pressed = st.button("Run")
    run_query = run_pressed or statement != st.session_state.statement

    # Get results for the query
    if run_query and statement:
        time_start = time.time()
        reset_results()
        st.session_state.statement = statement
        with st.spinner("Searching for similar sentences in the database..."):
            try:
                st.session_state.results = check_statement(statement, RETRIEVER_TOP_K)
                print(f"S: {statement}")
                time_end = time.time()
                print(time.strftime("%Y-%m-%d %H:%M:%S", time.gmtime()))
                print(f"elapsed time: {time_end - time_start}")
            except JSONDecodeError:
                st.error("👓 Error in the document store.")
                return
            except Exception as e:
                logging.exception(e)
                st.error("🐞 Generic error.")
                return

    # Display results
    if st.session_state.results:
        docs = st.session_state.results["documents"]
        agg_entailment_info = st.session_state.results["aggregate_entailment_info"]
        st.markdown("###### Aggregate entailment information:")
        st.write(agg_entailment_info)
        st.markdown("###### Most relevant snippets:")
        df = create_df_for_relevant_snippets(docs)
        if isinstance(df, str):
            st.markdown(df)
        else:
            st.dataframe(df)


main()