Spaces:
Running
Running
| from typing import List, Tuple, Dict | |
| from gliner import GLiNER | |
| import gradio as gr | |
# The GLiNER checkpoint is loaded once, at import time, and its weights are
# placed on the CPU; every call to process() reuses this single instance.
_MODEL_ID = "knowledgator/gliner-multitask-large-v0.5"
model = GLiNER.from_pretrained(_MODEL_ID).to("cpu")
# Demo passage; it is used as the text input of the example row below.
# Built from adjacent literals so each sentence sits on its own source line;
# the value starts with a newline and every sentence ends with one.
text = (
    "\n"
    'Dr. Paul Hammond, a renowned neurologist at Johns Hopkins University, has recently published a paper in the prestigious journal "Nature Neuroscience".\n'
    "His research focuses on a rare genetic mutation, found in less than 0.01% of the population, that appears to prevent the development of Alzheimer's disease.\n"
    "Collaborating with researchers at the University of California, San Francisco, the team is now working to understand the mechanism by which this mutation confers its protective effect.\n"
    "Funded by the National Institutes of Health, their research could potentially open new avenues for Alzheimer's treatment.\n"
)
_THRESHOLD = 0.5  # minimum model score for both the entity and the relation pass


def _parse_pairs(pairs_filter: str) -> List[Tuple[str, str]]:
    """Parse ``"a -> b, c -> d"`` into unique (source_label, target_label) pairs, in input order."""
    pairs: List[Tuple[str, str]] = []
    for item in (pairs_filter or "").split(","):
        source, sep, target = item.partition("->")
        source, target = source.strip(), target.strip()
        if sep and source and target:
            pairs.append((source, target))
    # A repeated pair would otherwise emit every matching relation twice.
    return list(dict.fromkeys(pairs))


def _parse_distance(distance_threshold: str) -> "int | None":
    """Return the max allowed character gap, or None (no limit) for blank/non-numeric input."""
    value = (distance_threshold or "").strip()
    # isdecimal, not isdigit: "²".isdigit() is True but int("²") raises ValueError.
    return int(value) if value.isdecimal() else None


def _span_gap(a: Dict, b: Dict) -> int:
    """Return the number of characters between spans *a* and *b*; 0 if they touch or overlap."""
    if a["end"] <= b["start"]:
        return b["start"] - a["end"]
    if b["end"] <= a["start"]:
        return a["start"] - b["end"]
    return 0


def _matching_hit(hits: List[Dict], target: Dict, offset: int) -> "Dict | None":
    """Return the best-scoring hit overlapping *target*; hit offsets are relative to *offset*."""
    start, end = target["start"] - offset, target["end"] - offset
    overlapping = [h for h in hits if int(h["start"]) < end and int(h["end"]) > start]
    return max(overlapping, key=lambda h: float(h.get("score", 0.0)), default=None)


