orbmol / app.py
annabossler's picture
examples (#4)
7d8bb7d verified
raw
history blame
24.1 kB
import os
os.environ["TORCH_DYNAMO_DISABLE"] = "1"
import tempfile
import numpy as np
import gradio as gr
from ase.io import read, write
from ase.io.trajectory import Trajectory
from gradio_molecule3d import Molecule3D
from simulation_scripts_orbmol import load_orbmol_model, run_md_simulation, run_relaxation_simulation
import hashlib
from pathlib import Path
# ==== Configuración Molecule3D ====
# Representation entries handed to Molecule3D through ``reps=``. The two
# entries differ only in drawing style and scale, so both are generated from
# one template.
DEFAULT_MOLECULAR_REPRESENTATIONS = [
    {
        "model": 0,
        "chain": "",
        "resname": "",
        "style": draw_style,
        "color": "Jmol",
        "around": 0,
        "byres": False,
        "scale": draw_scale,
    }
    for draw_style, draw_scale in (("sphere", 0.3), ("stick", 0.2))
]

# Viewer-wide options handed to Molecule3D through ``config=``.
DEFAULT_MOLECULAR_SETTINGS = dict(
    backgroundColor="white",
    orthographic=False,
    disableFog=False,
)
# ==== UI definition ====
# ==== SPE Inputs and Outputs ====
# Single Point Energy widgets. Every component is created with render=False
# and is placed later by calling .render() inside the Blocks layout.
input_sp = gr.File(
    render=False,
    file_count="single",
    file_types=[".xyz", ".pdb", ".cif", ".traj", ".mol", ".sdf"],
    label="Upload Structure File",
)
task_name_sp = gr.Radio(
    choices=["OMol", "OMat", "OMol-Direct"],
    label="Model Type",
    value="OMol",
    render=False,
)
total_charge_sp = gr.Slider(
    minimum=-10,
    maximum=10,
    value=0,
    step=1,
    label="Charge",
    render=False,
)
spin_multiplicity_sp = gr.Slider(
    minimum=1,
    maximum=11,
    value=1,
    step=1,
    label="Spin Multiplicity",
    render=False,
)
run_sp = gr.Button(value="Run OrbMol Prediction", variant="primary", render=False)

# Read-only outputs: text report, short status line and 3D structure viewer.
spe_out = gr.Textbox(label="Energy & Forces", lines=15, interactive=False, render=False)
spe_status = gr.Textbox(label="Status", interactive=False, render=False)
spe_viewer = Molecule3D(
    render=False,
    config=DEFAULT_MOLECULAR_SETTINGS,
    reps=DEFAULT_MOLECULAR_REPRESENTATIONS,
    label="Input Structure Viewer",
)
#==== MD Inputs and Outputs ====
# Molecular Dynamics widgets. Every component is created with render=False
# and is placed later by calling .render() inside the Blocks layout.
input_md = gr.File(
    render=False,
    file_count="single",
    file_types=[".xyz", ".pdb", ".cif", ".traj", ".mol", ".sdf"],
    label="Upload Structure File",
)
task_name_md = gr.Radio(
    choices=["OMol", "OMat", "OMol-Direct"],
    label="Model Type",
    value="OMol",
    render=False,
)
charge_md = gr.Slider(minimum=-10, maximum=10, value=0, step=1, label="Charge", render=False)
spin_md = gr.Slider(minimum=1, maximum=11, value=1, step=1, label="Spin Multiplicity", render=False)
steps_md = gr.Slider(minimum=10, maximum=2000, value=100, step=10, label="Steps", render=False)
temp_md = gr.Slider(minimum=10, maximum=1500, value=300, step=10, label="Temperature (K)", render=False)
timestep_md = gr.Slider(minimum=0.1, maximum=5.0, value=1.0, step=0.1, label="Timestep (fs)", render=False)
ensemble_md = gr.Radio(choices=["NVE", "NVT"], value="NVE", label="Ensemble", render=False)
run_md_btn = gr.Button(value="Run MD Simulation", variant="primary", render=False)

# Read-only outputs: status line, downloadable trajectory, 3D viewer, run log,
# reproduction script and a Markdown explanation.
md_status = gr.Textbox(label="MD Status", interactive=False, render=False)
md_traj = gr.File(label="Trajectory (.traj)", interactive=False, render=False)
md_viewer = Molecule3D(
    render=False,
    config=DEFAULT_MOLECULAR_SETTINGS,
    reps=DEFAULT_MOLECULAR_REPRESENTATIONS,
    label="MD Result Viewer",
)
md_log = gr.Textbox(label="Log", interactive=False, lines=15, render=False)
md_script = gr.Code(
    label="Reproduction Script",
    language="python",
    interactive=False,
    lines=20,
    render=False,
)
md_explain = gr.Markdown(render=False)
#==== Relax Inputs and Outputs ====
# Relaxation / optimization widgets. Every component is created with
# render=False and is placed later by calling .render() inside the Blocks
# layout.
input_rlx = gr.File(
    render=False,
    file_count="single",
    file_types=[".xyz", ".pdb", ".cif", ".traj", ".mol", ".sdf"],
    label="Upload Structure File",
)
task_name_rlx = gr.Radio(
    choices=["OMol", "OMat", "OMol-Direct"],
    label="Model Type",
    value="OMol",
    render=False,
)
steps_rlx = gr.Slider(minimum=1, maximum=2000, value=300, step=1, label="Max Steps", render=False)
fmax_rlx = gr.Slider(minimum=0.001, maximum=0.5, value=0.05, step=0.001, label="Fmax (eV/Å)", render=False)
charge_rlx = gr.Slider(minimum=-10, maximum=10, value=0, step=1, label="Charge", render=False)
spin_rlx = gr.Slider(minimum=1, maximum=11, value=1, step=1, label="Spin", render=False)
relax_cell = gr.Checkbox(value=False, label="Relax Unit Cell", render=False)
run_rlx_btn = gr.Button(value="Run Optimization", variant="primary", render=False)

# Read-only outputs: status line, downloadable trajectory, 3D viewer, run log,
# reproduction script and a Markdown explanation.
rlx_status = gr.Textbox(label="Status", interactive=False, render=False)
rlx_traj = gr.File(label="Trajectory (.traj)", interactive=False, render=False)
rlx_viewer = Molecule3D(
    render=False,
    config=DEFAULT_MOLECULAR_SETTINGS,
    reps=DEFAULT_MOLECULAR_REPRESENTATIONS,
    label="Optimized Structure Viewer",
)
rlx_log = gr.Textbox(label="Log", interactive=False, lines=15, render=False)
rlx_script = gr.Code(
    label="Reproduction Script",
    language="python",
    interactive=False,
    lines=20,
    render=False,
)
rlx_explain = gr.Markdown(render=False)
# ==== Conversión a PDB para Molecule3D ====
def convert_to_pdb_for_viewer(file_path):
    """Write a PDB copy of a structure file for the Molecule3D viewer.

    The file is loaded with ``ase.io.read`` (default arguments) and written
    with the ``proteindatabank`` format into ``<tempdir>/gradio``. The output
    name is ``mol_<digest>.pdb`` where the digest is the first 12 hex
    characters of the MD5 of the *path string*, so converting the same path
    again overwrites the previous PDB.

    Args:
        file_path: Path of the structure file as a string.

    Returns:
        The path of the written PDB file, or ``None`` when ``file_path`` is
        empty, does not exist, or any step of the conversion raises.
    """
    if not (file_path and os.path.exists(file_path)):
        return None

    try:
        structure = read(file_path)
        out_dir = os.path.join(tempfile.gettempdir(), "gradio")
        os.makedirs(out_dir, exist_ok=True)
        digest = hashlib.md5(file_path.encode()).hexdigest()
        target = os.path.join(out_dir, f"mol_{digest[:12]}.pdb")
        write(target, structure, format="proteindatabank")
    except Exception as e:
        # Best effort: the viewer is optional, so report the problem and let
        # the caller continue without a structure to display.
        print(f"Error converting to PDB: {e}")
        return None
    return target
# ==== OrbMol SPE ====
def predict_molecule(structure_file, task_name, charge=0, spin_multiplicity=1):
    """Compute the single-point energy and atomic forces of a structure.

    The uploaded file is validated *before* the model is loaded, so a
    missing, nonexistent or empty upload is reported immediately and never
    triggers a model load.

    Args:
        structure_file: Path of the uploaded structure file (falsy when
            nothing was uploaded).
        task_name: Model selector passed to ``load_orbmol_model``; the UI
            offers "OMol", "OMat" and "OMol-Direct".
        charge: Total charge; only applied for "OMol" and "OMol-Direct".
        spin_multiplicity: Spin multiplicity; only applied for "OMol" and
            "OMol-Direct".

    Returns:
        A 3-tuple ``(report, status, pdb_path)``: the text report with the
        energy and per-atom forces, a short status message, and the path of
        a PDB copy for the viewer (``None`` on error or if the conversion
        failed). Errors are returned as text, never raised.
    """
    try:
        # Cheap input checks come first so bad input fails fast.
        if not structure_file:
            return "Error: Please upload a structure file", "Error", None
        file_path = structure_file
        if not os.path.exists(file_path):
            return f"Error: File not found: {file_path}", "Error", None
        if os.path.getsize(file_path) == 0:
            return f"Error: Empty file: {file_path}", "Error", None

        calc = load_orbmol_model(task_name)
        atoms = read(file_path)
        if task_name in ("OMol", "OMol-Direct"):
            # Only these two models receive charge and spin.
            atoms.info = {"charge": int(charge), "spin": int(spin_multiplicity)}
        atoms.calc = calc

        energy = atoms.get_potential_energy()
        forces = atoms.get_forces()

        lines = [
            f"Model: {task_name}",
            f"Total Energy: {energy:.6f} eV",
            "",
            "Atomic Forces:",
        ]
        lines.extend(
            f"Atom {i}: [{fc[0]:.4f}, {fc[1]:.4f}, {fc[2]:.4f}] eV/Å"
            for i, fc in enumerate(forces, start=1)
        )
        # Largest per-atom force magnitude (Euclidean norm of each row).
        max_force = float(np.max(np.linalg.norm(forces, axis=1)))
        lines += ["", f"Max Force: {max_force:.4f} eV/Å"]

        pdb_file = convert_to_pdb_for_viewer(file_path)
        return "\n".join(lines), f"Calculation completed with {task_name}", pdb_file
    except Exception as e:
        # UI boundary: log the traceback server-side and show the message.
        import traceback

        traceback.print_exc()
        return f"Error during calculation: {e}", "Error", None
# ==== Wrappers MD y Relax ====
def md_wrapper(structure_file, task_name, charge, spin, steps, tempK, timestep_fs, ensemble):
    """Run an MD simulation from the UI and shape the result for the outputs.

    Coerces the widget values to plain Python types, calls
    ``run_md_simulation`` and converts the returned trajectory to PDB for the
    viewer.

    Args:
        structure_file: Path of the uploaded structure (falsy when missing).
        task_name: Model selector, forwarded as ``str``.
        charge: Total charge, forwarded as ``int``.
        spin: Spin multiplicity, forwarded as ``int``.
        steps: Number of MD steps, forwarded as ``int``.
        tempK: Temperature in kelvin, forwarded as ``float``.
        timestep_fs: Timestep in femtoseconds, forwarded as ``float``.
        ensemble: "NVT" selects NVT; any other value falls back to "NVE".

    Returns:
        A 6-tuple ``(status, traj_path, log, script, explanation, pdb_path)``.
        On a missing upload or any exception the status carries the error
        text and the remaining items are ``None``/empty strings.
    """

    def _failure(message):
        # Same output arity as the success path, with empty payloads.
        return (message, None, "", "", "", None)

    try:
        if not structure_file:
            return _failure("Error: Please upload a structure file")

        n_steps = int(steps)
        dt_fs = float(timestep_fs)
        temperature = float(tempK)
        chosen_ensemble = "NVT" if ensemble == "NVT" else "NVE"
        model_name = str(task_name)
        total_charge = int(charge)
        multiplicity = int(spin)

        traj_path, log_text, script_text, explanation = run_md_simulation(
            structure_file,
            n_steps,
            # NOTE(review): hard-coded third positional argument; its meaning
            # is defined by run_md_simulation — confirm against that function.
            20,
            dt_fs,
            temperature,
            chosen_ensemble,
            model_name,
            total_charge,
            multiplicity,
        )

        status = f"MD completed: {n_steps} steps at {int(tempK)} K ({ensemble})"
        pdb_file = convert_to_pdb_for_viewer(traj_path)
        return (status, traj_path, log_text, script_text, explanation, pdb_file)
    except Exception as e:
        import traceback

        traceback.print_exc()
        return _failure(f"Error: {e}")
def relax_wrapper(structure_file, task_name, steps, fmax, charge, spin, relax_cell):
    """Run a structure relaxation from the UI and shape the result for the outputs.

    Coerces the widget values to plain Python types, calls
    ``run_relaxation_simulation`` and converts the returned trajectory to PDB
    for the viewer.

    Args:
        structure_file: Path of the uploaded structure (falsy when missing).
        task_name: Model selector, forwarded as ``str``.
        steps: Maximum number of optimizer steps, forwarded as ``int``.
        fmax: Force threshold in eV/Å, forwarded as ``float``.
        charge: Total charge, forwarded as ``int``.
        spin: Spin multiplicity, forwarded as ``int``.
        relax_cell: Whether the unit cell is relaxed too, forwarded as ``bool``.

    Returns:
        A 6-tuple ``(status, traj_path, log, script, explanation, pdb_path)``.
        On a missing upload or any exception the status carries the error
        text and the remaining items are ``None``/empty strings.
    """

    def _failure(message):
        # Same output arity as the success path, with empty payloads.
        return (message, None, "", "", "", None)

    try:
        if not structure_file:
            return _failure("Error: Please upload a structure file")

        max_steps = int(steps)
        force_tol = float(fmax)
        outcome = run_relaxation_simulation(
            structure_file,
            max_steps,
            force_tol,
            str(task_name),
            int(charge),
            int(spin),
            bool(relax_cell),
        )
        traj_path, log_text, script_text, explanation = outcome

        status = f"Relaxation finished (<={max_steps} steps, fmax={force_tol} eV/Å)"
        pdb_file = convert_to_pdb_for_viewer(traj_path)
        return (status, traj_path, log_text, script_text, explanation, pdb_file)
    except Exception as e:
        import traceback

        traceback.print_exc()
        return _failure(f"Error: {e}")
# ==== UI ====
# Build the Gradio Blocks app (Ocean theme) and bind it to `demo`, which the
# __main__ guard launches.
with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
with gr.Tabs():
# -------- HOME TAB (new) --------
with gr.Tab("Home"):
with gr.Row():
# Column with scale=1: collapsed accordions holding reference text.
with gr.Column(scale=1):
gr.Markdown("## Learn more about OrbMol")
with gr.Accordion("What is OrbMol?", open=False):
gr.Markdown("""
OrbMol is a suite of quantum-accurate machine learning models for molecular predictions. Built on the **Orb-v3 architecture**, OrbMol provides fast and accurate calculations of energies, forces, and molecular properties at the level of advanced quantum chemistry methods.
""")
with gr.Accordion("Available Models", open=False):
gr.Markdown("""
**OMol** and **OMol-Direct**
- **Training dataset**: OMol25 (>100M calculations on small molecules, biomolecules, metal complexes, and electrolytes)
- **Level of theory**: ωB97M-V/def2-TZVPD with non-local dispersion; solvation treated explicitly
- **Inputs**: total charge & spin multiplicity
- **Applications**: biology, organic chemistry, protein folding, small-molecule drugs, organic liquids, homogeneous catalysis
- **Caveats**: trained only on aperiodic systems → periodic/inorganic cases may not work well
- **Difference**: OMol enforces energy–force consistency; OMol-Direct relaxes this for efficiency
**OMat**
- **Training dataset**: OMat24 (>100M inorganic calculations, from Materials Project, Alexandria, and far-from-equilibrium samples)
- **Level of theory**: PBE/PBE+U with Materials Project settings; VASP 54 pseudopotentials; no dispersion
- **Inputs**: No support for spin and charge
- **Applications**: inorganic discovery, photovoltaics, alloys, superconductors, electronic/optical materials
""")
with gr.Accordion("Supported File Formats", open=False):
gr.Markdown("""
OrbMol supports: `.xyz`, `.pdb`, `.cif`, `.traj`, `.mol`, `.sdf`
""")
with gr.Accordion("Resources & Support", open=False):
gr.Markdown("""
- [Orb-v3 paper](https://arxiv.org/abs/2504.06231)
- [Orb-Models GitHub repository](https://github.com/orbital-materials/orb-models)
""")
# Column with scale=2: logo image, title and a short usage overview.
with gr.Column(scale=2):
# Logo is loaded from a relative path; share, download, label and
# fullscreen controls are all switched off.
gr.Image("logo_color_text.png",
show_share_button=False,
show_download_button=False,
show_label=False,
show_fullscreen_button=False)
gr.Markdown("# OrbMol — Quantum-Accurate Molecular Predictions")
gr.Markdown("""
Welcome to the OrbMol demo! Use the tabs above to access different functionalities:
1. **Single Point Energy**: Calculate energies and forces
2. **Molecular Dynamics**: Run MD simulations
3. **Relaxation / Optimization**: Optimize structures
Supported formats: `.xyz`, `.pdb`, `.cif`, `.traj`, `.mol`, `.sdf`
""")
# -------- SPE (identical to the version that was working) --------
with gr.Tab("Single Point Energy"):
with gr.Row():
# Column with scale=2: headings, examples and the input widgets.
with gr.Column(scale=2):
gr.Markdown("# OrbMol: Single Point Energy")
gr.Markdown("Run a quantum-accurate single point energy calculation with OrbMol models.")
gr.Markdown("**Supported formats:** .xyz, .pdb, .cif, .traj, .mol, .sdf")
# Example rows are resolved relative to this file's directory and are wired
# to predict_molecule with the same inputs/outputs as the Run button.
# NOTE(review): each row holds five values (the last is "") while only four
# input components are listed — confirm the extra value is intentional.
# NOTE(review): run_on_click=True is combined with cache_examples=True; the
# Gradio docs describe run_on_click as applying when caching is off — confirm.
gr.Examples(
examples=[[str(Path(__file__).parent/ "./examples/10mer_b-DNA.pdb"),
"OMol",
0,
1,
"",
],
[str(Path(__file__).parent/ "./examples/PtO2.cif"),
"OMat",
0,
1,
"",
],
],
example_labels=[
"A 10-mer DNA (A tiny DNA fragment)",
"An inorganic crystal (PtO2)",
],
inputs=[
input_sp,
task_name_sp,
total_charge_sp,
spin_multiplicity_sp
],
outputs=[
spe_out,
spe_status,
spe_viewer
],
fn=predict_molecule,
run_on_click=True,
cache_examples=True,
label="Try an example!",
)
# Place the input widgets that were created earlier with render=False.
input_sp.render()
task_name_sp.render()
with gr.Row():
total_charge_sp.render()
spin_multiplicity_sp.render()
run_sp.render()
# Panel-styled column for the outputs: report, status and 3D viewer.
with gr.Column(variant="panel", min_width=500):
spe_out.render()
spe_status.render()
spe_viewer.render()
# Charge and spin sliders are visible only for OMol and OMol-Direct; selecting
# OMat hides both.
task_name_sp.change(
lambda x: (
gr.update(visible=x in ["OMol", "OMol-Direct"]),
gr.update(visible=x in ["OMol", "OMol-Direct"])
),
[task_name_sp],
[total_charge_sp, spin_multiplicity_sp]
)
# Run button: predict_molecule fills the report, status and viewer outputs.
run_sp.click(
predict_molecule,
[input_sp, task_name_sp, total_charge_sp, spin_multiplicity_sp],
[spe_out, spe_status, spe_viewer]
)
# -------- MD (identical to the version that was working) --------
with gr.Tab("Molecular Dynamics"):
with gr.Row():
# Column with scale=2: headings, examples and the input widgets.
with gr.Column(scale=2):
gr.Markdown("## OrbMol: Molecular Dynamics Simulation")
gr.Markdown("Run a quantum-accurate molecular dynamics simulation with OrbMol models.")
gr.Markdown("**Supported formats:** .xyz, .pdb, .cif, .traj, .mol, .sdf")
# Both examples use the same cyclohexane file and differ only in the model
# (OMol vs OMol-Direct); they are wired to md_wrapper like the Run button.
# NOTE(review): each row holds nine values (the last is "") while only eight
# input components are listed — confirm the extra value is intentional.
# NOTE(review): run_on_click=True is combined with cache_examples=True; the
# Gradio docs describe run_on_click as applying when caching is off — confirm.
gr.Examples(
examples=[[str(Path(__file__).parent/ "./examples/Cyclohexane.sdf"),
"OMol",
0,
1,
1000,
300,
0.5,
"NVT",
"",
],
[str(Path(__file__).parent/ "./examples/Cyclohexane.sdf"),
"OMol-Direct",
0,
1,
1000,
300,
0.5,
"NVT",
"",
],
],
example_labels=[
"Cyclohexane with OMol",
"Cyclohexane with OMol-Direct",
],
inputs=[
input_md,
task_name_md,
charge_md,
spin_md,
steps_md,
temp_md,
timestep_md,
ensemble_md
],
outputs=[
md_status,
md_traj,
md_log,
md_script,
md_explain,
md_viewer
],
fn=md_wrapper,
run_on_click=True,
cache_examples=True,
label="Try an example!",
)
# Place the input widgets that were created earlier with render=False.
input_md.render()
task_name_md.render()
with gr.Row():
charge_md.render()
spin_md.render()
with gr.Row():
steps_md.render()
temp_md.render()
with gr.Row():
timestep_md.render()
ensemble_md.render()
run_md_btn.render()
# Panel-styled column for the outputs. The on-screen order (viewer before
# log) differs from the order of the handler's output list below.
with gr.Column(variant="panel", min_width=520):
md_status.render()
md_traj.render()
md_viewer.render()
md_log.render()
md_script.render()
md_explain.render()
# Charge and spin sliders are visible only for OMol and OMol-Direct; selecting
# OMat hides both.
task_name_md.change(
lambda x: (
gr.update(visible=x in ["OMol", "OMol-Direct"]),
gr.update(visible=x in ["OMol", "OMol-Direct"])
),
[task_name_md],
[charge_md, spin_md]
)
# Run button: the output list matches the 6-tuple returned by md_wrapper
# (status, trajectory, log, script, explanation, viewer PDB).
run_md_btn.click(
md_wrapper,
[input_md, task_name_md, charge_md, spin_md, steps_md, temp_md, timestep_md, ensemble_md],
[md_status, md_traj, md_log, md_script, md_explain, md_viewer]
)
# -------- Relax (identical to the version that was working) --------
with gr.Tab("Relaxation / Optimization"):
with gr.Row():
# Column with scale=2: headings, examples and the input widgets.
with gr.Column(scale=2):
gr.Markdown("## OrbMol: Structure Relaxation/Optimization")
gr.Markdown("Run a quantum-accurate structure relaxation with OrbMol models.")
gr.Markdown("**Supported formats:** .xyz, .pdb, .cif, .traj, .mol, .sdf")
# One molecular example (OMol) and one crystal example (OMat), wired to
# relax_wrapper with the same inputs/outputs as the Run button.
# NOTE(review): each row holds eight values (the last is "") while only seven
# input components are listed — confirm the extra value is intentional.
# NOTE(review): run_on_click=True is combined with cache_examples=True; the
# Gradio docs describe run_on_click as applying when caching is off — confirm.
gr.Examples(
examples=[[str(Path(__file__).parent/ "./examples/dioctyl_sebacate.xyz"),
"OMol",
1000,
0.05,
0,
1,
False,
"",
],
[str(Path(__file__).parent/ "./examples/Fe2O3.cif"),
"OMat",
1000,
0.05,
0,
1,
False,
"",
],
],
example_labels=[
"A branched Alkane (Dioctyl Sebacate)",
"An inorganic crystal (Fe2O3)",
],
inputs=[
input_rlx,
task_name_rlx,
steps_rlx,
fmax_rlx,
charge_rlx,
spin_rlx,
relax_cell
],
outputs=[
rlx_status,
rlx_traj,
rlx_log,
rlx_script,
rlx_explain,
rlx_viewer
],
fn=relax_wrapper,
run_on_click=True,
cache_examples=True,
label="Try an example!",
)
# Place the input widgets that were created earlier with render=False.
input_rlx.render()
task_name_rlx.render()
with gr.Row():
steps_rlx.render()
fmax_rlx.render()
with gr.Row():
charge_rlx.render()
spin_rlx.render()
relax_cell.render()
run_rlx_btn.render()
# Panel-styled column for the outputs. The on-screen order (viewer before
# log) differs from the order of the handler's output list below.
with gr.Column(variant="panel", min_width=520):
rlx_status.render()
rlx_traj.render()
rlx_viewer.render()
rlx_log.render()
rlx_script.render()
rlx_explain.render()
# Charge and spin sliders are visible only for OMol and OMol-Direct; selecting
# OMat hides both.
task_name_rlx.change(
lambda x: (
gr.update(visible=x in ["OMol", "OMol-Direct"]),
gr.update(visible=x in ["OMol", "OMol-Direct"])
),
[task_name_rlx],
[charge_rlx, spin_rlx]
)
# Run button: the output list matches the 6-tuple returned by relax_wrapper
# (status, trajectory, log, script, explanation, viewer PDB).
run_rlx_btn.click(
relax_wrapper,
[input_rlx, task_name_rlx, steps_rlx, fmax_rlx, charge_rlx, spin_rlx, relax_cell],
[rlx_status, rlx_traj, rlx_log, rlx_script, rlx_explain, rlx_viewer]
)
if __name__ == "__main__":
    # Script entry point: serve the app on every network interface at port
    # 7860 with show_error enabled.
    demo.launch(
        show_error=True,
        server_port=7860,
        server_name="0.0.0.0",
    )