annabossler committed on
Commit
bdf6036
·
verified ·
1 Parent(s): c926c74

Create simulation_scripts_orbmol.py

Browse files
Files changed (1) hide show
  1. simulation_scripts_orbmol.py +305 -0
simulation_scripts_orbmol.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Minimal FAIRChem-like simulation helpers for OrbMol (local inference).
3
+
4
+ Usage from app.py:
5
+ from simulation_scripts_orbmol import (
6
+ load_orbmol_model,
7
+ validate_ase_atoms,
8
+ run_md_simulation,
9
+ run_relaxation_simulation,
10
+ atoms_to_xyz,
11
+ )
12
+ """
13
+
14
+ from __future__ import annotations
15
+ import os
16
+ import tempfile
17
+ from pathlib import Path
18
+
19
+ import numpy as np
20
+ import ase
21
+ import ase.io
22
+ from ase import units
23
+ from ase.io.trajectory import Trajectory
24
+ from ase.optimize import LBFGS
25
+ from ase.filters import FrechetCellFilter
26
+ from ase.md import MDLogger
27
+ from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
28
+ from ase.md.verlet import VelocityVerlet
29
+ from ase.md.nose_hoover_chain import NoseHooverChainNVT
30
+
31
+ # OrbMol
32
+ from orb_models.forcefield import pretrained
33
+ from orb_models.forcefield.calculator import ORBCalculator
34
+
35
+
36
+ # -----------------------------
37
+ # Global model (lazy singleton)
38
+ # -----------------------------
39
_model_calc: ORBCalculator | None = None
# (device, precision) the cached calculator was built with; None until the first load.
_model_key: tuple[str, str] | None = None


def load_orbmol_model(device: str = "cpu", precision: str = "float32-high") -> ORBCalculator:
    """
    Return a shared ORBCalculator, building it on first use.

    The calculator is cached at module level and reused for as long as the
    requested ``device`` and ``precision`` match the values it was built with.
    Asking for a different combination rebuilds the calculator and replaces the
    cached one, so the arguments are never silently ignored.

    Args:
        device: Device string forwarded to the pretrained loader and to
            ``ORBCalculator``.
        precision: Precision string forwarded to the pretrained loader.

    Returns:
        The cached (or freshly built) calculator.
    """
    global _model_calc, _model_key
    requested = (device, precision)
    if _model_calc is None or _model_key != requested:
        # NOTE(review): orb_v3_conservative_inf_omat looks like the OMat-trained
        # conservative checkpoint. The callers in this file set
        # atoms.info["charge"] / atoms.info["spin"], which presumably only the
        # OMol-trained checkpoints consume -- confirm this is the intended model.
        orbff = pretrained.orb_v3_conservative_inf_omat(
            device=device,
            precision=precision,
        )
        _model_calc = ORBCalculator(orbff, device=device)
        _model_key = requested
    return _model_calc
55
+
56
+
57
+ # -----------------------------
58
+ # Helpers
59
+ # -----------------------------
60
def atoms_to_xyz(atoms: ase.Atoms) -> str:
    """
    Serialize an Atoms object as XYZ-formatted text.

    The first line is the atom count, the second a fixed comment line, and
    every following line is ``<symbol> <x> <y> <z>`` with six decimals.
    Lines are joined with newlines; there is no trailing newline.
    """
    header = [str(len(atoms)), "generated by simulation_scripts_orbmol"]
    symbols = atoms.get_chemical_symbols()
    coordinates = atoms.get_positions()
    records = [
        f"{symbol} {px:.6f} {py:.6f} {pz:.6f}"
        for symbol, (px, py, pz) in zip(symbols, coordinates)
    ]
    return "\n".join(header + records)
68
+
69
def _center_atoms(atoms: ase.Atoms) -> None:
    """
    Translate the structure in place so it is centered.

    The center of mass is first moved to the origin. If the cell has any
    non-zero entry, the structure is then shifted by half the sum of the cell
    vectors so that it sits in the middle of the cell. The shift is a rigid
    translation of all positions.
    """
    center_of_mass = atoms.get_center_of_mass()
    atoms.positions -= center_of_mass

    cell = atoms.cell
    if cell is None or not cell.any():
        # No usable cell: leave the center of mass at the origin.
        return

    half_diagonal = atoms.get_cell().sum(axis=0) / 2
    atoms.positions += half_diagonal
77
+
78
def validate_ase_atoms(structure_file: str | Path, max_atoms: int = 5000) -> ase.Atoms:
    """
    Load a structure with ASE and check that it is usable by the simulations.

    Checks, in order: a path was given, the structure contains atoms, the
    periodic boundary flags are either all set or all unset, and the atom
    count does not exceed ``max_atoms``.

    Returns:
        The loaded Atoms object, centered via ``_center_atoms``.

    Raises:
        ValueError: if any of the checks above fails.
    """
    if not structure_file:
        raise ValueError("Missing input structure file path.")

    structure = ase.io.read(str(structure_file))
    atom_count = len(structure)

    if atom_count == 0:
        raise ValueError("No atoms found in the input structure.")

    # Only fully periodic or fully non-periodic structures are accepted;
    # "some but not all" axes periodic is rejected.
    periodic = np.array(structure.pbc, dtype=bool)
    has_mixed_pbc = periodic.any() and not periodic.all()
    if has_mixed_pbc:
        raise ValueError(f"Mixed PBC {structure.pbc} not supported. Set all True or all False.")

    if atom_count > max_atoms:
        raise ValueError(
            f"Structure has {atom_count} atoms, exceeding the limit of {max_atoms} for this demo."
        )

    _center_atoms(structure)
    return structure
102
+
103
+
104
+ # -----------------------------
105
+ # Molecular Dynamics (MD)
106
+ # -----------------------------
107
def _md_reproduction_script(
    *,
    num_steps: int,
    num_prerelax_steps: int,
    md_timestep: float,
    temperature_k: float,
    ensemble: str,
    total_charge: int,
    spin_multiplicity: int,
) -> str:
    """Return a standalone script that repeats the MD run with the given settings."""
    return f"""\
