Spaces:
Running
Running
"""
OrbMol simulations with a Facebook FAIRChem-style interface
"""
| from __future__ import annotations | |
| import os | |
| import tempfile | |
| from pathlib import Path | |
| import numpy as np | |
| import ase | |
| import ase.io | |
| from ase import units | |
| from ase.io.trajectory import Trajectory | |
| from ase.optimize import LBFGS | |
| from ase.filters import FrechetCellFilter | |
| from ase.md import MDLogger | |
| from ase.md.velocitydistribution import MaxwellBoltzmannDistribution | |
| from ase.md.verlet import VelocityVerlet | |
| from ase.md.nose_hoover_chain import NoseHooverChainNVT | |
| # OrbMol | |
| from orb_models.forcefield import pretrained | |
| from orb_models.forcefield.calculator import ORBCalculator | |
| # ----------------------------- | |
| # Global model | |
| # ----------------------------- | |
# Module-level cache: a single calculator is kept alive and reused between calls.
_model_calc: ORBCalculator | None = None
# Task name the cached calculator was built for.
_current_task_name = None
# (device, precision) the cached calculator was built with.
_current_model_config: tuple[str, str] | None = None

# Maps task names to the factory function names in
# `orb_models.forcefield.pretrained`; also used by the reproduction scripts.
model_name = {
    "OMol": "orb_v3_conservative_omol",
    "OMat": "orb_v3_conservative_inf_omat",
    "OMol-Direct": "orb_v3_direct_omol",
}


def load_orbmol_model(task_name, device: str = "cpu", precision: str = "float32-high") -> ORBCalculator:
    """Return the OrbMol calculator for ``task_name``, reusing the cached one when possible.

    The calculator is rebuilt only when no calculator is cached yet, or when
    the requested task name, device or precision differs from what the cached
    calculator was built with.

    Args:
        task_name: One of the keys of ``model_name``.
        device: Device string forwarded to the pretrained factory and to
            ``ORBCalculator``.
        precision: Precision string forwarded to the pretrained factory.

    Returns:
        The (possibly cached) ``ORBCalculator`` instance.

    Raises:
        ValueError: If ``task_name`` is not a known task.
    """
    global _model_calc, _current_task_name, _current_model_config

    # Validate first so a bad name never depends on the cache state.
    if not isinstance(task_name, str) or task_name not in model_name:
        raise ValueError(f"Unknown task_name: {task_name}")

    requested_config = (device, precision)
    cache_is_stale = (
        _model_calc is None
        or _current_task_name != task_name
        # Without this check a calculator built for another device/precision
        # would be handed back silently.
        or _current_model_config != requested_config
    )
    if cache_is_stale:
        factory = getattr(pretrained, model_name[task_name])
        orbff = factory(device=device, precision=precision)
        _model_calc = ORBCalculator(orbff, device=device)
        _current_task_name = task_name
        _current_model_config = requested_config
    return _model_calc
# -----------------------------
# FACEBOOK (FAIRChem)-STYLE FUNCTIONS - COPIED VERBATIM
# -----------------------------
def load_check_ase_atoms(structure_file):
    """Read a structure with ASE, validate it for the demo and centre it.

    The structure must have fully periodic or fully non-periodic boundary
    conditions, contain at least one atom and at most 2000 atoms. After
    validation the atoms are shifted so that their centre of mass sits at
    half the sum of the cell vectors.

    Args:
        structure_file: Anything accepted by ``ase.io.read``; must be truthy.

    Returns:
        The loaded, centred ``ase.Atoms`` object.

    Raises:
        Exception: If no file is given, or (wrapped with an
            "Error loading structure with ASE" prefix and chained to the
            original error) if reading or validation fails.
    """
    if not structure_file:
        raise Exception("You need an input structure file to run a simulation!")
    try:
        atoms = ase.io.read(structure_file)

        pbc = np.asarray(atoms.pbc, dtype=bool)
        if not (pbc.all() or not pbc.any()):
            raise Exception(
                "Mixed PBC are not supported yet - please set PBC all True or False in your structure before uploading"
            )

        n_atoms = len(atoms)
        if n_atoms == 0:
            raise Exception("Error: Structure file contains no atoms.")
        if n_atoms > 2000:
            raise Exception(
                f"Error: Structure file contains {n_atoms}, which is more than 2000 atoms. Please use a smaller structure for this demo, or run this on a local machine!"
            )

        # Centre for visualisation: move the centre of mass to the cell centre.
        atoms.positions -= atoms.get_center_of_mass()
        cell_center = atoms.get_cell().sum(axis=0) / 2
        atoms.positions += cell_center
        return atoms
    except Exception as exc:
        # Chain the cause so the original traceback stays visible to developers.
        raise Exception(f"Error loading structure with ASE: {str(exc)}") from exc
