# NOTE: this source was captured from a Hugging Face Space page whose status
# banner read "Runtime error" at the time of capture.
import tempfile

import gradio as gr
import numpy as np
import pandas as pd
import tensorflow as tf
from huggingface_hub import from_pretrained_keras
from rdkit import Chem, RDLogger
from rdkit.Chem.Draw import MolsToGridImage
from tensorflow import keras
# Config
class Featurizer:
    """One-hot encoder driven by a table of allowable feature values.

    ``allowable_sets`` maps a feature name to the collection of values that
    should each receive their own slot in the encoded vector. The feature name
    must also be the name of an attribute on the instance (normally a method
    defined by a subclass) that extracts that feature's value from the object
    being encoded.

    Attributes set by ``__init__``:
        dim: total number of slots in the encoded vector.
        features_mapping: ``{feature name: {value: slot index}}``; slot
            indices are assigned consecutively across features, in the
            iteration order of ``allowable_sets`` and, within one feature, in
            sorted value order.
    """

    def __init__(self, allowable_sets):
        self.dim = 0
        self.features_mapping = {}
        for feature_name, allowed_values in allowable_sets.items():
            ordered_values = sorted(allowed_values)
            # Slots for this feature start where the previous feature ended.
            first_slot = self.dim
            self.features_mapping[feature_name] = {
                value: first_slot + position
                for position, value in enumerate(ordered_values)
            }
            self.dim = first_slot + len(ordered_values)

    def encode(self, inputs):
        """Return a ``(dim,)`` float vector with a 1.0 per recognised feature.

        For every registered feature the same-named attribute is called with
        ``inputs``; a returned value that is not among the allowable values
        for that feature leaves the feature's slots at zero.
        """
        output = np.zeros((self.dim,))
        for feature_name, slot_by_value in self.features_mapping.items():
            extract = getattr(self, feature_name)
            slot = slot_by_value.get(extract(inputs))
            if slot is not None:
                output[slot] = 1.0
        return output
class AtomFeaturizer(Featurizer):
    """Featurizer whose extractors read properties off an atom object.

    Construction is inherited unchanged from ``Featurizer``; each method below
    is looked up by name from the keys of ``allowable_sets`` during encoding.
    The ``atom`` argument is expected to expose the ``Get*`` accessors called
    here (elsewhere in this file atoms come from ``molecule.GetAtoms()``).
    """

    def symbol(self, atom):
        """Return the value of ``atom.GetSymbol()``."""
        return atom.GetSymbol()

    def n_valence(self, atom):
        """Return the value of ``atom.GetTotalValence()``."""
        return atom.GetTotalValence()

    def n_hydrogens(self, atom):
        """Return the value of ``atom.GetTotalNumHs()``."""
        return atom.GetTotalNumHs()

    def hybridization(self, atom):
        """Return the lower-cased name of ``atom.GetHybridization()``."""
        hybridization_type = atom.GetHybridization()
        return hybridization_type.name.lower()
class BondFeaturizer(Featurizer):
    """Featurizer for bond objects, with one extra slot meaning "no bond".

    The encoded vector is one element longer than the allowable sets require;
    that trailing element is set only when ``encode`` is given ``None``.
    """

    def __init__(self, allowable_sets):
        super().__init__(allowable_sets)
        # Reserve the trailing slot used to flag a missing bond.
        self.dim += 1

    def encode(self, bond):
        """Encode ``bond``; ``None`` yields a vector with only the last slot set."""
        if bond is not None:
            return super().encode(bond)
        no_bond = np.zeros((self.dim,))
        no_bond[-1] = 1.0
        return no_bond

    def bond_type(self, bond):
        """Return the lower-cased name of ``bond.GetBondType()``."""
        kind = bond.GetBondType()
        return kind.name.lower()

    def conjugated(self, bond):
        """Return the value of ``bond.GetIsConjugated()``."""
        return bond.GetIsConjugated()
# Module-level featurizer instances shared by the graph-building helpers.
# Each keyword names the extractor method that produces the value to one-hot
# encode; the set lists the values that get their own slot.
atom_featurizer = AtomFeaturizer(
    allowable_sets=dict(
        symbol={"B", "Br", "C", "Ca", "Cl", "F", "H", "I", "N", "Na", "O", "P", "S"},
        n_valence={0, 1, 2, 3, 4, 5, 6},
        n_hydrogens={0, 1, 2, 3, 4},
        hybridization={"s", "sp", "sp2", "sp3"},
    )
)

