Spaces:
Running
on
Zero
Running
on
Zero
| import spaces | |
| import gradio as gr | |
| from phi3_instruct_graph import MODEL_LIST, Phi3InstructGraph | |
| import rapidjson | |
| from pyvis.network import Network | |
| import networkx as nx | |
| import spacy | |
| from spacy import displacy | |
| from spacy.tokens import Span | |
| import random | |
| import time | |
# Set up the theme and styling
# Extra CSS handed to gr.Blocks(css=...) when the UI is built. Besides the
# page font and heading styles it defines the `.info-box`, `.language-badge`
# and `.footer` classes that the HTML snippets generated in this module use.
CUSTOM_CSS = """
.gradio-container {
font-family: 'Inter', 'Segoe UI', Roboto, sans-serif;
}
.gr-prose h1 {
font-size: 2.5rem !important;
margin-bottom: 0.5rem !important;
background: linear-gradient(90deg, #4338ca, #a855f7);
-webkit-background-clip: text;
-webkit-text-fill-color: transparent;
}
.gr-prose h2 {
font-size: 1.8rem !important;
margin-top: 1rem !important;
}
.info-box {
padding: 1rem;
border-radius: 0.5rem;
background-color: #f3f4f6;
margin-bottom: 1rem;
border-left: 4px solid #6366f1;
}
.language-badge {
display: inline-block;
padding: 0.25rem 0.5rem;
border-radius: 9999px;
font-size: 0.75rem;
font-weight: 600;
background-color: #e0e7ff;
color: #4338ca;
margin-right: 0.5rem;
margin-bottom: 0.5rem;
}
.footer {
text-align: center;
margin-top: 2rem;
padding-top: 1rem;
border-top: 1px solid #e2e8f0;
font-size: 0.875rem;
color: #64748b;
}
"""
# Color utilities
def get_random_light_color():
    """Return a random light colour as a ``#rrggbb`` hex string.

    Red, green and blue are drawn in that order, each from
    ``random.randint(150, 255)`` (inclusive). Keeping every channel in the
    upper part of the range yields a light background that stays readable
    under dark text.
    """
    channels = [random.randint(150, 255) for _ in range(3)]  # r, g, b
    return "#" + "".join(format(level, "02x") for level in channels)
# Text processing helper
def handle_text(text):
    """Collapse every run of whitespace in *text* to a single space.

    Leading and trailing whitespace is dropped, so a multi-line literal
    becomes one single-line string.
    """
    words = text.split()
    return " ".join(words)
# Core extraction function
def extract(text, model):
    """Run knowledge-graph extraction on *text* with the selected model.

    Args:
        text: the passage to analyse.
        model: model identifier (the UI offers the entries of ``MODEL_LIST``),
            forwarded as ``Phi3InstructGraph(model=...)``.

    Returns:
        The extractor's output parsed with ``rapidjson.loads``. Callers in
        this module index it as a dict with ``"nodes"`` and ``"edges"``.

    Raises:
        gr.Error: if building the extractor, running it, or parsing its output
            fails. The underlying exception is chained as ``__cause__``.
    """
    try:
        # Construction sits inside the try so that a failure to load the model
        # is reported in the UI like any other extraction failure.
        extractor = Phi3InstructGraph(model=model)
        raw_output = extractor.extract(text)
        return rapidjson.loads(raw_output)
    except gr.Error:
        # Already a user-facing error; don't wrap it a second time.
        raise
    except Exception as e:
        raise gr.Error(f"π¨ Extraction failed: {str(e)}") from e
