annabossler Nekkrad commited on
Commit
94990ac
·
verified ·
1 Parent(s): 3f17e4b

multiple_model (#2)

Browse files

- Added support for multiple pretrained models (OMol,OMat,OMol-Direct) (ba742371795a1f34a0014d9deb3f3700df9d9292)


Co-authored-by: Davide Sarpa <Nekkrad@users.noreply.huggingface.co>

Files changed (2) hide show
  1. app.py +81 -23
  2. simulation_scripts_orbmol.py +37 -15
app.py CHANGED
@@ -6,6 +6,7 @@ import gradio as gr
6
  from ase.io import read, write
7
  from ase.io.trajectory import Trajectory
8
  from gradio_molecule3d import Molecule3D
 
9
  import hashlib
10
  import shutil
11
 
@@ -110,22 +111,22 @@ def prepare_input_for_viewer(structure_file):
110
  from orb_models.forcefield import pretrained
111
  from orb_models.forcefield.calculator import ORBCalculator
112
 
113
- _MODEL_CALC = None
114
- def _load_orbmol_calc():
115
- global _MODEL_CALC
116
- if _MODEL_CALC is None:
117
- orbff = pretrained.orb_v3_conservative_inf_omat(
118
- device="cpu", precision="float32-high"
119
- )
120
- _MODEL_CALC = ORBCalculator(orbff, device="cpu")
121
- return _MODEL_CALC
122
-
123
- def predict_molecule(structure_file, charge=0, spin_multiplicity=1):
124
  """
125
  Single Point Energy + fuerzas (OrbMol). Acepta archivos subidos.
126
  """
127
  try:
128
- calc = _load_orbmol_calc()
129
  if not structure_file:
130
  return "Error: Please upload a structure file", "Error", None
131
 
@@ -151,7 +152,7 @@ def predict_molecule(structure_file, charge=0, spin_multiplicity=1):
151
  # Preparar PDB para visualización
152
  pdb_file = prepare_input_for_viewer(file_path)
153
 
154
- return "\n".join(lines), "Calculation completed with OrbMol", pdb_file
155
  except Exception as e:
156
  import traceback; traceback.print_exc()
157
  return f"Error during calculation: {e}", "Error", None
@@ -163,7 +164,7 @@ from simulation_scripts_orbmol import (
163
  )
164
 
165
  # ==== Wrappers con debug y Molecule3D ====
166
- def md_wrapper(structure_file, charge, spin, steps, tempK, timestep_fs, ensemble):
167
  try:
168
  if not structure_file:
169
  return ("Error: Please upload a structure file", None, "", "", "", None)
@@ -178,6 +179,7 @@ def md_wrapper(structure_file, charge, spin, steps, tempK, timestep_fs, ensemble
178
  float(timestep_fs),
179
  float(tempK),
180
  "NVT" if ensemble == "NVT" else "NVE",
 
181
  int(charge),
182
  int(spin),
183
  )
@@ -194,7 +196,7 @@ def md_wrapper(structure_file, charge, spin, steps, tempK, timestep_fs, ensemble
194
  import traceback; traceback.print_exc()
195
  return (f"Error: {e}", None, "", "", "", None)
196
 
197
- def relax_wrapper(structure_file, steps, fmax, charge, spin, relax_cell):
198
  try:
199
  if not structure_file:
200
  return ("Error: Please upload a structure file", None, "", "", "", None)
@@ -206,6 +208,7 @@ def relax_wrapper(structure_file, steps, fmax, charge, spin, relax_cell):
206
  file_path,
207
  int(steps),
208
  float(fmax),
 
209
  int(charge),
210
  int(spin),
211
  bool(relax_cell),
@@ -238,9 +241,17 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
238
  file_types=[".xyz", ".pdb", ".cif", ".traj", ".mol", ".sdf"],
239
  file_count="single"
240
  )
 
 
 
 
 
 
 
241
  with gr.Row():
242
  charge_input = gr.Slider(minimum=-10, maximum=10, value=0, step=1, label="Charge")
243
  spin_input = gr.Slider(minimum=1, maximum=11, value=1, step=1, label="Spin Multiplicity")
 
244
  run_spe = gr.Button("Run OrbMol Prediction", variant="primary")
245
 
246
  with gr.Column(variant="panel", min_width=500):
@@ -256,8 +267,18 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
256
  }
257
  ]
258
  )
259
-
260
- run_spe.click(predict_molecule, [xyz_input, charge_input, spin_input], [spe_out, spe_status, spe_viewer])
 
 
 
 
 
 
 
 
 
 
261
 
