Spaces:
Running
Running
# simulation_scripts_orbmol.py
"""
Minimal FAIRChem-like simulation helpers for OrbMol (local inference).

Usage from app.py:

    from simulation_scripts_orbmol import (
        load_orbmol_model,
        validate_ase_atoms,
        run_md_simulation,
        run_relaxation_simulation,
        atoms_to_xyz,
        last_frame_xyz_from_traj,
    )
"""
| from __future__ import annotations | |
| import os | |
| import tempfile | |
| from pathlib import Path | |
| from typing import Tuple | |
| import numpy as np | |
| import ase | |
| import ase.io | |
| from ase import units | |
| from ase.io.trajectory import Trajectory | |
| from ase.optimize import LBFGS | |
| from ase.filters import FrechetCellFilter | |
| from ase.md import MDLogger | |
| from ase.md.velocitydistribution import MaxwellBoltzmannDistribution | |
| from ase.md.verlet import VelocityVerlet | |
| from ase.md.nose_hoover_chain import NoseHooverChainNVT | |
| # OrbMol | |
| from orb_models.forcefield import pretrained | |
| from orb_models.forcefield.calculator import ORBCalculator | |
# -----------------------------
# Global model (lazy singleton)
# -----------------------------
_model_calc: ORBCalculator | None = None


def load_orbmol_model(device: str = "cpu", precision: str = "float32-high") -> ORBCalculator:
    """
    Return the module-wide ORB calculator, building it on the first call.

    The pretrained force field is loaded only while the module-level cache is
    empty; every later call returns the cached calculator unchanged. As a
    consequence ``device`` and ``precision`` are only honoured by the call that
    actually performs the load.

    Args:
        device: Device string forwarded to the pretrained loader and to
            ``ORBCalculator``.
        precision: Precision string forwarded to the pretrained loader.

    Returns:
        The cached ``ORBCalculator`` instance.
    """
    # NOTE(review): the docstrings in this module speak of "OrbMol", but the
    # loader used here is `orb_v3_conservative_inf_omat` -- confirm this is the
    # intended checkpoint.
    global _model_calc
    if _model_calc is not None:
        return _model_calc

    forcefield = pretrained.orb_v3_conservative_inf_omat(
        device=device,
        precision=precision,
    )
    _model_calc = ORBCalculator(forcefield, device=device)
    return _model_calc
# -----------------------------
# Helpers
# -----------------------------
def atoms_to_xyz(atoms: ase.Atoms) -> str:
    """
    Serialize an ASE Atoms object as XYZ-formatted text.

    The output consists of the atom count, a fixed comment line, and one
    ``symbol x y z`` line per atom with coordinates printed to six decimals.
    Lines are joined with ``\\n`` and there is no trailing newline.

    Args:
        atoms: Structure providing ``get_chemical_symbols()`` and
            ``get_positions()``.

    Returns:
        The XYZ text.
    """
    header = [str(len(atoms)), "generated by simulation_scripts_orbmol"]
    symbols = atoms.get_chemical_symbols()
    coordinates = atoms.get_positions()
    atom_rows = [
        f"{symbol} {x:.6f} {y:.6f} {z:.6f}"
        for symbol, (x, y, z) in zip(symbols, coordinates)
    ]
    return "\n".join(header + atom_rows)
def last_frame_xyz_from_traj(traj_path: str | Path) -> str:
    """
    Read the last frame of an ASE ``.traj`` file and return it as XYZ text.

    The trajectory reader is used as a context manager so the underlying file
    handle is released as soon as the frame has been converted, instead of
    staying open until garbage collection.

    Args:
        traj_path: Path of the trajectory file to read.

    Returns:
        XYZ text of the final frame, as produced by ``atoms_to_xyz``.
    """
    with Trajectory(str(traj_path)) as reader:
        # Convert while the reader is still open so no data is touched after
        # the file has been closed.
        return atoms_to_xyz(reader[-1])