def find_token_indices(doc, substring, text):
    """Map each non-overlapping occurrence of *substring* in *text* to token indices.

    Args:
        doc: the tokenized form of *text*. It is iterated for tokens exposing
            ``idx`` (character offset), ``i`` (token index) and ``len()``
            (character length), as spaCy ``Token`` objects do.
        substring: text to search for. Matching is exact and case-sensitive
            (``str.find``). Non-string values, such as a numeric id from the
            model's JSON, are converted with ``str()``; ``None`` or an empty
            string yields no matches.
        text: the raw string that *doc* was built from.

    Returns:
        A list of ``{"start": first_token_index, "end": last_token_index + 1}``
        dicts, one per occurrence whose start and end both fall exactly on
        token boundaries. Occurrences that begin or end mid-token are skipped.
    """
    if substring is None:
        return []
    if not isinstance(substring, str):
        substring = str(substring)
    # An empty needle matches at every offset and find() never advances past
    # it, which previously made the loop below spin forever.
    if not substring:
        return []

    result = []
    start_index = text.find(substring)
    while start_index != -1:
        end_index = start_index + len(substring)
        start_token = None
        end_token = None
        for token in doc:
            if token.idx == start_index:
                start_token = token.i
            if token.idx + len(token) == end_index:
                end_token = token.i + 1
        if start_token is not None and end_token is not None:
            result.append({
                "start": start_token,
                "end": end_token
            })
        # Search for the next occurrence after this one (no overlaps).
        start_index = text.find(substring, end_index)
    return result
def create_custom_entity_viz(data, full_text):
    """Render the extracted entities as highlighted spans over *full_text*.

    Each node's ``"id"`` is searched for in *full_text* with
    ``find_token_indices``. Every token-aligned match that does not overlap an
    already-accepted span becomes a span labelled with the node's ``"type"``,
    so earlier nodes win overlaps.

    Args:
        data: extraction result; items of ``data["nodes"]`` carry ``"id"``
            and ``"type"`` keys.
        full_text: the original text the graph was extracted from.

    Returns:
        An HTML string. It starts with a legend of the entity types that
        produced at least one span, each with a random light background
        colour, followed by displaCy's ``"span"`` rendering of the text.
    """
    from html import escape

    # Blank multi-language ("xx") pipeline: only the tokenizer is needed here.
    nlp = spacy.blank("xx")
    doc = nlp(full_text)
    spans = []
    colors = {}
    for node in data["nodes"]:
        for match in find_token_indices(doc, node["id"], full_text):
            start = match["start"]
            end = match["end"]
            if start < len(doc) and end <= len(doc):
                # Skip matches that overlap a span accepted earlier.
                overlapping = any(s.start < end and start < s.end for s in spans)
                if not overlapping:
                    spans.append(Span(doc, start, end, label=node["type"]))
                    if node["type"] not in colors:
                        colors[node["type"]] = get_random_light_color()
    doc.set_ents(spans, default="unmodified")
    # displaCy's "span" style reads doc.spans["sc"] unless another key is configured.
    doc.spans["sc"] = spans
    # NOTE(review): "ents"/"style" look like "ent"-style options and "manual"
    # is normally an argument of displacy.render() rather than an option;
    # confirm which of these the span renderer actually uses.
    options = {
        "colors": colors,
        "ents": list(colors.keys()),
        "style": "ent",
        "manual": True
    }
    rendered = displacy.render(doc, style="span", options=options)
    # Entity-type labels come from model output, so escape them before
    # interpolating them into HTML.
    legend = ' '.join(
        f'<span style="display: inline-block; margin-right: 0.5rem; margin-bottom: 0.5rem; padding: 0.25rem 0.5rem; border-radius: 9999px; font-size: 0.75rem; background-color: {colors[entity_type]}; color: #1e293b;">{escape(str(entity_type))}</span>'
        for entity_type in colors
    )
    # Add custom styling to the entity visualization
    styled_html = f"""
<div style="border-radius: 0.5rem; padding: 1rem; background-color: white;
border: 1px solid #e2e8f0; box-shadow: 0 1px 3px 0 rgba(0, 0, 0, 0.1);">
<div style="margin-bottom: 0.75rem; font-weight: 500; color: #4b5563;">
Entity types found:
{legend}
</div>
{rendered}
</div>
"""
    return styled_html
