Spaces:
Sleeping
Sleeping
| """ | |
| Minimal FAIRChem-like simulation helpers for OrbMol (local inference). | |
| Usage from app.py: | |
| from simulation_scripts_orbmol import ( | |
| load_orbmol_model, | |
| validate_ase_atoms, | |
| run_md_simulation, | |
| run_relaxation_simulation, | |
| atoms_to_xyz, | |
| ) | |
| """ | |
| from __future__ import annotations | |
| import os | |
| import tempfile | |
| from pathlib import Path | |
| import numpy as np | |
| import ase | |
| import ase.io | |
| from ase import units | |
| from ase.io.trajectory import Trajectory | |
| from ase.optimize import LBFGS | |
| from ase.filters import FrechetCellFilter | |
| from ase.md import MDLogger | |
| from ase.md.velocitydistribution import MaxwellBoltzmannDistribution | |
| from ase.md.verlet import VelocityVerlet | |
| from ase.md.nose_hoover_chain import NoseHooverChainNVT | |
| # OrbMol | |
| from orb_models.forcefield import pretrained | |
| from orb_models.forcefield.calculator import ORBCalculator | |
| # ----------------------------- | |
| # Global model (lazy singleton) | |
| # ----------------------------- | |
_model_calc: ORBCalculator | None = None
# (device, precision) pair the cached calculator above was built with.
_model_calc_config: tuple[str, str] | None = None


def load_orbmol_model(device: str = "cpu", precision: str = "float32-high") -> ORBCalculator:
    """
    Return a cached ORB calculator, building it on first use.

    The calculator is cached at module level together with the
    ``(device, precision)`` pair it was created for. A later call with the
    same pair reuses the cached object; a call with a different pair rebuilds
    the calculator instead of silently handing back one that lives on another
    device or uses another precision.

    Args:
        device: Device string forwarded to the pretrained loader and to
            ``ORBCalculator``.
        precision: Precision string forwarded to the pretrained loader.

    Returns:
        The calculator matching the requested configuration.
    """
    global _model_calc, _model_calc_config
    requested = (device, precision)
    if _model_calc is None or _model_calc_config != requested:
        # NOTE: orb_v3_conservative_inf_omat is the conservative Orb family
        # entry point used in the OrbMol blog; works for molecules (aperiodic).
        orbff = pretrained.orb_v3_conservative_inf_omat(
            device=device,
            precision=precision,
        )
        _model_calc = ORBCalculator(orbff, device=device)
        _model_calc_config = requested
    return _model_calc
| # ----------------------------- | |
| # Helpers | |
| # ----------------------------- | |
def atoms_to_xyz(atoms: ase.Atoms) -> str:
    """
    Render *atoms* as XYZ-format text for quick visualization.

    The first line is the atom count, the second a fixed comment line, and
    each following line is ``symbol x y z`` with six decimals per coordinate.
    The result has no trailing newline.
    """
    header = [str(len(atoms)), "generated by simulation_scripts_orbmol"]
    symbols = atoms.get_chemical_symbols()
    coordinates = atoms.get_positions()
    body = [
        f"{symbol} {px:.6f} {py:.6f} {pz:.6f}"
        for symbol, (px, py, pz) in zip(symbols, coordinates)
    ]
    return "\n".join(header + body)
def _center_atoms(atoms: ase.Atoms) -> None:
    """
    Shift *atoms* in place so the structure is centered for display.

    The center of mass is moved to the origin first. If the structure has a
    cell with any non-zero entry, the positions are then shifted by half the
    sum of the cell vectors, which places the center of mass at the cell
    center.
    """
    center_of_mass = atoms.get_center_of_mass()
    atoms.positions -= center_of_mass

    cell = atoms.cell
    if cell is None or not cell.any():
        # No cell to center in: leave the center of mass at the origin.
        return

    half_diagonal = atoms.get_cell().sum(axis=0) / 2
    atoms.positions += half_diagonal
def validate_ase_atoms(structure_file: str | Path, max_atoms: int = 5000) -> ase.Atoms:
    """
    Load a structure with ``ase.io.read`` and check that it is usable.

    Checks, in order: the path is non-empty, the structure has at least one
    atom, the periodic-boundary flags are all True or all False, and the atom
    count does not exceed *max_atoms*. On success the atoms are centered with
    ``_center_atoms`` and returned.

    Raises:
        ValueError: if any of the checks above fails.
    """
    if not structure_file:
        raise ValueError("Missing input structure file path.")

    atoms = ase.io.read(str(structure_file))
    n_atoms = len(atoms)
    if n_atoms == 0:
        raise ValueError("No atoms found in the input structure.")

    # Uniform PBC (all True or all False). Mixed PBC often breaks MD settings.
    pbc_flags = np.asarray(atoms.pbc, dtype=bool)
    is_uniform = pbc_flags.all() or not pbc_flags.any()
    if not is_uniform:
        raise ValueError(f"Mixed PBC {atoms.pbc} not supported. Set all True or all False.")

    if n_atoms > max_atoms:
        raise ValueError(
            f"Structure has {n_atoms} atoms, exceeding the limit of {max_atoms} for this demo."
        )

    _center_atoms(atoms)
    return atoms
