Spaces:
				
			
			
	
			
			
					
		Running
		
			on 
			
			CPU Upgrade
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
			on 
			
			CPU Upgrade
	| """ | |
| Copyright (c) Meta Platforms, Inc. and affiliates. | |
| This source code is licensed under the MIT license found in the | |
| LICENSE file in the root directory of this source tree. | |
| """ | |
| import hashlib | |
| import os | |
| from pathlib import Path | |
| import ase | |
| import backoff | |
| import gradio as gr | |
| import huggingface_hub as hf_hub | |
| import requests | |
| from ase.calculators.calculator import Calculator | |
| from ase.db.core import now | |
| from ase.db.row import AtomsRow | |
| from ase.io.jsonio import decode, encode | |
| from requests.exceptions import HTTPError | |
def hash_save_file(atoms: ase.Atoms, task_name, path: Path | str):
    """Write a task-tagged copy of ``atoms`` to ``<path>/<md5 hex digest>.traj``.

    Args:
        atoms: Structure to store. The function works on ``atoms.copy()``, so
            the ``task_name`` entry is added to the copy's ``info`` only.
        task_name: Value stored under ``info["task_name"]`` of the copy.
        path: Directory the file is written into.

    Returns:
        None.
    """
    target_dir = Path(path)

    tagged = atoms.copy()
    tagged.info["task_name"] = task_name

    # The file name is the MD5 hex digest of the UTF-8 encoded JSON form.
    # NOTE(review): atoms_to_json stamps the current time (ctime/mtime) into
    # the JSON, so this is presumably not a pure content hash -- confirm
    # whether identical structures are meant to map to the same file.
    serialized = atoms_to_json(tagged)
    digest = hashlib.md5(serialized.encode("utf-8")).hexdigest()

    tagged.write(target_dir / f"{digest}.traj")
def validate_uma_access(oauth_token):
    """Report whether ``oauth_token`` passes the auth check for ``facebook/UMA``.

    Args:
        oauth_token: Object exposing a ``token`` attribute that is forwarded
            to ``HfApi.auth_check``.

    Returns:
        True when the auth check completes without raising. False when it
        raises ``HfHubHTTPError``, or when an ``AttributeError`` occurs (for
        instance because ``oauth_token`` has no ``token`` attribute --
        presumably the not-logged-in case; verify against the caller).
    """
    try:
        client = hf_hub.HfApi()
        client.auth_check(repo_id="facebook/UMA", token=oauth_token.token)
    except (hf_hub.errors.HfHubHTTPError, AttributeError):
        granted = False
    else:
        granted = True
    return granted