def create_graph(json_data):
    """Render the extracted graph as an interactive pyvis network inside an iframe.

    Args:
        json_data: extraction result. Items of ``json_data["nodes"]`` need
            ``"id"`` and ``"type"`` (``"detailed_type"`` is optional). Items of
            ``json_data["edges"]`` need ``"from"``, ``"to"`` and ``"label"``.

    Returns:
        An ``<iframe>`` HTML string whose ``srcdoc`` attribute holds the
        complete, HTML-escaped pyvis page.
    """
    from html import escape

    nodes = json_data['nodes']

    # Directed graph: each edge points from edge["from"] to edge["to"].
    graph = nx.DiGraph()
    for node in nodes:
        detailed_type = node.get('detailed_type')
        if detailed_type is not None:
            title = f"{node['type']}: {detailed_type}"
        else:
            title = f"{node['type']}"
        graph.add_node(node['id'], title=title, group=node['type'])  # group nodes by type
    for edge in json_data['edges']:
        graph.add_edge(edge['from'], edge['to'], title=edge['label'], label=edge['label'])

    net = Network(
        width="100%",
        height="600px",
        directed=True,
        notebook=False,
        bgcolor="#fafafa",
        font_color="#1e293b"
    )
    net.from_nx(graph)
    net.barnes_hut(
        gravity=-3000,
        central_gravity=0.3,
        spring_length=150,
        spring_strength=0.001,
        damping=0.09,
        overlap=0,
    )

    # One colour per node type. Types are taken in first-seen order rather
    # than set order, which varies between runs under string-hash
    # randomisation, so the same graph always gets the same colours. Stepping
    # the hue by 137 degrees (roughly the golden angle) spreads the hues apart.
    node_types = dict.fromkeys(node['type'] for node in nodes)
    colors = {
        node_type: f"hsl({(i * 137) % 360}, 70%, 70%)"
        for i, node_type in enumerate(node_types)
    }

    # Keep the first record per id, matching the former linear next(...) search,
    # but with O(1) lookups.
    nodes_by_id = {}
    for node in nodes:
        nodes_by_id.setdefault(node['id'], node)

    for vis_node in net.nodes:
        node_data = nodes_by_id.get(vis_node['id'])
        if node_data is not None:
            vis_node['color'] = colors.get(node_data['type'], "#bfdbfe")
            vis_node['shape'] = 'dot'
            vis_node['size'] = 20
            vis_node['borderWidth'] = 2
            vis_node['borderWidthSelected'] = 4
            vis_node['font'] = {'size': 14, 'color': '#1e293b', 'face': 'Inter, Arial'}

    for vis_edge in net.edges:
        vis_edge['color'] = {'color': '#94a3b8', 'highlight': '#6366f1', 'hover': '#818cf8'}
        vis_edge['width'] = 1.5
        vis_edge['selectionWidth'] = 2
        vis_edge['hoverWidth'] = 2
        vis_edge['arrows'] = {'to': {'enabled': True, 'type': 'arrow'}}
        vis_edge['smooth'] = {'type': 'continuous', 'roundness': 0.2}
        vis_edge['font'] = {'size': 12, 'color': '#4b5563', 'face': 'Inter, Arial', 'strokeWidth': 2, 'strokeColor': '#ffffff'}

    page = net.generate_html()
    page = page.replace('height: 600px;', 'height: 600px; border-radius: 8px;')
    # Escape the page for the srcdoc attribute; the browser decodes the
    # entities back to the original markup. The previous replace("'", '"')
    # hack also rewrote apostrophes inside node/edge titles embedded in the
    # page's JavaScript (e.g. "Tyler's"), which could break the script.
    srcdoc = escape(page, quote=True)
    return f"""<iframe style="width: 100%; height: 620px; margin: 0 auto; border-radius: 8px; box-shadow: 0 4px 6px -1px rgba(0, 0, 0, 0.1), 0 2px 4px -1px rgba(0, 0, 0, 0.06);"
name="result" allow="midi; geolocation; microphone; camera; display-capture; encrypted-media;"
sandbox="allow-modals allow-forms allow-scripts allow-same-origin allow-popups
allow-top-navigation-by-user-activation allow-downloads" allowfullscreen=""
allowpaymentrequest="" frameborder="0" srcdoc='{srcdoc}'></iframe>"""
def process_and_visualize(text, model, progress=gr.Progress()):
    """Run extraction on *text* and build every output shown in the UI.

    Args:
        text: input passage from the text area.
        model: model name from the dropdown, forwarded to ``extract``.
        progress: Gradio progress tracker, which Gradio supplies when this
            function runs as an event handler.

    Returns:
        ``(graph_html, entities_html, json_data, stats_html)``, in the same
        order as the ``outputs`` wired to the submit button.

    Raises:
        gr.Error: when text (including whitespace-only text) or model is
            missing, when extraction fails, or when the extraction result
            lacks ``"nodes"``/``"edges"`` lists.
    """
    from collections import Counter
    from html import escape

    if not text or not text.strip() or not model:
        raise gr.Error("β οΈ Please provide both text and model")

    progress(0.1, "Initializing...")
    time.sleep(0.2)  # deliberate short pause so the first progress message is visible

    progress(0.2, "Extracting knowledge graph...")
    json_data = extract(text, model)
    # Fail with a readable message instead of a KeyError deeper down.
    if (not isinstance(json_data, dict)
            or not isinstance(json_data.get('nodes'), list)
            or not isinstance(json_data.get('edges'), list)):
        raise gr.Error("Extraction returned data without 'nodes' and 'edges' lists")

    progress(0.6, "Identifying entities...")
    entities_viz = create_custom_entity_viz(json_data, text)

    progress(0.8, "Building graph visualization...")
    graph_html = create_graph(json_data)

    # Count nodes per entity type, in first-seen order.
    entity_types = Counter(node['type'] for node in json_data['nodes'])
    # Type labels come from model output, so escape them before embedding.
    badges = ''.join(
        f'<span class="language-badge">{escape(str(entity_type))}: {count}</span>'
        for entity_type, count in entity_types.items()
    )
    stats_html = f"""
<div class="info-box">
<h3 style="margin-top: 0;">π Extraction Results</h3>
<p>β Successfully extracted <b>{len(json_data['nodes'])}</b> entities and <b>{len(json_data['edges'])}</b> relationships.</p>
<div>
<h4>Entity Types:</h4>
<div>
{badges}
</div>
</div>
</div>
"""
    progress(1.0, "Done!")
    return graph_html, entities_viz, json_data, stats_html