| # ----------------------------- | |
| # Molecular Dynamics (MD) | |
| # ----------------------------- | |
def run_md_simulation(
    structure_file: str | Path,
    num_steps: int,
    num_prerelax_steps: int,
    md_timestep: float,  # fs
    temperature_k: float,  # K
    md_ensemble: str,  # "NVE" or "NVT"
    total_charge: int,
    spin_multiplicity: int,
    explanation: str | None = None,
) -> tuple[str, str, str, str]:
    """
    Run a short MD simulation with the cached ORB calculator.

    The structure is validated, optionally pre-relaxed with LBFGS, given
    Maxwell-Boltzmann velocities at twice the target temperature, and then
    propagated for *num_steps* steps. The pre-relaxation and the MD run share
    one trajectory file and one log file: the optimizer writes first and the
    MD writers append.

    Args:
        structure_file: Path to any ASE-readable structure file.
        num_steps: Number of MD steps to run.
        num_prerelax_steps: Maximum LBFGS steps before MD; 0 or less skips
            the pre-relaxation.
        md_timestep: MD timestep in femtoseconds.
        temperature_k: Target temperature in kelvin.
        md_ensemble: ``"NVT"`` (case-insensitive) selects the Nose-Hoover
            chain thermostat; any other value runs velocity Verlet (NVE).
        total_charge: Stored as ``atoms.info["charge"]``.
        spin_multiplicity: Stored as ``atoms.info["spin"]``.
        explanation: Optional summary text; generated when ``None``.

    Returns:
        ``(traj_path, md_log_text, reproduction_script, explanation)``. The
        trajectory file is left on disk for the caller.

    Raises:
        RuntimeError: wraps any failure; the temporary trajectory and log
            files are removed before it is raised.
    """
    traj_path = None
    md_log_path = None
    atoms = None
    try:
        atoms = validate_ase_atoms(structure_file)

        # Attach the calculator
        calc = load_orbmol_model()
        atoms.info["charge"] = int(total_charge)
        atoms.info["spin"] = int(spin_multiplicity)
        atoms.calc = calc

        # Output files: only the names are needed, the writers below reopen them.
        with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as tf:
            traj_path = tf.name
        with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as lf:
            md_log_path = lf.name

        # Quick pre-relaxation to remove bad contacts. The optimizer is used as
        # a context manager so its trajectory/log handles are closed before the
        # MD writers append to the same files.
        with LBFGS(atoms, logfile=md_log_path, trajectory=traj_path) as opt:
            if num_prerelax_steps > 0:
                opt.run(fmax=0.05, steps=int(num_prerelax_steps))

        # Initialize velocities (double T after relaxation as in UMA demo)
        MaxwellBoltzmannDistribution(atoms, temperature_K=2 * float(temperature_k))

        # Choose integrator/ensemble
        if md_ensemble.upper() == "NVT":
            dyn = NoseHooverChainNVT(
                atoms,
                timestep=float(md_timestep) * units.fs,
                temperature_K=float(temperature_k),
                tdamp=10 * float(md_timestep) * units.fs,
            )
        else:
            dyn = VelocityVerlet(atoms, timestep=float(md_timestep) * units.fs)

        # Attach trajectory writer and MD logger, and close both once the run
        # ends so everything is on disk before the log is read back and the
        # trajectory path is handed to the caller.
        with Trajectory(traj_path, "a", atoms) as traj:
            md_logger = MDLogger(
                dyn, atoms, md_log_path, header=True, stress=False, peratom=True, mode="a"
            )
            try:
                dyn.attach(traj.write, interval=1)
                dyn.attach(md_logger, interval=10)
                # Run MD
                dyn.run(int(num_steps))
            finally:
                md_logger.close()

        # Prepare reproduction script (using OrbMol locally)
        reproduction_script = f"""\
import ase.io
from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
from ase.md.verlet import VelocityVerlet
from ase.md.nose_hoover_chain import NoseHooverChainNVT
from ase.optimize import LBFGS
from ase.io.trajectory import Trajectory
from ase.md import MDLogger
from ase import units
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator
atoms = ase.io.read('input_file.traj') # any ASE-readable file
atoms.info['charge'] = {int(total_charge)}
atoms.info['spin'] = {int(spin_multiplicity)}
orbff = pretrained.orb_v3_conservative_inf_omat(device='cpu', precision='float32-high')
atoms.calc = ORBCalculator(orbff, device='cpu')
opt = LBFGS(atoms, trajectory='relaxation_output.traj')
opt.run(fmax=0.05, steps={int(num_prerelax_steps)})
MaxwellBoltzmannDistribution(atoms, temperature_K={float(temperature_k)}*2)
ensemble = '{md_ensemble.upper()}'
if ensemble == 'NVT':
    dyn = NoseHooverChainNVT(atoms, timestep={float(md_timestep)}*units.fs,
                             temperature_K={float(temperature_k)}, tdamp=10*{float(md_timestep)}*units.fs)
else:
    dyn = VelocityVerlet(atoms, timestep={float(md_timestep)}*units.fs)
dyn.attach(MDLogger(dyn, atoms, 'md.log', header=True, stress=False, peratom=True, mode='w'), interval=10)
traj = Trajectory('md_output.traj', 'w', atoms)
dyn.attach(traj.write, interval=1)
dyn.run({int(num_steps)})
"""
        md_log_text = Path(md_log_path).read_text(encoding="utf-8", errors="ignore")

        if explanation is None:
            explanation = (
                f"MD of {len(atoms)} atoms for {int(num_steps)} steps at "
                f"{float(temperature_k)} K, timestep {float(md_timestep)} fs, "
                f"ensemble {md_ensemble.upper()} (prerelax {int(num_prerelax_steps)} steps)."
            )
        return traj_path, md_log_text, reproduction_script, explanation
    except Exception as e:
        # Nothing is returned on failure, so the temp files would be orphaned:
        # remove them (best effort) before bubbling up a clean error.
        for leftover in (traj_path, md_log_path):
            if leftover is not None:
                try:
                    os.remove(leftover)
                except OSError:
                    pass  # already gone or not removable; the real error matters more
        raise RuntimeError(f"Error running MD: {e}") from e
    finally:
        # Detach calculator to free memory
        if atoms is not None and getattr(atoms, "calc", None) is not None:
            atoms.calc = None
        # (No deletion of traj/log on success; the UI needs the files.)
