annabossler commited on
Commit
c926c74
·
verified ·
1 Parent(s): e72bc1b

Update app.py

Browse files

separate logic (orbmol and visualization )

Files changed (1) hide show
  1. app.py +61 -133
app.py CHANGED
@@ -10,12 +10,12 @@ from ase.md.verlet import VelocityVerlet
10
  from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
11
  from ase.io.trajectory import Trajectory
12
 
13
- # py3Dmol es opcional; si no está instalado haremos fallback a 3Dmol.js
14
  try:
15
- import py3Dmol # noqa: F401
16
- HAVE_PY3DMOL = True
17
  except Exception:
18
- HAVE_PY3DMOL = False
19
 
20
  from orb_models.forcefield import pretrained
21
  from orb_models.forcefield.calculator import ORBCalculator
@@ -67,16 +67,13 @@ def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
67
  forces = atoms.get_forces() # eV/Å
68
 
69
  lines = []
70
- lines.append(f"Total Energy: {energy:.6f} eV")
71
- lines.append("")
72
  lines.append("Atomic Forces:")
73
  for i, f in enumerate(forces):
74
- lines.append(
75
- f"Atom {i+1}: [{f[0]:.4f}, {f[1]:.4f}, {f[2]:.4f}] eV/Å"
76
- )
77
  max_force = float(np.max(np.linalg.norm(forces, axis=1)))
78
- lines.append("")
79
- lines.append(f"Max Force: {max_force:.4f} eV/Å")
80
 
81
  os.unlink(xyz_file)
82
  return "\n".join(lines), "Calculation completed with OrbMol"
@@ -86,17 +83,11 @@ def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
86
 
87
 
88
  # =========================
89
- # Trajectory → HTML 3D
90
- # (funciona con o sin py3Dmol)
91
  # =========================
92
  def traj_to_html(traj_path, width=520, height=520, interval_ms=200):
93
- """
94
- Genera un bloque HTML con 3Dmol para animar una lista de frames en XYZ.
95
- Sin dependencias de JupyterLab.
96
- """
97
  traj = Trajectory(traj_path)
98
  xyz_frames = []
99
-
100
  for atoms in traj:
101
  symbols = atoms.get_chemical_symbols()
102
  coords = atoms.get_positions()
@@ -105,37 +96,25 @@ def traj_to_html(traj_path, width=520, height=520, interval_ms=200):
105
  parts.append(f"{s} {x:.6f} {y:.6f} {z:.6f}")
106
  xyz_frames.append("\n".join(parts))
107
 
108
- # HTML+JS autopackage
109
  html = f"""
110
- <div id="viewer_md" style="width:{width}px; height:{height}px; position:relative;"></div>
111
  <script src="https://3dmol.org/build/3Dmol-min.js"></script>
112
  <script>
113
  (function() {{
114
- var element = document.getElementById('viewer_md');
115
- if (!element) return;
116
- var viewer = $3Dmol.createViewer(element, {{backgroundColor: 'white'}});
117
  var frames = {xyz_frames!r};
118
  var i = 0;
119
-
120
  function show(i) {{
121
  viewer.clear();
122
- viewer.addModel(frames[i], 'xyz');
123
- viewer.setStyle({{}}, {{stick: {{}} }});
124
  viewer.zoomTo();
125
  viewer.render();
126
  }}
127
-
128
- if (frames.length === 0) {{
129
- element.innerHTML = '<div style="padding:8px;color:#555">Empty trajectory</div>';
130
- }} else if (frames.length === 1) {{
131
- show(0);
132
- }} else {{
133
- show(0);
134
- setInterval(function() {{
135
- i = (i + 1) % frames.length;
136
- show(i);
137
- }}, {int(interval_ms)});
138
- }}
139
  }})();
140
  </script>
141
  """