def language_info():
    """Return the static HTML for the "Multilingual Support" info box.

    The markup uses the ``info-box`` and ``language-badge`` classes defined
    in ``CUSTOM_CSS``.
    """
    return """
<div class="info-box">
<h3 style="margin-top: 0;">π Multilingual Support</h3>
<p>This application supports text analysis in multiple languages, including:</p>
<div>
<span class="language-badge">English π¬π§</span>
<span class="language-badge">Korean π°π·</span>
<span class="language-badge">Spanish πͺπΈ</span>
<span class="language-badge">French π«π·</span>
<span class="language-badge">German π©πͺ</span>
<span class="language-badge">Japanese π―π΅</span>
<span class="language-badge">Chinese π¨π³</span>
<span class="language-badge">And more...</span>
</div>
</div>
"""
def tips_html():
    """Return the static HTML for the "Tips for Best Results" info box.

    The markup uses the ``info-box`` class defined in ``CUSTOM_CSS``.
    """
    return """
<div class="info-box">
<h3 style="margin-top: 0;">π‘ Tips for Best Results</h3>
<ul>
<li>Use clear, descriptive sentences with well-defined relationships</li>
<li>Include specific entities, events, dates, and locations for better extraction</li>
<li>Longer texts provide more context for relationship identification</li>
<li>Try different models to compare extraction results</li>
</ul>
</div>
"""
# Examples in multiple languages
# Each entry is a one-element list holding the value for the single
# `input_text` input of gr.Examples. handle_text() collapses the line breaks
# inside the triple-quoted literals into single spaces.
# NOTE(review): the two Korean examples appear mis-encoded (they look like
# UTF-8 bytes decoded with a Greek code page). The literals are left
# untouched here; re-enter the original Korean text to fix them.
EXAMPLES = [
    [handle_text("""Legendary rock band Aerosmith has officially announced their retirement from touring after 54 years, citing
lead singer Steven Tyler's unrecoverable vocal cord injury.
The decision comes after months of unsuccessful treatment for Tyler's fractured larynx,
which he suffered in September 2023.""")],
    [handle_text("""Pop star Justin Timberlake, 43, had his driver's license suspended by a New York judge during a virtual
court hearing on August 2, 2024. The suspension follows Timberlake's arrest for driving while intoxicated (DWI)
in Sag Harbor on June 18. Timberlake, who is currently on tour in Europe,
pleaded not guilty to the charges.""")],
    [handle_text("""μΈκ³μ μΈ κΈ°μ κΈ°μ μΌμ±μ μλ μλ‘μ΄ μΈκ³΅μ§λ₯ κΈ°λ° μ€λ§νΈν°μ μ¬ν΄ νλ°κΈ°μ μΆμν μμ μ΄λΌκ³ λ°ννλ€.
μ΄ μ€λ§νΈν°μ νμ¬ κ°λ° μ€μΈ κ°€λμ μ리μ¦μ μ΅μ μμΌλ‘, κ°λ ₯ν AI κΈ°λ₯κ³Ό νμ μ μΈ μΉ΄λ©λΌ μμ€ν μ νμ¬ν κ²μΌλ‘ μλ €μ‘λ€.
μΌμ±μ μμ CEOλ μ΄λ² μ μ νμ΄ μ€λ§νΈν° μμ₯μ μλ‘μ΄ νμ μ κ°μ Έμ¬ κ²μ΄λΌκ³ μ λ§νλ€.""")],
    [handle_text("""νκ΅ μν 'κΈ°μμΆ©'μ 2020λ μμΉ΄λ°λ―Έ μμμμμ μνμ, κ°λ μ, κ°λ³Έμ, κ΅μ μνμ λ± 4κ° λΆλ¬Έμ μμνλ©° μμ¬λ₯Ό μλ‘ μΌλ€.
λ΄μ€νΈ κ°λ μ΄ μ°μΆν μ΄ μνλ νκ΅ μν μ΅μ΄λ‘ μΉΈ μνμ ν©κΈμ’ λ €μλ μμνμΌλ©°, μ μΈκ³μ μΌλ‘ μμ²λ ν₯νκ³Ό
νλ¨μ νΈνμ λ°μλ€.""")]
]
# Main UI
# Two-column layout: input, model picker, buttons, examples and tips on the
# left (scale=2); extraction stats and three result tabs on the right (scale=3).
with gr.Blocks(css=CUSTOM_CSS, title="π§ Phi-3 Knowledge Graph Explorer") as demo:
    # Header
    gr.Markdown("# π§ Phi-3 Knowledge Graph Explorer")
    gr.Markdown("### β¨ Extract and visualize knowledge graphs from text in any language")
    with gr.Row():
        # Left column: inputs and help content
        with gr.Column(scale=2):
            input_text = gr.TextArea(
                label="π Enter your text",
                placeholder="Paste or type your text here...",
                lines=10
            )
            with gr.Row():
                input_model = gr.Dropdown(
                    MODEL_LIST,
                    label="π€ Model",
                    # Pre-select the first model; no selection if the list is empty
                    value=MODEL_LIST[0] if MODEL_LIST else None,
                    info="Select the model to use for extraction"
                )
                with gr.Column():
                    submit_button = gr.Button("π Extract & Visualize", variant="primary")
                    clear_button = gr.Button("π Clear", variant="secondary")
            # Multilingual support info
            gr.HTML(language_info())
            # Examples section: each example supplies a value for input_text
            gr.Examples(
                examples=EXAMPLES,
                inputs=input_text,
                label="π Example Texts (English & Korean)"
            )
            # Tips
            gr.HTML(tips_html())
        # Right column: outputs
        with gr.Column(scale=3):
            # Stats output
            stats_output = gr.HTML(label="")
            # Tabs for different visualizations
            with gr.Tabs():
                with gr.TabItem("π Knowledge Graph"):
                    output_graph = gr.HTML()
                with gr.TabItem("π·οΈ Entity Recognition"):
                    output_entity_viz = gr.HTML()
                with gr.TabItem("π JSON Data"):
                    output_json = gr.JSON()
    # Footer
    gr.HTML("""
<div class="footer">
<p>π Powered by Phi-3 Instruct Graph | Created by Emergent Methods</p>
<p>Β© 2025 | Knowledge Graph Explorer</p>
</div>
""")
    # Set up event handlers.
    # The outputs order matches the tuple returned by process_and_visualize.
    submit_button.click(
        fn=process_and_visualize,
        inputs=[input_text, input_model],
        outputs=[output_graph, output_entity_viz, output_json, stats_output]
    )
    # Reset only the four output components; the input text and the model
    # choice are left as they are.
    clear_button.click(
        fn=lambda: [None, None, None, ""],
        inputs=[],
        outputs=[output_graph, output_entity_viz, output_json, stats_output]
    )
# Launch the app only when this file is executed as a script, so importing
# the module (for tests or tooling) does not start a server.
if __name__ == "__main__":
    demo.launch(share=False)