262
  # -------- MD --------
263
  with gr.Tab("Molecular Dynamics"):
@@ -271,6 +292,13 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
271
  file_types=[".xyz", ".pdb", ".cif", ".traj", ".mol", ".sdf"],
272
  file_count="single"
273
  )
 
 
 
 
 
 
 
274
  with gr.Row():
275
  charge_md = gr.Slider(minimum=-10, maximum=10, value=0, step=1, label="Charge")
276
  spin_md = gr.Slider(minimum=1, maximum=11, value=1, step=1, label="Spin Multiplicity")
@@ -303,10 +331,21 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
303
  md_log = gr.Textbox(label="Log", interactive=False, lines=15, max_lines=25)
304
  md_script = gr.Code(label="Reproduction Script", language="python", interactive=False, lines=20, max_lines=30)
305
  md_explain = gr.Markdown()
306
-
 
 
 
 
 
 
 
 
 
 
 
307
  run_md_btn.click(
308
  md_wrapper,
309
- inputs=[xyz_md, charge_md, spin_md, steps_md, temp_md, timestep_md, ensemble_md],
310
  outputs=[md_status, md_traj, md_log, md_script, md_explain, md_viewer],
311
  )
312
 
@@ -322,8 +361,16 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
322
  file_types=[".xyz", ".pdb", ".cif", ".traj", ".mol", ".sdf"],
323
  file_count="single"
324
  )
325
- steps_rlx = gr.Slider(minimum=1, maximum=2000, value=300, step=1, label="Max Steps")
326
- fmax_rlx = gr.Slider(minimum=0.001, maximum=0.5, value=0.05, step=0.001, label="Fmax (eV/Å)")
 
 
 
 
 
 
 
 
327
  with gr.Row():
328
  charge_rlx = gr.Slider(minimum=-10, maximum=10, value=0, step=1, label="Charge")
329
  spin_rlx = gr.Slider(minimum=1, maximum=11, value=1, step=1, label="Spin")
@@ -352,14 +399,25 @@ with gr.Blocks(theme=gr.themes.Ocean(), title="OrbMol Demo") as demo:
352
  rlx_script = gr.Code(label="Reproduction Script", language="python", interactive=False, lines=20, max_lines=30)
353
  rlx_explain = gr.Markdown()
354
 
 
 
 
 
 
 
 
 
 
 
 
355
  run_rlx_btn.click(
356
  relax_wrapper,
357
- inputs=[xyz_rlx, steps_rlx, fmax_rlx, charge_rlx, spin_rlx, relax_cell],
358
  outputs=[rlx_status, rlx_traj, rlx_log, rlx_script, rlx_explain, rlx_viewer],
359
  )
360
 
361
  print("Starting OrbMol model loading…")
362
- _ = _load_orbmol_calc()
363
 
364
  if __name__ == "__main__":
365
  demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
 
6
  from ase.io import read, write
7
  from ase.io.trajectory import Trajectory
8
  from gradio_molecule3d import Molecule3D
9
+ from simulation_scripts_orbmol import load_orbmol_model # CORREGIDO
10
  import hashlib
11
  import shutil
12
 
 
111
  from orb_models.forcefield import pretrained
112
  from orb_models.forcefield.calculator import ORBCalculator
113
 
114
+ #_MODEL_CALC = None
115
+ #def _load_orbmol_calc():
116
+ # global _MODEL_CALC
117
+ # if _MODEL_CALC is None:
118
+ # orbff = pretrained.orb_v3_conservative_inf_omat(
119
+ # device="cpu", precision="float32-high"
120
+ # )
121
+ # _MODEL_CALC = ORBCalculator(orbff, device="cpu")
122
+ # return _MODEL_CALC
123
+
124
+ def predict_molecule(structure_file,task_name,charge=0, spin_multiplicity=1):
125
  """
126
  Single Point Energy + fuerzas (OrbMol). Acepta archivos subidos.
127
  """
128
  try:
129
+ calc = load_orbmol_model(task_name)
130
  if not structure_file:
131
  return "Error: Please upload a structure file", "Error", None
132
 
 
152
  # Preparar PDB para visualización
153
  pdb_file = prepare_input_for_viewer(file_path)
154
 
155
+ return "\n".join(lines), f"Calculation completed with {task_name}", pdb_file
156
  except Exception as e:
157
  import traceback; traceback.print_exc()
158
  return f"Error during calculation: {e}", "Error", None
 
164
  )
165
 
166
  # ==== Wrappers con debug y Molecule3D ====
