annabossler commited on
Commit
72ea7e1
·
verified ·
1 Parent(s): 0ade04a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +61 -58
app.py CHANGED
@@ -4,21 +4,17 @@ import numpy as np
4
  import gradio as gr
5
 
6
  from ase.io import read
7
- from ase.optimize import LBFGS
8
- from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
9
- from ase.md.verlet import VelocityVerlet
10
  from ase.io.trajectory import Trajectory
11
- from ase import units
12
 
13
- # Visualizador 3D “pro” si está disponible
14
  try:
15
  from gradio_molecule3d import Molecule3D
16
  HAVE_MOL3D = True
17
  except Exception:
18
  HAVE_MOL3D = False
19
 
20
- # --- Helpers de visualización: fallback 3Dmol.js si no hay Molecule3D ---
21
  def traj_to_html(traj_path, width=520, height=520, interval_ms=200):
 
22
  traj = Trajectory(traj_path)
23
  xyz_frames = []
24
  for atoms in traj:
@@ -53,22 +49,27 @@ def traj_to_html(traj_path, width=520, height=520, interval_ms=200):
53
  """
54
  return html
55
 
56
- # --- OrbMol directo para SPE ---
 
57
  from orb_models.forcefield import pretrained
58
  from orb_models.forcefield.calculator import ORBCalculator
59
 
60
- _model_calc = None
61
  def _load_orbmol_calc():
62
- global _model_calc
63
- if _model_calc is None:
64
- orbff = pretrained.orb_v3_conservative_inf_omat(device="cpu", precision="float32-high")
65
- _model_calc = ORBCalculator(orbff, device="cpu")
66
- return _model_calc
 
 
 
67
 
68
  def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
 
69
  try:
70
  calc = _load_orbmol_calc()
71
- if not xyz_content.strip():
72
  return "Error: Please enter XYZ coordinates", "Error"
73
 
74
  with tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False) as f:
@@ -97,20 +98,21 @@ def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
97
  except Exception as e:
98
  return f"Error during calculation: {e}", "Error"
99
 
100
- # --- Importa las rutinas FAIRChem-like para MD/Relax (ya soportan string XYZ o ruta) ---
 
101
  from simulation_scripts_orbmol import (
102
  run_md_simulation,
103
  run_relaxation_simulation,
104
  last_frame_xyz_from_traj,
105
  )
106
 
107
- # Wrappers para conectar outputs de Gradio correctamente (string XYZ / HTML, file, logs...)
108
  def md_wrapper(xyz_content, charge, spin, steps, tempK, timestep_fs, ensemble):
 
109
  try:
110
  traj_path, log_text, script_text, explanation = run_md_simulation(
111
- xyz_content, # acepta string XYZ
112
  int(steps),
113
- 20, # pre-relax steps como en la UI de UMA
114
  float(timestep_fs),
115
  float(tempK),
116
  "NVT" if ensemble == "NVT" else "NVE",
@@ -118,28 +120,28 @@ def md_wrapper(xyz_content, charge, spin, steps, tempK, timestep_fs, ensemble):
118
  int(spin),
119
  )
120
  status = f"MD completed: {int(steps)} steps at {int(tempK)} K ({ensemble})"
121
- # Viewer
122
  if HAVE_MOL3D:
123
- xyz_final = last_frame_xyz_from_traj(traj_path) # value para Molecule3D
124
- viewer_value = xyz_final
125
  html_value = None
126
  else:
127
  viewer_value = None
128
  html_value = traj_to_html(traj_path)
129
 
130
  return (
131
- status, # MD Status
132
- viewer_value, # Molecule3D value (o None)
133
- html_value, # HTML fallback (o None)
134
- traj_path, # File download
135
- log_text, # Log
136
- script_text, # Script
137
- explanation, # Explanation
138
  )
139
  except Exception as e:
140
  return (f"Error: {e}", None, None, None, "", "", "")
141
 
142
  def relax_wrapper(xyz_content, steps, fmax, charge, spin, relax_cell):
 
143
  try:
144
  traj_path, log_text, script_text, explanation = run_relaxation_simulation(
145
  xyz_content,
@@ -150,6 +152,7 @@ def relax_wrapper(xyz_content, steps, fmax, charge, spin, relax_cell):
150
  bool(relax_cell),
151
  )
152
  status = f"Relaxation finished (≤ {int(steps)} steps, fmax={float(fmax)} eV/Å)"
 
153
  if HAVE_MOL3D:
154
  viewer_value = last_frame_xyz_from_traj(traj_path)
155
  html_value = None
@@ -158,20 +161,19 @@ def relax_wrapper(xyz_content, steps, fmax, charge, spin, relax_cell):
158
  html_value = traj_to_html(traj_path)
159
 
160
  return (
161
- status,
162
- viewer_value,
163
- html_value,
164
- traj_path,
165
- log_text,
166
- script_text,
167
- explanation,
168
  )
169
  except Exception as e:
170
  return (f"Error: {e}", None, None, None, "", "", "")
171
 
172
- # ------------------------
173
- # Ejemplos rápidos
174
- # ------------------------
175
  examples = [
176
  ["""2
