# GLiNER_HandyLab / interfaces / relation_e.py
# alexandrlukashov's picture
# fixed demo
# 1d354d1 verified
from typing import List, Tuple, Dict
from gliner import GLiNER
import gradio as gr
# Load the pretrained GLiNER checkpoint once, at import time, and move it to
# the CPU. process() reuses this module-level instance for every call.
model = GLiNER.from_pretrained("knowledgator/gliner-multitask-large-v0.5").to("cpu")
# Sample passage used as the text value of the bundled example row
# (see relation_e_examples).
text = """
Dr. Paul Hammond, a renowned neurologist at Johns Hopkins University, has recently published a paper in the prestigious journal "Nature Neuroscience".
His research focuses on a rare genetic mutation, found in less than 0.01% of the population, that appears to prevent the development of Alzheimer's disease.
Collaborating with researchers at the University of California, San Francisco, the team is now working to understand the mechanism by which this mutation confers its protective effect.
Funded by the National Institutes of Health, their research could potentially open new avenues for Alzheimer's treatment.
"""
def _parse_pairs(pairs_filter: str) -> List[Tuple[str, str]]:
    """Parse ``"src -> tgt, src2 -> tgt2"`` into unique ``(src, tgt)`` label pairs.

    Items without ``->`` or with an empty side are ignored. Order of first
    appearance is preserved; repeated pairs are dropped so the same relation
    is not reported twice.
    """
    pairs: List[Tuple[str, str]] = []
    seen = set()
    for item in (pairs_filter or "").split(","):
        source, arrow, target = item.partition("->")
        source, target = source.strip(), target.strip()
        if not (arrow and source and target):
            continue
        pair = (source, target)
        if pair not in seen:
            seen.add(pair)
            pairs.append(pair)
    return pairs


def _span_gap(a: Dict, b: Dict) -> int:
    """Return the character gap between two spans (0 when they touch or overlap)."""
    if a["end"] <= b["start"]:
        return b["start"] - a["end"]
    if b["end"] <= a["start"]:
        return a["start"] - b["end"]
    return 0


def process(relation: str, text: str, distance_threshold: str, pairs_filter: str, labels: str) -> str:
    """Extract ``relation`` links between labelled entities found in ``text``.

    Entities are predicted for ``labels`` with the module-level ``model``.
    For every allowed (source label, target label) pair, each source/target
    entity combination within the distance limit is checked by running the
    model on the text chunk spanning both entities with the prompt
    ``"<source label> <> <relation>"``; any prediction on that chunk counts
    as a relation between the two entities.

    Args:
        relation: Name of the relation to look for.
        text: Text to analyse.
        distance_threshold: Optional maximum gap, in characters, between the
            two entity spans. Anything that is not a non-negative decimal
            integer is ignored (no limit).
        pairs_filter: Comma-separated ``source -> target`` label pairs.
        labels: Comma-separated entity labels to detect.

    Returns:
        One line per relation, formatted as
        ``"<source> (<label>) -> <relation> -> <target> (<label>)"``, or
        ``"No relations found"``, or an ``"Error: ..."`` message when a
        required field is missing.
    """
    entity_labels: List[str] = [x.strip() for x in (labels or "").split(",") if x.strip()]
    if not entity_labels:
        return "Error: provide Labels (comma-separated)."

    relation = (relation or "").strip()
    if not relation:
        # An empty relation would produce a meaningless "<label> <> " prompt.
        return "Error: provide Relation."

    raw_threshold = "" if distance_threshold is None else str(distance_threshold).strip()
    # isdecimal() rather than isdigit(): isdigit() also accepts characters
    # such as "²" that int() rejects with ValueError.
    dist = int(raw_threshold) if raw_threshold.isdecimal() else None

    pairs = _parse_pairs(pairs_filter)
    text = text or ""
    if not pairs or not text.strip():
        # Nothing can be related; skip the model call entirely.
        return "No relations found"

    ents_raw = model.predict_entities(text, entity_labels, threshold=0.5)
    entities: List[Dict] = [{
        "label": e["label"],
        "text": e["text"],
        "start": int(e["start"]),
        "end": int(e["end"]),
        "score": float(e.get("score", 0.0)),
    } for e in ents_raw]
    entities.sort(key=lambda x: (x["start"], x["end"], x["label"]))

    by_label: Dict[str, List[Dict]] = {}
    for e in entities:
        by_label.setdefault(e["label"], []).append(e)

    rels: List[Dict] = []
    for s_lbl, t_lbl in pairs:
        sources = by_label.get(s_lbl, [])
        targets = by_label.get(t_lbl, [])
        if not sources or not targets:
            continue
        # NOTE(review): the prompt is built from the source *label*; the
        # model card examples appear to use the source entity text instead
        # ("<entity> <> <relation>") - confirm which one is intended.
        rel_label = f"{s_lbl} <> {relation}"
        for s in sources:
            for t in targets:
                if s is t:
                    # Same-label pairs (e.g. "a -> a") iterate one list twice;
                    # never relate an entity to itself.
                    continue
                d = _span_gap(s, t)
                if dist is not None and d > dist:
                    continue
                cs, ce = min(s["start"], t["start"]), max(s["end"], t["end"])
                chunk = text[cs:ce]
                try:
                    hit = model.predict_entities(chunk, [rel_label], threshold=0.5)
                except Exception:
                    # Best effort: a failure on one chunk must not abort the
                    # remaining candidate pairs.
                    hit = []
                if not hit:
                    continue
                rels.append({
                    "relation": relation,
                    "source": {"text": s["text"], "label": s["label"], "start": s["start"], "end": s["end"]},
                    "target": {"text": t["text"], "label": t["label"], "start": t["start"], "end": t["end"]},
                    "score": float(hit[0].get("score", 0.0)),
                    "distance": int(d),
                })

    if not rels:
        return "No relations found"

    rels.sort(key=lambda r: (r["relation"], r["source"]["start"], r["target"]["start"]))
    lines = [
        f"{r['source']['text']} ({r['source']['label']}) -> {r['relation']} -> {r['target']['text']} ({r['target']['label']})"
        for r in rels
    ]
    return "\n".join(lines)
