annabossler commited on
Commit
e72bc1b
·
verified ·
1 Parent(s): 0fc2a04

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +240 -104
app.py CHANGED
@@ -1,5 +1,4 @@
1
  import gradio as gr
2
- import torch
3
  import numpy as np
4
  import tempfile
5
  import os
@@ -10,22 +9,28 @@ from ase.optimize import LBFGS
10
  from ase.md.verlet import VelocityVerlet
11
  from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
12
  from ase.io.trajectory import Trajectory
13
- import py3Dmol
 
 
 
 
 
 
14
 
15
  from orb_models.forcefield import pretrained
16
  from orb_models.forcefield.calculator import ORBCalculator
17
 
18
 
19
- # -----------------------------
20
- # Global OrbMol model
21
- # -----------------------------
22
  model_calc = None
23
 
24
  def load_orbmol_model():
25
- """Load OrbMol model once"""
26
  global model_calc
27
  if model_calc is None:
28
  try:
 
29
  orbff = pretrained.orb_v3_conservative_inf_omat(
30
  device="cpu",
31
  precision="float32-high"
@@ -33,14 +38,14 @@ def load_orbmol_model():
33
  model_calc = ORBCalculator(orbff, device="cpu")
34
  print("OrbMol model loaded successfully")
35
  except Exception as e:
36
- print(f"Error loading OrbMol model: {e}")
37
  model_calc = None
38
  return model_calc
39
 
40
 
41
- # -----------------------------
42
- # Single-point calculation
43
- # -----------------------------
44
  def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
45
  try:
46
  calc = load_orbmol_model()
@@ -50,7 +55,7 @@ def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
50
  if not xyz_content.strip():
51
  return "Error: Please enter XYZ coordinates", ""
52
 
53
- with tempfile.NamedTemporaryFile(mode='w', suffix='.xyz', delete=False) as f:
54
  f.write(xyz_content)
55
  xyz_file = f.name
56
 
@@ -58,45 +63,90 @@ def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
58
  atoms.info = {"charge": int(charge), "spin": int(spin_multiplicity)}
59
  atoms.calc = calc
60
 
61
- energy = atoms.get_potential_energy()
62
- forces = atoms.get_forces()
63
 
64
- result = f"Total Energy: {energy:.6f} eV\n\nAtomic Forces:\n"
 
 
 
65
  for i, f in enumerate(forces):
66
- result += f"Atom {i+1}: [{f[0]:.4f}, {f[1]:.4f}, {f[2]:.4f}] eV/Å\n"
67
-
68
- max_force = np.max(np.linalg.norm(forces, axis=1))
69
- result += f"\nMax Force: {max_force:.4f} eV/Å"
 
 
70
 
71
  os.unlink(xyz_file)
72
- return result, "Calculation completed with OrbMol"
 
73
  except Exception as e:
74
  return f"Error during calculation: {str(e)}", "Error"
75
 
76
 
77
- # -----------------------------
78
- # Trajectory → HTML
79
- # -----------------------------
80
- def traj_to_html(traj_file):
81
- traj = Trajectory(traj_file)
82
- view = py3Dmol.view(width=400, height=400)
 
 
 
 
 
 
83
  for atoms in traj:
84
  symbols = atoms.get_chemical_symbols()
85
- xyz = atoms.get_positions()
86
- mol = ""
87
- for s, (x, y, z) in zip(symbols, xyz):
88
- mol += f"{s} {x} {y} {z}\n"
89
- view.addModel(mol, "xyz")
90
- view.setStyle({"stick": {}})
91
- view.zoomTo()
92
- view.animate({"loop": "forward"})
93
- return view.render()
94
-
95
-
96
- # -----------------------------
97
- # MD simulation
98
- # -----------------------------
99
- def run_md(xyz_content, charge=0, spin_multiplicity=1, steps=100, temperature=300, timestep=1.0):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
100
  try:
101
  calc = load_orbmol_model()
102
  if calc is None:
@@ -105,7 +155,8 @@ def run_md(xyz_content, charge=0, spin_multiplicity=1, steps=100, temperature=30
105
  if not xyz_content.strip():
106
  return "Error: Please enter XYZ coordinates", ""
107
 
108
- with tempfile.NamedTemporaryFile(mode='w', suffix='.xyz', delete=False) as f:
 
109
  f.write(xyz_content)
110
  xyz_file = f.name
111
 
@@ -113,99 +164,184 @@ def run_md(xyz_content, charge=0, spin_multiplicity=1, steps=100, temperature=30
113
  atoms.info = {"charge": int(charge), "spin": int(spin_multiplicity)}
114
  atoms.calc = calc
115
 
116
- # Pre-relaxation
117
- opt = LBFGS(atoms)
118
  opt.run(fmax=0.05, steps=20)
119
 
120
- # Initialize velocities
121
- MaxwellBoltzmannDistribution(atoms, temperature_K=2 * temperature)
122
 
123
- # Run MD
124
- dyn = VelocityVerlet(atoms, timestep=timestep * units.fs)
125
- traj_file = tempfile.NamedTemporaryFile(suffix=".traj", delete=False)
126
- traj = Trajectory(traj_file.name, "w", atoms)
 
 
 
127
  dyn.attach(traj.write, interval=1)
128
- dyn.run(steps)
129
 
130
- html = traj_to_html(traj_file.name)
 
 
 
 
 
 
 
 
 
 
 
 
131
 
132
- os.unlink(xyz_file)
133
- return f"MD completed: {steps} steps at {temperature} K", html
134
  except Exception as e:
135
  return f"Error during MD simulation: {str(e)}", ""
136
 
137
 
138
- # -----------------------------
139
- # Examples
140
- # -----------------------------
141
  examples = [
142
  ["""2
143
  Hydrogen molecule
144
  H 0.0 0.0 0.0
145
  H 0.0 0.0 0.74""", 0, 1],
 
146
  ["""3
147
- Water molecule
148
  O 0.0000 0.0000 0.0000
149
  H 0.7571 0.0000 0.5864
150
  H -0.7571 0.0000 0.5864""", 0, 1],
 
151
  ["""5
152
  Methane
153
  C 0.0000 0.0000 0.0000
154
  H 1.0890 0.0000 0.0000
155
  H -0.3630 1.0267 0.0000
156
  H -0.3630 -0.5133 0.8887
157
- H -0.3630 -0.5133 -0.8887""", 0, 1]
158
  ]
159
 
160
 
161
- # -----------------------------
162
- # Gradio UI
163
- # -----------------------------
164
  with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
165
 
166
- with gr.Tab("Single Point Energy"):
167
- with gr.Row():
168
- with gr.Column(scale=2):
169
- with gr.Column(variant="panel"):
170
- gr.Markdown("# OrbMol Demo - Quantum-Accurate Molecular Predictions")
171
- xyz_input = gr.Textbox(label="XYZ Coordinates", lines=12, placeholder="Paste XYZ here")
172
- with gr.Row():
173
- charge_input = gr.Slider(value=0, label="Charge", minimum=-10, maximum=10, step=1)
174
- spin_input = gr.Slider(value=1, maximum=11, minimum=1, step=1, label="Spin Multiplicity")
175
- predict_btn = gr.Button("Run OrbMol Prediction", variant="primary", size="lg")
176
- with gr.Column(variant="panel", min_width=500):
177
- gr.Markdown("## Results")
178
- results_output = gr.Textbox(label="Energy & Forces", lines=15, interactive=False)
179
- status_output = gr.Textbox(label="Status", interactive=False, max_lines=1)
180
-
181
- gr.Examples(examples=examples, inputs=[xyz_input, charge_input, spin_input])
182
-
183
- predict_btn.click(
184
- predict_molecule,
185
- inputs=[xyz_input, charge_input, spin_input],
186
- outputs=[results_output, status_output]
187
- )
188
-
189
- with gr.Tab("Molecular Dynamics"):
190
- with gr.Row():
191
- with gr.Column(scale=2):
192
- xyz_input_md = gr.Textbox(label="XYZ Coordinates", lines=12)
193
- charge_input_md = gr.Slider(value=0, minimum=-10, maximum=10, step=1, label="Charge")
194
- spin_input_md = gr.Slider(value=1, minimum=1, maximum=11, step=1, label="Spin Multiplicity")
195
- steps_input = gr.Slider(value=100, minimum=10, maximum=1000, step=10, label="Steps")
196
- temp_input = gr.Slider(value=300, minimum=10, maximum=1000, step=10, label="Temperature (K)")
197
- timestep_input = gr.Slider(value=1.0, minimum=0.1, maximum=5.0, step=0.1, label="Timestep (fs)")
198
- run_md_btn = gr.Button("Run MD Simulation", variant="primary")
199
- md_status = gr.Textbox(label="MD Status", lines=2)
200
- with gr.Column(scale=1, variant="panel"):
201
- gr.Markdown("## MD Visualization")
202
- md_view = gr.HTML()
203
-
204
- run_md_btn.click(
205
- run_md,
206
- inputs=[xyz_input_md, charge_input_md, spin_input_md, steps_input, temp_input, timestep_input],
207
- outputs=[md_status, md_view],
208
- )
 
 
 
209
 
210
 
211
  print("Starting OrbMol model loading...")
 
1
  import gradio as gr
 
2
  import numpy as np
3
  import tempfile
4
  import os
 
9
  from ase.md.verlet import VelocityVerlet
10
  from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
11
  from ase.io.trajectory import Trajectory
12
+
13
+ # py3Dmol es opcional; si no está instalado haremos fallback a 3Dmol.js
14
+ try:
15
+ import py3Dmol # noqa: F401
16
+ HAVE_PY3DMOL = True
17
+ except Exception:
18
+ HAVE_PY3DMOL = False
19
 
20
  from orb_models.forcefield import pretrained
21
  from orb_models.forcefield.calculator import ORBCalculator
22
 
23
 
24
+ # =========================
25
+ # OrbMol global model
26
+ # =========================
27
  model_calc = None
28
 
29
  def load_orbmol_model():
 
30
  global model_calc
31
  if model_calc is None:
32
  try:
33
+ print("Loading OrbMol model...")
34
  orbff = pretrained.orb_v3_conservative_inf_omat(
35
  device="cpu",
36
  precision="float32-high"
 
38
  model_calc = ORBCalculator(orbff, device="cpu")
39
  print("OrbMol model loaded successfully")
40
  except Exception as e:
41
+ print(f"Error loading model: {e}")
42
  model_calc = None
43
  return model_calc
44
 
45
 
46
+ # =========================
47
+ # Single-point (SPE)
48
+ # =========================
49
  def predict_molecule(xyz_content, charge=0, spin_multiplicity=1):
50
  try:
51
  calc = load_orbmol_model()
 
55
  if not xyz_content.strip():
56
  return "Error: Please enter XYZ coordinates", ""
57
 
58
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False) as f:
59
  f.write(xyz_content)
60
  xyz_file = f.name
61
 
 
63
  atoms.info = {"charge": int(charge), "spin": int(spin_multiplicity)}
64
  atoms.calc = calc
65
 
66
+ energy = atoms.get_potential_energy() # eV
67
+ forces = atoms.get_forces() # eV/Å
68
 
69
+ lines = []
70
+ lines.append(f"Total Energy: {energy:.6f} eV")
71
+ lines.append("")
72
+ lines.append("Atomic Forces:")
73
  for i, f in enumerate(forces):
74
+ lines.append(
75
+ f"Atom {i+1}: [{f[0]:.4f}, {f[1]:.4f}, {f[2]:.4f}] eV/Å"
76
+ )
77
+ max_force = float(np.max(np.linalg.norm(forces, axis=1)))
78
+ lines.append("")
79
+ lines.append(f"Max Force: {max_force:.4f} eV/Å")
80
 
81
  os.unlink(xyz_file)
82
+ return "\n".join(lines), "Calculation completed with OrbMol"
83
+
84
  except Exception as e:
85
  return f"Error during calculation: {str(e)}", "Error"
86
 
87
 
88
+ # =========================
89
+ # Trajectory → HTML 3D
90
+ # (funciona con o sin py3Dmol)
91
+ # =========================
92
+ def traj_to_html(traj_path, width=520, height=520, interval_ms=200):
93
+ """
94
+ Genera un bloque HTML con 3Dmol para animar una lista de frames en XYZ.
95
+ Sin dependencias de JupyterLab.
96
+ """
97
+ traj = Trajectory(traj_path)
98
+ xyz_frames = []
99
+
100
  for atoms in traj:
101
  symbols = atoms.get_chemical_symbols()
102
+ coords = atoms.get_positions()
103
+ parts = [str(len(symbols)), "frame"]
104
+ for s, (x, y, z) in zip(symbols, coords):
105
+ parts.append(f"{s} {x:.6f} {y:.6f} {z:.6f}")
106
+ xyz_frames.append("\n".join(parts))
107
+
108
+ # HTML+JS autopackage
109
+ html = f"""
110
+ <div id="viewer_md" style="width:{width}px; height:{height}px; position:relative;"></div>
111
+ <script src="https://3dmol.org/build/3Dmol-min.js"></script>
112
+ <script>
113
+ (function() {{
114
+ var element = document.getElementById('viewer_md');
115
+ if (!element) return;
116
+ var viewer = $3Dmol.createViewer(element, {{backgroundColor: 'white'}});
117
+ var frames = {xyz_frames!r};
118
+ var i = 0;
119
+
120
+ function show(i) {{
121
+ viewer.clear();
122
+ viewer.addModel(frames[i], 'xyz');
123
+ viewer.setStyle({{}}, {{stick: {{}} }});
124
+ viewer.zoomTo();
125
+ viewer.render();
126
+ }}
127
+
128
+ if (frames.length === 0) {{
129
+ element.innerHTML = '<div style="padding:8px;color:#555">Empty trajectory</div>';
130
+ }} else if (frames.length === 1) {{
131
+ show(0);
132
+ }} else {{
133
+ show(0);
134
+ setInterval(function() {{
135
+ i = (i + 1) % frames.length;
136
+ show(i);
137
+ }}, {int(interval_ms)});
138
+ }}
139
+ }})();
140
+ </script>
141
+ """
142
+ return html
143
+
144
+
145
+ # =========================
146
+ # MD with OrbMol
147
+ # =========================
148
+ def run_md(xyz_content, charge=0, spin_multiplicity=1,
149
+ steps=100, temperature=300, timestep=1.0):
150
  try:
151
  calc = load_orbmol_model()
152
  if calc is None:
 
155
  if not xyz_content.strip():
156
  return "Error: Please enter XYZ coordinates", ""
157
 
158
+ # Leer estructura
159
+ with tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False) as f:
160
  f.write(xyz_content)
161
  xyz_file = f.name
162
 
 
164
  atoms.info = {"charge": int(charge), "spin": int(spin_multiplicity)}
165
  atoms.calc = calc
166
 
167
+ # Pre-relajación ligera
168
+ opt = LBFGS(atoms, logfile=None)
169
  opt.run(fmax=0.05, steps=20)
170
 
171
+ # Velocidades (2*T por partición de energía)
172
+ MaxwellBoltzmannDistribution(atoms, temperature_K=2*float(temperature))
173
 
174
+ # Integrador NVE
175
+ dyn = VelocityVerlet(atoms, timestep=float(timestep) * units.fs)
176
+
177
+ # Trayectoria temporal
178
+ tf = tempfile.NamedTemporaryFile(suffix=".traj", delete=False)
179
+ tf.close()
180
+ traj = Trajectory(tf.name, "w", atoms)
181
  dyn.attach(traj.write, interval=1)
 
182
 
183
+ # Ejecutar MD
184
+ dyn.run(int(steps))
185
+
186
+ # Visualización 3D
187
+ html = traj_to_html(tf.name)
188
+
189
+ # Limpieza archivo xyz temporal
190
+ try:
191
+ os.unlink(xyz_file)
192
+ except Exception:
193
+ pass
194
+
195
+ return f"MD completed: {int(steps)} steps at {int(temperature)} K", html
196
 
 
 
197
  except Exception as e:
198
  return f"Error during MD simulation: {str(e)}", ""
199
 
200
 
201
+ # =========================
202
+ # Ejemplos (tu set)
203
+ # =========================
204
  examples = [
205
  ["""2
206
  Hydrogen molecule
207
  H 0.0 0.0 0.0
208
  H 0.0 0.0 0.74""", 0, 1],
209
+
210
  ["""3
211
+ Water molecule
212
  O 0.0000 0.0000 0.0000
213
  H 0.7571 0.0000 0.5864
214
  H -0.7571 0.0000 0.5864""", 0, 1],
215
+
216
  ["""5
217
  Methane
218
  C 0.0000 0.0000 0.0000
219
  H 1.0890 0.0000 0.0000
220
  H -0.3630 1.0267 0.0000
221
  H -0.3630 -0.5133 0.8887
222
+ H -0.3630 -0.5133 -0.8887""", 0, 1],
223
  ]
224
 
225
 
226
+ # =========================
227
+ # Gradio UI
228
+ # =========================
229
  with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
230
+ with gr.Tabs():
231
+ # -------- Tab 1: tu layout original --------
232
+ with gr.Tab("Single Point Energy"):
233
+ with gr.Row():
234
+ with gr.Column(scale=2):
235
+ with gr.Column(variant="panel"):
236
+ gr.Markdown("# OrbMol Demo - Quantum-Accurate Molecular Predictions")
237
+
238
+ gr.Markdown(
239
+ "OrbMol is a neural network potential trained on the OMol25 dataset. "
240
+ "Predict energies and forces with configurable charge and spin."
241
+ )
242
+
243
+ gr.Markdown("## Simulation inputs")
244
+ with gr.Column(variant="panel"):
245
+ gr.Markdown("### Input molecular structure")
246
+
247
+ xyz_input = gr.Textbox(
248
+ label="XYZ Coordinates",
249
+ placeholder=(
250
+ "3\nWater molecule\n"
251
+ "O 0.0000 0.0000 0.0000\n"
252
+ "H 0.7571 0.0000 0.5864\n"
253
+ "H -0.7571 0.0000 0.5864"
254
+ ),
255
+ lines=12,
256
+ info="Paste XYZ coordinates of your molecule here",
257
+ )
258
+
259
+ gr.Markdown("OMol-specific settings for total charge and spin multiplicity")
260
+ with gr.Row():
261
+ charge_input = gr.Slider(
262
+ value=0, label="Total Charge", minimum=-10, maximum=10, step=1
263
+ )
264
+ spin_input = gr.Slider(
265
+ value=1, maximum=11, minimum=1, step=1, label="Spin Multiplicity"
266
+ )
267
+
268
+ predict_btn = gr.Button("Run OrbMol Prediction", variant="primary", size="lg")
269
+
270
+ with gr.Column(variant="panel", elem_id="results", min_width=500):
271
+ gr.Markdown("## OrbMol Prediction Results")
272
+
273
+ results_output = gr.Textbox(
274
+ label="Energy & Forces",
275
+ lines=15,
276
+ interactive=False,
277
+ info="OrbMol energy and force predictions",
278
+ )
279
+
280
+ status_output = gr.Textbox(
281
+ label="Status",
282
+ interactive=False,
283
+ max_lines=1,
284
+ )
285
+
286
+ gr.Markdown("### Examples")
287
+ gr.Examples(
288
+ examples=examples,
289
+ inputs=[xyz_input, charge_input, spin_input],
290
+ label="Click any example to load it",
291
+ )
292
+
293
+ predict_btn.click(
294
+ predict_molecule,
295
+ inputs=[xyz_input, charge_input, spin_input],
296
+ outputs=[results_output, status_output],
297
+ )
298
 
299
+ with gr.Sidebar(open=True):
300
+ gr.Markdown("## Learn more about OrbMol")
301
+ with gr.Accordion("What is OrbMol?", open=False):
302
+ gr.Markdown(
303
+ "* OrbMol is a neural network potential for molecular property prediction\n"
304
+ "* Built on the Orb-v3 architecture and trained on OMol25 dataset\n"
305
+ "* Supports configurable charge and spin multiplicity"
306
+ )
307
+ with gr.Accordion("Model Disclaimers", open=False):
308
+ gr.Markdown(
309
+ "* Validate results for your use case\n"
310
+ "* Consider limitations of the training level of theory"
311
+ )
312
+ with gr.Accordion("Open source packages", open=False):
313
+ gr.Markdown(
314
+ "* Model code: orbital-materials/orb-models\n"
315
+ "* This demo uses ASE and Gradio"
316
+ )
317
+
318
+ # -------- Tab 2: MD + Visualización 3D --------
319
+ with gr.Tab("Molecular Dynamics"):
320
+ with gr.Row():
321
+ with gr.Column(scale=2):
322
+ with gr.Column(variant="panel"):
323
+ xyz_input_md = gr.Textbox(
324
+ label="XYZ Coordinates",
325
+ lines=12,
326
+ placeholder="Paste XYZ here",
327
+ )
328
+ charge_input_md = gr.Slider(value=0, minimum=-10, maximum=10, step=1, label="Charge")
329
+ spin_input_md = gr.Slider(value=1, minimum=1, maximum=11, step=1, label="Spin Multiplicity")
330
+ steps_input = gr.Slider(value=100, minimum=10, maximum=1000, step=10, label="Steps")
331
+ temp_input = gr.Slider(value=300, minimum=10, maximum=1000, step=10, label="Temperature (K)")
332
+ timestep_input = gr.Slider(value=1.0, minimum=0.1, maximum=5.0, step=0.1, label="Timestep (fs)")
333
+ run_md_btn = gr.Button("Run MD Simulation", variant="primary")
334
+
335
+ with gr.Column(variant="panel", min_width=520):
336
+ gr.Markdown("## MD Visualization")
337
+ md_status = gr.Textbox(label="MD Status", lines=2, interactive=False)
338
+ md_view = gr.HTML()
339
+
340
+ run_md_btn.click(
341
+ run_md,
342
+ inputs=[xyz_input_md, charge_input_md, spin_input_md, steps_input, temp_input, timestep_input],
343
+ outputs=[md_status, md_view],
344
+ )
345
 
346
 
347
  print("Starting OrbMol model loading...")