177
  Hydrogen molecule
@@ -191,21 +193,19 @@ H -0.3630 -0.5133 0.8887
191
  H -0.3630 -0.5133 -0.8887""", 0, 1],
192
  ]
193
 
194
- # ------------------------
195
- # UI Gradio
196
- # ------------------------
197
  with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
198
  with gr.Tabs():
199
- # ========== SPE ==========
200
  with gr.Tab("Single Point Energy"):
201
  with gr.Row():
202
  with gr.Column(scale=2):
203
  gr.Markdown("# OrbMol Demo — Quantum-Accurate Molecular Predictions")
204
  gr.Markdown(
205
- "Predict **energies** and **forces** with OrbMol (OMol25). "
206
- "Supports **charge** and **spin multiplicity**."
207
  )
208
-
209
  xyz_input = gr.Textbox(
210
  label="XYZ Coordinates",
211
  placeholder="Paste XYZ here...",
@@ -237,17 +237,17 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
237
  )
238
  with gr.Accordion("Benchmarks", open=False):
239
  gr.Markdown(
240
- "* Near-DFT accuracy on **GMTKN55** and **Wiggle150**\n"
241
- "* Accurate **protein–ligand** interaction energies (PLA15)\n"
242
- "* Stable long MD on biomolecules grandes"
243
  )
244
  with gr.Accordion("Disclaimers", open=False):
245
  gr.Markdown(
246
- "* Validate for your use case\n"
247
- "* Consider training **level of theory** and intended domain"
248
  )
249
 
250
- # ========== MD ==========
251
  with gr.Tab("Molecular Dynamics"):
252
  with gr.Row():
253
  with gr.Column(scale=2):
@@ -256,7 +256,7 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
256
  charge_md = gr.Slider(value=0, minimum=-10, maximum=10, step=1, label="Charge")
257
  spin_md = gr.Slider(value=1, minimum=1, maximum=11, step=1, label="Spin Multiplicity")
258
  with gr.Row():
259
- steps_md = gr.Slider(value=100, minimum=10, maximum=1000, step=10, label="Steps")
260
  temp_md = gr.Slider(value=300, minimum=10, maximum=1500, step=10, label="Temperature (K)")
261
  with gr.Row():
262
  timestep_md = gr.Slider(value=1.0, minimum=0.1, maximum=5.0, step=0.1, label="Timestep (fs)")
@@ -265,15 +265,16 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
265
 
266
  with gr.Column(variant="panel", min_width=520):
267
  md_status = gr.Textbox(label="MD Status", interactive=False)
 
268
  if HAVE_MOL3D:
269
  md_viewer = Molecule3D(label="Trajectory Viewer")
270
- md_html = gr.HTML(visible=False) # placeholder para consistencia de outputs
271
  else:
272
  md_viewer = gr.Textbox(visible=False) # placeholder
273
  md_html = gr.HTML()
274
 
275
  md_traj = gr.File(label="Trajectory (.traj)", interactive=False)
276
- md_log = gr.Code(label="Log", language="text", interactive=False, lines=15, max_lines=25)
277
  md_script = gr.Code(label="Reproduction Script", language="python", interactive=False, lines=20, max_lines=30)
278
  md_explain = gr.Markdown()
279
 
@@ -283,12 +284,12 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
283
  outputs=[md_status, md_viewer, md_html, md_traj, md_log, md_script, md_explain],
284
  )
285
 
286
- # ========== Relaxation ==========
287
  with gr.Tab("Relaxation / Optimization"):
288
  with gr.Row():
289
  with gr.Column(scale=2):
290
  xyz_rlx = gr.Textbox(label="XYZ Coordinates", lines=12, placeholder="Paste XYZ here...")
291
- steps_rlx = gr.Slider(value=300, minimum=1, maximum=1000, step=1, label="Max Steps")
292
  fmax_rlx = gr.Slider(value=0.05, minimum=0.001, maximum=0.5, step=0.001, label="Fmax (eV/Å)")
293
  with gr.Row():
294
  charge_rlx = gr.Slider(value=0, minimum=-10, maximum=10, step=1, label="Charge")
@@ -298,14 +299,16 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
298
 
299
  with gr.Column(variant="panel", min_width=520):
300
  rlx_status = gr.Textbox(label="Status", interactive=False)
 
301
  if HAVE_MOL3D:
302
  rlx_viewer = Molecule3D(label="Final Structure")
303
  rlx_html = gr.HTML(visible=False)
304
  else:
305
  rlx_viewer = gr.Textbox(visible=False)
306
  rlx_html = gr.HTML()
 
307
  rlx_traj = gr.File(label="Trajectory (.traj)", interactive=False)
308
- rlx_log = gr.Code(label="Log", language="text", interactive=False, lines=15, max_lines=25)
309
  rlx_script = gr.Code(label="Reproduction Script", language="python", interactive=False, lines=20, max_lines=30)
310
  rlx_explain = gr.Markdown()
311
 
@@ -316,7 +319,7 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
316
  )
317
 
318
  print("Starting OrbMol model loading…")
319
- _load_orbmol_calc()
320
 
321
  if __name__ == "__main__":
322
  demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
4
  import gradio as gr
5
 
6
  from ase.io import read
 
 
 
7
  from ase.io.trajectory import Trajectory
 
8
 
9
+ # ==== Visualizador 3D ====
10
  try:
11
  from gradio_molecule3d import Molecule3D
12
  HAVE_MOL3D = True
13
  except Exception:
14
  HAVE_MOL3D = False
15
 
 
16
  def traj_to_html(traj_path, width=520, height=520, interval_ms=200):
17
+ """Fallback de visualización con 3Dmol.js si no hay Molecule3D."""
18
  traj = Trajectory(traj_path)
19
  xyz_frames = []
20
  for atoms in traj:
 
49
  """
