Spaces:
Running
Running
| import networkx as nx | |
| import matplotlib.pyplot as plt | |
| import jraph | |
| import jax.numpy as jnp | |
| from datasets import load_dataset | |
| import spacy | |
| import gradio as gr | |
| import en_core_web_sm | |
# Transcript corpus browsed by the demo, fetched through Hugging Face `datasets`.
dataset = load_dataset(path="gigant/tib_transcripts")

# spaCy pipeline used for dependency parsing; loaded once at import time.
nlp = spacy.load("en_core_web_sm")
def dependency_parser(sentences):
    """Parse each sentence with the module-level spaCy pipeline ``nlp``.

    Args:
        sentences: Iterable of sentence strings.

    Returns:
        A list holding one parsed document per input sentence, in input order.
    """
    # ``nlp.pipe`` processes the texts as a batch, which is the spaCy-recommended
    # way to handle several texts instead of calling ``nlp`` once per text.
    return list(nlp.pipe(sentences))
def construct_dependency_graph(docs):
    """Turn parsed documents into plain edge-list graphs.

    Args:
        docs: List of outputs of the spaCy dependency parser.

    Returns:
        One dict per document with the keys ``"nodes"`` (token texts),
        ``"senders"`` and ``"receivers"`` (token indices; every edge runs
        from a token to one of its children).
    """

    def _as_graph(doc):
        labels = [tok.text for tok in doc]
        # One (head index, child index) pair per arc, in token order.
        arcs = [(head.i, dep.i) for head in doc for dep in head.children]
        return {
            "nodes": labels,
            "senders": [src for src, _ in arcs],
            "receivers": [dst for _, dst in arcs],
        }

    return [_as_graph(doc) for doc in docs]
def to_jraph(graph):
    """Convert an edge-list graph dict into a single-graph ``jraph.GraphsTuple``.

    Args:
        graph: Dict with the keys ``"nodes"``, ``"senders"`` and
            ``"receivers"``; senders/receivers are parallel lists of node
            indices, one entry per directed edge.

    Returns:
        A ``jraph.GraphsTuple`` with a placeholder integer feature (0) per
        node, no edge features and no globals.
    """
    nodes = graph["nodes"]
    s = graph["senders"]
    r = graph["receivers"]

    # Every node gets the same dummy integer feature; only the structure matters.
    # The explicit integer dtype matters for empty inputs: ``jnp.array([])``
    # would otherwise default to a float array.
    node_features = jnp.array([0] * len(nodes), dtype=int)

    # Directed edges: edge k goes from node senders[k] to node receivers[k].
    # An integer dtype is forced so a graph without edges (e.g. a one-token
    # sentence) still yields integer index arrays.
    senders = jnp.array(s, dtype=int)
    receivers = jnp.array(r, dtype=int)

    # Node and edge counts, stored as length-1 arrays because this
    # GraphsTuple holds exactly one graph.
    n_node = jnp.array([len(nodes)])
    n_edge = jnp.array([len(s)])

    return jraph.GraphsTuple(
        nodes=node_features,
        senders=senders,
        receivers=receivers,
        edges=None,
        n_node=n_node,
        n_edge=n_edge,
        globals=None,
    )
def convert_jraph_to_networkx_graph(jraph_graph: jraph.GraphsTuple) -> nx.Graph:
    """Build a directed networkx graph from the first graph of a GraphsTuple.

    Nodes are numbered ``0 .. n_node[0] - 1`` and edges are taken from the
    first ``n_edge[0]`` sender/receiver pairs. When node (resp. edge) features
    are present they are attached under the ``node_feature`` (resp.
    ``edge_feature``) attribute.

    Args:
        jraph_graph: The graph tuple to convert.

    Returns:
        An ``nx.DiGraph`` mirroring the nodes and edges described above.
    """
    nodes, edges, receivers, senders, _, _, _ = jraph_graph
    digraph = nx.DiGraph()

    for node_idx in range(jraph_graph.n_node[0]):
        # Attach the feature only when the tuple actually carries node features.
        node_attrs = {} if nodes is None else {"node_feature": nodes[node_idx]}
        digraph.add_node(node_idx, **node_attrs)

    for edge_idx in range(jraph_graph.n_edge[0]):
        source = int(senders[edge_idx])
        target = int(receivers[edge_idx])
        edge_attrs = {} if edges is None else {"edge_feature": edges[edge_idx]}
        digraph.add_edge(source, target, **edge_attrs)

    return digraph
def plot_graph_sentence(sentence):
    """Draw the dependency graph of a single sentence.

    Args:
        sentence: Sentence text to parse. ``None`` or blank text yields an
            empty figure instead of being sent to the parser.

    Returns:
        A 6x6 inch matplotlib ``Figure`` with one yellow node per token,
        labelled with the token text, and one directed edge per
        head -> child dependency.
    """
    # Create the figure directly rather than through ``plt.figure`` so it is
    # not registered in pyplot's global figure list. This callback runs once
    # per UI event, and pyplot-managed figures accumulate until explicitly
    # closed.
    figure = plt.Figure(figsize=(6, 6))
    if sentence is None or not sentence.strip():
        return figure

    graph = construct_dependency_graph(dependency_parser([sentence]))[0]
    nx_graph = convert_jraph_to_networkx_graph(to_jraph(graph))

    # Axes spanning the whole figure, passed explicitly so the drawing does
    # not depend on pyplot's "current axes" global state.
    axes = figure.add_axes((0, 0, 1, 1))
    nx.draw(
        nx_graph,
        pos=nx.spring_layout(nx_graph),
        ax=axes,
        labels=dict(enumerate(graph["nodes"])),
        with_labels=True,
        node_size=800,
        font_color="black",
        node_color="yellow",
    )
    return figure
def get_list_sentences(id):
    """Return a Gradio update listing the sentences of one transcript.

    Args:
        id: Index of the record in ``dataset["train"]``. The UI slider may
            deliver the value as a float, so it is coerced to ``int`` before
            indexing.

    Returns:
        ``gr.update(choices=...)`` whose choices are the non-blank fragments
        obtained by splitting the record's transcript on ``"."``.
    """
    transcript = dataset["train"][int(id)]["transcript"]
    # Skip empty / whitespace-only fragments (e.g. the one after a final period).
    sentences = [part for part in transcript.split(".") if part.strip()]
    return gr.update(choices=sentences)
# UI: pick a record with the slider, pick one of its sentences in the
# dropdown, and the dependency graph of that sentence is plotted.
with gr.Blocks() as demo:
    # ``step=1`` keeps the slider on whole-number record indices. The variable
    # is not called ``id`` so that the builtin of that name is not shadowed.
    record_slider = gr.Slider(
        minimum=0,
        maximum=len(dataset["train"]) - 1,
        step=1,
        label="Record #",
    )
    sentence = gr.Dropdown(
        label="Transcript sentence",
        choices=dataset["train"][0]["transcript"].split("."),
        interactive=True,
    )
    plot = gr.Plot(label="Dependency graph")

    # Changing the record refreshes the sentence choices; changing the
    # sentence redraws the graph.
    record_slider.change(get_list_sentences, record_slider, sentence)
    sentence.change(plot_graph_sentence, sentence, plot)

demo.launch()