annabossler commited on
Commit
d3f816b
verified
1 Parent(s): 7441807

Update simulation_scripts_orbmol.py

Browse files
Files changed (1) hide show
  1. simulation_scripts_orbmol.py +233 -236
simulation_scripts_orbmol.py CHANGED
@@ -1,24 +1,11 @@
1
- # simulation_scripts_orbmol.py
2
  """
3
- Minimal FAIRChem-like simulation helpers for OrbMol (local inference).
4
-
5
- Usage from app.py:
6
- from simulation_scripts_orbmol import (
7
- load_orbmol_model,
8
- validate_ase_atoms,
9
- run_md_simulation,
10
- run_relaxation_simulation,
11
- atoms_to_xyz,
12
- last_frame_xyz_from_traj,
13
- )
14
  """
15
 
16
  from __future__ import annotations
17
  import os
18
  import tempfile
19
  from pathlib import Path
20
- from typing import Tuple
21
-
22
  import numpy as np
23
  import ase
24
  import ase.io
@@ -35,16 +22,13 @@ from ase.md.nose_hoover_chain import NoseHooverChainNVT
35
  from orb_models.forcefield import pretrained
36
  from orb_models.forcefield.calculator import ORBCalculator
37
 
38
-
39
  # -----------------------------
40
- # Global model (lazy singleton)
41
  # -----------------------------
42
  _model_calc: ORBCalculator | None = None
43
 
44
  def load_orbmol_model(device: str = "cpu", precision: str = "float32-high") -> ORBCalculator:
45
- """
46
- Load OrbMol once and reuse the same calculator.
47
- """
48
  global _model_calc
49
  if _model_calc is None:
