Spaces:
				
			
			
	
			
			
					
		Running
		
			on 
			
			CPU Upgrade
	
	
	
			
			
	
	
	
	
		
		
					
		Running
		
			on 
			
			CPU Upgrade
	better error handling
Browse files
- evaluator.py +43 -11
    	
        evaluator.py
    CHANGED
    
    | @@ -16,6 +16,11 @@ from fairchem.data.omol.modules.evaluator import ( | |
| 16 | 
             
                unoptimized_spin_gap,
         | 
| 17 | 
             
            )
         | 
| 18 |  | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 19 | 
             
            OMOL_EVAL_FUNCTIONS = {
         | 
| 20 | 
             
                "Ligand pocket": ligand_pocket,
         | 
| 21 | 
             
                "Ligand strain": ligand_strain,
         | 
| @@ -66,8 +71,13 @@ def reorder(ref: np.ndarray, to_reorder: np.ndarray) -> np.ndarray: | |
| 66 |  | 
| 67 | 
             
            def get_order(path_submission: Path, path_annotations: Path):
         | 
| 68 |  | 
| 69 | 
            -
                 | 
| 70 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 71 |  | 
| 72 | 
             
                with np.load(path_annotations, allow_pickle=True) as data:
         | 
| 73 | 
             
                    annotations_ids = data["ids"]
         | 
| @@ -86,6 +96,10 @@ def get_order(path_submission: Path, path_annotations: Path): | |
| 86 | 
             
                    )
         | 
| 87 | 
             
                    raise Exception(f"IDs don't match.\n{details}")
         | 
| 88 |  | 
|  | |
|  | |
|  | |
|  | |
| 89 | 
             
                return reorder(annotations_ids, submission_ids)
         | 
| 90 |  | 
| 91 |  | 
| @@ -96,10 +110,17 @@ def s2ef_metrics( | |
| 96 | 
             
            ) -> Dict[str, float]:
         | 
| 97 | 
             
                order = get_order(submission_filename, annotations_path)
         | 
| 98 |  | 
| 99 | 
            -
                 | 
| 100 | 
            -
                     | 
| 101 | 
            -
             | 
| 102 | 
            -
             | 
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
| 103 |  | 
| 104 | 
             
                if len(set(np.where(np.isinf(energy))[0])) != 0:
         | 
| 105 | 
             
                    inf_energy_ids = list(set(np.where(np.isinf(energy))[0]))
         | 