# Example rows shown in the UI. Each row lists its values in the order
# [relation, text, distance_threshold, pairs_filter, labels], which matches
# the `inputs` list given to gr.Examples and the parameters of process().
relation_e_examples = [
    [
        "worked at",
        text,
        "",
        "scientist -> university, scientist -> other",
        "scientist, university, city, research, journal"
    ]
]
# Gradio UI: five text inputs, one output box, and a submit button, all wired
# to process().
with gr.Blocks(title="Open Information Extracting") as relation_e_interface:
    relation = gr.Textbox(label="Relation", placeholder="Enter relation you want to extract here")
    input_text = gr.Textbox(label="Text input", placeholder="Enter your text here")
    labels = gr.Textbox(label="Labels", placeholder="Enter your labels here (comma separated)", scale=2)
    pairs_filter = gr.Textbox(label="Pairs Filter", placeholder="It specifies possible members of relations by their entity labels. Write as: source -> target,..")
    distance_threshold = gr.Textbox(label="Distance Threshold", placeholder="It specifies the max distance in characters between spans in the text")
    output = gr.Textbox(label="Predicted Relation")
    submit_btn = gr.Button("Submit")

    # Single source of truth for the argument order; it must match the
    # positional parameters of process() and the columns of each example row.
    process_inputs = [relation, input_text, distance_threshold, pairs_filter, labels]

    examples = gr.Examples(
        relation_e_examples,
        fn=process,
        inputs=process_inputs,
        outputs=output,
        cache_examples=True,
    )

    # NOTE(review): this theme object is created but never passed to the
    # Blocks above, so it has no visible effect here. Kept so the names this
    # module defines stay unchanged.
    theme = gr.themes.Base()

    # Pressing Enter in any input box runs the extraction. The Relation box
    # is included so that every text input behaves the same way.
    for textbox in (relation, input_text, labels, pairs_filter, distance_threshold):
        textbox.submit(fn=process, inputs=process_inputs, outputs=output)
    submit_btn.click(fn=process, inputs=process_inputs, outputs=output)
# Launch the interface only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    relation_e_interface.launch()