| # ----------------------------- | |
| # Geometry optimization | |
| # ----------------------------- | |
def run_relaxation_simulation(
    structure_file: str | Path,
    num_steps: int,
    fmax: float,  # eV/Å
    total_charge: int,
    spin_multiplicity: int,
    relax_unit_cell: bool,
    explanation: str | None = None,
) -> tuple[str, str, str, str]:
    """
    Run an LBFGS geometry relaxation with the cached ORB calculator.

    Args:
        structure_file: Path to any ASE-readable structure file.
        num_steps: Maximum number of optimizer steps.
        fmax: Force convergence threshold passed to the optimizer (eV/Å).
        total_charge: Stored as ``atoms.info["charge"]``.
        spin_multiplicity: Stored as ``atoms.info["spin"]``.
        relax_unit_cell: When true, the atoms are wrapped in
            ``FrechetCellFilter`` so the cell is optimized as well.
        explanation: Optional summary text; generated when ``None``.

    Returns:
        ``(traj_path, log_text, reproduction_script, explanation)``. The
        trajectory file is left on disk for the caller.

    Raises:
        RuntimeError: wraps any failure; the temporary trajectory and log
            files are removed before it is raised.
    """
    traj_path = None
    opt_log_path = None
    atoms = None
    try:
        atoms = validate_ase_atoms(structure_file)

        calc = load_orbmol_model()
        atoms.info["charge"] = int(total_charge)
        atoms.info["spin"] = int(spin_multiplicity)
        atoms.calc = calc

        # Only the file names are needed; the optimizer reopens them itself.
        with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as tf:
            traj_path = tf.name
        with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as lf:
            opt_log_path = lf.name

        subject = FrechetCellFilter(atoms) if relax_unit_cell else atoms
        # Context manager closes the optimizer's trajectory/log handles, so the
        # files are complete before the log is read and the path is returned.
        with LBFGS(subject, trajectory=traj_path, logfile=opt_log_path) as optimizer:
            optimizer.run(fmax=float(fmax), steps=int(num_steps))

        reproduction_script = f"""\
import ase.io
from ase.optimize import LBFGS
from ase.filters import FrechetCellFilter
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator
atoms = ase.io.read('input_file.traj')
atoms.info['charge'] = {int(total_charge)}
atoms.info['spin'] = {int(spin_multiplicity)}
orbff = pretrained.orb_v3_conservative_inf_omat(device='cpu', precision='float32-high')
atoms.calc = ORBCalculator(orbff, device='cpu')
relax_unit_cell = {bool(relax_unit_cell)}
subject = FrechetCellFilter(atoms) if relax_unit_cell else atoms
optimizer = LBFGS(subject, trajectory='relaxation_output.traj')
optimizer.run(fmax={float(fmax)}, steps={int(num_steps)})
"""
        log_text = Path(opt_log_path).read_text(encoding="utf-8", errors="ignore")

        if explanation is None:
            explanation = (
                f"Relaxation of {len(atoms)} atoms for up to {int(num_steps)} steps "
                f"with fmax {float(fmax)} eV/Å (relax_cell={bool(relax_unit_cell)})."
            )
        return traj_path, log_text, reproduction_script, explanation
    except Exception as e:
        # Nothing is returned on failure, so the temp files would be orphaned:
        # remove them (best effort) before re-raising.
        for leftover in (traj_path, opt_log_path):
            if leftover is not None:
                try:
                    os.remove(leftover)
                except OSError:
                    pass  # already gone or not removable; the real error matters more
        raise RuntimeError(f"Error running relaxation: {e}") from e
    finally:
        # Detach calculator to free memory
        if atoms is not None and getattr(atoms, "calc", None) is not None:
            atoms.calc = None