50
  return html
51
 
52
+
53
+ # ==== OrbMol SPE directo (rápido) ====
54
  from orb_models.forcefield import pretrained
55
  from orb_models.forcefield.calculator import ORBCalculator
56
 
57
+ _MODEL_CALC = None
58
  def _load_orbmol_calc():
59
+ global _MODEL_CALC
60
+ if _MODEL_CALC is None:
61
+ orbff = pretrained.orb_v3_conservative_inf_omat(
62
+ device="cpu",
63
+ precision="float32-high"
64
+ )
65
+ _MODEL_CALC = ORBCalculator(orbff, device="cpu")
66
+ return _MODEL_CALC
67
 
68
  def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
69
+ """Single Point Energy/Forces con OrbMol (input en Textbox XYZ)."""
70
  try:
71
  calc = _load_orbmol_calc()
72
+ if not xyz_content or not xyz_content.strip():
73
  return "Error: Please enter XYZ coordinates", "Error"
74
 
75
  with tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False) as f:
 
98
  except Exception as e:
99
  return f"Error during calculation: {e}", "Error"
100
 
101
+
102
+ # ==== Simulación (estilo UMA) vía helpers locales ====
103
  from simulation_scripts_orbmol import (
104
  run_md_simulation,
105
  run_relaxation_simulation,
106
  last_frame_xyz_from_traj,
107
  )
