update from https://github.com/ArneBinder/argumentation-structure-identification/pull/529
d868d2e
verified
| import hydra | |
| import pyrootutils | |
| from omegaconf import DictConfig, OmegaConf, SCMode | |
| root = pyrootutils.setup_root( | |
| search_from=__file__, | |
| indicator=[".project-root"], | |
| pythonpath=True, | |
| dotenv=True, | |
| ) | |
| import json | |
| import logging | |
| import gradio as gr | |
| import torch | |
| import yaml | |
| from src.demo.annotation_utils import load_argumentation_model | |
| from src.demo.backend_utils import ( | |
| download_processed_documents, | |
| load_acl_anthology_venues, | |
| process_text_from_arxiv, | |
| process_uploaded_files, | |
| process_uploaded_pdf_files, | |
| render_annotated_document, | |
| upload_processed_documents, | |
| wrapped_add_annotated_pie_documents_from_dataset, | |
| wrapped_process_text, | |
| ) | |
| from src.demo.frontend_utils import ( | |
| change_tab, | |
| escape_regex, | |
| get_cell_for_fixed_column_from_df, | |
| open_accordion, | |
| open_accordion_with_stats, | |
| unescape_regex, | |
| ) | |
| from src.demo.rendering_utils import AVAILABLE_RENDER_MODES, HIGHLIGHT_SPANS_JS | |
| from src.demo.retriever_utils import ( | |
| get_document_as_dict, | |
| get_span_annotation, | |
| load_retriever, | |
| retrieve_all_relevant_spans, | |
| retrieve_all_similar_spans, | |
| retrieve_relevant_spans, | |
| retrieve_similar_spans, | |
| ) | |
def load_yaml_config(path: str) -> str:
    """Load a YAML file and return its content re-serialized as a YAML string.

    The file is parsed with ``yaml.safe_load`` and dumped again with ``yaml.dump``,
    so the result is a normalized rendering of the parsed data, not the raw file
    text (comments and the original formatting are not preserved).

    Args:
        path: Path of the YAML file to read.

    Returns:
        The parsed configuration, dumped back to a YAML string.
    """
    # Decode explicitly as UTF-8 instead of relying on the platform's locale encoding.
    with open(path, "r", encoding="utf-8") as file:
        config = yaml.safe_load(file)
    return yaml.dump(config)
def resolve_config(cfg) -> dict:
    """Turn an OmegaConf config into plain Python containers.

    Interpolations are resolved (``resolve=True``) and structured configs are
    converted to dictionaries (``SCMode.DICT``).

    Args:
        cfg: The OmegaConf config object to convert.

    Returns:
        The result of ``OmegaConf.to_container`` for ``cfg``.
    """
    container = OmegaConf.to_container(
        cfg,
        resolve=True,
        structured_config_mode=SCMode.DICT,
    )
    return container