167
+ def md_wrapper(structure_file,task_name, charge, spin, steps, tempK, timestep_fs, ensemble):
168
  try:
169
  if not structure_file:
170
  return ("Error: Please upload a structure file", None, "", "", "", None)
 
179
  float(timestep_fs),
180
  float(tempK),
181
  "NVT" if ensemble == "NVT" else "NVE",
182
+ str(task_name),
183
  int(charge),
184
  int(spin),
185
  )
 
196
  import traceback; traceback.print_exc()
197
  return (f"Error: {e}", None, "", "", "", None)
198
 
199
+ def relax_wrapper(structure_file,task_name, steps, fmax, charge, spin, relax_cell):
200
  try:
201
  if not structure_file:
202
  return ("Error: Please upload a structure file", None, "", "", "", None)
 
208
  file_path,
209
  int(steps),
210
  float(fmax),
211
+ str(task_name),
212
  int(charge),
213
  int(spin),
214
  bool(relax_cell),
 
241
  file_types=[".xyz", ".pdb", ".cif", ".traj", ".mol", ".sdf"],
242
  file_count="single"
243
  )
244
+ with gr.Row():
245
+ task_name = gr.Radio(
246
+ ["OMol", "OMat", "OMol-Direct"],
247
+ value="OMol",
248
+ label="Model Type",
249
+ info="Choose the OrbMol model variant for the calculation."
250
+ )
251
  with gr.Row():
252
  charge_input = gr.Slider(minimum=-10, maximum=10, value=0, step=1, label="Charge")
253
  spin_input = gr.Slider(minimum=1, maximum=11, value=1, step=1, label="Spin Multiplicity")
254
+
255
  run_spe = gr.Button("Run OrbMol Prediction", variant="primary")
256
 
257
  with gr.Column(variant="panel", min_width=500):
 
267
  }
268
  ]
269
  )
270
+ # Charge and Spin are only applicable to OMol and OMol-Direct
271
+ task_name.input(
272
+ lambda x: (
273
+ (gr.Number(visible=True), gr.Number(visible=True))
274
+ if x == "OMol" or x == "OMol-Direct"
275
+ else (gr.Number(visible=False), gr.Number(visible=False))
276
+ ),
277
+ [task_name],
278
+ [charge_input, spin_input],
279
+ )
280
+
281
+ run_spe.click(predict_molecule, [xyz_input, task_name,charge_input, spin_input], [spe_out, spe_status, spe_viewer])
282
 
283
  # -------- MD --------
284
  with gr.Tab("Molecular Dynamics"):
 
292
  file_types=[".xyz", ".pdb", ".cif", ".traj", ".mol", ".sdf"],
293
  file_count="single"
294
  )
295
+ with gr.Row():
296
+ task_name = gr.Radio(
297
+ ["OMol", "OMat", "OMol-Direct"],
298
+ value="OMol",
299
+ label="Model Type",
300
+ info="Choose the OrbMol model variant for the calculation."
301
+ )
302
  with gr.Row():
303
  charge_md = gr.Slider(minimum=-10, maximum=10, value=0, step=1, label="Charge")
304
  spin_md = gr.Slider(minimum=1, maximum=11, value=1, step=1, label="Spin Multiplicity")
 
331
  md_log = gr.Textbox(label="Log", interactive=False, lines=15, max_lines=25)
332
  md_script = gr.Code(label="Reproduction Script", language="python", interactive=False, lines=20, max_lines=30)
333
  md_explain = gr.Markdown()
334
+
335
+ # Charge and Spin are only applicable to OMol and OMol-Direct
336
+ task_name.input(
337
+ lambda x: (
338
+ (gr.Number(visible=True), gr.Number(visible=True))
339
+ if x == "OMol" or x == "OMol-Direct"
340
+ else (gr.Number(visible=False), gr.Number(visible=False))
341
+ ),
342
+ [task_name],
343
+ [charge_md, spin_md],
344
+ )
345
+
346
  run_md_btn.click(
347
  md_wrapper,
348
+ inputs=[xyz_md, task_name,charge_md, spin_md, steps_md, temp_md, timestep_md, ensemble_md],
349
  outputs=[md_status, md_traj, md_log, md_script, md_explain, md_viewer],
350
  )
351
 
 
361
  file_types=[".xyz", ".pdb", ".cif", ".traj", ".mol", ".sdf"],
362
  file_count="single"
363
  )