| def run_md_simulation( | |
| structure_file, | |
| num_steps, | |
| num_prerelax_steps, | |
| md_timestep, | |
| temperature_k, | |
| md_ensemble, | |
| task_name, | |
| total_charge=0, | |
| spin_multiplicity=1, | |
| explanation: str | None = None, | |
| oauth_token=None, # Ignorado | |
| progress=None, # Para compatibilidad Gradio | |
| ): | |
| """ | |
| MD simulation estilo Facebook pero con OrbMol | |
| """ | |
| temp_path = None | |
| traj_path = None | |
| md_log_path = None | |
| atoms = None | |
| try: | |
| # Cargar 谩tomos (igual que Facebook) | |
| atoms = load_check_ase_atoms(structure_file) | |
| # Configurar charge y spin | |
| atoms.info["charge"] = total_charge | |
| atoms.info["spin"] = spin_multiplicity | |
| # AQU脥 EL CAMBIO: OrbMol en lugar de HFEndpointCalculator | |
| calc = load_orbmol_model(task_name) | |
| atoms.calc = calc | |
| # Progress callback si existe | |
| interval = 1 | |
| steps = [0] | |
| expected_steps = num_steps + num_prerelax_steps | |
| def update_progress(): | |
| steps[-1] += interval | |
| if progress: | |
| progress(steps[-1] / expected_steps) | |
| # Archivos temporales (igual que Facebook) | |
| with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as traj_f: | |
| traj_path = traj_f.name | |
| with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as log_f: | |
| md_log_path = log_f.name | |
| # Pre-relaxaci贸n (igual que Facebook) | |
| opt = LBFGS(atoms, logfile=md_log_path, trajectory=traj_path) | |
| if progress: | |
| opt.attach(update_progress, interval=interval) | |
| opt.run(fmax=0.05, steps=num_prerelax_steps) | |
| # Velocidades (igual que Facebook - x2 despu茅s de relajaci贸n) | |
| MaxwellBoltzmannDistribution(atoms, temperature_K=temperature_k * 2) | |
| # Integrador (igual que Facebook) | |
| if md_ensemble == "NVE": | |
| dyn = VelocityVerlet(atoms, timestep=md_timestep * units.fs) | |
| elif md_ensemble == "NVT": | |
| dyn = NoseHooverChainNVT( | |
| atoms, | |
| timestep=md_timestep * units.fs, | |
| temperature_K=temperature_k, | |
| tdamp=10 * md_timestep * units.fs, | |
| ) | |
| # Trajectory y logging (igual que Facebook) | |
| traj = Trajectory(traj_path, "a", atoms) | |
| dyn.attach(traj.write, interval=1) | |
| if progress: | |
| dyn.attach(update_progress, interval=interval) | |
| dyn.attach( | |
| MDLogger( | |
| dyn, | |
| atoms, | |
| md_log_path, | |
| header=True, | |
| stress=False, | |
| peratom=True, | |
| mode="a", | |
| ), | |
| interval=10, | |
| ) | |
| # Ejecutar MD | |
| dyn.run(num_steps) | |
| # Script de reproducci贸n (estilo Facebook pero con OrbMol) | |
| reproduction_script = f""" | |
| import ase.io | |
| from ase.md.velocitydistribution import MaxwellBoltzmannDistribution | |
| from ase.md.verlet import VelocityVerlet | |
| from ase.optimize import LBFGS | |
| from ase.io.trajectory import Trajectory | |
| from ase.md import MDLogger | |
| from ase import units | |
| from orb_models.forcefield import pretrained | |
| from orb_models.forcefield.calculator import ORBCalculator | |
| # Read the atoms object from ASE read-able file | |
| atoms = ase.io.read('input_file.traj') | |
| # Set the total charge and spin multiplicity | |
| atoms.info["charge"] = {total_charge} | |
| atoms.info["spin"] = {spin_multiplicity} | |
| # Set up the OrbMol calculator | |
| orbff = pretrained.{model_name[task_name]}(device='cpu', precision='float32-high') | |
| atoms.calc = ORBCalculator(orbff, device='cpu') | |
| # Do a quick pre-relaxation to make sure the system is stable | |
| opt = LBFGS(atoms, trajectory="relaxation_output.traj") | |
| opt.run(fmax=0.05, steps={num_prerelax_steps}) | |
| # Initialize the velocity distribution; we set twice the temperature since we did a relaxation and | |
| # much of the kinetic energy will partition to the potential energy right away | |
| MaxwellBoltzmannDistribution(atoms, temperature_K={temperature_k}*2) | |
| # Initialize the integrator; NVE is shown here as an example | |
| dyn = VelocityVerlet(atoms, timestep={md_timestep} * units.fs) | |
| # Set up trajectory and MD logger | |
| dyn.attach(MDLogger(dyn, atoms, 'md.log', header=True, stress=False, peratom=True, mode="w"), interval=10) | |
| traj = Trajectory("md_output.traj", "w", atoms) | |
| dyn.attach(traj.write, interval=1) | |
| # Run the simulation! | |
| dyn.run({num_steps}) | |
| """ | |
| # Leer log | |
| with open(md_log_path, "r") as md_log_file: | |
| md_log = md_log_file.read() | |
| if explanation is None: | |
| explanation = f"MD simulation of {len(atoms)} atoms for {num_steps} steps with a timestep of {md_timestep} fs at {temperature_k} K in the {md_ensemble} ensemble using OrbMol. You submitted this simulation, so I hope you know what you're looking for or what it means!" | |
| return traj_path, md_log, reproduction_script, explanation | |
| except Exception as e: | |
| raise Exception( | |
| f"Error running MD simulation: {str(e)}. Please try again or report this error." | |
| ) | |
| finally: | |
| # Limpieza (igual que Facebook) | |
| if temp_path and os.path.exists(temp_path): | |
| os.remove(temp_path) | |
| if md_log_path and os.path.exists(md_log_path): | |
| os.remove(md_log_path) | |
| if atoms is not None and getattr(atoms, "calc", None) is not None: | |
| atoms.calc = None | |
def run_relaxation_simulation(
    structure_file,
    num_steps,
    fmax,
    task_name,
    total_charge: float = 0,
    spin_multiplicity: float = 1,
    relax_unit_cell=False,
    explanation: str | None = None,
    oauth_token=None,  # Ignored
    progress=None,
):
    """Run an LBFGS geometry relaxation with OrbMol.

    Args:
        structure_file: Input structure, loaded via ``load_check_ase_atoms``.
        num_steps: Maximum number of optimizer steps.
        fmax: Force convergence threshold passed to the optimizer.
        task_name: Key of ``model_name`` selecting the OrbMol model.
        total_charge: Stored in ``atoms.info["charge"]``.
        spin_multiplicity: Stored in ``atoms.info["spin"]``.
        relax_unit_cell: When true, the atoms are wrapped in a
            ``FrechetCellFilter`` so the cell is optimised as well.
        explanation: Optional text returned as-is; a default is generated
            when ``None``.
        oauth_token: Unused.
        progress: Optional callable receiving the completed fraction.

    Returns:
        ``(traj_path, opt_log, reproduction_script, explanation)``.
        ``traj_path`` is a temporary ``.traj`` file left on disk for the
        caller.

    Raises:
        Exception: Any failure, wrapped with an "Error running relaxation"
            message and chained to the original error.
    """
    traj_path = None
    opt_log_path = None
    atoms = None
    succeeded = False
    try:
        atoms = load_check_ase_atoms(structure_file)
        atoms.info["charge"] = total_charge
        atoms.info["spin"] = spin_multiplicity
        atoms.calc = load_orbmol_model(task_name)

        # Reserve temporary output files; only their names are needed here.
        with tempfile.NamedTemporaryFile(suffix=".traj", delete=False) as traj_f:
            traj_path = traj_f.name
        with tempfile.NamedTemporaryFile(suffix=".log", delete=False) as log_f:
            opt_log_path = log_f.name

        interval = 1
        steps_done = 0

        def update_progress():
            nonlocal steps_done
            steps_done += interval
            # Clamp to 1.0 and avoid dividing by zero when num_steps is 0.
            if num_steps > 0:
                progress(min(1.0, steps_done / num_steps))

        target = FrechetCellFilter(atoms) if relax_unit_cell else atoms
        # Closing the optimizer releases its log and trajectory handles
        # before the log is read and the trajectory path is returned.
        with LBFGS(target, trajectory=traj_path, logfile=opt_log_path) as optimizer:
            if progress:
                optimizer.attach(update_progress, interval=interval)
            optimizer.run(fmax=fmax, steps=num_steps)

        # Stand-alone script the user can run locally to reproduce the run.
        reproduction_script = f"""
import ase.io
from ase.optimize import LBFGS
from ase.filters import FrechetCellFilter
from orb_models.forcefield import pretrained
from orb_models.forcefield.calculator import ORBCalculator
# Read the atoms object from ASE read-able file
atoms = ase.io.read('input_file.traj')
# Set the total charge and spin multiplicity
atoms.info["charge"] = {total_charge}
atoms.info["spin"] = {spin_multiplicity}
# Set up the OrbMol calculator
orbff = pretrained.{model_name[task_name]}(device='cpu', precision='float32-high')
atoms.calc = ORBCalculator(orbff, device='cpu')
# Initialize the optimizer
relax_unit_cell = {relax_unit_cell}
optimizer = LBFGS(FrechetCellFilter(atoms) if relax_unit_cell else atoms, trajectory="relaxation_output.traj")
# Run the optimization!
optimizer.run(fmax={fmax}, steps={num_steps})
"""

        with open(opt_log_path, "r") as opt_log_file:
            opt_log = opt_log_file.read()

        if explanation is None:
            explanation = f"Relaxation of {len(atoms)} atoms for {num_steps} steps with a force tolerance of {fmax} eV/Å using OrbMol. You submitted this simulation, so I hope you know what you're looking for or what it means!"

        succeeded = True
        return traj_path, opt_log, reproduction_script, explanation
    except Exception as e:
        raise Exception(
            f"Error running relaxation: {str(e)}. Please try again or report this error."
        ) from e
    finally:
        # The log content has been read (or the run failed); drop the file.
        if opt_log_path and os.path.exists(opt_log_path):
            os.remove(opt_log_path)
        # The trajectory is only useful to the caller when the run succeeded.
        if not succeeded and traj_path and os.path.exists(traj_path):
            os.remove(traj_path)
        # Detach the shared calculator from this Atoms object.
        if atoms is not None and getattr(atoms, "calc", None) is not None:
            atoms.calc = None