class HFEndpointCalculator(Calculator):
    """ASE calculator that delegates every calculation to a remote HTTP endpoint.

    Each ``calculate`` call serializes the atoms with ``atoms_to_json``, POSTs
    them to the configured endpoint URL, and stores the decoded ``results`` on
    the calculator and the decoded ``info`` on the atoms object.
    """

    implemented_properties = ["energy", "free_energy", "stress", "forces"]

    def __init__(
        self,
        atoms,
        endpoint_url,
        oauth_token,
        task_name,
        example=False,
        *args,
        **kwargs,
    ):
        """Validate access for custom inputs and store the endpoint settings.

        Args:
            atoms: Structure the calculator is created for.
            endpoint_url: URL that ``calculate`` POSTs to.
            oauth_token: Token object checked by ``validate_uma_access``;
                only consulted when ``example`` is False.
            task_name: Task identifier; sent lower-cased with every request.
            example: When True, skip the access check and the input logging.
            *args: Forwarded to ``Calculator.__init__``.
            **kwargs: Forwarded to ``Calculator.__init__``.

        Raises:
            gr.Error: If ``example`` is False and the access check fails.
            KeyError: If the ``HF_TOKEN`` environment variable is not set.
        """
        # If we have an example structure, we don't need to check for authentication.
        # Otherwise, the user must be authenticated and have gated access to
        # the UMA models.
        if not example:
            if not validate_uma_access(oauth_token):
                raise gr.Error(
                    "You need to log in to HF and have gated model access to UMA before running your own simulations!"
                )
            try:
                hash_save_file(atoms, task_name, "/data/custom_inputs/")
            except FileNotFoundError:
                # Best effort: a missing storage directory must not block
                # the calculation itself.
                pass

        self.url = endpoint_url
        # The request is authorized with the HF_TOKEN environment variable,
        # not with the caller's oauth_token.
        self.token = os.environ["HF_TOKEN"]
        self.atoms = atoms
        self.task_name = task_name
        super().__init__(*args, **kwargs)

    # Declared static because it takes no ``self`` and is invoked as
    # ``self._post_with_backoff(url, headers, payload)``; without the
    # decorator that call would pass the instance as an extra positional
    # argument and fail with TypeError.
    # NOTE(review): despite the name, no retry/backoff is applied here. The
    # module imports ``backoff``, so a ``backoff.on_exception`` decorator was
    # presumably intended -- confirm the retry policy before adding one.
    @staticmethod
    def _post_with_backoff(url, headers, payload):
        """POST ``payload`` as JSON to ``url`` and return the response.

        Raises:
            requests.exceptions.HTTPError: If the response status indicates
                an error (via ``raise_for_status``).
        """
        response = requests.post(url, headers=headers, json=payload)
        response.raise_for_status()
        return response

    def calculate(self, atoms, properties, system_changes):
        """Run the calculation remotely and store the returned results.

        Args:
            atoms: Structure to evaluate; its ``info`` is sent as the payload
                data and is replaced by the ``info`` of the response.
            properties: Requested properties, forwarded to the endpoint.
            system_changes: Change list, forwarded to the endpoint.

        Raises:
            gr.Error: If the endpoint answers with an HTTP error status.
        """
        Calculator.calculate(self, atoms, properties, system_changes)

        task_name = self.task_name.lower()
        payload = {
            "inputs": atoms_to_json(atoms, data=atoms.info),
            "properties": properties,
            "system_changes": system_changes,
            "task_name": task_name,
        }
        headers = {
            "Accept": "application/json",
            "Authorization": f"Bearer {self.token}",
            "Content-Type": "application/json",
        }

        try:
            response = self._post_with_backoff(self.url, headers, payload)
            response_dict = response.json()
        except HTTPError as error:
            # Keep the failing input for later inspection. As in __init__,
            # a missing storage directory must not mask the real error.
            try:
                hash_save_file(atoms, task_name, "/data/custom_inputs/errors/")
            except FileNotFoundError:
                pass
            raise gr.Error(
                f"Backend failure during your calculation; if you have continued issues please file an issue in the main FAIR chemistry repo (https://github.com/facebookresearch/fairchem).\n{error}"
            ) from error

        # Load the response and store the results in the calc and atoms object
        response_dict = decode(response_dict)
        self.results = response_dict["results"]
        atoms.info = response_dict["info"]
def atoms_to_json(atoms, data=None):
    """Serialize ``atoms`` to a JSON string, similar to ``ase.db.jsondb``.

    Args:
        atoms: Structure to serialize; wrapped in an ``AtomsRow``.
        data: Optional extra payload stored under the ``"data"`` key. A
            falsy value is replaced by an empty dict.

    Returns:
        The string produced by ``ase.io.jsonio.encode`` for the row
        dictionary.
    """
    timestamp = now()
    row = AtomsRow(atoms)
    row.ctime = timestamp

    # Export the row's public attributes, leaving out private names, the
    # row's key-value keys and the database id.
    record = {
        name: row[name]
        for name in row.__dict__
        if name[0] != "_" and name not in row._keys and name != "id"
    }
    record["mtime"] = timestamp
    record["data"] = data if data else {}

    # Constraints are only written when the row actually has some.
    row_constraints = row.get("constraints")
    if row_constraints:
        record["constraints"] = row_constraints

    return encode(record)