def _center_atoms(atoms: ase.Atoms) -> None:
    """
    Translate the structure in place for nicer visualization.

    The centre of mass is first moved to the origin. If the structure has a
    cell with at least one non-zero entry, the atoms are then shifted by half
    the sum of the cell vectors, i.e. onto the middle of the cell.

    Args:
        atoms: Structure whose ``positions`` are modified in place.
    """
    atoms.positions -= atoms.get_center_of_mass()

    has_cell = atoms.cell is not None and np.array(atoms.cell).any()
    if not has_cell:
        return

    # Half the sum of the cell vectors is the geometric middle of the cell.
    atoms.positions += atoms.get_cell().sum(axis=0) / 2
def _string_looks_like_xyz(text: str) -> bool:
    """
    Simple heuristic to decide whether an input is XYZ-formatted text.

    The input qualifies when it is a string with at least two non-blank lines
    and the first whitespace-separated token of its first non-blank line
    parses as an integer (the atom count). Nothing else is validated.

    Args:
        text: Candidate input; non-string values yield ``False``.

    Returns:
        ``True`` if the text passes the heuristic, ``False`` otherwise.
    """
    if not isinstance(text, str):
        return False

    content_rows = [row for row in text.strip().splitlines() if row.strip()]
    if len(content_rows) < 2:
        return False

    # First line: number of atoms.
    try:
        int(content_rows[0].split()[0])
    except (ValueError, IndexError):
        return False
    return True
def _materialize_input_to_file(input_or_path: str | Path | dict) -> Tuple[str, bool]:
    """
    Resolve a user-supplied structure input to a file path on disk.

    Three input shapes are accepted, checked in this order:

    1. A dict containing a ``"path"`` key (Gradio File payload): the value is
       returned as-is, without checking that it exists.
    2. A ``str`` or ``Path`` that exists on disk: returned as a string.
    3. A string that ``_string_looks_like_xyz`` accepts: written to a new
       temporary ``.xyz`` file, which the caller is responsible for deleting.

    Args:
        input_or_path: Dict payload, existing path, or raw XYZ text.

    Returns:
        ``(file_path, is_temp)`` where ``is_temp`` is ``True`` only when this
        function created the file.

    Raises:
        ValueError: If the input matches none of the accepted shapes.
    """
    # Case: Gradio File dict {'path': ...}
    if isinstance(input_or_path, dict) and "path" in input_or_path:
        return input_or_path["path"], False

    # Case: Path or existing path string
    if isinstance(input_or_path, (str, Path)) and os.path.exists(str(input_or_path)):
        return str(input_or_path), False

    # Case: probably raw XYZ text
    if isinstance(input_or_path, str) and _string_looks_like_xyz(input_or_path):
        tmp_name = None
        try:
            # The context manager guarantees the handle is flushed and closed
            # even if the write fails.
            with tempfile.NamedTemporaryFile(
                mode="w", suffix=".xyz", delete=False
            ) as handle:
                tmp_name = handle.name
                handle.write(input_or_path)
        except Exception:
            # Do not leave a half-written temp file behind; cleanup is
            # best-effort and the original error is what the caller sees.
            if tmp_name is not None:
                try:
                    os.remove(tmp_name)
                except OSError:
                    pass
            raise
        return tmp_name, True

    raise ValueError("Input must be an existing file path or a valid XYZ string.")