bond_featurizer = BondFeaturizer(
    allowable_sets=dict(
        bond_type={"single", "double", "triple", "aromatic"},
        conjugated={True, False},
    )
)
def molecule_from_smiles(smiles):
    """Parse a SMILES string into a molecule, tolerating sanitization failures.

    The string is parsed without sanitization, then sanitized explicitly. If a
    sanitization step fails, sanitization is re-run with that step excluded.
    Stereochemistry is assigned last.

    Args:
        smiles: SMILES string describing the molecule.

    Returns:
        The parsed molecule object.

    Raises:
        ValueError: if the parser cannot build a molecule from ``smiles``
            (``Chem.MolFromSmiles`` returned ``None``).
    """
    # MolFromSmiles(m, sanitize=True) should be equivalent to
    # MolFromSmiles(m, sanitize=False) -> SanitizeMol(m) -> AssignStereochemistry(m, ...)
    molecule = Chem.MolFromSmiles(smiles, sanitize=False)
    if molecule is None:
        # Fail here with a readable message instead of letting None reach
        # SanitizeMol, which would raise an opaque argument-type error.
        raise ValueError(f"Could not parse SMILES string: {smiles!r}")

    # If sanitization is unsuccessful, catch the error, and try again without
    # the sanitization step that caused the error
    flag = Chem.SanitizeMol(molecule, catchErrors=True)
    if flag != Chem.SanitizeFlags.SANITIZE_NONE:
        Chem.SanitizeMol(molecule, sanitizeOps=Chem.SanitizeFlags.SANITIZE_ALL ^ flag)

    Chem.AssignStereochemistry(molecule, cleanIt=True, force=True)
    return molecule
def graph_from_molecule(molecule):
    """Convert a molecule into atom features, bond features and index pairs.

    For each atom, in ``molecule.GetAtoms()`` order, one atom-feature row is
    produced. Each atom also contributes a self-loop pair (encoded with the
    "no bond" feature vector) followed by one pair per neighbour, so
    ``bond_features`` and ``pair_indices`` have the same number of rows.

    Returns:
        Tuple ``(atom_features, bond_features, pair_indices)`` of NumPy arrays.
    """
    atom_features = []
    bond_features = []
    pair_indices = []

    for atom in molecule.GetAtoms():
        source = atom.GetIdx()
        atom_features.append(atom_featurizer.encode(atom))

        # Self-loop: the atom paired with itself, with no bond object.
        pair_indices.append([source, source])
        bond_features.append(bond_featurizer.encode(None))

        for neighbor in atom.GetNeighbors():
            target = neighbor.GetIdx()
            connecting_bond = molecule.GetBondBetweenAtoms(source, target)
            pair_indices.append([source, target])
            bond_features.append(bond_featurizer.encode(connecting_bond))

    return np.array(atom_features), np.array(bond_features), np.array(pair_indices)
def graphs_from_smiles(smiles_list):
    """Build graph arrays for every SMILES string and pack them as ragged tensors.

    Each string is parsed with ``molecule_from_smiles`` and converted with
    ``graph_from_molecule``; the per-molecule arrays are then collected into
    three ragged tensors (atom features, bond features, pair indices).

    Returns:
        Tuple of ``tf.RaggedTensor`` objects: float32 atom features, float32
        bond features and int64 pair indices.
    """
    graphs = [
        graph_from_molecule(molecule_from_smiles(smiles)) for smiles in smiles_list
    ]
    atom_features_list = [graph[0] for graph in graphs]
    bond_features_list = [graph[1] for graph in graphs]
    pair_indices_list = [graph[2] for graph in graphs]

    # Ragged tensors are used because molecules differ in atom and bond count;
    # they are sliced by tf.data.Dataset later on.
    return (
        tf.ragged.constant(atom_features_list, dtype=tf.float32),
        tf.ragged.constant(bond_features_list, dtype=tf.float32),
        tf.ragged.constant(pair_indices_list, dtype=tf.int64),
    )
def prepare_batch(x_batch, y_batch):
    """Merge the (sub)graphs of a batch into one global, disconnected graph.

    Args:
        x_batch: tuple ``(atom_features, bond_features, pair_indices)`` of
            ragged tensors, one row per molecule.
        y_batch: labels for the batch; returned unchanged.

    Returns:
        ``((atom_features, bond_features, pair_indices, molecule_indicator),
        y_batch)`` where the first three entries are flattened dense tensors
        and ``molecule_indicator`` gives, for every atom row, the index of the
        molecule it came from.
    """
    atom_features, bond_features, pair_indices = x_batch

    # Per-molecule atom and bond counts.
    num_atoms = atom_features.row_lengths()
    num_bonds = bond_features.row_lengths()

    # molecule_indicator repeats each molecule's index once per atom; the
    # model can use it to regroup atoms of the global graph by molecule.
    molecule_indices = tf.range(len(num_atoms))
    molecule_indicator = tf.repeat(molecule_indices, num_atoms)

    # Pair indices are local to each molecule. Every molecule after the first
    # is shifted by the number of atoms preceding it; the first molecule's
    # pairs get a zero shift, supplied by left-padding with num_bonds[0] zeros.
    bond_owner = tf.repeat(molecule_indices[:-1], num_bonds[1:])
    atoms_before = tf.cumsum(num_atoms[:-1])
    offsets = tf.pad(tf.gather(atoms_before, bond_owner), [(num_bonds[0], 0)])

    flat_pairs = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    pair_indices = flat_pairs + offsets[:, tf.newaxis]
    atom_features = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
    bond_features = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()

    merged_graph = (atom_features, bond_features, pair_indices, molecule_indicator)
    return merged_graph, y_batch