50
  orbff = pretrained.orb_v3_conservative_inf_omat(
@@ -54,182 +38,136 @@ def load_orbmol_model(device: str = "cpu", precision: str = "float32-high") -> O
54
  _model_calc = ORBCalculator(orbff, device=device)
55
  return _model_calc
56
 
57
-
58
  # -----------------------------
59
- # Helpers
60
  # -----------------------------
61
- def atoms_to_xyz(atoms: ase.Atoms) -> str:
62
- """
63
- Convert ASE Atoms to an XYZ string for quick visualization.
64
- """
65
- lines = [str(len(atoms)), "generated by simulation_scripts_orbmol"]
66
- for s, (x, y, z) in zip(atoms.get_chemical_symbols(), atoms.get_positions()):
67
- lines.append(f"{s} {x:.6f} {y:.6f} {z:.6f}")
68
- return "\n".join(lines)
69
 
70
- def last_frame_xyz_from_traj(traj_path: str | Path) -> str:
71
- """
72
- Read the last frame of an ASE .traj and return it as XYZ string.
73
- """
74
- tr = Trajectory(str(traj_path))
75
- last = tr[-1]
76
- return atoms_to_xyz(last)
77
-
78
- def _center_atoms(atoms: ase.Atoms) -> None:
79
- """
80
- Center coordinates for nicer visualization (no effect on energies).
81
- """
82
- atoms.positions -= atoms.get_center_of_mass()
83
- if atoms.cell is not None and np.array(atoms.cell).any():
84
- cell_center = atoms.get_cell().sum(axis=0) / 2
85
- atoms.positions += cell_center
86
 
87
- def _string_looks_like_xyz(text: str) -> bool:
88
- """
89
- Heur铆stica simple para detectar si un input es un XYZ en texto.
90
- """
91
- if not isinstance(text, str):
92
- return False
93
- lines = [l for l in text.strip().splitlines() if l.strip()]
94
- if len(lines) < 2:
95
- return False
96
- # primera l铆nea: n煤mero de 谩tomos
97
  try:
98
- _ = int(lines[0].split()[0])
99
- return True
100
- except Exception:
101
- return False
102
 
103
- def _materialize_input_to_file(input_or_path: str | Path) -> Tuple[str, bool]:
104
- """
105
- Devuelve (file_path, is_temp). Si input es un string XYZ, lo guarda a un .xyz temporal.
106
- Si es una ruta existente, la devuelve tal cual.
107
- """
108
- # Caso: dict de Gradio File {'path': ...}
109
- if isinstance(input_or_path, dict) and "path" in input_or_path:
110
- p = input_or_path["path"]
111
- return p, False
112
-
113
- # Caso: Path o ruta existente
114
- if isinstance(input_or_path, (str, Path)) and os.path.exists(str(input_or_path)):
115
- return str(input_or_path), False
116
-
117
- # Caso: probablemente es un string XYZ
118
- if isinstance(input_or_path, str) and _string_looks_like_xyz(input_or_path):
119
- tf = tempfile.NamedTemporaryFile(mode="w", suffix=".xyz", delete=False)
120
- tf.write(input_or_path)
121
- tf.flush()
122
- tf.close()
123
- return tf.name, True
124
-
125
- raise ValueError("Input must be an existing file path or a valid XYZ string.")
126
-
127
- def validate_ase_atoms(structure_file: str | Path, max_atoms: int = 5000) -> ase.Atoms:
128
- """
129
- Read & validate an ASE-compatible file; ensures uniform PBC and non-empty.
130
- Returns a centered Atoms object.
131
- """
132
- if not structure_file:
133
- raise ValueError("Missing input structure file path.")
134
- atoms = ase.io.read(str(structure_file))
135
-
136
- if len(atoms) == 0:
137
- raise ValueError("No atoms found in the input structure.")
138
 
139
- # Uniform PBC (all True or all False). Mixed PBC often breaks MD settings.
140
- pbc = np.array(atoms.pbc, dtype=bool)
141
- if not (pbc.all() or (~pbc).all()):
142
- raise ValueError(f"Mixed PBC {atoms.pbc} not supported. Set all True or all False.")
143
 
144
- if len(atoms) > max_atoms:
145
- raise ValueError(
146
- f"Structure has {len(atoms)} atoms, exceeding the limit of {max_atoms} for this demo."
147
- )
148
 
149
- _center_atoms(atoms)
150
- return atoms
 
 
151
 
 
 
 
152
 
153
- # -----------------------------
154
- # Molecular Dynamics (MD)
155
- # -----------------------------
156
  def run_md_simulation(
157
- structure_file_or_xyz: str | Path,
158
- num_steps: int,
159
- num_prerelax_steps: int,
160
- md_timestep: float, # fs
161
- temperature_k: float, # K
162
- md_ensemble: str, # "NVE" or "NVT"
163
- total_charge: int,
164
- spin_multiplicity: int,
 
165
  explanation: str | None = None,
166
- ) -> tuple[str, str, str, str]:
 
 
167
  """
168
- Run short MD using OrbMol.
169
- Accepts a path or an XYZ string.
170
- Returns: (traj_path, md_log_text, reproduction_script, explanation)
171
  """
 
172
  traj_path = None
173
  md_log_path = None
174
  atoms = None
175
- realized_path = None
176
- is_temp = False
177
 
178
  try:
179
- # Permitir tanto ruta como string XYZ
180
- realized_path, is_temp = _materialize_input_to_file(structure_file_or_xyz)
181
- atoms = validate_ase_atoms(realized_path)
 
 
 
182
 
183
- # Attach the calculator
184
  calc = load_orbmol_model()
185
- atoms.info["charge"] = int(total_charge)
186
- atoms.info["spin"] = int(spin_multiplicity)
187
  atoms.calc = calc
188
 
189
- # Output files
190
- with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as tf:
191
- traj_path = tf.name
192
- with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as lf:
193
- md_log_path = lf.name
 
 
 
 
194
 
195
- # Quick pre-relaxation to remove bad contacts
 
 
 
 
 
 
196
  opt = LBFGS(atoms, logfile=md_log_path, trajectory=traj_path)
197
- if int(num_prerelax_steps) > 0:
198
- opt.run(fmax=0.05, steps=int(num_prerelax_steps))
 
199
 
200
- # Initialize velocities (double T after relaxation as in UMA demo)
201
- MaxwellBoltzmannDistribution(atoms, temperature_K=2 * float(temperature_k))
202
 
203
- # Choose integrator/ensemble
204
- if md_ensemble.upper() == "NVT":
 
 
205
  dyn = NoseHooverChainNVT(
206
  atoms,
207
- timestep=float(md_timestep) * units.fs,
208
- temperature_K=float(temperature_k),
209
- tdamp=10 * float(md_timestep) * units.fs,
210
  )
211
- else:
212
- dyn = VelocityVerlet(atoms, timestep=float(md_timestep) * units.fs)
213
 
214
- # Attach trajectory writer and MD logger
215
  traj = Trajectory(traj_path, "a", atoms)
216
  dyn.attach(traj.write, interval=1)
 
 
217
  dyn.attach(
218
  MDLogger(
219
- dyn, atoms, md_log_path, header=True, stress=False, peratom=True, mode="a"
 
 
 
 
 
 
220
  ),
221
  interval=10,
222
  )
223
 
224
- # Run MD
225
- dyn.run(int(num_steps))
226
 
227
- # Prepare reproduction script (using OrbMol locally)
228
- reproduction_script = f"""\
229
  import ase.io
230
  from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
231
  from ase.md.verlet import VelocityVerlet
232
- from ase.md.nose_hoover_chain import NoseHooverChainNVT
233
  from ase.optimize import LBFGS
234
  from ase.io.trajectory import Trajectory
235
  from ase.md import MDLogger
@@ -237,134 +175,193 @@ from ase import units
237
  from orb_models.forcefield import pretrained
238
  from orb_models.forcefield.calculator import ORBCalculator
239
 
240
- atoms = ase.io.read('input_file.traj') # any ASE-readable file
241
- atoms.info['charge'] = {int(total_charge)}
242
- atoms.info['spin'] = {int(spin_multiplicity)}
243
-
 
 
244
  orbff = pretrained.orb_v3_conservative_inf_omat(device='cpu', precision='float32-high')
245
  atoms.calc = ORBCalculator(orbff, device='cpu')
246
-
247
- opt = LBFGS(atoms, trajectory='relaxation_output.traj')
248
- opt.run(fmax=0.05, steps={int(num_prerelax_steps)})
249
-
250
- MaxwellBoltzmannDistribution(atoms, temperature_K={float(temperature_k)}*2)
251
-
252
- ensemble = '{md_ensemble.upper()}'
253
- if ensemble == 'NVT':
254
- dyn = NoseHooverChainNVT(atoms, timestep={float(md_timestep)}*units.fs,
255
- temperature_K={float(temperature_k)}, tdamp=10*{float(md_timestep)}*units.fs)
256
- else:
257
- dyn = VelocityVerlet(atoms, timestep={float(md_timestep)}*units.fs)
258
-
259
- dyn.attach(MDLogger(dyn, atoms, 'md.log', header=True, stress=False, peratom=True, mode='w'), interval=10)
260
- traj = Trajectory('md_output.traj', 'w', atoms)
261
  dyn.attach(traj.write, interval=1)
262
- dyn.run({int(num_steps)})
263
- """
 
264
 
265
- md_log_text = Path(md_log_path).read_text(encoding="utf-8", errors="ignore")
 
 
266
 
267
  if explanation is None:
268
- explanation = (
269
- f"MD of {len(atoms)} atoms for {int(num_steps)} steps at "
270
- f"{float(temperature_k)} K, timestep {float(md_timestep)} fs, "
271
- f"ensemble {md_ensemble.upper()} (prerelax {int(num_prerelax_steps)} steps)."
272
- )
273
 
274
- return traj_path, md_log_text, reproduction_script, explanation
275
 
276
  except Exception as e:
277
- raise RuntimeError(f"Error running MD: {e}") from e
 
 
278
  finally:
279
- # Detach calculator to free memory
 
 
 
 
 
 
280
  if atoms is not None and getattr(atoms, "calc", None) is not None:
281
  atoms.calc = None
282
- # Limpieza del .xyz temporal si lo generamos nosotros
283
- if is_temp and realized_path and os.path.exists(realized_path):
284
- try:
285
- os.remove(realized_path)
286
- except Exception:
287
- pass
288
-
289
 
290
- # -----------------------------
291
- # Geometry optimization
292
- # -----------------------------
293
  def run_relaxation_simulation(
294
- structure_file_or_xyz: str | Path,
295
- num_steps: int,
296
- fmax: float, # eV/脜
297
- total_charge: int,
298
- spin_multiplicity: int,
299
- relax_unit_cell: bool,
 
300
  explanation: str | None = None,
301
- ) -> tuple[str, str, str, str]:
 
 
302
  """
303
- Run LBFGS relaxation (with optional cell relaxation).
304
- Accepts a path or an XYZ string.
305
- Returns: (traj_path, log_text, reproduction_script, explanation)
306
  """
 
307
  traj_path = None
308
  opt_log_path = None
309
  atoms = None
310
- realized_path = None
311
- is_temp = False
312
 
313
  try:
314
- realized_path, is_temp = _materialize_input_to_file(structure_file_or_xyz)
315
- atoms = validate_ase_atoms(realized_path)
 
 
 
 
316
 
 
317
  calc = load_orbmol_model()
318
- atoms.info["charge"] = int(total_charge)
319
- atoms.info["spin"] = int(spin_multiplicity)
320
  atoms.calc = calc
321
 
322
- with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as tf:
323
- traj_path = tf.name
324
- with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as lf:
325
- opt_log_path = lf.name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
326
 
327
- subject = FrechetCellFilter(atoms) if relax_unit_cell else atoms
328
- optimizer = LBFGS(subject, trajectory=traj_path, logfile=opt_log_path)
329
- optimizer.run(fmax=float(fmax), steps=int(num_steps))
330
 
331
- reproduction_script = f"""\
 
332
  import ase.io
333
  from ase.optimize import LBFGS
334
  from ase.filters import FrechetCellFilter
335
  from orb_models.forcefield import pretrained
336
  from orb_models.forcefield.calculator import ORBCalculator
337
 
 
338
  atoms = ase.io.read('input_file.traj')
339
- atoms.info['charge'] = {int(total_charge)}
340
- atoms.info['spin'] = {int(spin_multiplicity)}
341
-
 
342
  orbff = pretrained.orb_v3_conservative_inf_omat(device='cpu', precision='float32-high')
343
  atoms.calc = ORBCalculator(orbff, device='cpu')
 
 
 
 
 
 
344
 
345
- relax_unit_cell = {bool(relax_unit_cell)}
346
- subject = FrechetCellFilter(atoms) if relax_unit_cell else atoms
347
- optimizer = LBFGS(subject, trajectory='relaxation_output.traj')
348
- optimizer.run(fmax={float(fmax)}, steps={int(num_steps)})
349
- """
350
-
351
- log_text = Path(opt_log_path).read_text(encoding="utf-8", errors="ignore")
352
 
353
  if explanation is None:
354
- explanation = (
355
- f"Relaxation of {len(atoms)} atoms for up to {int(num_steps)} steps "
356
- f"with fmax {float(fmax)} eV/脜 (relax_cell={bool(relax_unit_cell)})."
357
- )
358
 
359
- return traj_path, log_text, reproduction_script, explanation
360
 
361
  except Exception as e:
362
- raise RuntimeError(f"Error running relaxation: {e}") from e
 
 
363
  finally:
 
 
 
 
 
364
  if atoms is not None and getattr(atoms, "calc", None) is not None:
365
  atoms.calc = None
366
- if is_temp and realized_path and os.path.exists(realized_path):
367
- try:
368
- os.remove(realized_path)
369
- except Exception:
370
- pass
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  """
2
+ Simulaciones OrbMol con interfaz estilo Facebook FAIRChem
 
 
 
 
 
 
 
 
 
 
3
  """
4
 
5
  from __future__ import annotations
6
  import os
7
  import tempfile
8
  from pathlib import Path
 
 
9
  import numpy as np
10
  import ase
11
  import ase.io
 
22
  from orb_models.forcefield import pretrained
23
  from orb_models.forcefield.calculator import ORBCalculator
24
 
 
25
  # -----------------------------
26
+ # Global model
27
  # -----------------------------
28
  _model_calc: ORBCalculator | None = None
29
 
30
  def load_orbmol_model(device: str = "cpu", precision: str = "float32-high") -> ORBCalculator:
31
+ """Load OrbMol once and reuse the same calculator."""
 
 
32
  global _model_calc
33
  if _model_calc is None:
34
  orbff = pretrained.orb_v3_conservative_inf_omat(
 
38
  _model_calc = ORBCalculator(orbff, device=device)
39
  return _model_calc
40
 
 
41
  # -----------------------------
42
+ # FUNCIONES ESTILO FACEBOOK - COPIADAS EXACTAS
43
  # -----------------------------
 
 
 
 
 
 
 
 
44
 
45
+ def load_check_ase_atoms(structure_file):
46
+ """COPIA EXACTA de Facebook - valida y carga estructura"""
47
+ if not structure_file:
48
+ raise Exception("You need an input structure file to run a simulation!")
 
 
 
 
 
 
 
 
 
 
 
 
49
 
 
 
 
 
 
 
 
 
 
 
50
  try:
51
+ atoms = ase.io.read(structure_file)
 
 
 
52
 
53
+ if not (all(atoms.pbc) or np.all(~np.array(atoms.pbc))):
54
+ raise Exception(
55
+ "Mixed PBC are not supported yet - please set PBC all True or False in your structure before uploading"
56
+ )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
+ if len(atoms) == 0:
59
+ raise Exception("Error: Structure file contains no atoms.")
 
 
60
 
61
+ if len(atoms) > 2000:
62
+ raise Exception(
63
+ f"Error: Structure file contains {len(atoms)}, which is more than 2000 atoms. Please use a smaller structure for this demo, or run this on a local machine!"
64
+ )
65
 
66
+ # Centrar para visualizaci贸n
67
+ atoms.positions -= atoms.get_center_of_mass()
68
+ cell_center = atoms.get_cell().sum(axis=0) / 2
69
+ atoms.positions += cell_center
70
 
71
+ return atoms
72
+ except Exception as e:
73
+ raise Exception(f"Error loading structure with ASE: {str(e)}")
74
 
 
 
 
75
  def run_md_simulation(
76
+ structure_file,
77
+ num_steps,
78
+ num_prerelax_steps,
79
+ md_timestep,
80
+ temperature_k,
81
+ md_ensemble,
82
+ task_name="OMol", # Siempre OMol para OrbMol
83
+ total_charge=0,
84
+ spin_multiplicity=1,
85
  explanation: str | None = None,
86
+ oauth_token=None, # Ignorado
87
+ progress=None, # Para compatibilidad Gradio
88
+ ):
89
  """
90
+ MD simulation estilo Facebook pero con OrbMol
 
 
91
  """
92
+ temp_path = None
93
  traj_path = None
94
  md_log_path = None
95
  atoms = None
 
 
96
 
97
  try:
98
+ # Cargar 谩tomos (igual que Facebook)
99
+ atoms = load_check_ase_atoms(structure_file)
100
+
101
+ # Configurar charge y spin
102
+ atoms.info["charge"] = total_charge
103
+ atoms.info["spin"] = spin_multiplicity
104
 
105
+ # AQU脥 EL CAMBIO: OrbMol en lugar de HFEndpointCalculator
106
  calc = load_orbmol_model()
 
 
107
  atoms.calc = calc
108
 
109
+ # Progress callback si existe
110
+ interval = 1
111
+ steps = [0]
112
+ expected_steps = num_steps + num_prerelax_steps
113
+
114
+ def update_progress():
115
+ steps[-1] += interval
116
+ if progress:
117
+ progress(steps[-1] / expected_steps)
118
 
119
+ # Archivos temporales (igual que Facebook)
120
+ with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as traj_f:
121
+ traj_path = traj_f.name
122
+ with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as log_f:
123
+ md_log_path = log_f.name
124
+
125
+ # Pre-relaxaci贸n (igual que Facebook)
126
  opt = LBFGS(atoms, logfile=md_log_path, trajectory=traj_path)
127
+ if progress:
128
+ opt.attach(update_progress, interval=interval)
129
+ opt.run(fmax=0.05, steps=num_prerelax_steps)
130
 
131
+ # Velocidades (igual que Facebook - x2 despu茅s de relajaci贸n)
132
+ MaxwellBoltzmannDistribution(atoms, temperature_K=temperature_k * 2)
133
 
134
+ # Integrador (igual que Facebook)
135
+ if md_ensemble == "NVE":
136
+ dyn = VelocityVerlet(atoms, timestep=md_timestep * units.fs)
137
+ elif md_ensemble == "NVT":
138
  dyn = NoseHooverChainNVT(
139
  atoms,
140
+ timestep=md_timestep * units.fs,
141
+ temperature_K=temperature_k,
142
+ tdamp=10 * md_timestep * units.fs,
143
  )
 
 
144
 
145
+ # Trajectory y logging (igual que Facebook)
146
  traj = Trajectory(traj_path, "a", atoms)
147
  dyn.attach(traj.write, interval=1)
148
+ if progress:
149
+ dyn.attach(update_progress, interval=interval)
150
  dyn.attach(
151
  MDLogger(
152
+ dyn,
153
+ atoms,
154
+ md_log_path,
155
+ header=True,
156
+ stress=False,
157
+ peratom=True,
158
+ mode="a",
159
  ),
160
  interval=10,
161
  )
162
 
163
+ # Ejecutar MD
164
+ dyn.run(num_steps)
165
 
166
+ # Script de reproducci贸n (estilo Facebook pero con OrbMol)
167
+ reproduction_script = f"""
168
  import ase.io
169
  from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
170
  from ase.md.verlet import VelocityVerlet
 
171
  from ase.optimize import LBFGS
172
  from ase.io.trajectory import Trajectory
173
  from ase.md import MDLogger
 
175
  from orb_models.forcefield import pretrained
176
  from orb_models.forcefield.calculator import ORBCalculator
177
 
178
+ # Read the atoms object from ASE read-able file
179
+ atoms = ase.io.read('input_file.traj')
180
+ # Set the total charge and spin multiplicity
181
+ atoms.info["charge"] = {total_charge}
182
+ atoms.info["spin"] = {spin_multiplicity}
183
+ # Set up the OrbMol calculator
184
  orbff = pretrained.orb_v3_conservative_inf_omat(device='cpu', precision='float32-high')
185
  atoms.calc = ORBCalculator(orbff, device='cpu')
186
+ # Do a quick pre-relaxation to make sure the system is stable
187
+ opt = LBFGS(atoms, trajectory="relaxation_output.traj")
188
+ opt.run(fmax=0.05, steps={num_prerelax_steps})
189
+ # Initialize the velocity distribution; we set twice the temperature since we did a relaxation and
190
+ # much of the kinetic energy will partition to the potential energy right away
191
+ MaxwellBoltzmannDistribution(atoms, temperature_K={temperature_k}*2)
192
+ # Initialize the integrator; NVE is shown here as an example
193
+ dyn = VelocityVerlet(atoms, timestep={md_timestep} * units.fs)
194
+ # Set up trajectory and MD logger
195
+ dyn.attach(MDLogger(dyn, atoms, 'md.log', header=True, stress=False, peratom=True, mode="w"), interval=10)
196
+ traj = Trajectory("md_output.traj", "w", atoms)
 
 
 
 
197
  dyn.attach(traj.write, interval=1)
198
+ # Run the simulation!
199
+ dyn.run({num_steps})
200
+ """
201
 
202
+ # Leer log
203
+ with open(md_log_path, "r") as md_log_file:
204
+ md_log = md_log_file.read()
205
 
206
  if explanation is None:
207
+ explanation = f"MD simulation of {len(atoms)} atoms for {num_steps} steps with a timestep of {md_timestep} fs at {temperature_k} K in the {md_ensemble} ensemble using OrbMol. You submitted this simulation, so I hope you know what you're looking for or what it means!"
 
 
 
 
208
 
209
+ return traj_path, md_log, reproduction_script, explanation
210
 
211
  except Exception as e:
212
+ raise Exception(
213
+ f"Error running MD simulation: {str(e)}. Please try again or report this error."
214
+ )
215
  finally:
216
+ # Limpieza (igual que Facebook)
217
+ if temp_path and os.path.exists(temp_path):
218
+ os.remove(temp_path)
219
+
220
+ if md_log_path and os.path.exists(md_log_path):
221
+ os.remove(md_log_path)
222
+
223
  if atoms is not None and getattr(atoms, "calc", None) is not None:
224
  atoms.calc = None
 
 
 
 
 
 
 
225
 
 
 
 
226
  def run_relaxation_simulation(
227
+ structure_file,
228
+ num_steps,
229
+ fmax,
230
+ task_name="OMol", # Siempre OMol para OrbMol
231
+ total_charge: float = 0,
232
+ spin_multiplicity: float = 1,
233
+ relax_unit_cell=False,
234
  explanation: str | None = None,
235
+ oauth_token=None, # Ignorado
236
+ progress=None,
237
+ ):
238
  """
239
+ Relaxation simulation estilo Facebook pero con OrbMol
 
 
240
  """
241
+ temp_path = None
242
  traj_path = None
243
  opt_log_path = None
244
  atoms = None
 
 
245
 
246
  try:
247
+ # Cargar 谩tomos (igual que Facebook)
248
+ atoms = load_check_ase_atoms(structure_file)
249
+
250
+ # Configurar charge y spin
251
+ atoms.info["charge"] = total_charge
252
+ atoms.info["spin"] = spin_multiplicity
253
 
254
+ # AQU脥 EL CAMBIO: OrbMol en lugar de HFEndpointCalculator
255
  calc = load_orbmol_model()
 
 
256
  atoms.calc = calc
257
 
258
+ # Archivos temporales
259
+ with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as traj_f:
260
+ traj_path = traj_f.name
261
+ with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as log_f:
262
+ opt_log_path = log_f.name
263
+
264
+ # Optimizador (igual que Facebook)
265
+ optimizer = LBFGS(
266
+ FrechetCellFilter(atoms) if relax_unit_cell else atoms,
267
+ trajectory=traj_path,
268
+ logfile=opt_log_path,
269
+ )
270
+
271
+ # Progress callback si existe
272
+ if progress:
273
+ interval = 1
274
+ steps = [0]
275
+ def update_progress(steps):
276
+ steps[-1] += interval
277
+ progress(steps[-1] / num_steps)
278
+ optimizer.attach(update_progress, interval=interval, steps=steps)
279
 
280
+ # Ejecutar optimizaci贸n
281
+ optimizer.run(fmax=fmax, steps=num_steps)
 
282
 
283
+ # Script de reproducci贸n
284
+ reproduction_script = f"""
285
  import ase.io
286
  from ase.optimize import LBFGS
287
  from ase.filters import FrechetCellFilter
288
  from orb_models.forcefield import pretrained
289
  from orb_models.forcefield.calculator import ORBCalculator
290
 
291
+ # Read the atoms object from ASE read-able file
292
  atoms = ase.io.read('input_file.traj')
293
+ # Set the total charge and spin multiplicity
294
+ atoms.info["charge"] = {total_charge}
295
+ atoms.info["spin"] = {spin_multiplicity}
296
+ # Set up the OrbMol calculator
297
  orbff = pretrained.orb_v3_conservative_inf_omat(device='cpu', precision='float32-high')
298
  atoms.calc = ORBCalculator(orbff, device='cpu')
299
+ # Initialize the optimizer
300
+ relax_unit_cell = {relax_unit_cell}
301
+ optimizer = LBFGS(FrechetCellFilter(atoms) if relax_unit_cell else atoms, trajectory="relaxation_output.traj")
302
+ # Run the optimization!
303
+ optimizer.run(fmax={fmax}, steps={num_steps})
304
+ """
305
 
306
+ # Leer log
307
+ with open(opt_log_path, "r") as opt_log_file:
308
+ opt_log = opt_log_file.read()
 
 
 
 
309
 
310
  if explanation is None:
311
+ explanation = f"Relaxation of {len(atoms)} atoms for {num_steps} steps with a force tolerance of {fmax} eV/脜 using OrbMol. You submitted this simulation, so I hope you know what you're looking for or what it means!"
 
 
 
312
 
313
+ return traj_path, opt_log, reproduction_script, explanation
314
 
315
  except Exception as e:
316
+ raise Exception(
317
+ f"Error running relaxation: {str(e)}. Please try again or report this error."
318
+ )
319
  finally:
320
+ # Limpieza (igual que Facebook)
321
+ if temp_path and os.path.exists(temp_path):
322
+ os.remove(temp_path)
323
+ if opt_log_path and os.path.exists(opt_log_path):
324
+ os.remove(opt_log_path)
325
  if atoms is not None and getattr(atoms, "calc", None) is not None:
326
  atoms.calc = None
327
+
328
+ # -----------------------------
329
+ # Helper functions para compatibilidad
330
+ # -----------------------------
331
+
332
+ def atoms_to_xyz(atoms: ase.Atoms) -> str:
333
+ """Convert ASE Atoms to an XYZ string for quick visualization."""
334
+ lines = [str(len(atoms)), "generated by simulation_scripts_orbmol"]
335
+ for s, (x, y, z) in zip(atoms.get_chemical_symbols(), atoms.get_positions()):
336
+ lines.append(f"{s} {x:.6f} {y:.6f} {z:.6f}")
337
+ return "\n".join(lines)
338
+
339
+ def last_frame_xyz_from_traj(traj_path: str | Path) -> str:
340
+ """Read the last frame of an ASE .traj and return it as XYZ string."""
341
+ tr = Trajectory(str(traj_path))
342
+ last = tr[-1]
343
+ return atoms_to_xyz(last)
344
+
345
+ # Funci贸n de validaci贸n simplificada (sin autenticaci贸n)
346
+ def validate_ase_atoms_and_login(structure_file, login_button_value="", oauth_token=None):
347
+ """Validaci贸n simplificada - sin login UMA"""
348
+ if not structure_file:
349
+ return (False, False, "Missing input structure!")
350
+
351
+ if isinstance(structure_file, dict):
352
+ structure_file = structure_file["path"]
353
+
354
+ try:
355
+ atoms = ase.io.read(structure_file)
356
+
357
+ if len(atoms) == 0:
358
+ return (False, False, "No atoms in the structure file!")
359
+ elif not (all(atoms.pbc) or np.all(~np.array(atoms.pbc))):
360
+ return (False, False, f"Mixed PBC {atoms.pbc} not supported!")
361
+ elif len(atoms) > 2000:
362
+ return (False, False, f"Too many atoms ({len(atoms)}), max 2000!")
363
+ else:
364
+ return (True, True, "Structure loaded successfully - ready for OrbMol simulation!")
365
+
366
+ except Exception as e:
367
+ return (False, False, f"Failed to load structure: {str(e)}")