# Spaces: Paused
| import os | |
| import gradio as gr | |
| import json | |
| from main import ChemEagle # 支持 API key 通过环境变量 | |
| from rdkit import Chem | |
| from rdkit.Chem import rdChemReactions | |
| from rdkit.Chem import Draw | |
| from rdkit.Chem import AllChem | |
| from rdkit.Chem.Draw import rdMolDraw2D | |
| import cairosvg | |
| import re | |
| import torch | |
# Image shown by default in the "Schematic Diagram" panel and returned again
# by process_chem_image after every run.
example_diagram = "examples/exp.png"
# Placeholder image that show_split displays when there is no SMILES input.
rdkit_image = "examples/rdkit.png"
# Parse the structured data returned by ChemEagle.
def parse_reactions(output_json):
    """Format reaction data as colour-coded HTML and plain reaction SMILES.

    Args:
        output_json: A JSON string (or bytes) or an already-decoded dict.
            Reactions are read from its ``"reactions"`` key; each reaction
            may carry ``reaction_id``, ``reactants``, ``condition``,
            ``products`` and ``additional_info``. Missing or ``None`` fields
            are treated as empty, and missing values are shown as
            ``"Unknown"``.

    Returns:
        A ``(detailed_output, smiles_output)`` tuple of two lists holding one
        entry per reaction: an HTML fragment describing the reaction, and a
        ``reactants>>products`` reaction SMILES string.

    Raises:
        TypeError: If the input (after JSON decoding) is not a dict.
        json.JSONDecodeError: If a string input is not valid JSON.
    """
    import html

    if isinstance(output_json, (str, bytes, bytearray)):
        reactions_data = json.loads(output_json)
    else:
        reactions_data = output_json
    if not isinstance(reactions_data, dict):
        raise TypeError(
            "parse_reactions expects a JSON object string or a dict, "
            f"got {type(output_json).__name__}"
        )

    def first_present(entry, *keys):
        # First non-None value among `keys`, as text; "Unknown" if none exists.
        for key in keys:
            value = entry.get(key)
            if value is not None:
                return str(value)
        return "Unknown"

    def esc(value):
        # The values come from image extraction and end up inside gr.HTML, so
        # characters such as "<" or "&" must not be interpreted as markup.
        return html.escape(str(value), quote=False)

    def span(colour, text):
        return f"<span style='color:{colour}'>{text}</span>"

    detailed_output = []
    smiles_output = []
    for reaction in reactions_data.get("reactions") or []:
        reaction_id = esc(reaction.get("reaction_id", "Unknown ID"))

        # Raw SMILES feed the machine-readable output; escaped copies feed HTML.
        reactants = [
            first_present(r, "smiles") for r in reaction.get("reactants") or []
        ]
        products = [
            first_present(p, "smiles") for p in reaction.get("products") or []
        ]
        reactants_html = [esc(r) for r in reactants]
        products_html = [esc(p) for p in products]

        # A condition is displayed as "<smiles or text>[<role>]".
        conditions = [
            f"{esc(first_present(c, 'smiles', 'text'))}[{esc(first_present(c, 'role'))}]"
            for c in reaction.get("condition") or []
        ]

        additional = reaction.get("additional_info") or []
        if isinstance(additional, str):
            # A bare string would otherwise be iterated character by character.
            additional = [additional]
        additional_str = [esc(x) for x in additional if x is not None]

        # Full reaction line: everything rendered in black.
        tail_str = ", ".join([span("black", c) for c in conditions] + additional_str)
        products_black = ".".join(span("black", p) for p in products_html)
        full_reaction = span(
            "black", f"{'.'.join(reactants_html)}>>{products_black} | {tail_str}"
        )

        # Detailed, colour-coded description of the reaction.
        reaction_output = f"<b>Reaction: </b> {reaction_id}<br>"
        reaction_output += f" Reactants: {span('blue', ', '.join(reactants_html))}<br>"
        reaction_output += f" Conditions: {', '.join(span('red', c) for c in conditions)}<br>"
        reaction_output += f" Products: {', '.join(span('orange', p) for p in products_html)}<br>"
        reaction_output += f" additional_info: {', '.join(additional_str)}<br>"
        reaction_output += f" <b>Full Reaction:</b> {full_reaction}<br>"
        reaction_output += "<br>"
        detailed_output.append(reaction_output)

        smiles_output.append(f"{'.'.join(reactants)}>>{'.'.join(products)}")
    return detailed_output, smiles_output
