Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -136,8 +136,36 @@ def graphs_from_smiles(smiles_list):
|
|
| 136 |
tf.ragged.constant(bond_features_list, dtype=tf.float32),
|
| 137 |
tf.ragged.constant(pair_indices_list, dtype=tf.int64),
|
| 138 |
)
|
| 139 |
-
|
| 140 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 141 |
def MPNNDataset(X, y, batch_size=32, shuffle=False):
|
| 142 |
dataset = tf.data.Dataset.from_tensor_slices((X, (y)))
|
| 143 |
if shuffle:
|
|
|
|
| 136 |
tf.ragged.constant(bond_features_list, dtype=tf.float32),
|
| 137 |
tf.ragged.constant(pair_indices_list, dtype=tf.int64),
|
| 138 |
)
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def prepare_batch(x_batch, y_batch):
|
| 142 |
+
"""Merges (sub)graphs of batch into a single global (disconnected) graph
|
| 143 |
+
"""
|
| 144 |
+
|
| 145 |
+
atom_features, bond_features, pair_indices = x_batch
|
| 146 |
+
|
| 147 |
+
# Obtain number of atoms and bonds for each graph (molecule)
|
| 148 |
+
num_atoms = atom_features.row_lengths()
|
| 149 |
+
num_bonds = bond_features.row_lengths()
|
| 150 |
+
|
| 151 |
+
# Obtain partition indices (molecule_indicator), which will be used to
|
| 152 |
+
# gather (sub)graphs from global graph in model later on
|
| 153 |
+
molecule_indices = tf.range(len(num_atoms))
|
| 154 |
+
molecule_indicator = tf.repeat(molecule_indices, num_atoms)
|
| 155 |
+
|
| 156 |
+
# Merge (sub)graphs into a global (disconnected) graph. Adding 'increment' to
|
| 157 |
+
# 'pair_indices' (and merging ragged tensors) actualizes the global graph
|
| 158 |
+
gather_indices = tf.repeat(molecule_indices[:-1], num_bonds[1:])
|
| 159 |
+
increment = tf.cumsum(num_atoms[:-1])
|
| 160 |
+
increment = tf.pad(tf.gather(increment, gather_indices), [(num_bonds[0], 0)])
|
| 161 |
+
pair_indices = pair_indices.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
|
| 162 |
+
pair_indices = pair_indices + increment[:, tf.newaxis]
|
| 163 |
+
atom_features = atom_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
|
| 164 |
+
bond_features = bond_features.merge_dims(outer_axis=0, inner_axis=1).to_tensor()
|
| 165 |
+
|
| 166 |
+
return (atom_features, bond_features, pair_indices, molecule_indicator), y_batch
|
| 167 |
+
|
| 168 |
+
|
| 169 |
def MPNNDataset(X, y, batch_size=32, shuffle=False):
|
| 170 |
dataset = tf.data.Dataset.from_tensor_slices((X, (y)))
|
| 171 |
if shuffle:
|