import ase.io
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.md.verlet import VelocityVerlet
from ase.md.nose_hoover_chain import NoseHooverChainNVT
from ase.optimize import LBFGS
from ase.io.trajectory import Trajectory
from ase.md import MDLogger
from ase import units
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator

atoms = ase.io.read('input_file.traj') # any ASE-readable file
atoms.info['charge'] = {total_charge}
atoms.info['spin'] = {spin_multiplicity}

orbff = pretrained.orb_v3_conservative_inf_omat(device='cpu', precision='float32-high')
atoms.calc = ORBCalculator(orbff, device='cpu')

opt = LBFGS(atoms, trajectory='relaxation_output.traj')
opt.run(fmax=0.05, steps={num_prerelax_steps})

MaxwellBoltzmannDistribution(atoms, temperature_K={temperature_k}*2)

ensemble = '{ensemble}'
if ensemble == 'NVT':
    dyn = NoseHooverChainNVT(atoms, timestep={md_timestep}*units.fs,
                             temperature_K={temperature_k}, tdamp=10*{md_timestep}*units.fs)
else:
    dyn = VelocityVerlet(atoms, timestep={md_timestep}*units.fs)

dyn.attach(MDLogger(dyn, atoms, 'md.log', header=True, stress=False, peratom=True, mode='w'), interval=10)
traj = Trajectory('md_output.traj', 'w', atoms)
dyn.attach(traj.write, interval=1)
dyn.run({num_steps})
"""


def run_md_simulation(
    structure_file: str | Path,
    num_steps: int,
    num_prerelax_steps: int,
    md_timestep: float,  # fs
    temperature_k: float,  # K
    md_ensemble: str,  # "NVE" or "NVT"
    total_charge: int,
    spin_multiplicity: int,
    explanation: str | None = None,
) -> tuple[str, str, str, str]:
    """
    Run a short MD simulation using the shared OrbMol calculator.

    The structure is validated and centered, optionally pre-relaxed with LBFGS
    (fmax 0.05), given Maxwell-Boltzmann velocities at twice the target
    temperature, and then propagated with either a Nose-Hoover chain (NVT) or
    velocity Verlet (any other ensemble string).

    Args:
        structure_file: Path to an ASE-readable structure file.
        num_steps: Number of MD steps.
        num_prerelax_steps: Maximum LBFGS steps before MD; values <= 0 skip
            the pre-relaxation entirely.
        md_timestep: Timestep in fs.
        temperature_k: Target temperature in K.
        md_ensemble: "NVT" (case-insensitive) selects the thermostat; anything
            else falls back to NVE.
        total_charge: Stored as ``atoms.info["charge"]``.
        spin_multiplicity: Stored as ``atoms.info["spin"]``.
        explanation: Optional text returned unchanged; a summary is generated
            when omitted.

    Returns:
        (traj_path, md_log_text, reproduction_script, explanation). The
        trajectory file is a temp file left on disk for the caller.

    Raises:
        RuntimeError: wraps any exception raised while preparing or running
            the simulation. Temp files created so far are removed in that case.
    """
    traj_path = None
    md_log_path = None
    atoms = None
    completed = False

    try:
        atoms = validate_ase_atoms(structure_file)

        # Attach the calculator
        calc = load_orbmol_model()
        atoms.info["charge"] = int(total_charge)
        atoms.info["spin"] = int(spin_multiplicity)
        atoms.calc = calc

        # Output files: only the names are reserved here; writers reopen them.
        with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as tf:
            traj_path = tf.name
        with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as lf:
            md_log_path = lf.name

        # Quick pre-relaxation to remove bad contacts. The optimizer is used as
        # a context manager so its trajectory/log handles are closed before the
        # MD writers touch the same files.
        prerelax_steps = int(num_prerelax_steps)
        did_prerelax = prerelax_steps > 0
        if did_prerelax:
            with LBFGS(atoms, logfile=md_log_path, trajectory=traj_path) as opt:
                opt.run(fmax=0.05, steps=prerelax_steps)

        # Initialize velocities (double T after relaxation as in UMA demo)
        MaxwellBoltzmannDistribution(atoms, temperature_K=2 * float(temperature_k))

        # Choose integrator/ensemble
        ensemble = md_ensemble.upper()
        if ensemble == "NVT":
            dyn = NoseHooverChainNVT(
                atoms,
                timestep=float(md_timestep) * units.fs,
                temperature_K=float(temperature_k),
                tdamp=10 * float(md_timestep) * units.fs,
            )
        else:
            dyn = VelocityVerlet(atoms, timestep=float(md_timestep) * units.fs)

        # Append to the relaxation frames when there are any; otherwise start
        # a fresh trajectory instead of appending to a file with no frames.
        traj_mode = "a" if did_prerelax else "w"
        with Trajectory(traj_path, traj_mode, atoms) as traj:
            dyn.attach(traj.write, interval=1)
            md_logger = MDLogger(
                dyn, atoms, md_log_path, header=True, stress=False, peratom=True, mode="a"
            )
            try:
                dyn.attach(md_logger, interval=10)
                dyn.run(int(num_steps))
            finally:
                # Close the log before it is read back below.
                md_logger.close()

        reproduction_script = _md_reproduction_script(
            num_steps=int(num_steps),
            num_prerelax_steps=prerelax_steps,
            md_timestep=float(md_timestep),
            temperature_k=float(temperature_k),
            ensemble=ensemble,
            total_charge=int(total_charge),
            spin_multiplicity=int(spin_multiplicity),
        )

        md_log_text = Path(md_log_path).read_text(encoding="utf-8", errors="ignore")

        if explanation is None:
            explanation = (
                f"MD of {len(atoms)} atoms for {int(num_steps)} steps at "
                f"{float(temperature_k)} K, timestep {float(md_timestep)} fs, "
                f"ensemble {ensemble} (prerelax {prerelax_steps} steps)."
            )

        completed = True
        return traj_path, md_log_text, reproduction_script, explanation

    except Exception as e:
        # Bubble up a clean error
        raise RuntimeError(f"Error running MD: {e}") from e
    finally:
        # Detach the shared calculator so this Atoms object does not keep it alive.
        if atoms is not None and getattr(atoms, "calc", None) is not None:
            atoms.calc = None
        # On success the files are handed to the caller; on failure nobody
        # receives the paths, so remove them rather than leak temp files.
        if not completed:
            for leftover in (traj_path, md_log_path):
                if leftover is None:
                    continue
                try:
                    os.unlink(leftover)
                except OSError:
                    # Best-effort cleanup; the original error matters more.
                    pass
232
+
233
+
234
+ # -----------------------------
235
+ # Geometry optimization
236
+ # -----------------------------
237
def run_relaxation_simulation(
    structure_file: str | Path,
    num_steps: int,
    fmax: float,  # eV/Å
    total_charge: int,
    spin_multiplicity: int,
    relax_unit_cell: bool,
    explanation: str | None = None,
) -> tuple[str, str, str, str]:
    """
    Run an LBFGS geometry optimization using the shared OrbMol calculator.

    Args:
        structure_file: Path to an ASE-readable structure file.
        num_steps: Maximum number of optimizer steps.
        fmax: Force convergence threshold in eV/Å.
        total_charge: Stored as ``atoms.info["charge"]``.
        spin_multiplicity: Stored as ``atoms.info["spin"]``.
        relax_unit_cell: When true, the optimizer acts on a
            ``FrechetCellFilter`` wrapping the atoms instead of the atoms
            themselves.
        explanation: Optional text returned unchanged; a summary is generated
            when omitted.

    Returns:
        (traj_path, log_text, reproduction_script, explanation). The
        trajectory file is a temp file left on disk for the caller.

    Raises:
        RuntimeError: wraps any exception raised while preparing or running
            the optimization. Temp files created so far are removed in that
            case.
    """
    traj_path = None
    opt_log_path = None
    atoms = None
    completed = False

    try:
        atoms = validate_ase_atoms(structure_file)

        calc = load_orbmol_model()
        atoms.info["charge"] = int(total_charge)
        atoms.info["spin"] = int(spin_multiplicity)
        atoms.calc = calc

        # Only the names are reserved here; the optimizer reopens the files.
        with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as tf:
            traj_path = tf.name
        with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as lf:
            opt_log_path = lf.name

        subject = FrechetCellFilter(atoms) if relax_unit_cell else atoms
        # Context manager closes the optimizer's trajectory and log handles, so
        # both files are complete before they are read or handed to the caller.
        with LBFGS(subject, trajectory=traj_path, logfile=opt_log_path) as optimizer:
            optimizer.run(fmax=float(fmax), steps=int(num_steps))

        reproduction_script = f"""\