# Core handler: needs only the API key and the uploaded image.
def process_chem_image(api_key, image):
    """Run ChemEagle on an uploaded image and prepare the UI outputs.

    Args:
        api_key: API key typed by the user; exported as the
            ``CHEMEAGLE_API_KEY`` environment variable before ChemEagle runs.
            ``None`` is exported as an empty string.
        image: Uploaded image object; it must provide ``save(path)`` (the
            input component is configured with ``type="pil"``).

    Returns:
        A 4-tuple ``(html, smiles, diagram, json_path)``: the detailed HTML
        fragments joined by blank lines, the list of reaction SMILES, the
        schematic diagram path, and the path of the JSON file for download.

    Raises:
        gr.Error: If no image was uploaded.
    """
    import tempfile

    if image is None:
        raise gr.Error("Please upload a reaction image before running.")

    # Always overwrite the variable (even with an empty value) so that a key
    # left behind by a previous request is never reused for this one.
    # NOTE(review): the environment is process-wide, so concurrent requests
    # with different keys can still race here.
    os.environ["CHEMEAGLE_API_KEY"] = api_key or ""

    # A private directory per request keeps concurrent users from overwriting
    # each other's uploaded image and JSON result.
    work_dir = tempfile.mkdtemp(prefix="chemeagle_")
    image_path = os.path.join(work_dir, "temp_image.png")
    json_path = os.path.join(work_dir, "output.json")

    image.save(image_path)
    try:
        chemeagle_result = ChemEagle(image_path)
    finally:
        # The uploaded copy is only needed while ChemEagle runs.
        if os.path.exists(image_path):
            os.remove(image_path)

    detailed, smiles = parse_reactions(chemeagle_result)

    # Written to disk so the gr.File component can offer it for download.
    with open(json_path, "w", encoding="utf-8") as jf:
        json.dump(chemeagle_result, jf, indent=2)

    return "\n\n".join(detailed), smiles, example_diagram, json_path
# Build the Gradio interface.

# Removes the list-repr wrapper around one comma-separated item: a leading
# "[" and/or "'" and a trailing "'" and/or "]", plus surrounding whitespace.
_LIST_ITEM_WRAPPER_RE = re.compile(r"^\s*\[?'?|'\]?\s*$")


def _strip_atom_maps(rxn):
    """Reset the atom-map number of every reactant and product atom to 0."""
    for template in list(rxn.GetReactants()) + list(rxn.GetProducts()):
        for atom in template.GetAtoms():
            atom.SetAtomMapNum(0)
    return rxn


def _render_reaction_png(reaction_smiles, png_path):
    """Draw a reaction SMILES to ``png_path``.

    Returns:
        True when the picture was written, False when the reaction could not
        be parsed or rebuilt (for example the ``Unknown>>Unknown`` placeholder).
    """
    try:
        rxn = rdChemReactions.ReactionFromSmarts(reaction_smiles, useSmiles=True)
        if rxn is None:
            return False
        # Rebuild the reaction from mol-block round-tripped templates, as the
        # drawing below is performed on these rebuilt molecules.
        rebuilt = AllChem.ChemicalReaction()
        for template in rxn.GetReactants():
            mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(template))
            if mol is None:
                return False
            rebuilt.AddReactantTemplate(mol)
        for template in rxn.GetProducts():
            mol = Chem.MolFromMolBlock(Chem.MolToMolBlock(template))
            if mol is None:
                return False
            rebuilt.AddProductTemplate(mol)
    except (ValueError, RuntimeError):
        # NOTE(review): RDKit is expected to report parse/sanitisation
        # failures as ValueError; RuntimeError is caught defensively.
        return False
    rxn = _strip_atom_maps(rebuilt)

    # Reference bond length: the first of the first two reactants that has
    # any bond; 1.0 when neither has one (or there are no reactants at all).
    bond_length_reference = 1.0
    for template in list(rxn.GetReactants())[:2]:
        if template.GetNumBonds() > 0:
            bond_length_reference = Draw.MeanBondLength(template)
            break

    drawer = rdMolDraw2D.MolDraw2DSVG(-1, -1)
    dopts = drawer.drawOptions()
    dopts.padding = 0.1
    dopts.includeRadicals = True
    Draw.SetACS1996Mode(dopts, bond_length_reference * 0.55)
    dopts.bondLineWidth = 1.5
    drawer.DrawReaction(rxn)
    drawer.FinishDrawing()
    svg_content = drawer.GetDrawingText()
    # Convert in memory; no intermediate SVG file is needed.
    cairosvg.svg2png(bytestring=svg_content.encode("utf-8"), write_to=png_path)
    return True


