Spaces:
Sleeping
Sleeping
import os
import re
import subprocess
import sys

# Install runtime dependencies not baked into the Space image.
# Use the current interpreter's pip via subprocess (no shell string)
# instead of os.system("pip install ..."), which may target another
# Python and is shell-injection-prone in general.
for _pkg in ("networkx", "Cython", "benepar"):
    subprocess.run([sys.executable, "-m", "pip", "install", _pkg], check=False)

import networkx as nx
import matplotlib.pyplot as plt
import jraph
import jax.numpy as jnp
from datasets import load_dataset
import spacy
import gradio as gr
import en_core_web_trf
import numpy as np
import benepar

# Lecture-transcript dataset used to populate the sentence dropdown.
dataset = load_dataset("gigant/tib_transcripts")

# spaCy transformer pipeline plus the benepar constituency parser.
nlp = en_core_web_trf.load()
benepar.download('benepar_en3')
nlp.add_pipe('benepar', config={'model': 'benepar_en3'})
def parse_tree(sentence):
    """Parse a bracketed (Penn-style) constituency string into nested lists.

    Tokens are split on parentheses and whitespace: '(' opens a nested
    list, ')' closes it, and any other token is appended as a string.

    Raises:
        ValueError: if the parentheses are unbalanced in either direction.
    """
    tokenizer = re.compile(r'(?:([()])|\s+)')
    root = []
    open_scopes = []  # stack of parent lists for each unclosed '('
    current = root
    for tok in tokenizer.split(sentence):
        if not tok:  # drop the Nones and empty strings split() produces
            continue
        if tok == '(':
            open_scopes.append(current)
            nested = []
            current.append(nested)
            current = nested
        elif tok == ')':
            if not open_scopes:
                raise ValueError("Unbalanced parentheses")
            current = open_scopes.pop()
        else:
            current.append(tok)
    if open_scopes:
        raise ValueError("Unbalanced parentheses")
    return root
class Tree:
    """Simple n-ary tree: a name per node, integer ids assigned in preorder."""

    def __init__(self, name, children):
        self.name = name          # label of this node
        self.children = children  # list of child Tree instances
        self.id = None            # preorder id, filled by set_all_ids()

    def set_id_rec(self, id=0):
        """Assign ids in preorder starting at `id`; return the last id used."""
        self.id = id
        last_id = id
        for subtree in self.children:
            last_id = subtree.set_id_rec(id=last_id + 1)
        return last_id

    def set_all_ids(self):
        """Number the whole tree in preorder starting from 0."""
        self.set_id_rec(0)

    def print_tree(self, level=0):
        """Render the subtree, one node per line, with a '-' per depth level."""
        rendered = [f'|{"-" * level} {self.name} ({self.id})']
        rendered.extend(subtree.print_tree(level + 1) for subtree in self.children)
        return "\n".join(rendered)

    def __str__(self):
        return self.print_tree(0)

    def get_list_nodes(self):
        """Flatten the node names in preorder."""
        names = [self.name]
        for subtree in self.children:
            names.extend(subtree.get_list_nodes())
        return names
def rec_const_parsing(list_nodes):
    """Convert the nested-list output of parse_tree into a Tree.

    A list is interpreted as [name, *children]; a bare string becomes
    a leaf with no children.
    """
    if isinstance(list_nodes, list):
        name, children = list_nodes[0], list_nodes[1:]
    else:
        name, children = list_nodes, []
    # The original used enumerate() but never used the index; iterate directly.
    return Tree(name, [rec_const_parsing(child) for child in children])
def tree_to_graph(t):
    """Return (senders, receivers) edge lists for the subtree rooted at `t`.

    Emits one parent->child edge per child, in preorder: the edge to a
    child comes immediately before that child's own subtree edges.
    Node ids must already be set (see Tree.set_all_ids).
    """
    src, dst = [], []
    for node in t.children:
        src.append(t.id)
        dst.append(node.id)
        sub_src, sub_dst = tree_to_graph(node)
        src += sub_src
        dst += sub_dst
    return src, dst
def construct_constituency_graph(docs):
    """Build a constituency-parse graph for the first sentence of the first doc.

    Relies on the benepar pipeline component providing `sent._.parse_string`.
    Returns a one-element list holding the usual graph dict (no edge labels).
    """
    first_sentence = next(iter(docs[0].sents))
    print(first_sentence._.parse_string)
    tree = rec_const_parsing(parse_tree(first_sentence._.parse_string)[0])
    tree.set_all_ids()
    senders, receivers = tree_to_graph(tree)
    return [{
        "nodes": tree.get_list_nodes(),
        "senders": senders,
        "receivers": receivers,
        "edge_labels": {},
    }]