import ase.io
from ase.optimize import LBFGS
from ase.filters import FrechetCellFilter
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator

atoms = ase.io.read('input_file.traj')
atoms.info['charge'] = {int(total_charge)}
atoms.info['spin'] = {int(spin_multiplicity)}

orbff = pretrained.orb_v3_conservative_inf_omat(device='cpu', precision='float32-high')
atoms.calc = ORBCalculator(orbff, device='cpu')

relax_unit_cell = {bool(relax_unit_cell)}
subject = FrechetCellFilter(atoms) if relax_unit_cell else atoms
optimizer = LBFGS(subject, trajectory='relaxation_output.traj')
optimizer.run(fmax={float(fmax)}, steps={int(num_steps)})
"""

        log_text = Path(opt_log_path).read_text(encoding="utf-8", errors="ignore")

        if explanation is None:
            explanation = (
                f"Relaxation of {len(atoms)} atoms for up to {int(num_steps)} steps "
                f"with fmax {float(fmax)} eV/Å (relax_cell={bool(relax_unit_cell)})."
            )

        completed = True
        return traj_path, log_text, reproduction_script, explanation

    except Exception as e:
        raise RuntimeError(f"Error running relaxation: {e}") from e
    finally:
        # Detach the shared calculator so this Atoms object does not keep it alive.
        if atoms is not None and getattr(atoms, "calc", None) is not None:
            atoms.calc = None
        # On failure nobody receives the paths, so remove the temp files.
        if not completed:
            for leftover in (traj_path, opt_log_path):
                if leftover is None:
                    continue
                try:
                    os.unlink(leftover)
                except OSError:
                    # Best-effort cleanup; the original error matters more.
                    pass