364
+ with gr.Row():
365
+ task_name = gr.Radio(
366
+ ["OMol", "OMat", "OMol-Direct"],
367
+ value="OMol",
368
+ label="Model Type",
369
+ info="Choose the OrbMol model variant for the calculation."
370
+ )
371
+ with gr.Row():
372
+ steps_rlx = gr.Slider(minimum=1, maximum=2000, value=300, step=1, label="Max Steps")
373
+ fmax_rlx = gr.Slider(minimum=0.001, maximum=0.5, value=0.05, step=0.001, label="Fmax (eV/Å)")
374
  with gr.Row():
375
  charge_rlx = gr.Slider(minimum=-10, maximum=10, value=0, step=1, label="Charge")
376
  spin_rlx = gr.Slider(minimum=1, maximum=11, value=1, step=1, label="Spin")
 
399
  rlx_script = gr.Code(label="Reproduction Script", language="python", interactive=False, lines=20, max_lines=30)
400
  rlx_explain = gr.Markdown()
401
 
402
+ # Charge and Spin are only applicable to OMol and OMol-Direct
403
+ task_name.input(
404
+ lambda x: (
405
+ (gr.Number(visible=True), gr.Number(visible=True))
406
+ if x == "OMol" or x == "OMol-Direct"
407
+ else (gr.Number(visible=False), gr.Number(visible=False))
408
+ ),
409
+ [task_name],
410
+ [charge_rlx, spin_rlx],
411
+ )
412
+
413
  run_rlx_btn.click(
414
  relax_wrapper,
415
+ inputs=[xyz_rlx,task_name, steps_rlx, fmax_rlx, charge_rlx, spin_rlx, relax_cell],
416
  outputs=[rlx_status, rlx_traj, rlx_log, rlx_script, rlx_explain, rlx_viewer],
417
  )
418
 
419
  print("Starting OrbMol model loading…")
420
+ #_ = load_orbmol_model(task_name)
421
 
422
  if __name__ == "__main__":
423
  demo.launch(server_name="0.0.0.0", server_port=7860, show_error=True)
simulation_scripts_orbmol.py CHANGED
@@ -26,16 +26,37 @@ from orb_models.forcefield.calculator import ORBCalculator
26
  # Global model
27
  # -----------------------------
28
  _model_calc: ORBCalculator | None = None
