""" Simulaciones OrbMol con interfaz estilo Facebook FAIRChem """ from __future__ import annotations import os import tempfile from pathlib import Path import numpy as np import ase import ase.io from ase import units from ase.io.trajectory import Trajectory from ase.optimize import LBFGS from ase.filters import FrechetCellFilter from ase.md import MDLogger from ase.md.velocitydistribution import MaxwellBoltzmannDistribution from ase.md.verlet import VelocityVerlet from ase.md.nose_hoover_chain import NoseHooverChainNVT # OrbMol from orb_models.forcefield import pretrained from orb_models.forcefield.calculator import ORBCalculator # ----------------------------- # Global model # ----------------------------- _model_calc: ORBCalculator | None = None _current_task_name = None # Necesario para script de reproducción model_name = { "OMol": "orb_v3_conservative_omol", "OMat": "orb_v3_conservative_inf_omat", "OMol-Direct": "orb_v3_direct_omol"} def load_orbmol_model(task_name,device: str = "cpu", precision: str = "float32-high") -> ORBCalculator: """Load OrbMol calculator, switches only if another model is required.""" global _model_calc, _current_task_name if _model_calc is None or _current_task_name != task_name: if task_name == "OMol": orbff= pretrained.orb_v3_conservative_omol( device=device, precision=precision, ) elif task_name == "OMat": orbff = pretrained.orb_v3_conservative_inf_omat( device=device, precision=precision, ) elif task_name == "OMol-Direct": orbff = pretrained.orb_v3_direct_omol( device=device, precision=precision, ) else: raise ValueError(f"Unknown task_name: {task_name}") _model_calc = ORBCalculator(orbff, device=device) _current_task_name = task_name return _model_calc # ----------------------------- # FUNCIONES ESTILO FACEBOOK - COPIADAS EXACTAS # ----------------------------- def load_check_ase_atoms(structure_file): """COPIA EXACTA de Facebook - valida y carga estructura""" if not structure_file: raise Exception("You need an input structure file to run a simulation!") try: atoms = ase.io.read(structure_file) if not (all(atoms.pbc) or np.all(~np.array(atoms.pbc))): raise Exception( "Mixed PBC are not supported yet - please set PBC all True or False in your structure before uploading" ) if len(atoms) == 0: raise Exception("Error: Structure file contains no atoms.") if len(atoms) > 2000: raise Exception( f"Error: Structure file contains {len(atoms)}, which is more than 2000 atoms. Please use a smaller structure for this demo, or run this on a local machine!" ) # Centrar para visualización atoms.positions -= atoms.get_center_of_mass() cell_center = atoms.get_cell().sum(axis=0) / 2 atoms.positions += cell_center return atoms except Exception as e: raise Exception(f"Error loading structure with ASE: {str(e)}") def run_md_simulation( structure_file, num_steps, num_prerelax_steps, md_timestep, temperature_k, md_ensemble, task_name, total_charge=0, spin_multiplicity=1, explanation: str | None = None, oauth_token=None, # Ignorado progress=None, # Para compatibilidad Gradio ): """ MD simulation estilo Facebook pero con OrbMol """ temp_path = None traj_path = None md_log_path = None atoms = None try: # Cargar átomos (igual que Facebook) atoms = load_check_ase_atoms(structure_file) # Configurar charge y spin atoms.info["charge"] = total_charge atoms.info["spin"] = spin_multiplicity # AQUÍ EL CAMBIO: OrbMol en lugar de HFEndpointCalculator calc = load_orbmol_model(task_name) atoms.calc = calc # Progress callback si existe interval = 1 steps = [0] expected_steps = num_steps + num_prerelax_steps def update_progress(): steps[-1] += interval if progress: progress(steps[-1] / expected_steps) # Archivos temporales (igual que Facebook) with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as traj_f: traj_path = traj_f.name with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as log_f: md_log_path = log_f.name # Pre-relaxación (igual que Facebook) opt = LBFGS(atoms, logfile=md_log_path, trajectory=traj_path) if progress: opt.attach(update_progress, interval=interval) opt.run(fmax=0.05, steps=num_prerelax_steps) # Velocidades (igual que Facebook - x2 después de relajación) MaxwellBoltzmannDistribution(atoms, temperature_K=temperature_k * 2) # Integrador (igual que Facebook) if md_ensemble == "NVE": dyn = VelocityVerlet(atoms, timestep=md_timestep * units.fs) elif md_ensemble == "NVT": dyn = NoseHooverChainNVT( atoms, timestep=md_timestep * units.fs, temperature_K=temperature_k, tdamp=10 * md_timestep * units.fs, ) # Trajectory y logging (igual que Facebook) traj = Trajectory(traj_path, "a", atoms) dyn.attach(traj.write, interval=1) if progress: dyn.attach(update_progress, interval=interval) dyn.attach( MDLogger( dyn, atoms, md_log_path, header=True, stress=False, peratom=True, mode="a", ), interval=10, ) # Ejecutar MD dyn.run(num_steps) # Script de reproducción (estilo Facebook pero con OrbMol) reproduction_script = f""" import ase.io from ase.md.velocitydistribution import MaxwellBoltzmannDistribution from ase.md.verlet import VelocityVerlet from ase.optimize import LBFGS from ase.io.trajectory import Trajectory from ase.md import MDLogger from ase import units from orb_models.forcefield import pretrained from orb_models.forcefield.calculator import ORBCalculator # Read the atoms object from ASE read-able file atoms = ase.io.read('input_file.traj') # Set the total charge and spin multiplicity atoms.info["charge"] = {total_charge} atoms.info["spin"] = {spin_multiplicity} # Set up the OrbMol calculator orbff = pretrained.{model_name[task_name]}(device='cpu', precision='float32-high') atoms.calc = ORBCalculator(orbff, device='cpu') # Do a quick pre-relaxation to make sure the system is stable opt = LBFGS(atoms, trajectory="relaxation_output.traj") opt.run(fmax=0.05, steps={num_prerelax_steps}) # Initialize the velocity distribution; we set twice the temperature since we did a relaxation and # much of the kinetic energy will partition to the potential energy right away MaxwellBoltzmannDistribution(atoms, temperature_K={temperature_k}*2) # Initialize the integrator; NVE is shown here as an example dyn = VelocityVerlet(atoms, timestep={md_timestep} * units.fs) # Set up trajectory and MD logger dyn.attach(MDLogger(dyn, atoms, 'md.log', header=True, stress=False, peratom=True, mode="w"), interval=10) traj = Trajectory("md_output.traj", "w", atoms) dyn.attach(traj.write, interval=1) # Run the simulation! dyn.run({num_steps}) """ # Leer log with open(md_log_path, "r") as md_log_file: md_log = md_log_file.read() if explanation is None: explanation = f"MD simulation of {len(atoms)} atoms for {num_steps} steps with a timestep of {md_timestep} fs at {temperature_k} K in the {md_ensemble} ensemble using OrbMol. You submitted this simulation, so I hope you know what you're looking for or what it means!" return traj_path, md_log, reproduction_script, explanation except Exception as e: raise Exception( f"Error running MD simulation: {str(e)}. Please try again or report this error." ) finally: # Limpieza (igual que Facebook) if temp_path and os.path.exists(temp_path): os.remove(temp_path) if md_log_path and os.path.exists(md_log_path): os.remove(md_log_path) if atoms is not None and getattr(atoms, "calc", None) is not None: atoms.calc = None def run_relaxation_simulation( structure_file, num_steps, fmax, task_name, total_charge: float = 0, spin_multiplicity: float = 1, relax_unit_cell=False, explanation: str | None = None, oauth_token=None, # Ignorado progress=None, ): """ Relaxation simulation estilo Facebook pero con OrbMol """ temp_path = None traj_path = None opt_log_path = None atoms = None try: # Cargar átomos (igual que Facebook) atoms = load_check_ase_atoms(structure_file) # Configurar charge y spin atoms.info["charge"] = total_charge atoms.info["spin"] = spin_multiplicity # AQUÍ EL CAMBIO: OrbMol en lugar de HFEndpointCalculator calc = load_orbmol_model(task_name) atoms.calc = calc # Archivos temporales with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as traj_f: traj_path = traj_f.name with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as log_f: opt_log_path = log_f.name # Optimizador (igual que Facebook) optimizer = LBFGS( FrechetCellFilter(atoms) if relax_unit_cell else atoms, trajectory=traj_path, logfile=opt_log_path, ) # Progress callback si existe if progress: interval = 1 steps = [0] def update_progress(steps): steps[-1] += interval progress(steps[-1] / num_steps) optimizer.attach(update_progress, interval=interval, steps=steps) # Ejecutar optimización optimizer.run(fmax=fmax, steps=num_steps) # Script de reproducción reproduction_script = f""" import ase.io from ase.optimize import LBFGS from ase.filters import FrechetCellFilter from orb_models.forcefield import pretrained from orb_models.forcefield.calculator import ORBCalculator # Read the atoms object from ASE read-able file atoms = ase.io.read('input_file.traj') # Set the total charge and spin multiplicity atoms.info["charge"] = {total_charge} atoms.info["spin"] = {spin_multiplicity} # Set up the OrbMol calculator orbff = pretrained.{model_name[task_name]}(device='cpu', precision='float32-high') atoms.calc = ORBCalculator(orbff, device='cpu') # Initialize the optimizer relax_unit_cell = {relax_unit_cell} optimizer = LBFGS(FrechetCellFilter(atoms) if relax_unit_cell else atoms, trajectory="relaxation_output.traj") # Run the optimization! optimizer.run(fmax={fmax}, steps={num_steps}) """ # Leer log with open(opt_log_path, "r") as opt_log_file: opt_log = opt_log_file.read() if explanation is None: explanation = f"Relaxation of {len(atoms)} atoms for {num_steps} steps with a force tolerance of {fmax} eV/Å using OrbMol. You submitted this simulation, so I hope you know what you're looking for or what it means!" return traj_path, opt_log, reproduction_script, explanation except Exception as e: raise Exception( f"Error running relaxation: {str(e)}. Please try again or report this error." ) finally: # Limpieza (igual que Facebook) if temp_path and os.path.exists(temp_path): os.remove(temp_path) if opt_log_path and os.path.exists(opt_log_path): os.remove(opt_log_path) if atoms is not None and getattr(atoms, "calc", None) is not None: atoms.calc = None # ----------------------------- # Helper functions para compatibilidad # ----------------------------- def atoms_to_xyz(atoms: ase.Atoms) -> str: """Convert ASE Atoms to an XYZ string for quick visualization.""" lines = [str(len(atoms)), "generated by simulation_scripts_orbmol"] for s, (x, y, z) in zip(atoms.get_chemical_symbols(), atoms.get_positions()): lines.append(f"{s} {x:.6f} {y:.6f} {z:.6f}") return "\n".join(lines) def last_frame_xyz_from_traj(traj_path: str | Path) -> str: """Read the last frame of an ASE .traj and return it as XYZ string.""" tr = Trajectory(str(traj_path)) last = tr[-1] return atoms_to_xyz(last) # Función de validación simplificada (sin autenticación) def validate_ase_atoms_and_login(structure_file, login_button_value="", oauth_token=None): """Validación simplificada - sin login UMA""" if not structure_file: return (False, False, "Missing input structure!") if isinstance(structure_file, dict): structure_file = structure_file["path"] try: atoms = ase.io.read(structure_file) if len(atoms) == 0: return (False, False, "No atoms in the structure file!") elif not (all(atoms.pbc) or np.all(~np.array(atoms.pbc))): return (False, False, f"Mixed PBC {atoms.pbc} not supported!") elif len(atoms) > 2000: return (False, False, f"Too many atoms ({len(atoms)}), max 2000!") else: return (True, True, "Structure loaded successfully - ready for OrbMol simulation!") except Exception as e: return (False, False, f"Failed to load structure: {str(e)}")