with gr.Blocks() as demo:
    gr.Markdown(
        """
<center><h1>ChemEagle: A Multi-Agent System for Multimodal Chemical Information Extraction</h1></center>
Upload a multimodal reaction image and type your OpenAI API key to extract multimodal chemical information.
"""
    )
    with gr.Row():
        # ---- Left: upload + API key + buttons ----
        with gr.Column(scale=1):
            image_input = gr.Image(type="pil", label="Upload a multimodal reaction image")
            api_key_input = gr.Textbox(
                label="Your API-Key",
                placeholder="Type your OpenAI_API_KEY",
                type="password"
            )
            with gr.Row():
                clear_btn = gr.Button("Clear")
                run_btn = gr.Button("Run", elem_id="submit-btn")
        # ---- Middle: parsed result + schematic diagram ----
        with gr.Column(scale=1):
            gr.Markdown("### Parsed Reactions")
            reaction_output = gr.HTML(label="Detailed Reaction Output")
            gr.Markdown("### Schematic Diagram")
            schematic_diagram = gr.Image(value=example_diagram, label="Schematic Diagram")
        # ---- Right: per-reaction SMILES & RDKit rendering + JSON download ----
        with gr.Column(scale=1):
            gr.Markdown("### Machine-readable Output")
            smiles_output = gr.Textbox(
                label="Reaction SMILES",
                show_copy_button=True,
                interactive=False,
                visible=False
            )

            # NOTE(review): the original comment here says the gr.render
            # decorator binds this function to its input, but no decorator (and
            # no call) is visible in this file -- confirm whether
            # `@gr.render(inputs=smiles_output)` was intended.
            def show_split(inputs):
                """Create one SMILES textbox and one RDKit image per reaction.

                Args:
                    inputs: Text holding comma-separated reaction SMILES,
                        optionally in Python list-repr form (``"['a', 'b']"``).

                Returns:
                    For empty input, a placeholder ``(Textbox, Image)`` pair;
                    otherwise a list of components with a Textbox per reaction,
                    followed by an Image when the reaction could be drawn.
                """
                import tempfile

                if not inputs or (isinstance(inputs, str) and not inputs.strip()):
                    return (
                        gr.Textbox(label="SMILES of Reaction i"),
                        gr.Image(value=rdkit_image, label="RDKit Image of Reaction i", height=100),
                    )

                # Only the list-repr wrapper is removed from each item. Square
                # brackets inside a SMILES are bracket atoms (e.g. [Na+]) and
                # must be kept.
                smiles_list = [
                    _LIST_ITEM_WRAPPER_RE.sub("", item) for item in inputs.split(",")
                ]

                # A fresh directory per call keeps pictures of concurrent
                # requests from overwriting each other.
                out_dir = tempfile.mkdtemp(prefix="chemeagle_rxn_")
                components = []
                for i, smiles in enumerate(smiles_list):
                    png_file = os.path.join(out_dir, f"reaction_{i+1}.png")
                    rendered = _render_reaction_png(smiles, png_file)
                    components.append(
                        gr.Textbox(
                            value=smiles,
                            label=f"SMILES of Reaction {i}",
                            show_copy_button=True,
                            interactive=False,
                        )
                    )
                    if rendered:
                        components.append(
                            gr.Image(value=png_file, label=f"RDKit Image of Reaction {i}")
                        )
                return components

            download_json = gr.File(label="Download JSON File")

    # NOTE(review): indentation was lost in this copy of the file; placing the
    # examples below the three-column row is an assumption -- confirm.
    gr.Examples(
        examples=[
            ["examples/reaction1.jpg", ""],
            ["examples/reaction2.png", ""],
            ["examples/reaction3.png", ""],
            ["examples/reaction4.png", ""],
        ],
        inputs=[image_input, api_key_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json],
        cache_examples=False,
        examples_per_page=4,
    )

    # ---- Wire up the Clear and Run buttons ----
    clear_btn.click(
        lambda: (None, None, None, None, None),
        inputs=[],
        outputs=[image_input, api_key_input, reaction_output, smiles_output, download_json]
    )
    run_btn.click(
        process_chem_image,
        inputs=[api_key_input, image_input],
        outputs=[reaction_output, smiles_output, schematic_diagram, download_json]
    )

# Custom style for the Run button (matched through elem_id="submit-btn").
demo.css = """
#submit-btn {
background-color: #FF914D;
color: white;
font-weight: bold;
}
"""
demo.launch()