| @@ -129,10 +150,12 @@ def s2ef_metrics( | |
| 129 |  | 
| 130 | 
             
                    forces_mae = 0
         | 
| 131 | 
             
                    natoms = 0
         | 
| 132 | 
            -
                    for sub_forces, sub_target_forces in zip( | 
|  | |
|  | |
| 133 | 
             
                        forces_mae += np.sum(np.abs(sub_target_forces - sub_forces))
         | 
| 134 | 
             
                        natoms += sub_forces.shape[0]
         | 
| 135 | 
            -
                    forces_mae /=  | 
| 136 |  | 
| 137 | 
             
                    metrics[f"{subset}_forces_mae"] = forces_mae
         | 
| 138 |  | 
| @@ -144,8 +167,12 @@ def omol_evaluations( | |
| 144 | 
             
                submission_filename: Path,
         | 
| 145 | 
             
                eval_type: str,
         | 
| 146 | 
             
            ) -> Dict[str, float]:
         | 
| 147 | 
            -
                 | 
| 148 | 
            -
                     | 
|  | |
|  | |
|  | |
|  | |
| 149 | 
             
                with open(annotations_path) as f:
         | 
| 150 | 
             
                    annotations_data = json.load(f)
         | 
| 151 |  | 
| @@ -159,6 +186,11 @@ def omol_evaluations( | |
| 159 | 
             
                        f"Missing entries in submission: {missing}\n"
         | 
| 160 | 
             
                        f"Unexpected entries in submission: {unexpected}"
         | 
| 161 | 
             
                    )
         | 
|  | |
|  | |
|  | |
|  | |
|  | |
| 162 | 
             
                eval_fn = OMOL_EVAL_FUNCTIONS.get(eval_type)
         | 
| 163 | 
             
                metrics = eval_fn(annotations_data, submission_data)
         | 
| 164 | 
             
                return metrics
         | 
| @@ -190,4 +222,4 @@ def evaluate( | |
| 190 | 
             
                else:
         | 
| 191 | 
             
                    raise ValueError(f"Unknown eval_type: {eval_type}")
         | 
| 192 |  | 
| 193 | 
            -
                return metrics
         | 
|  | |
| 16 | 
             
                unoptimized_spin_gap,
         | 
| 17 | 
             
            )
         | 
| 18 |  | 
| 19 | 
            +
             | 
| 20 | 
            +
            class SubmissionLoadError(Exception):
         | 
| 21 | 
            +
                """Raised if unable to load the submission file."""
         | 
| 22 | 
            +
             | 
| 23 | 
            +
             | 
| 24 | 
             
            OMOL_EVAL_FUNCTIONS = {
         | 
| 25 | 
             
                "Ligand pocket": ligand_pocket,
         | 
| 26 | 
             
                "Ligand strain": ligand_strain,
         | 
|  | |
| 71 |  | 
| 72 | 
             
            def get_order(path_submission: Path, path_annotations: Path):
         | 
| 73 |  | 
| 74 | 
            +
                try:
         | 
| 75 | 
            +
                    with np.load(path_submission) as data:
         | 
| 76 | 
            +
                        submission_ids = data["ids"]
         | 
| 77 | 
            +
                except Exception as e:
         | 
| 78 | 
            +
                    raise SubmissionLoadError(
         | 
| 79 | 
            +
                        f"Error loading submission file. 'ids' must not be object types."
         | 
| 80 | 
            +
                    ) from e
         | 
| 81 |  | 
| 82 | 
             
                with np.load(path_annotations, allow_pickle=True) as data:
         | 
| 83 | 
             
                    annotations_ids = data["ids"]
         | 
|  | |
| 96 | 
             
                    )
         | 
| 97 | 
             
                    raise Exception(f"IDs don't match.\n{details}")
         | 
| 98 |  | 
| 99 | 
            +
                assert len(submission_ids) == len(
         | 
| 100 | 
            +
                    submission_set
         | 
| 101 | 
            +
                ), "Duplicate IDs found in submission."
         | 
| 102 | 
            +
             | 
| 103 | 
             
                return reorder(annotations_ids, submission_ids)
         | 
| 104 |  | 
| 105 |  | 
|  | |
| 110 | 
             
            ) -> Dict[str, float]:
         | 
| 111 | 
             
                order = get_order(submission_filename, annotations_path)
         | 
| 112 |  | 
| 113 | 
            +
                try:
         | 
| 114 | 
            +
                    with np.load(submission_filename) as data:
         | 
| 115 | 
            +
                        forces = data["forces"]
         | 
| 116 | 
            +
                        energy = data["energy"][order]
         | 
| 117 | 
            +
                        forces = np.array(
         | 
| 118 | 
            +
                            np.split(forces, np.cumsum(data["natoms"])[:-1]), dtype=object
         | 
| 119 | 
            +
                        )[order]
         | 
| 120 | 
            +
                except Exception as e:
         | 
| 121 | 
            +
                    raise SubmissionLoadError(
         | 
| 122 | 
            +
                        f"Error loading submission data. Make sure you concatenated your forces and there are no object types."
         | 
| 123 | 
            +
                    ) from e
         | 
| 124 |  | 
| 125 | 
             
                if len(set(np.where(np.isinf(energy))[0])) != 0:
         | 
| 126 | 
             
                    inf_energy_ids = list(set(np.where(np.isinf(energy))[0]))
         | 
|  | |
| 150 |  | 
| 151 | 
             
                    forces_mae = 0
         | 
| 152 | 
             
                    natoms = 0
         | 
| 153 | 
            +
                    for sub_forces, sub_target_forces in zip(
         | 
| 154 | 
            +
                        forces[subset_mask], target_forces[subset_mask]
         | 
| 155 | 
            +
                    ):
         | 
| 156 | 
             
                        forces_mae += np.sum(np.abs(sub_target_forces - sub_forces))
         | 
| 157 | 
             
                        natoms += sub_forces.shape[0]
         | 
| 158 | 
            +
                    forces_mae /= 3 * natoms
         | 
| 159 |  | 
| 160 | 
             
                    metrics[f"{subset}_forces_mae"] = forces_mae
         | 
| 161 |  | 
|  | |
| 167 | 
             
                submission_filename: Path,
         | 
| 168 | 
             
                eval_type: str,
         | 
| 169 | 
             
            ) -> Dict[str, float]:
         | 
| 170 | 
            +
                try:
         | 
| 171 | 
            +
                    with open(submission_filename) as f:
         | 
| 172 | 
            +
                        submission_data = json.load(f)
         | 
| 173 | 
            +
                except Exception as e:
         | 
| 174 | 
            +
                    raise SubmissionLoadError(f"Error loading submission file") from e
         | 
| 175 | 
            +
             | 
| 176 | 
             
                with open(annotations_path) as f:
         | 
| 177 | 
             
                    annotations_data = json.load(f)
         | 
| 178 |  | 
|  | |
| 186 | 
             
                        f"Missing entries in submission: {missing}\n"
         | 
| 187 | 
             
                        f"Unexpected entries in submission: {unexpected}"
         | 
| 188 | 
             
                    )
         | 
| 189 | 
            +
             | 
| 190 | 
            +
                assert len(submission_entries) == len(
         | 
| 191 | 
            +
                    submission_data
         | 
| 192 | 
            +
                ), "Duplicate entries found in submission."
         | 
| 193 | 
            +
             | 
| 194 | 
             
                eval_fn = OMOL_EVAL_FUNCTIONS.get(eval_type)
         | 
| 195 | 
             
                metrics = eval_fn(annotations_data, submission_data)
         | 
| 196 | 
             
                return metrics
         | 
|  | |
| 222 | 
             
                else:
         | 
| 223 | 
             
                    raise ValueError(f"Unknown eval_type: {eval_type}")
         | 
| 224 |  | 
| 225 | 
            +
                return metrics
         | 