# -----------------------------
# Helper functions for compatibility
# -----------------------------
def atoms_to_xyz(atoms: ase.Atoms) -> str:
    """Serialise an ASE Atoms object as XYZ text for quick visualisation.

    The output has the atom count on the first line, a fixed comment on the
    second, then one ``symbol x y z`` line per atom with six decimals. There
    is no trailing newline.
    """
    header = [f"{len(atoms)}", "generated by simulation_scripts_orbmol"]
    symbols = atoms.get_chemical_symbols()
    coords = atoms.get_positions()
    atom_rows = [
        f"{symbol} {px:.6f} {py:.6f} {pz:.6f}"
        for symbol, (px, py, pz) in zip(symbols, coords)
    ]
    return "\n".join(header + atom_rows)
def last_frame_xyz_from_traj(traj_path: str | Path) -> str:
    """Return the last frame of an ASE ``.traj`` file as an XYZ string.

    Args:
        traj_path: Path to the trajectory file.

    Returns:
        The XYZ text produced by ``atoms_to_xyz`` for the final frame.
    """
    # The context manager closes the reader's file handle, which the
    # original left open; the frame is converted while the file is still open.
    with Trajectory(str(traj_path)) as reader:
        return atoms_to_xyz(reader[-1])
# Simplified validation function (no authentication)
def validate_ase_atoms_and_login(structure_file, login_button_value="", oauth_token=None):
    """Check that an uploaded structure can be simulated (no login step).

    Returns a 3-tuple ``(flag, flag, message)``: both flags are ``True`` only
    when the structure loads, is non-empty, has uniform periodic boundary
    conditions and has at most 2000 atoms. Otherwise both are ``False`` and
    the message states the reason.
    """

    def _rejected(reason):
        return (False, False, reason)

    if not structure_file:
        return _rejected("Missing input structure!")

    # An upload may arrive as a dict carrying the file location under "path".
    path = structure_file["path"] if isinstance(structure_file, dict) else structure_file

    try:
        atoms = ase.io.read(path)
        n_atoms = len(atoms)
        if n_atoms == 0:
            return _rejected("No atoms in the structure file!")

        pbc_flags = np.array(atoms.pbc)
        fully_periodic = all(atoms.pbc)
        if not fully_periodic and not np.all(~pbc_flags):
            return _rejected(f"Mixed PBC {atoms.pbc} not supported!")

        if n_atoms > 2000:
            return _rejected(f"Too many atoms ({n_atoms}), max 2000!")

        return (True, True, "Structure loaded successfully - ready for OrbMol simulation!")
    except Exception as e:
        return _rejected(f"Failed to load structure: {str(e)}")