def half_circle_layout(n_nodes, sentence_node=True):
    """Position the first n_nodes - 1 nodes on a flattened half circle.

    Word nodes span the upper arc from (-1, 0) to (1, 0) with a halved
    vertical extent; the final node (the "Sentence" node) sits at (0, -0.25).
    Returns a dict mapping node index -> (x, y).
    """
    # Division stays inside the comprehension so n_nodes == 1 (empty range)
    # never divides by zero.
    pos = {
        k: (-np.cos(k * np.pi / (n_nodes - 1)),
            -0.5 * np.sin(k * np.pi / (n_nodes - 1)))
        for k in range(n_nodes - 1)
    }
    pos[n_nodes - 1] = (0, -0.25)
    return pos
def get_adjacency_matrix(jraph_graph: jraph.GraphsTuple):
    """Dense {0, 1} adjacency matrix (sender -> receiver) of a GraphsTuple."""
    nodes, edges, receivers, senders, _, _, _ = jraph_graph
    n = len(nodes)
    adj_mat = jnp.zeros((n, n))
    # jax arrays are immutable, so each edge write rebinds via .at[].set().
    for src, dst in zip(senders, receivers):
        adj_mat = adj_mat.at[src, dst].set(1)
    return adj_mat
def dependency_parser(sentences):
    """Run the module-level spaCy pipeline over each sentence string."""
    return list(map(nlp, sentences))
def construct_dependency_graph(docs):
    """
    Build one graph per spaCy doc from its dependency arcs.

    Edges point child -> head; `edge_labels` is keyed (head.i, child.i)
    and holds the child's dependency relation (`dep_`).
    """
    graphs = []
    for doc in docs:
        senders, receivers, edge_labels = [], [], {}
        for head in doc:
            for dependent in head.children:
                senders.append(dependent.i)
                receivers.append(head.i)
                edge_labels[(head.i, dependent.i)] = dependent.dep_
        graphs.append({
            "nodes": [token.text for token in doc],
            "senders": senders,
            "receivers": receivers,
            "edge_labels": edge_labels,
        })
    return graphs
def construct_both_graph(docs):
    """
    Build one graph per spaCy doc with structural AND dependency edges.

    Structural part: bidirectional next/previous edges between consecutive
    tokens, plus an "in" edge from every token to an extra "Sentence" node
    appended last. Dependency part: child -> head edges whose label, keyed
    (head.i, child.i), is the child's relation — overwriting any structural
    label already stored under that key.
    """
    graphs = []
    for doc in docs:
        ids = [token.i for token in doc]
        nodes = [token.text for token in doc] + ["Sentence"]
        sentence_id = len(nodes) - 1
        # next edges (i -> i+1) followed by previous edges (i+1 -> i)
        senders = ids[:-1] + ids[1:]
        receivers = ids[1:] + ids[:-1]
        edge_labels = {}
        for token in doc[:-1]:
            edge_labels[(token.i, token.i + 1)] = "next"
        for token in doc[:-1]:
            edge_labels[(token.i + 1, token.i)] = "previous"
        # membership edges into the sentence node
        for word_id in range(sentence_id):
            senders.append(word_id)
            receivers.append(sentence_id)
            edge_labels[(word_id, sentence_id)] = "in"
        # dependency edges
        for head in doc:
            for dependent in head.children:
                senders.append(dependent.i)
                receivers.append(head.i)
                edge_labels[(head.i, dependent.i)] = dependent.dep_
        graphs.append({"nodes": nodes, "senders": senders,
                       "receivers": receivers, "edge_labels": edge_labels})
    return graphs
def construct_structural_graph(docs):
    """
    Build one purely structural graph per spaCy doc.

    Consecutive tokens are linked in both directions ("next"/"previous"),
    and every token gets an "in" edge to an extra "Sentence" node that is
    appended as the last node.
    """
    graphs = []
    for doc in docs:
        ids = [token.i for token in doc]
        nodes = [token.text for token in doc] + ["Sentence"]
        sentence_id = len(nodes) - 1
        # next edges (i -> i+1) followed by previous edges (i+1 -> i)
        senders = ids[:-1] + ids[1:]
        receivers = ids[1:] + ids[:-1]
        edge_labels = {}
        for token in doc[:-1]:
            edge_labels[(token.i, token.i + 1)] = "next"
            edge_labels[(token.i + 1, token.i)] = "previous"
        # membership edges into the sentence node
        for word_id in range(sentence_id):
            senders.append(word_id)
            receivers.append(sentence_id)
            edge_labels[(word_id, sentence_id)] = "in"
        graphs.append({"nodes": nodes, "senders": senders,
                       "receivers": receivers, "edge_labels": edge_labels})
    return graphs
def to_jraph(graph):
    """Convert a graph dict (nodes/senders/receivers) into a jraph.GraphsTuple.

    Node features are dummy zeros (one per node); edge features and
    globals are left as None. n_node/n_edge are single-element arrays,
    as jraph expects for a batch of one graph.
    """
    n_nodes = len(graph["nodes"])
    n_edges = len(graph["senders"])
    return jraph.GraphsTuple(
        nodes=jnp.array([0] * n_nodes),
        senders=jnp.array(graph["senders"]),
        receivers=jnp.array(graph["receivers"]),
        edges=None,
        n_node=jnp.array([n_nodes]),
        n_edge=jnp.array([n_edges]),
        globals=None,
    )