@@ -168,38 +147,39 @@ def run_md(xyz_content, charge=0, spin_multiplicity=1,
168
  opt = LBFGS(atoms, logfile=None)
169
  opt.run(fmax=0.05, steps=20)
170
 
171
- # Velocidades (2*T por partición de energía)
172
  MaxwellBoltzmannDistribution(atoms, temperature_K=2*float(temperature))
173
 
174
- # Integrador NVE
175
  dyn = VelocityVerlet(atoms, timestep=float(timestep) * units.fs)
176
 
177
- # Trayectoria temporal
178
  tf = tempfile.NamedTemporaryFile(suffix=".traj", delete=False)
179
  tf.close()
180
  traj = Trajectory(tf.name, "w", atoms)
181
  dyn.attach(traj.write, interval=1)
182
-
183
- # Ejecutar MD
184
  dyn.run(int(steps))
185
 
186
- # Visualización 3D
187
- html = traj_to_html(tf.name)
 
 
 
 
 
 
 
188
 
189
- # Limpieza archivo xyz temporal
190
  try:
191
  os.unlink(xyz_file)
192
  except Exception:
193
  pass
194
 
195
- return f"MD completed: {int(steps)} steps at {int(temperature)} K", html
196
 
197
  except Exception as e:
198
  return f"Error during MD simulation: {str(e)}", ""
199
 
200
 
201
  # =========================
202
- # Ejemplos (tu set)
203
  # =========================
204
  examples = [
205
  ["""2
@@ -228,67 +208,31 @@ H -0.3630 -0.5133 -0.8887""", 0, 1],
228
  # =========================
229
  with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
230
  with gr.Tabs():
231
- # -------- Tab 1: tu layout original --------
232
  with gr.Tab("Single Point Energy"):
233
  with gr.Row():
234
  with gr.Column(scale=2):
235
  with gr.Column(variant="panel"):
236
- gr.Markdown("# OrbMol Demo - Quantum-Accurate Molecular Predictions")
 
237
 
238
- gr.Markdown(
239
- "OrbMol is a neural network potential trained on the OMol25 dataset. "
240
- "Predict energies and forces with configurable charge and spin."
 
241
  )
242
 
243
- gr.Markdown("## Simulation inputs")
244
- with gr.Column(variant="panel"):
245
- gr.Markdown("### Input molecular structure")
246
-
247
- xyz_input = gr.Textbox(
248
- label="XYZ Coordinates",
249
- placeholder=(
250
- "3\nWater molecule\n"
251
- "O 0.0000 0.0000 0.0000\n"
252
- "H 0.7571 0.0000 0.5864\n"
253
- "H -0.7571 0.0000 0.5864"
254
- ),
255
- lines=12,
256
- info="Paste XYZ coordinates of your molecule here",
257
- )
258
-
259
- gr.Markdown("OMol-specific settings for total charge and spin multiplicity")
260
- with gr.Row():
261
- charge_input = gr.Slider(
262
- value=0, label="Total Charge", minimum=-10, maximum=10, step=1
263
- )
264
- spin_input = gr.Slider(
265
- value=1, maximum=11, minimum=1, step=1, label="Spin Multiplicity"
266
- )
267
-
268
- predict_btn = gr.Button("Run OrbMol Prediction", variant="primary", size="lg")
269
-
270
- with gr.Column(variant="panel", elem_id="results", min_width=500):
271
- gr.Markdown("## OrbMol Prediction Results")
272
-
273
- results_output = gr.Textbox(
274
- label="Energy & Forces",
275
- lines=15,
276
- interactive=False,
277
- info="OrbMol energy and force predictions",
278
- )
279
-
280
- status_output = gr.Textbox(
281
- label="Status",
282
- interactive=False,
283
- max_lines=1,
284
- )
285
-
286
- gr.Markdown("### Examples")
287
- gr.Examples(
288
- examples=examples,
289
- inputs=[xyz_input, charge_input, spin_input],
290
- label="Click any example to load it",
291
- )
292
 
293
  predict_btn.click(
294
  predict_molecule,
@@ -299,43 +243,27 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
299
  with gr.Sidebar(open=True):
300
  gr.Markdown("## Learn more about OrbMol")
301
  with gr.Accordion("What is OrbMol?", open=False):
302
- gr.Markdown(
303
- "* OrbMol is a neural network potential for molecular property prediction\n"
304
- "* Built on the Orb-v3 architecture and trained on OMol25 dataset\n"
305
- "* Supports configurable charge and spin multiplicity"
306
- )
307
- with gr.Accordion("Model Disclaimers", open=False):
308
- gr.Markdown(
309
- "* Validate results for your use case\n"
310
- "* Consider limitations of the training level of theory"
311
- )
312
- with gr.Accordion("Open source packages", open=False):
313
- gr.Markdown(
314
- "* Model code: orbital-materials/orb-models\n"
315
- "* This demo uses ASE and Gradio"
316
- )
317
-
318
- # -------- Tab 2: MD + Visualización 3D --------
319
  with gr.Tab("Molecular Dynamics"):
320
  with gr.Row():
321
  with gr.Column(scale=2):
322
- with gr.Column(variant="panel"):
323
- xyz_input_md = gr.Textbox(
324
- label="XYZ Coordinates",
325
- lines=12,
326
- placeholder="Paste XYZ here",
327
- )
328
- charge_input_md = gr.Slider(value=0, minimum=-10, maximum=10, step=1, label="Charge")
329
- spin_input_md = gr.Slider(value=1, minimum=1, maximum=11, step=1, label="Spin Multiplicity")
330
- steps_input = gr.Slider(value=100, minimum=10, maximum=1000, step=10, label="Steps")
331
- temp_input = gr.Slider(value=300, minimum=10, maximum=1000, step=10, label="Temperature (K)")
332
- timestep_input = gr.Slider(value=1.0, minimum=0.1, maximum=5.0, step=0.1, label="Timestep (fs)")
333
- run_md_btn = gr.Button("Run MD Simulation", variant="primary")
334
 
335
  with gr.Column(variant="panel", min_width=520):
336
- gr.Markdown("## MD Visualization")
337
  md_status = gr.Textbox(label="MD Status", lines=2, interactive=False)
338
- md_view = gr.HTML()
339
 
340
  run_md_btn.click(
341
  run_md,
 
10
  from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
11
  from ase.io.trajectory import Trajectory
12
 
13
+ # Intentar importar Molecule3D para vista 3D nativa
14
  try:
15
+ from gradio_molecule3d import Molecule3D
16
+ HAVE_MOL3D = True
17
  except Exception:
18
+ HAVE_MOL3D = False
19
 
20
  from orb_models.forcefield import pretrained
21
  from orb_models.forcefield.calculator import ORBCalculator
 
67
  forces = atoms.get_forces() # eV/Å
68
 
69
  lines = []
70
+ lines.append(f"Total Energy: {energy:.6f} eV\n")
 
71
  lines.append("Atomic Forces:")
72
  for i, f in enumerate(forces):
73
+ lines.append(f"Atom {i+1}: [{f[0]:.4f}, {f[1]:.4f}, {f[2]:.4f}] eV/Å")
74
+
 
75
  max_force = float(np.max(np.linalg.norm(forces, axis=1)))
76
+ lines.append(f"\nMax Force: {max_force:.4f} eV/Å")
 
77
 
78
  os.unlink(xyz_file)
79
  return "\n".join(lines), "Calculation completed with OrbMol"
 
83
 
84
 
85
  # =========================
86
+ # Trajectory → HTML 3D fallback
 
87
  # =========================
88
  def traj_to_html(traj_path, width=520, height=520, interval_ms=200):
 
 
 
 
89
  traj = Trajectory(traj_path)
90
  xyz_frames = []
 
91
  for atoms in traj:
92
  symbols = atoms.get_chemical_symbols()
93
  coords = atoms.get_positions()
 
96
  parts.append(f"{s} {x:.6f} {y:.6f} {z:.6f}")
97
  xyz_frames.append("\n".join(parts))
98
 
 
99
  html = f"""
100
+ <div id="viewer_md" style="width:{width}px; height:{height}px;"></div>
101
  <script src="https://3dmol.org/build/3Dmol-min.js"></script>
102
  <script>
103
  (function() {{
104
+ var viewer = $3Dmol.createViewer("viewer_md", {{backgroundColor: 'white'}});
 
 
105
  var frames = {xyz_frames!r};
106
  var i = 0;
 
107
  function show(i) {{
108
  viewer.clear();
109
+ viewer.addModel(frames[i], "xyz");
110
+ viewer.setStyle({{}}, {{stick: {{}}}});
111
  viewer.zoomTo();
112
  viewer.render();
113
  }}
114
+ if(frames.length>0) show(0);
115
+ if(frames.length>1) setInterval(function(){{
116
+ i=(i+1)%frames.length; show(i);
117
+ }}, {int(interval_ms)});
 
 
 
 
 
 
 
 
118
  }})();
119
  </script>
120
  """
 
147
  opt = LBFGS(atoms, logfile=None)
148
  opt.run(fmax=0.05, steps=20)
149
 
 
150
  MaxwellBoltzmannDistribution(atoms, temperature_K=2*float(temperature))
151
 
 
152
  dyn = VelocityVerlet(atoms, timestep=float(timestep) * units.fs)
153
 
 
154
  tf = tempfile.NamedTemporaryFile(suffix=".traj", delete=False)
155
  tf.close()
156
  traj = Trajectory(tf.name, "w", atoms)
157
  dyn.attach(traj.write, interval=1)
 
 
158
  dyn.run(int(steps))
159
 
160
+ if HAVE_MOL3D:
161
+ # Mostrar último frame en Molecule3D
162
+ last = traj[-1]
163
+ mol_xyz = f"{len(last)}\nFinal frame\n"
164
+ for s, (x, y, z) in zip(last.get_chemical_symbols(), last.get_positions()):
165
+ mol_xyz += f"{s} {x:.6f} {y:.6f} {z:.6f}\n"
166
+ view = Molecule3D(value=mol_xyz, label="Final Frame (XYZ)")
167
+ else:
168
+ view = traj_to_html(tf.name)
169
 
 
170
  try:
171
  os.unlink(xyz_file)
172
  except Exception:
173
  pass
174
 
175
+ return f"MD completed: {int(steps)} steps at {int(temperature)} K", view
176
 
177
  except Exception as e:
178
  return f"Error during MD simulation: {str(e)}", ""
179
 
180
 
181
  # =========================
182
+ # Ejemplos
183
  # =========================
184
  examples = [
185
  ["""2
 
208
  # =========================
209
  with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
210
  with gr.Tabs():
211
+ # -------- Tab 1: Single Point --------
212
  with gr.Tab("Single Point Energy"):
213
  with gr.Row():
214
  with gr.Column(scale=2):
215
  with gr.Column(variant="panel"):
216
+ gr.Markdown("# OrbMol Demo - Quantum-Accurate Predictions")
217
+ gr.Markdown("OrbMol is a neural network potential trained on the OMol25 dataset.")
218
 
219
+ xyz_input = gr.Textbox(
220
+ label="XYZ Coordinates",
221
+ placeholder="Paste XYZ here...",
222
+ lines=12,
223
  )
224
 
225
+ with gr.Row():
226
+ charge_input = gr.Slider(value=0, minimum=-10, maximum=10, step=1, label="Charge")
227
+ spin_input = gr.Slider(value=1, minimum=1, maximum=11, step=1, label="Spin Multiplicity")
228
+
229
+ predict_btn = gr.Button("Run OrbMol Prediction", variant="primary")
230
+
231
+ with gr.Column(variant="panel", min_width=500):
232
+ results_output = gr.Textbox(label="Energy & Forces", lines=15, interactive=False)
233
+ status_output = gr.Textbox(label="Status", interactive=False, max_lines=1)
234
+
235
+ gr.Examples(examples=examples, inputs=[xyz_input, charge_input, spin_input])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
236
 
237
  predict_btn.click(
238
  predict_molecule,
 
243
  with gr.Sidebar(open=True):
244
  gr.Markdown("## Learn more about OrbMol")
245
  with gr.Accordion("What is OrbMol?", open=False):
246
+ gr.Markdown("* Neural network potential for molecules\n* Built on Orb-v3, trained on OMol25\n* Supports charge and spin")
247
+ with gr.Accordion("Benchmarks", open=False):
248
+ gr.Markdown("* <1 kcal/mol error on Wiggle150\n* Accurate protein–ligand binding energies\n* Stable MD on biomolecules >20k atoms")
249
+ with gr.Accordion("Disclaimers", open=False):
250
+ gr.Markdown("* Validate results for your use case\n* Training level of theory may limit accuracy")
251
+
252
+ # -------- Tab 2: MD --------
 
 
 
 
 
 
 
 
 
 
253
  with gr.Tab("Molecular Dynamics"):
254
  with gr.Row():
255
  with gr.Column(scale=2):
256
+ xyz_input_md = gr.Textbox(label="XYZ Coordinates", lines=12)
257
+ charge_input_md = gr.Slider(value=0, minimum=-10, maximum=10, step=1, label="Charge")
258
+ spin_input_md = gr.Slider(value=1, minimum=1, maximum=11, step=1, label="Spin Multiplicity")
259
+ steps_input = gr.Slider(value=100, minimum=10, maximum=1000, step=10, label="Steps")
260
+ temp_input = gr.Slider(value=300, minimum=10, maximum=1000, step=10, label="Temperature (K)")
261
+ timestep_input = gr.Slider(value=1.0, minimum=0.1, maximum=5.0, step=0.1, label="Timestep (fs)")
262
+ run_md_btn = gr.Button("Run MD Simulation", variant="primary")
 
 
 
 
 
263
 
264
  with gr.Column(variant="panel", min_width=520):
 
265
  md_status = gr.Textbox(label="MD Status", lines=2, interactive=False)
266
+ md_view = gr.HTML() if not HAVE_MOL3D else Molecule3D(label="Trajectory Viewer")
267
 
268
  run_md_btn.click(
269
  run_md,