108
 
 
109
  def md_wrapper(xyz_content, charge, spin, steps, tempK, timestep_fs, ensemble):
110
+ """Conecta Gradio con run_md_simulation y adapta outputs (viewer/HTML/file/log/script)."""
111
  try:
112
  traj_path, log_text, script_text, explanation = run_md_simulation(
113
+ xyz_content, # acepta string XYZ o ruta
114
  int(steps),
115
+ 20, # prerelax por defecto, como UMA
116
  float(timestep_fs),
117
  float(tempK),
118
  "NVT" if ensemble == "NVT" else "NVE",
 
120
  int(spin),
121
  )
122
  status = f"MD completed: {int(steps)} steps at {int(tempK)} K ({ensemble})"
123
+
124
  if HAVE_MOL3D:
125
+ viewer_value = last_frame_xyz_from_traj(traj_path) # string XYZ como value
 
126
  html_value = None
127
  else:
128
  viewer_value = None
129
  html_value = traj_to_html(traj_path)
130
 
131
  return (
132
+ status, # md_status
133
+ viewer_value, # md_viewer (Molecule3D value) o None
134
+ html_value, # md_html (fallback) o None
135
+ traj_path, # md_traj file
136
+ log_text, # md_log (Textbox)
137
+ script_text, # md_script (Code py)
138
+ explanation, # md_explain
139
  )
140
  except Exception as e:
141
  return (f"Error: {e}", None, None, None, "", "", "")
142
 
143
  def relax_wrapper(xyz_content, steps, fmax, charge, spin, relax_cell):
144
+ """Conecta Gradio con run_relaxation_simulation y adapta outputs."""
145
  try:
146
  traj_path, log_text, script_text, explanation = run_relaxation_simulation(
147
  xyz_content,
 
152
  bool(relax_cell),
153
  )
154
  status = f"Relaxation finished (≤ {int(steps)} steps, fmax={float(fmax)} eV/Å)"
155
+
156
  if HAVE_MOL3D:
157
  viewer_value = last_frame_xyz_from_traj(traj_path)
158
  html_value = None
 
161
  html_value = traj_to_html(traj_path)
162
 
163
  return (
164
+ status, # rlx_status
165
+ viewer_value, # rlx_viewer (Molecule3D value) o None
166
+ html_value, # rlx_html (fallback) o None
167
+ traj_path, # rlx_traj file
168
+ log_text, # rlx_log (Textbox)
169
+ script_text, # rlx_script (Code py)
170
+ explanation, # rlx_explain
171
  )
172
  except Exception as e:
173
  return (f"Error: {e}", None, None, None, "", "", "")
174
 
175
+
176
+ # ==== Ejemplos ====
 
177
  examples = [
178
  ["""2
179
  Hydrogen molecule
 
193
  H -0.3630 -0.5133 -0.8887""", 0, 1],
194
  ]
195
 
196
+
197
+ # ==== UI Gradio ====
 
198
  with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
199
  with gr.Tabs():
200
+ # --- Tab SPE ---
201
  with gr.Tab("Single Point Energy"):
202
  with gr.Row():
203
  with gr.Column(scale=2):
204
  gr.Markdown("# OrbMol Demo — Quantum-Accurate Molecular Predictions")
205
  gr.Markdown(
206
+ "Predict **energies** and **forces** with OrbMol (trained on **OMol25**, "
207
+ "ωB97M-V/def2-TZVPD). Supports **charge** and **spin multiplicity**."
208
  )
 
209
  xyz_input = gr.Textbox(
210
  label="XYZ Coordinates",
211
  placeholder="Paste XYZ here...",
 
237
  )
238
  with gr.Accordion("Benchmarks", open=False):