def process(relation: str, text: str, distance_threshold: str, pairs_filter: str, labels: str) -> str:
    """Extract ``relation`` triples from ``text`` using the module-level GLiNER ``model``.

    Two passes are made:

    1. Entity pass: spans for the comma-separated ``labels`` are predicted.
    2. Relation pass: for every allowed (source label, target label) pair from
       ``pairs_filter`` and every pair of matching entity spans no more than
       ``distance_threshold`` characters apart, the text covering both spans is
       queried with the label ``"<source entity text> <> <relation>"``. The
       relation is kept only when a predicted span overlaps the target entity.

    Args:
        relation: Relation name to extract, e.g. ``"worked at"``.
        text: Input text.
        distance_threshold: Max character gap between source and target spans;
            blank or non-numeric means "no limit".
        pairs_filter: Comma-separated ``"source -> target"`` label pairs; only
            these label combinations are tested.
        labels: Comma-separated entity labels.

    Returns:
        One ``"src (label) -> relation -> tgt (label)"`` line per relation,
        an ``"Error: ..."`` message when labels or relation are missing, or
        ``"No relations found"``.
    """
    import logging

    entity_labels: List[str] = [x.strip() for x in (labels or "").split(",") if x.strip()]
    if not entity_labels:
        return "Error: provide Labels (comma-separated)."
    relation = (relation or "").strip()
    if not relation:
        return "Error: provide Relation."
    pairs = _parse_pairs(pairs_filter)
    dist = _parse_distance(distance_threshold)
    # Without a label pair or any text nothing can match; skip the costly model calls.
    if not pairs or not (text or "").strip():
        return "No relations found"

    ents_raw = model.predict_entities(text, entity_labels, threshold=_THRESHOLD)
    entities: List[Dict] = [{
        "label": e["label"],
        "text": e["text"],
        "start": int(e["start"]),
        "end": int(e["end"]),
        "score": float(e.get("score", 0.0)),
    } for e in ents_raw]
    entities.sort(key=lambda x: (x["start"], x["end"], x["label"]))
    by_label: Dict[str, List[Dict]] = {}
    for e in entities:
        by_label.setdefault(e["label"], []).append(e)

    rels: List[Dict] = []
    for s_lbl, t_lbl in pairs:
        for s in by_label.get(s_lbl, []):
            # GLiNER-multitask relation prompt: "<head entity text> <> <relation>";
            # the model then tags the tail entity span(s).
            rel_label = f"{s['text']} <> {relation}"
            for t in by_label.get(t_lbl, []):
                if s is t:  # same span when source and target labels coincide
                    continue
                d = _span_gap(s, t)
                if dist is not None and d > dist:
                    continue
                cs, ce = min(s["start"], t["start"]), max(s["end"], t["end"])
                try:
                    hits = model.predict_entities(text[cs:ce], [rel_label], threshold=_THRESHOLD)
                except Exception:
                    # Best effort: one failing query should not abort the whole request.
                    logging.getLogger(__name__).exception("Relation query %r failed", rel_label)
                    continue
                # Only credit the relation to this target if the model actually tagged it.
                hit = _matching_hit(hits, t, cs)
                if hit is None:
                    continue
                rels.append({
                    "relation": relation,
                    "source": {"text": s["text"], "label": s["label"], "start": s["start"], "end": s["end"]},
                    "target": {"text": t["text"], "label": t["label"], "start": t["start"], "end": t["end"]},
                    "score": float(hit.get("score", 0.0)),
                    "distance": int(d),
                })

    if not rels:
        return "No relations found"
    rels.sort(key=lambda r: (r["relation"], r["source"]["start"], r["target"]["start"]))
    lines = [
        f"{r['source']['text']} ({r['source']['label']}) -> {r['relation']} -> {r['target']['text']} ({r['target']['label']})"
        for r in rels
    ]
    return "\n".join(lines)
# Example rows for gr.Examples. Values follow the input order of process():
# relation, text, distance_threshold, pairs_filter, labels.
# Every label used in the pairs filter must also appear in the labels list,
# otherwise that pair can never match.
relation_e_examples = [
    [
        "worked at",
        text,
        "",  # blank distance threshold: no limit on the gap between spans
        "scientist -> university",
        "scientist, university, city, research, journal",
    ]
]
# Gradio UI for relation extraction: five textboxes feed process(), whose
# result is shown in the "Predicted Relation" box.
with gr.Blocks(title="Open Information Extracting") as relation_e_interface:
    relation = gr.Textbox(label="Relation", placeholder="Enter relation you want to extract here")
    input_text = gr.Textbox(label="Text input", placeholder="Enter your text here")
    labels = gr.Textbox(label="Labels", placeholder="Enter your labels here (comma separated)", scale=2)
    pairs_filter = gr.Textbox(label="Pairs Filter", placeholder="It specifies possible members of relations by their entity labels. Write as: source -> target,..")
    distance_threshold = gr.Textbox(label="Distance Threshold", placeholder="It specifies the max distance in characters between spans in the text")
    output = gr.Textbox(label="Predicted Relation")
    submit_btn = gr.Button("Submit")

    # Order must match process(relation, text, distance_threshold, pairs_filter, labels).
    _process_inputs = [relation, input_text, distance_threshold, pairs_filter, labels]

    examples = gr.Examples(
        relation_e_examples,
        fn=process,
        inputs=_process_inputs,
        outputs=output,
        cache_examples=True,
    )
    # NOTE(review): this theme is not passed to gr.Blocks above, so it has no
    # effect on this interface; kept in case an importing app reads it — confirm
    # and either wire it in or remove it.
    theme = gr.themes.Base()

    # Pressing Enter in any textbox (including Relation) or clicking the
    # button runs the extraction.
    for _trigger in (
        relation.submit,
        input_text.submit,
        labels.submit,
        pairs_filter.submit,
        distance_threshold.submit,
        submit_btn.click,
    ):
        _trigger(fn=process, inputs=_process_inputs, outputs=output)
def _main() -> None:
    """Launch the relation-extraction Gradio interface."""
    relation_e_interface.launch()


if __name__ == "__main__":
    _main()