def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.Graph:
    """Turn a GraphsTuple into a networkx DiGraph.

    Node/edge features are attached as 'node_feature'/'edge_feature'
    attributes when present; otherwise bare nodes/edges are added.
    """
    nodes, edges, receivers, senders, _, _, _ = jraph_graph
    nx_graph = nx.DiGraph()
    for n in range(jraph_graph.n_node[0]):
        if nodes is None:
            nx_graph.add_node(n)
        else:
            nx_graph.add_node(n, node_feature=nodes[n])
    for e in range(jraph_graph.n_edge[0]):
        u, v = int(senders[e]), int(receivers[e])
        if edges is None:
            nx_graph.add_edge(u, v)
        else:
            nx_graph.add_edge(u, v, edge_feature=edges[e])
    return nx_graph
def plot_graph_sentence(sentence, graph_type="constituency"):
    """Parse `sentence`, build the requested graph, and render two figures.

    Returns a list of two gr.update() payloads: the word-graph figure and
    its adjacency-matrix figure, for the two gr.Plot outputs.

    Raises:
        ValueError: if `graph_type` is not a supported value (previously an
            unknown type crashed later with an UnboundLocalError on `graphs`).
    """
    docs = dependency_parser([sentence])
    builders = {
        "dependency": construct_dependency_graph,
        "structural": construct_structural_graph,
        "structural+dependency": construct_both_graph,
        "constituency": construct_constituency_graph,
    }
    if graph_type not in builders:
        raise ValueError(f"Unknown graph type: {graph_type}")
    graphs = builders[graph_type](docs)
    g = to_jraph(graphs[0])
    adj_mat = get_adjacency_matrix(g)
    nx_graph = convert_jraph_to_networkx_graph(g)
    if graph_type == "constituency":
        # A constituency tree is planar; the half-circle layout is meant
        # for flat token graphs with a trailing "Sentence" node.
        pos = nx.planar_layout(nx_graph)
    else:
        pos = half_circle_layout(len(graphs[0]["nodes"]))
    plot = plt.figure(figsize=(12, 6))
    nx.draw(nx_graph, pos=pos,
            labels={i: e for i, e in enumerate(graphs[0]["nodes"])},
            with_labels=True, edge_color="blue",
            node_size=1000, font_color='black', node_color="yellow")
    nx.draw_networkx_edge_labels(
        nx_graph, pos=pos,
        edge_labels=graphs[0]["edge_labels"],
        font_color='red'
    )
    adj_mat_plot, ax = plt.subplots(figsize=(6, 6))
    ax.matshow(adj_mat)
    return [gr.update(value=plot), gr.update(value=adj_mat_plot)]
def get_list_sentences(id):
    """Return a Dropdown update listing the sentences of transcript `id`.

    The index is clamped into [0, len(train) - 1]: the original only capped
    the upper bound, so a negative Number input silently indexed from the
    end of the dataset.
    """
    id = int(max(0, min(id, len(dataset["train"]) - 1)))
    return gr.update(choices=dataset["train"][id]["transcript"].split("."))
# Gradio UI: pick a graph type, choose or type a sentence, and view the
# resulting word graph plus its adjacency matrix.
with gr.Blocks() as demo:
    with gr.Row():
        # Graph construction strategy shared by both input tabs.
        graph_type = gr.Dropdown(label="Graph type", choices=["structural", "dependency", "structural+dependency", "constituency"], value="structural+dependency", interactive = True)
    with gr.Tab("From transcript"):
        with gr.Row():
            with gr.Column():
                # Index into the train split; changing it repopulates the dropdown.
                id = gr.Number(label="Transcript")
            with gr.Column(scale=3):
                # Sentences of the selected transcript (split on '.').
                sentence_transcript = gr.Dropdown(label="Sentence", choices = dataset["train"][0]["transcript"].split(".")[1:], interactive = True)
    with gr.Tab("Type sentence"):
        with gr.Row():
            sentence_typed = gr.Textbox(label="Sentence", interactive = True)
    with gr.Row():
        with gr.Column(scale=2):
            plot_graph = gr.Plot(label="Word graph")
        with gr.Column():
            plot_adj = gr.Plot(label="Word graph adjacency matrix")
    # Wiring: transcript id refreshes the sentence list; either sentence
    # source re-renders both plots with the currently selected graph type.
    id.change(get_list_sentences, id, sentence_transcript)
    sentence_transcript.change(plot_graph_sentence, [sentence_transcript, graph_type], [plot_graph, plot_adj])
    sentence_typed.change(plot_graph_sentence, [sentence_typed, graph_type], [plot_graph, plot_adj])
demo.launch()