239
  gr.Markdown(
240
+ "* Strong results on **GMTKN55** y **Wiggle150**\n"
241
+ "* Accurate **protein–ligand** energies (PLA15)\n"
242
+ "* Stable MD en biomoléculas grandes"
243
  )
244
  with gr.Accordion("Disclaimers", open=False):
245
  gr.Markdown(
246
+ "* Verifica resultados para tu caso\n"
247
+ "* Considera el **nivel de teoría** de entrenamiento"
248
  )
249
 
250
+ # --- Tab MD ---
251
  with gr.Tab("Molecular Dynamics"):
252
  with gr.Row():
253
  with gr.Column(scale=2):
 
256
  charge_md = gr.Slider(value=0, minimum=-10, maximum=10, step=1, label="Charge")
257
  spin_md = gr.Slider(value=1, minimum=1, maximum=11, step=1, label="Spin Multiplicity")
258
  with gr.Row():
259
+ steps_md = gr.Slider(value=100, minimum=10, maximum=2000, step=10, label="Steps")
260
  temp_md = gr.Slider(value=300, minimum=10, maximum=1500, step=10, label="Temperature (K)")
261
  with gr.Row():
262
  timestep_md = gr.Slider(value=1.0, minimum=0.1, maximum=5.0, step=0.1, label="Timestep (fs)")
 
265
 
266
  with gr.Column(variant="panel", min_width=520):
267
  md_status = gr.Textbox(label="MD Status", interactive=False)
268
+
269
  if HAVE_MOL3D:
270
  md_viewer = Molecule3D(label="Trajectory Viewer")
271
+ md_html = gr.HTML(visible=False) # placeholder para layout consistente
272
  else:
273
  md_viewer = gr.Textbox(visible=False) # placeholder
274
  md_html = gr.HTML()
275
 
276
  md_traj = gr.File(label="Trajectory (.traj)", interactive=False)
277
+ md_log = gr.Textbox(label="Log", interactive=False, lines=15, max_lines=25) # <- FIX Code->Textbox
278
  md_script = gr.Code(label="Reproduction Script", language="python", interactive=False, lines=20, max_lines=30)
279
  md_explain = gr.Markdown()
280
 
 
284
  outputs=[md_status, md_viewer, md_html, md_traj, md_log, md_script, md_explain],
285
  )
286
 
287
+ # --- Tab Relaxation ---
288
  with gr.Tab("Relaxation / Optimization"):
289
  with gr.Row():
290
  with gr.Column(scale=2):
291
  xyz_rlx = gr.Textbox(label="XYZ Coordinates", lines=12, placeholder="Paste XYZ here...")
292
+ steps_rlx = gr.Slider(value=300, minimum=1, maximum=2000, step=1, label="Max Steps")
293
  fmax_rlx = gr.Slider(value=0.05, minimum=0.001, maximum=0.5, step=0.001, label="Fmax (eV/Å)")
294
  with gr.Row():
295
  charge_rlx = gr.Slider(value=0, minimum=-10, maximum=10, step=1, label="Charge")
 
299
 
300
  with gr.Column(variant="panel", min_width=520):
301
  rlx_status = gr.Textbox(label="Status", interactive=False)
302
+
303
  if HAVE_MOL3D:
304
  rlx_viewer = Molecule3D(label="Final Structure")
305
  rlx_html = gr.HTML(visible=False)
306
  else:
307
  rlx_viewer = gr.Textbox(visible=False)
308
  rlx_html = gr.HTML()
309
+
310
  rlx_traj = gr.File(label="Trajectory (.traj)", interactive=False)
311
+ rlx_log = gr.Textbox(label="Log", interactive=False, lines=15, max_lines=25) # <- FIX Code->Textbox
312
  rlx_script = gr.Code(label="Reproduction Script", language="python", interactive=False, lines=20, max_lines=30)
313
  rlx_explain = gr.Markdown()
314
 
 
319
  )
320
 
321
  print("Starting OrbMol model loading…")
322
+ _LOAD = _load_orbmol_calc()
323
 
324
  if __name__ == "__main__":
325
  demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)