29
-
30
- def load_orbmol_model(device: str = "cpu", precision: str = "float32-high") -> ORBCalculator:
31
- """Load OrbMol once and reuse the same calculator."""
32
- global _model_calc
33
- if _model_calc is None:
34
- orbff = pretrained.orb_v3_conservative_inf_omat(
35
- device=device,
36
- precision=precision,
37
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  _model_calc = ORBCalculator(orbff, device=device)
 
39
  return _model_calc
40
 
41
  # -----------------------------
@@ -79,7 +100,7 @@ def run_md_simulation(
79
  md_timestep,
80
  temperature_k,
81
  md_ensemble,
82
- task_name="OMol", # Siempre OMol para OrbMol
83
  total_charge=0,
84
  spin_multiplicity=1,
85
  explanation: str | None = None,
@@ -103,7 +124,7 @@ def run_md_simulation(
103
  atoms.info["spin"] = spin_multiplicity
104
 
105
  # AQUÍ EL CAMBIO: OrbMol en lugar de HFEndpointCalculator
106
- calc = load_orbmol_model()
107
  atoms.calc = calc
108
 
109
  # Progress callback si existe
@@ -164,6 +185,7 @@ def run_md_simulation(
164
  dyn.run(num_steps)
165
 
166
  # Script de reproducción (estilo Facebook pero con OrbMol)
 
167
  reproduction_script = f"""
168
  import ase.io
169
  from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
@@ -181,7 +203,7 @@ atoms = ase.io.read('input_file.traj')
181
  atoms.info["charge"] = {total_charge}
182
  atoms.info["spin"] = {spin_multiplicity}
183
  # Set up the OrbMol calculator
184
- orbff = pretrained.orb_v3_conservative_inf_omat(device='cpu', precision='float32-high')
185
  atoms.calc = ORBCalculator(orbff, device='cpu')
186
  # Do a quick pre-relaxation to make sure the system is stable
187
  opt = LBFGS(atoms, trajectory="relaxation_output.traj")
@@ -227,7 +249,7 @@ def run_relaxation_simulation(
227
  structure_file,
228
  num_steps,
229
  fmax,
230
- task_name="OMol", # Siempre OMol para OrbMol
231
  total_charge: float = 0,
232
  spin_multiplicity: float = 1,
233
  relax_unit_cell=False,
@@ -252,7 +274,7 @@ def run_relaxation_simulation(
252
  atoms.info["spin"] = spin_multiplicity
253
 
254
  # AQUÍ EL CAMBIO: OrbMol en lugar de HFEndpointCalculator
255
- calc = load_orbmol_model()
256
  atoms.calc = calc
257
 
258
  # Archivos temporales
@@ -294,7 +316,7 @@ atoms = ase.io.read('input_file.traj')
294
  atoms.info["charge"] = {total_charge}
295
  atoms.info["spin"] = {spin_multiplicity}
296
  # Set up the OrbMol calculator
297
- orbff = pretrained.orb_v3_conservative_inf_omat(device='cpu', precision='float32-high')
298
  atoms.calc = ORBCalculator(orbff, device='cpu')
299
  # Initialize the optimizer
300
  relax_unit_cell = {relax_unit_cell}
 
26
  # Global model
27
  # -----------------------------
28
  _model_calc: ORBCalculator | None = None
29
+ _current_task_name = None
30
+ # Necesario para script de reproducción
31
+ model_name = {
32
+ "OMol": "orb_v3_conservative_omol",
33
+ "OMat": "orb_v3_conservative_inf_omat",
34
+ "OMol-Direct": "orb_v3_direct_omol"}
35
+
36
+
37
+ def load_orbmol_model(task_name,device: str = "cpu", precision: str = "float32-high") -> ORBCalculator:
38
+ """Load OrbMol calculator, switches only if another model is required."""
39
+ global _model_calc, _current_task_name
40
+ if _model_calc is None or _current_task_name != task_name:
41
+ if task_name == "OMol":
42
+ orbff= pretrained.orb_v3_conservative_omol(
43
+ device=device,
44
+ precision=precision,
45
+ )
46
+ elif task_name == "OMat":
47
+ orbff = pretrained.orb_v3_conservative_inf_omat(
48
+ device=device,
49
+ precision=precision,
50
+ )
51
+ elif task_name == "OMol-Direct":
52
+ orbff = pretrained.orb_v3_direct_omol(
53
+ device=device,
54
+ precision=precision,
55
+ )
56
+ else:
57
+ raise ValueError(f"Unknown task_name: {task_name}")
58
  _model_calc = ORBCalculator(orbff, device=device)
59
+ _current_task_name = task_name
60
  return _model_calc
61
 
62
  # -----------------------------
 
100
  md_timestep,
101
  temperature_k,
102
  md_ensemble,
103
+ task_name,
104
  total_charge=0,
105
  spin_multiplicity=1,
106
  explanation: str | None = None,
 
124
  atoms.info["spin"] = spin_multiplicity
125
 
126
  # AQUÍ EL CAMBIO: OrbMol en lugar de HFEndpointCalculator
127
+ calc = load_orbmol_model(task_name)
128
  atoms.calc = calc
129
 
130
  # Progress callback si existe
 
185
  dyn.run(num_steps)
186
 
187
  # Script de reproducción (estilo Facebook pero con OrbMol)
188
+
189
  reproduction_script = f"""
190
  import ase.io
191
  from ase.md.velocitydistribution import MaxwellBoltzmannDistribution
 
203
  atoms.info["charge"] = {total_charge}
204
  atoms.info["spin"] = {spin_multiplicity}
205
  # Set up the OrbMol calculator
206
+ orbff = pretrained.{model_name[task_name]}(device='cpu', precision='float32-high')
207
  atoms.calc = ORBCalculator(orbff, device='cpu')
208
  # Do a quick pre-relaxation to make sure the system is stable
209
  opt = LBFGS(atoms, trajectory="relaxation_output.traj")
 
249
  structure_file,
250
  num_steps,
251
  fmax,
252
+ task_name,
253
  total_charge: float = 0,
254
  spin_multiplicity: float = 1,
255
  relax_unit_cell=False,
 
274
  atoms.info["spin"] = spin_multiplicity
275
 
276
  # AQUÍ EL CAMBIO: OrbMol en lugar de HFEndpointCalculator
277
+ calc = load_orbmol_model(task_name)
278
  atoms.calc = calc
279
 
280
  # Archivos temporales
 
316
  atoms.info["charge"] = {total_charge}
317
  atoms.info["spin"] = {spin_multiplicity}
318
  # Set up the OrbMol calculator
319
+ orbff = pretrained.{model_name[task_name]}(device='cpu', precision='float32-high')
320
  atoms.calc = ORBCalculator(orbff, device='cpu')
321
  # Initialize the optimizer
322
  relax_unit_cell = {relax_unit_cell}