def validate_ase_atoms(structure_file: str | Path, max_atoms: int = 5000) -> ase.Atoms:
    """
    Load a structure with ``ase.io.read`` and run basic sanity checks on it.

    Checks are applied in this order: the path must be truthy, the structure
    must contain at least one atom, the periodic boundary flags must be all
    ``True`` or all ``False``, and the atom count must not exceed
    ``max_atoms``. A structure that passes is centered in place with
    ``_center_atoms`` before being returned.

    Args:
        structure_file: Path of an ASE-readable structure file.
        max_atoms: Largest accepted number of atoms.

    Returns:
        The loaded, centered Atoms object.

    Raises:
        ValueError: If any of the checks above fails.
    """
    if not structure_file:
        raise ValueError("Missing input structure file path.")

    atoms = ase.io.read(str(structure_file))
    n_atoms = len(atoms)
    if n_atoms == 0:
        raise ValueError("No atoms found in the input structure.")

    # Uniform PBC only (all True or all False); mixed PBC often breaks MD settings.
    periodic = np.array(atoms.pbc, dtype=bool)
    if periodic.any() and not periodic.all():
        raise ValueError(f"Mixed PBC {atoms.pbc} not supported. Set all True or all False.")

    if n_atoms > max_atoms:
        raise ValueError(
            f"Structure has {n_atoms} atoms, exceeding the limit of {max_atoms} for this demo."
        )

    _center_atoms(atoms)
    return atoms