def main(cfg: DictConfig) -> None:
    """Build the Gradio demo UI, wire all event handlers and launch the app.

    The config is first resolved into plain containers. The following keys are read
    from it: ``example_text``, ``retriever``, ``argumentation_model``,
    ``handle_parts_of_same``, ``default_arxiv_id``, ``default_load_pie_dataset_kwargs``,
    ``default_render_mode``, ``default_render_kwargs``, ``default_split_regex``,
    ``render_mode_captions``, ``default_min_similarity``, ``default_top_k``,
    ``default_min_score``, ``layer_caption_mapping`` and ``relation_name_mapping``.
    Optional keys: ``pdf_fulltext_extractor``, ``acl_anthology_data_dir`` and
    ``acl_anthology_pdf_dir``.

    NOTE(review): the script entry point calls ``main()`` without arguments, so this
    function presumably relies on a ``@hydra.main(...)`` decorator to inject ``cfg`` --
    confirm that the decorator is present.
    NOTE(review): the nesting of the widgets inside the ``with`` blocks determines the
    page layout -- confirm it against the running app.

    Args:
        cfg: The demo configuration.

    Raises:
        ValueError: If ``cfg["default_render_mode"]`` is not one of
            ``AVAILABLE_RENDER_MODES``.
    """
    # configure logging
    logging.basicConfig()

    # resolve everything in the config to prevent any issues with JSON serialization etc.
    cfg = resolve_config(cfg)

    example_text = cfg["example_text"]

    default_device = "cuda:0" if torch.cuda.is_available() else "cpu"

    default_retriever_config_str = yaml.dump(cfg["retriever"])

    default_argumentation_model_config_str = yaml.dump(cfg["argumentation_model"])

    handle_parts_of_same = cfg["handle_parts_of_same"]

    default_arxiv_id = cfg["default_arxiv_id"]
    default_load_pie_dataset_kwargs_str = json.dumps(
        cfg["default_load_pie_dataset_kwargs"], indent=2
    )

    default_render_mode = cfg["default_render_mode"]
    if default_render_mode not in AVAILABLE_RENDER_MODES:
        raise ValueError(
            f"Invalid default render mode '{default_render_mode}'. "
            f"Choose one of {AVAILABLE_RENDER_MODES}."
        )
    default_render_kwargs = cfg["default_render_kwargs"]
    # serialize once; used both as initial value and to size the code editor
    default_render_kwargs_str = json.dumps(default_render_kwargs, indent=2)

    default_split_regex = cfg["default_split_regex"]

    # captions for better readability:
    # map from render mode to the corresponding caption (falls back to the mode itself)
    render_mode2caption = {
        render_mode: cfg["render_mode_captions"].get(render_mode, render_mode)
        for render_mode in AVAILABLE_RENDER_MODES
    }
    # inverse mapping to translate the dropdown selection back to the render mode
    render_caption2mode = {v: k for k, v in render_mode2caption.items()}

    default_min_similarity = cfg["default_min_similarity"]
    default_top_k = cfg["default_top_k"]
    default_min_score = cfg["default_min_score"]
    layer_caption_mapping = cfg["layer_caption_mapping"]
    relation_name_mapping = cfg["relation_name_mapping"]

    indexed_documents_label = "Indexed Documents"
    indexed_documents_caption2column = {
        "documents": "TOTAL",
        "ADUs": "num_adus",
        "Relations": "num_relations",
    }

    def _split_regex_or_none(escaped):
        """Return the unescaped split regex, or None if the textbox content is empty."""
        return unescape_regex(escaped) if escaped else None

    gr.Info("Loading models ...")
    argumentation_model = load_argumentation_model(
        config_str=default_argumentation_model_config_str,
        device=default_device,
    )
    retriever = load_retriever(
        config_str=default_retriever_config_str, device=default_device, config_format="yaml"
    )

    # the PDF fulltext extractor is optional; PDF related widgets are hidden without it
    if cfg.get("pdf_fulltext_extractor"):
        gr.Info("Loading PDF fulltext extractor ...")
        pdf_fulltext_extractor = hydra.utils.instantiate(cfg["pdf_fulltext_extractor"])
    else:
        pdf_fulltext_extractor = None

    with gr.Blocks() as demo:
        # wrap the pipeline and the embedding model/tokenizer in a tuple to avoid that it gets called
        # models_state = gr.State((argumentation_model, embedding_model))
        argumentation_model_state = gr.State((argumentation_model,))
        retriever_state = gr.State((retriever,))

        with gr.Row():
            with gr.Tabs() as left_tabs:
                with gr.Tab("User Input", id="user_input") as user_input_tab:
                    doc_id = gr.Textbox(
                        label="Document ID",
                        value="user_input",
                    )
                    doc_text = gr.Textbox(
                        label="Text",
                        lines=20,
                        value=example_text,
                    )

                    with gr.Accordion("Model Configuration", open=False):
                        with gr.Accordion("argumentation structure", open=True):
                            argumentation_model_config_str = gr.Code(
                                language="yaml",
                                label="Argumentation Model Configuration",
                                value=default_argumentation_model_config_str,
                                lines=len(default_argumentation_model_config_str.split("\n")),
                            )
                            load_arg_model_btn = gr.Button("Load Argumentation Model")

                        with gr.Accordion("retriever", open=True):
                            retriever_config_str = gr.Code(
                                language="yaml",
                                label="Retriever Configuration",
                                value=default_retriever_config_str,
                                lines=len(default_retriever_config_str.split("\n")),
                            )
                            load_retriever_btn = gr.Button("Load Retriever")

                        device = gr.Textbox(
                            label="Device (e.g. 'cuda' or 'cpu')",
                            value=default_device,
                        )
                        # the handlers return 1-tuples to match the tuple-wrapped states
                        load_arg_model_btn.click(
                            fn=lambda _argumentation_model_config_str, _device: (
                                load_argumentation_model(
                                    config_str=_argumentation_model_config_str,
                                    device=_device,
                                ),
                            ),
                            inputs=[argumentation_model_config_str, device],
                            outputs=argumentation_model_state,
                        )
                        load_retriever_btn.click(
                            fn=lambda _retriever_config, _device, _previous_retriever: (
                                load_retriever(
                                    config_str=_retriever_config,
                                    device=_device,
                                    previous_retriever=_previous_retriever[0],
                                    config_format="yaml",
                                ),
                            ),
                            inputs=[retriever_config_str, device, retriever_state],
                            outputs=retriever_state,
                        )

                        split_regex_escaped = gr.Textbox(
                            label="Regex to partition the text",
                            placeholder="Regular expression pattern to split the text into partitions",
                            value=escape_regex(default_split_regex),
                        )

                    predict_btn = gr.Button("Analyse")

                with gr.Tab("Analysed Document", id="analysed_document") as analysed_document_tab:
                    selected_document_id = gr.Textbox(
                        label="Document ID", max_lines=1, interactive=False
                    )
                    rendered_output = gr.HTML(label="Rendered Output")

                    with gr.Accordion("Render Options", open=False):
                        render_as = gr.Dropdown(
                            label="Render with",
                            choices=list(render_mode2caption.values()),
                            value=render_mode2caption[default_render_mode],
                        )
                        render_kwargs = gr.Code(
                            language="json",
                            label="Render Arguments",
                            value=default_render_kwargs_str,
                            lines=len(default_render_kwargs_str.split("\n")),
                        )
                        render_btn = gr.Button("Re-render")

                    with gr.Accordion("See plain result ...", open=False):
                        get_document_json_btn = gr.Button("Fetch annotated document as JSON")
                        document_json = gr.JSON(label="Model Output")

            with gr.Tabs() as right_tabs:
                with gr.Tab("Retrieval", id="retrieval") as retrieval_tab:
                    with gr.Accordion(
                        indexed_documents_label, open=False
                    ) as processed_documents_accordion:
                        processed_documents_df = gr.DataFrame(
                            headers=["id", "num_adus", "num_relations"],
                            interactive=False,
                            elem_classes="df-docstore",
                        )
                        gr.Markdown("Data Snapshot:")
                        with gr.Row():
                            download_processed_documents_btn = gr.DownloadButton("Download")
                            upload_processed_documents_btn = gr.UploadButton(
                                "Upload", file_types=["file"]
                            )

                    # currently not used
                    # relation_types = set_relation_types(
                    #    argumentation_model_state.value[0], default=["supports_reversed", "contradicts_reversed"]
                    # )

                    # Dummy textbox to hold the hover adu id. On click on the rendered output,
                    # its content will be copied to selected_adu_id which will trigger the retrieval.
                    hover_adu_id = gr.Textbox(
                        label="ID (hover)",
                        elem_id="hover_adu_id",
                        interactive=False,
                        visible=False,
                    )
                    selected_adu_id = gr.Textbox(
                        label="ID (selected)",
                        elem_id="selected_adu_id",
                        interactive=False,
                        visible=False,
                    )
                    selected_adu_text = gr.Textbox(label="Selected ADU", interactive=False)

                    with gr.Accordion("Relevant ADUs from other documents", open=True):
                        relevant_adus_df = gr.DataFrame(
                            headers=[
                                "relation",
                                "adu",
                                "reference_adu",
                                "doc_id",
                                "sim_score",
                                "rel_score",
                            ],
                            interactive=False,
                        )

                    with gr.Accordion("Retrieval Configuration", open=False):
                        min_similarity = gr.Slider(
                            label="Minimum Similarity",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=default_min_similarity,
                        )
                        top_k = gr.Slider(
                            label="Top K",
                            minimum=2,
                            maximum=50,
                            step=1,
                            value=default_top_k,
                        )
                        min_score = gr.Slider(
                            label="Minimum Score",
                            minimum=0.0,
                            maximum=1.0,
                            step=0.01,
                            value=default_min_score,
                        )
                        retrieve_similar_adus_btn = gr.Button(
                            "Retrieve *similar* ADUs for *selected* ADU"
                        )
                        similar_adus_df = gr.DataFrame(
                            headers=["doc_id", "adu_id", "score", "text"], interactive=False
                        )
                        retrieve_all_similar_adus_btn = gr.Button(
                            "Retrieve *similar* ADUs for *all* ADUs in the document"
                        )
                        all_similar_adus_df = gr.DataFrame(
                            headers=["doc_id", "query_adu_id", "adu_id", "score", "text"],
                            interactive=False,
                        )
                        retrieve_all_relevant_adus_btn = gr.Button(
                            "Retrieve *relevant* ADUs for *all* ADUs in the document"
                        )
                        all_relevant_adus_df = gr.DataFrame(
                            headers=["doc_id", "adu_id", "score", "text", "query_span_id"],
                            interactive=False,
                        )
                        # hidden: remembers for which document all_relevant_adus_df was computed
                        all_relevant_adus_query_doc_id = gr.Textbox(visible=False)

                with gr.Tab("Import Documents", id="import_documents") as import_documents_tab:
                    upload_btn = gr.UploadButton(
                        "Batch Analyse Texts",
                        file_types=["text"],
                        file_count="multiple",
                    )

                    upload_pdf_btn = gr.UploadButton(
                        "Batch Analyse PDFs",
                        # file_types=["pdf"],
                        file_count="multiple",
                        visible=pdf_fulltext_extractor is not None,
                    )

                    # importing from the ACL Anthology needs the PDF extractor and a data dir
                    enable_acl_venue_loading = (
                        pdf_fulltext_extractor is not None
                        and cfg.get("acl_anthology_data_dir") is not None
                    )
                    acl_anthology_venues = gr.Textbox(
                        label="ACL Anthology Venues",
                        value="wiesp",
                        max_lines=1,
                        visible=enable_acl_venue_loading,
                    )
                    load_acl_anthology_venues_btn = gr.Button(
                        "Import from ACL Anthology",
                        variant="secondary",
                        visible=enable_acl_venue_loading,
                    )

                    with gr.Accordion("Import text from arXiv", open=False):
                        arxiv_id = gr.Textbox(
                            label="arXiv paper ID",
                            placeholder=f"e.g. {default_arxiv_id}",
                            max_lines=1,
                        )
                        load_arxiv_only_abstract = gr.Checkbox(label="abstract only", value=False)
                        load_arxiv_btn = gr.Button(
                            "Load & Analyse from arXiv", variant="secondary"
                        )

                    with gr.Accordion(
                        "Import argument structure annotated PIE dataset", open=False
                    ):
                        load_pie_dataset_kwargs_str = gr.Code(
                            language="json",
                            label="Parameters for Loading the PIE Dataset",
                            value=default_load_pie_dataset_kwargs_str,
                            lines=len(default_load_pie_dataset_kwargs_str.split("\n")),
                        )
                        load_pie_dataset_btn = gr.Button("Load & Embed PIE Dataset")

        # Shared event definition: (re-)render the selected document. If no document is
        # selected (blank id), the current rendered output is returned unchanged. Query spans
        # from all_relevant_adus_df are highlighted only if it was computed for this document.
        render_event_kwargs = dict(
            fn=lambda _rendered_output, _retriever, _document_id, _render_as, _render_kwargs, _all_relevant_adus_df, _all_relevant_adus_query_doc_id: (
                render_annotated_document(
                    retriever=_retriever[0],
                    document_id=_document_id,
                    render_with=render_caption2mode[_render_as],
                    render_kwargs_json=_render_kwargs,
                    highlight_span_ids=(
                        _all_relevant_adus_df["query_span_id"].tolist()
                        if _document_id == _all_relevant_adus_query_doc_id
                        else None
                    ),
                )
                if _document_id.strip() != ""
                else _rendered_output
            ),
            inputs=[
                rendered_output,
                retriever_state,
                selected_document_id,
                render_as,
                render_kwargs,
                all_relevant_adus_df,
                all_relevant_adus_query_doc_id,
            ],
            outputs=rendered_output,
        )

        # Shared event definition: refresh the overview table of the indexed documents.
        show_overview_kwargs = dict(
            fn=lambda _retriever: _retriever[0].docstore.overview(
                layer_captions=layer_caption_mapping, use_predictions=True
            ),
            inputs=[retriever_state],
            outputs=[processed_documents_df],
        )

        # Shared event definition: update the accordion of the indexed documents
        # based on the content of the overview table.
        show_stats_kwargs = dict(
            fn=lambda _processed_documents_df: open_accordion_with_stats(
                _processed_documents_df,
                base_label=indexed_documents_label,
                caption2column=indexed_documents_caption2column,
                total_column="TOTAL",
            ),
            inputs=[processed_documents_df],
            outputs=[processed_documents_accordion],
        )

        predict_btn.click(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(
            fn=lambda _doc_text, _doc_id, _argumentation_model, _retriever, _split_regex_escaped: wrapped_process_text(
                text=_doc_text,
                doc_id=_doc_id,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=_split_regex_or_none(_split_regex_escaped),
                handle_parts_of_same=handle_parts_of_same,
            ),
            inputs=[
                doc_text,
                doc_id,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[selected_document_id],
            api_name="predict",
        ).success(
            **show_overview_kwargs
        ).success(
            **show_stats_kwargs
        ).success(
            **render_event_kwargs
        )
        render_btn.click(**render_event_kwargs, api_name="render")

        # NOTE(review): this chain uses the same api_name ("predict") as the "Analyse" chain
        # above -- confirm that this is intended and how Gradio names the resulting endpoint.
        load_arxiv_btn.click(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(
            fn=lambda _arxiv_id, _load_arxiv_only_abstract, _argumentation_model, _retriever, _split_regex_escaped: process_text_from_arxiv(
                # fall back to the default arXiv id if nothing was entered
                arxiv_id=_arxiv_id.strip() or default_arxiv_id,
                abstract_only=_load_arxiv_only_abstract,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=_split_regex_or_none(_split_regex_escaped),
                handle_parts_of_same=handle_parts_of_same,
            ),
            inputs=[
                arxiv_id,
                load_arxiv_only_abstract,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[selected_document_id],
            api_name="predict",
        ).success(
            **show_overview_kwargs
        ).success(
            **show_stats_kwargs
        )

        load_pie_dataset_btn.click(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _retriever, _load_pie_dataset_kwargs_str: wrapped_add_annotated_pie_documents_from_dataset(
                retriever=_retriever[0],
                verbose=True,
                layer_captions=layer_caption_mapping,
                # the JSON object from the code editor is passed on as keyword arguments
                **json.loads(_load_pie_dataset_kwargs_str),
            ),
            inputs=[retriever_state, load_pie_dataset_kwargs_str],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        # selecting another document switches to the "Analysed Document" tab and renders it
        selected_document_id.change(
            fn=lambda: change_tab(analysed_document_tab.id), inputs=[], outputs=[left_tabs]
        ).then(**render_event_kwargs)

        get_document_json_btn.click(
            fn=lambda _retriever, _document_id: get_document_as_dict(
                retriever=_retriever[0], doc_id=_document_id
            ),
            inputs=[retriever_state, selected_document_id],
            outputs=[document_json],
        )

        upload_btn.upload(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_files(
                file_names=_file_names,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=_split_regex_or_none(_split_regex_escaped),
                handle_parts_of_same=handle_parts_of_same,
                layer_captions=layer_caption_mapping,
            ),
            inputs=[
                upload_btn,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        upload_pdf_btn.upload(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _file_names, _argumentation_model, _retriever, _split_regex_escaped: process_uploaded_pdf_files(
                file_names=_file_names,
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=_split_regex_or_none(_split_regex_escaped),
                handle_parts_of_same=handle_parts_of_same,
                layer_captions=layer_caption_mapping,
                pdf_fulltext_extractor=pdf_fulltext_extractor,
            ),
            inputs=[
                upload_pdf_btn,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        load_acl_anthology_venues_btn.click(
            fn=lambda: change_tab(retrieval_tab.id), inputs=[], outputs=[right_tabs]
        ).then(fn=open_accordion, inputs=[], outputs=[processed_documents_accordion]).then(
            fn=lambda _acl_anthology_venues, _argumentation_model, _retriever, _split_regex_escaped: load_acl_anthology_venues(
                pdf_fulltext_extractor=pdf_fulltext_extractor,
                # the venues are entered as a comma separated list
                venues=[venue.strip() for venue in _acl_anthology_venues.split(",")],
                argumentation_model=_argumentation_model[0],
                retriever=_retriever[0],
                split_regex_escaped=_split_regex_or_none(_split_regex_escaped),
                handle_parts_of_same=handle_parts_of_same,
                layer_captions=layer_caption_mapping,
                acl_anthology_data_dir=cfg.get("acl_anthology_data_dir"),
                pdf_output_dir=cfg.get("acl_anthology_pdf_dir"),
            ),
            inputs=[
                acl_anthology_venues,
                argumentation_model_state,
                retriever_state,
                split_regex_escaped,
            ],
            outputs=[processed_documents_df],
        ).success(
            **show_stats_kwargs
        )

        # clicking a row in the overview table selects the respective document
        processed_documents_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[processed_documents_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        download_processed_documents_btn.click(
            fn=lambda _retriever: download_processed_documents(
                _retriever[0], file_name="processed_documents"
            ),
            inputs=[retriever_state],
            outputs=[download_processed_documents_btn],
        )
        upload_processed_documents_btn.upload(
            fn=lambda file_name, _retriever: upload_processed_documents(
                file_name, retriever=_retriever[0], layer_captions=layer_caption_mapping
            ),
            inputs=[upload_processed_documents_btn, retriever_state],
            outputs=[processed_documents_df],
        ).success(**show_stats_kwargs)

        # Shared event definition: retrieve the relevant ADUs for the selected ADU.
        retrieve_relevant_adus_event_kwargs = dict(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k, _min_score: retrieve_relevant_spans(
                retriever=_retriever[0],
                query_span_id=_selected_adu_id,
                k=_top_k,
                min_score=_min_score,
                score_threshold=_min_similarity,
                relation_label_mapping=relation_name_mapping,
                # columns=relevant_adus.headers
            ),
            inputs=[
                retriever_state,
                selected_adu_id,
                min_similarity,
                top_k,
                min_score,
            ],
            outputs=[relevant_adus_df],
        )
        relevant_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[relevant_adus_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        # a newly selected ADU first updates the shown text, then triggers the retrieval
        selected_adu_id.change(
            fn=lambda _retriever, _selected_adu_id: get_span_annotation(
                retriever=_retriever[0], span_id=_selected_adu_id
            ),
            inputs=[retriever_state, selected_adu_id],
            outputs=[selected_adu_text],
        ).success(**retrieve_relevant_adus_event_kwargs)

        retrieve_similar_adus_btn.click(
            fn=lambda _retriever, _selected_adu_id, _min_similarity, _top_k, _min_score: retrieve_similar_spans(
                retriever=_retriever[0],
                query_span_id=_selected_adu_id,
                k=_top_k,
                min_score=_min_score,
                score_threshold=_min_similarity,
            ),
            inputs=[
                retriever_state,
                selected_adu_id,
                min_similarity,
                top_k,
                min_score,
            ],
            outputs=[similar_adus_df],
        )
        similar_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[similar_adus_df, gr.State("doc_id")],
            outputs=[selected_document_id],
        )

        retrieve_all_similar_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k, _min_score: retrieve_all_similar_spans(
                retriever=_retriever[0],
                query_doc_id=_document_id,
                k=_top_k,
                min_score=_min_score,
                score_threshold=_min_similarity,
                query_span_id_column="query_span_id",
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                min_similarity,
                top_k,
                min_score,
            ],
            outputs=[all_similar_adus_df],
        )

        # returns the result table together with the id of the queried document, so that the
        # renderer can tell whether the table belongs to the currently shown document
        retrieve_all_relevant_adus_btn.click(
            fn=lambda _retriever, _document_id, _min_similarity, _top_k, _min_score: (
                retrieve_all_relevant_spans(
                    retriever=_retriever[0],
                    query_doc_id=_document_id,
                    k=_top_k,
                    min_score=_min_score,
                    score_threshold=_min_similarity,
                    query_span_id_column="query_span_id",
                    query_span_text_column="query_span_text",
                ),
                _document_id,
            ),
            inputs=[
                retriever_state,
                selected_document_id,
                min_similarity,
                top_k,
                min_score,
            ],
            outputs=[all_relevant_adus_df, all_relevant_adus_query_doc_id],
        )

        # re-render whenever the "all relevant ADUs" table changes (updates the highlighting)
        all_relevant_adus_df.change(**render_event_kwargs)

        # select query span id from the "retrieve all" result data frames
        all_similar_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[all_similar_adus_df, gr.State("query_span_id")],
            outputs=[selected_adu_id],
        )
        all_relevant_adus_df.select(
            fn=get_cell_for_fixed_column_from_df,
            inputs=[all_relevant_adus_df, gr.State("query_span_id")],
            outputs=[selected_adu_id],
        )

        # argumentation_model_state.change(
        #    fn=lambda _argumentation_model: set_relation_types(_argumentation_model[0]),
        #    inputs=[argumentation_model_state],
        #    outputs=[relation_types],
        # )

        # run the client side JavaScript whenever the rendered output changes
        rendered_output.change(fn=None, js=HIGHLIGHT_SPANS_JS, inputs=[], outputs=[])

    demo.launch()
# Script entry point: only start the demo when this file is executed directly.
# NOTE(review): main() is called without arguments although it declares a required
# `cfg` parameter; presumably a @hydra.main(...) decorator on main() supplies the
# config -- confirm that the decorator is present, otherwise this raises a TypeError.
if __name__ == "__main__":
    main()