def MPNNDataset(X, y, batch_size=32, shuffle=False):
    """Build the ``tf.data`` pipeline that feeds the MPNN model.

    Args:
        X: tuple of graph tensors, as produced by ``graphs_from_smiles``.
        y: labels aligned with ``X``.
        batch_size: number of molecules per batch.
        shuffle: when true, shuffle with a buffer of 1024 before batching.

    Returns:
        A dataset of batches, each passed through ``prepare_batch``.
    """
    dataset = tf.data.Dataset.from_tensor_slices((X, y))
    if shuffle:
        dataset = dataset.shuffle(1024)
    batched = dataset.batch(batch_size)
    # -1 is passed as map's second positional argument and as the prefetch
    # buffer size; tf.data treats -1 as its AUTOTUNE setting.
    merged = batched.map(prepare_batch, -1)
    return merged.prefetch(-1)
# Pretrained MPNN model, loaded once when the module is imported and reused by
# every call to predict().
model = from_pretrained_keras(
    "keras-io/MPNN-for-molecular-property-prediction"
)
def predict(smiles, label):
    """Predict the property for one molecule and render it with a legend.

    Args:
        smiles: SMILES string of the molecule to evaluate.
        label: the known ("true") value, shown next to the prediction in the
            legend; it is not used to compute the prediction.

    Returns:
        Path (str) of a PNG file showing the molecule with a
        ``y_true/y_pred`` legend.
    """
    molecules = [molecule_from_smiles(smiles)]
    graphs = graphs_from_smiles([smiles])
    label = pd.Series([label])
    test_dataset = MPNNDataset(graphs, label)
    y_pred = tf.squeeze(model.predict(test_dataset), axis=1)
    legends = [
        f"y_true/y_pred = {true_value}/{y_pred[i]:.2f}"
        for i, true_value in enumerate(label)
    ]
    grid = MolsToGridImage(
        molecules,
        molsPerRow=1,
        legends=legends,
        returnPNG=False,
        subImgSize=(650, 650),
    )

    # Write each result to its own file. A single fixed "img.png" is shared by
    # all requests, so with queued/concurrent requests one call could
    # overwrite another's image before it is read back.
    # NOTE: the temporary file is not deleted here; cleanup is left to the
    # host environment.
    with tempfile.NamedTemporaryFile(suffix=".png", delete=False) as handle:
        image_path = handle.name
    grid.save(image_path)
    return image_path
# Gradio UI: two text boxes (SMILES string and known label) mapped onto
# predict(smiles, label); the output is the rendered molecule image.
inputs = [
    gr.Textbox(label="Smiles of molecular"),
    gr.Textbox(label="Molecular permeability"),
]

# Each example row is [SMILES string, label], matching the order of `inputs`.
examples = [
    ["CO/N=C(C(=O)N[C@H]1[C@H]2SCC(=C(N2C1=O)C(O)=O)C)/c3csc(N)n3", 0],
    ["[C@H]37[C@H]2[C@@]([C@](C(COC(C1=CC(=CC=C1)[S](O)(=O)=O)=O)=O)(O)[C@@H](C2)C)(C[C@@H]([C@@H]3[C@@]4(C(=CC5=C(C4)C=N[N]5C6=CC=CC=C6)C(=C7)C)C)O)C", 1],
    ["CNCCCC2(C)C(=O)N(c1ccccc1)c3ccccc23", 1],
    ["O.N[C@@H](C(=O)NC1C2CCC(=C(N2C1=O)C(O)=O)Cl)c3ccccc3", 0],
    ["[C@@]4([C@@]3([C@H]([C@H]2[C@@H]([C@@]1(C(=CC(=O)CC1)CC2)C)[C@H](C3)O)CC4)C)(C(COC(C)=O)=O)OC(CC)=O", 1],
    ["[C@]34([C@H](C2[C@@](F)([C@@]1(C(=CC(=O)C=C1)[C@@H](F)C2)C)[C@@H](O)C3)C[C@H]5OC(O[C@@]45C(=O)COC(=O)C6CC6)(C)C)C", 1],
]

demo = gr.Interface(
    fn=predict,
    inputs=inputs,
    outputs="image",
    examples=examples,
    title="Predict blood-brain barrier permeability of molecular",
    description="Message-passing neural network (MPNN) for molecular property prediction",
    article="Author: <a href=\"https://huggingface.co/vumichien\">Vu Minh Chien</a>. Based on the keras example from <a href=\"https://keras.io/examples/graph/mpnn-molecular-graphs/\">Alexander Kensert</a>",
)
# NOTE(review): `enable_queue` as a launch() argument is believed to be
# deprecated/removed in newer Gradio releases in favour of `.queue()` --
# confirm against the Gradio version this app is pinned to.
demo.launch(debug=False, enable_queue=True)