# -----------------------------
# Molecular Dynamics (MD)
# -----------------------------
def run_md_simulation(
    structure_file_or_xyz: str | Path,
    num_steps: int,
    num_prerelax_steps: int,
    md_timestep: float,  # fs
    temperature_k: float,  # K
    md_ensemble: str,  # "NVE" or "NVT"
    total_charge: int,
    spin_multiplicity: int,
    explanation: str | None = None,
) -> tuple[str, str, str, str]:
    """
    Run a short MD simulation with the cached ORB calculator.

    The input is resolved to a file, validated, optionally pre-relaxed with
    LBFGS, given Maxwell-Boltzmann velocities at twice the target temperature,
    and then propagated with Nose-Hoover-chain NVT (when ``md_ensemble`` is
    "NVT", case-insensitive) or velocity Verlet (any other value).

    The optimizer, trajectory writer and MD logger are all closed before their
    files are read or handed back, so the returned data is fully flushed.

    Args:
        structure_file_or_xyz: Existing structure file path, or raw XYZ text.
        num_steps: Number of MD steps.
        num_prerelax_steps: Maximum LBFGS steps before MD; ``0`` skips the run.
        md_timestep: MD timestep in femtoseconds.
        temperature_k: Target temperature in Kelvin.
        md_ensemble: "NVT" selects the thermostat; anything else runs NVE.
        total_charge: Stored in ``atoms.info["charge"]``.
        spin_multiplicity: Stored in ``atoms.info["spin"]``.
        explanation: Optional text returned as-is; generated when ``None``.

    Returns:
        ``(traj_path, md_log_text, reproduction_script, explanation)``. The
        trajectory is a temporary file owned by the caller; the log is returned
        as text and its temporary file is removed.

    Raises:
        RuntimeError: Wraps any exception raised while preparing or running
            the simulation.
    """
    traj_path = None
    md_log_path = None
    atoms = None
    realized_path = None
    is_temp = False
    completed = False
    try:
        # Accept either a path or raw XYZ text.
        realized_path, is_temp = _materialize_input_to_file(structure_file_or_xyz)
        atoms = validate_ase_atoms(realized_path)

        # Attach the calculator
        calc = load_orbmol_model()
        atoms.info["charge"] = int(total_charge)
        atoms.info["spin"] = int(spin_multiplicity)
        atoms.calc = calc

        # Output files
        with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as tf:
            traj_path = tf.name
        with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as lf:
            md_log_path = lf.name

        # Quick pre-relaxation to remove bad contacts. The optimizer is always
        # created (it initializes the trajectory file) and is closed before MD
        # re-opens the same trajectory and log files in append mode.
        with LBFGS(atoms, logfile=md_log_path, trajectory=traj_path) as opt:
            if int(num_prerelax_steps) > 0:
                opt.run(fmax=0.05, steps=int(num_prerelax_steps))

        # Initialize velocities (double T after relaxation as in UMA demo)
        MaxwellBoltzmannDistribution(atoms, temperature_K=2 * float(temperature_k))

        # Choose integrator/ensemble
        if md_ensemble.upper() == "NVT":
            dyn = NoseHooverChainNVT(
                atoms,
                timestep=float(md_timestep) * units.fs,
                temperature_K=float(temperature_k),
                tdamp=10 * float(md_timestep) * units.fs,
            )
        else:
            dyn = VelocityVerlet(atoms, timestep=float(md_timestep) * units.fs)

        # Attach trajectory writer and MD logger; both are closed (and thereby
        # flushed) as soon as the run finishes or fails.
        with Trajectory(traj_path, "a", atoms) as traj:
            with MDLogger(
                dyn, atoms, md_log_path, header=True, stress=False, peratom=True, mode="a"
            ) as md_logger:
                dyn.attach(traj.write, interval=1)
                dyn.attach(md_logger, interval=10)
                # Run MD
                dyn.run(int(num_steps))

        # Prepare reproduction script (using OrbMol locally)
        reproduction_script = f"""\
import ase.io
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.md.verlet import VelocityVerlet
from ase.md.nose_hoover_chain import NoseHooverChainNVT
from ase.optimize import LBFGS
from ase.io.trajectory import Trajectory
from ase.md import MDLogger
from ase import units
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator
atoms = ase.io.read('input_file.traj') # any ASE-readable file
atoms.info['charge'] = {int(total_charge)}
atoms.info['spin'] = {int(spin_multiplicity)}
orbff = pretrained.orb_v3_conservative_inf_omat(device='cpu', precision='float32-high')
atoms.calc = ORBCalculator(orbff, device='cpu')
opt = LBFGS(atoms, trajectory='relaxation_output.traj')
opt.run(fmax=0.05, steps={int(num_prerelax_steps)})
MaxwellBoltzmannDistribution(atoms, temperature_K={float(temperature_k)}*2)
ensemble = '{md_ensemble.upper()}'
if ensemble == 'NVT':
    dyn = NoseHooverChainNVT(atoms, timestep={float(md_timestep)}*units.fs,
                             temperature_K={float(temperature_k)}, tdamp=10*{float(md_timestep)}*units.fs)
else:
    dyn = VelocityVerlet(atoms, timestep={float(md_timestep)}*units.fs)
dyn.attach(MDLogger(dyn, atoms, 'md.log', header=True, stress=False, peratom=True, mode='w'), interval=10)
traj = Trajectory('md_output.traj', 'w', atoms)
dyn.attach(traj.write, interval=1)
dyn.run({int(num_steps)})
"""
        md_log_text = Path(md_log_path).read_text(encoding="utf-8", errors="ignore")
        if explanation is None:
            explanation = (
                f"MD of {len(atoms)} atoms for {int(num_steps)} steps at "
                f"{float(temperature_k)} K, timestep {float(md_timestep)} fs, "
                f"ensemble {md_ensemble.upper()} (prerelax {int(num_prerelax_steps)} steps)."
            )
        completed = True
        return traj_path, md_log_text, reproduction_script, explanation
    except Exception as e:
        raise RuntimeError(f"Error running MD: {e}") from e
    finally:
        # Detach calculator to free memory
        if atoms is not None and getattr(atoms, "calc", None) is not None:
            atoms.calc = None

        # Temporary files nobody needs any more: the log (its content is
        # returned as text), the .xyz we generated ourselves, and the
        # trajectory when the run failed and its path is never returned.
        leftovers = [md_log_path]
        if is_temp:
            leftovers.append(realized_path)
        if not completed:
            leftovers.append(traj_path)
        for leftover in leftovers:
            if not leftover:
                continue
            try:
                os.remove(leftover)
            except OSError:
                # Best-effort cleanup; a stale temp file must not mask the
                # simulation result or the original error.
                pass
# -----------------------------
# Geometry optimization
# -----------------------------
def run_relaxation_simulation(
    structure_file_or_xyz: str | Path,
    num_steps: int,
    fmax: float,  # eV/Å
    total_charge: int,
    spin_multiplicity: int,
    relax_unit_cell: bool,
    explanation: str | None = None,
) -> tuple[str, str, str, str]:
    """
    Run an LBFGS geometry relaxation with the cached ORB calculator.

    The input is resolved to a file, validated, and relaxed; when
    ``relax_unit_cell`` is truthy the atoms are wrapped in a
    ``FrechetCellFilter`` so the optimizer acts on the filter instead of the
    bare atoms. The optimizer is closed before its log and trajectory are
    read or handed back, so the returned data is fully flushed.

    Args:
        structure_file_or_xyz: Existing structure file path, or raw XYZ text.
        num_steps: Maximum number of optimizer steps.
        fmax: Force convergence threshold passed to the optimizer.
        total_charge: Stored in ``atoms.info["charge"]``.
        spin_multiplicity: Stored in ``atoms.info["spin"]``.
        relax_unit_cell: Whether to optimize through ``FrechetCellFilter``.
        explanation: Optional text returned as-is; generated when ``None``.

    Returns:
        ``(traj_path, log_text, reproduction_script, explanation)``. The
        trajectory is a temporary file owned by the caller; the log is returned
        as text and its temporary file is removed.

    Raises:
        RuntimeError: Wraps any exception raised while preparing or running
            the relaxation.
    """
    traj_path = None
    opt_log_path = None
    atoms = None
    realized_path = None
    is_temp = False
    completed = False
    try:
        realized_path, is_temp = _materialize_input_to_file(structure_file_or_xyz)
        atoms = validate_ase_atoms(realized_path)

        calc = load_orbmol_model()
        atoms.info["charge"] = int(total_charge)
        atoms.info["spin"] = int(spin_multiplicity)
        atoms.calc = calc

        with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as tf:
            traj_path = tf.name
        with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as lf:
            opt_log_path = lf.name

        subject = FrechetCellFilter(atoms) if relax_unit_cell else atoms
        # Closing the optimizer releases and flushes its log and trajectory
        # files before they are read below or returned to the caller.
        with LBFGS(subject, trajectory=traj_path, logfile=opt_log_path) as optimizer:
            optimizer.run(fmax=float(fmax), steps=int(num_steps))

        reproduction_script = f"""\
import ase.io
from ase.optimize import LBFGS
from ase.filters import FrechetCellFilter
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator
atoms = ase.io.read('input_file.traj')
atoms.info['charge'] = {int(total_charge)}
atoms.info['spin'] = {int(spin_multiplicity)}
orbff = pretrained.orb_v3_conservative_inf_omat(device='cpu', precision='float32-high')
atoms.calc = ORBCalculator(orbff, device='cpu')
relax_unit_cell = {bool(relax_unit_cell)}
subject = FrechetCellFilter(atoms) if relax_unit_cell else atoms
optimizer = LBFGS(subject, trajectory='relaxation_output.traj')
optimizer.run(fmax={float(fmax)}, steps={int(num_steps)})
"""
        log_text = Path(opt_log_path).read_text(encoding="utf-8", errors="ignore")
        if explanation is None:
            explanation = (
                f"Relaxation of {len(atoms)} atoms for up to {int(num_steps)} steps "
                f"with fmax {float(fmax)} eV/Å (relax_cell={bool(relax_unit_cell)})."
            )
        completed = True
        return traj_path, log_text, reproduction_script, explanation
    except Exception as e:
        raise RuntimeError(f"Error running relaxation: {e}") from e
    finally:
        # Detach the calculator from the atoms object.
        if atoms is not None and getattr(atoms, "calc", None) is not None:
            atoms.calc = None

        # Temporary files nobody needs any more: the log (its content is
        # returned as text), the .xyz we generated ourselves, and the
        # trajectory when the run failed and its path is never returned.
        leftovers = [opt_log_path]
        if is_temp:
            leftovers.append(realized_path)
        if not completed:
            leftovers.append(traj_path)
        for leftover in leftovers:
            if not leftover:
                continue
            try:
                os.remove(leftover)
            except OSError:
                # Best-effort cleanup; a stale temp file must not mask the
                # relaxation result or the original error.
                pass