Spaces:
Running
Running
| from torch import nn | |
| import torch | |
| from .encoder import GotenNet | |
| from .utils import get_symmetric_displacement, BatchedPeriodicDistance, ACT_CLASS_MAPPING | |
| #from torch_scatter import scatter | |
class NodeInvariantReadout(nn.Module):
    """Map stacked per-layer node features to a per-node output.

    Each of the first ``num_residues - 1`` slices along the last axis (after
    ``squeeze(2)``) gets its own ``nn.Linear`` head. The final slice goes
    through a small MLP (Linear -> activation -> Linear). All head outputs
    are summed, and a trailing size-1 axis is squeezed away, so
    ``out_channels == 1`` yields one value per node.

    Args:
        in_channels: feature width of each slice.
        num_residues: number of slices the heads are built for; the
            intermediate slices get a Linear head each and the last gets the MLP.
        hidden_channels: hidden width of the final MLP.
        out_channels: output width of every head.
        activation: key into ``ACT_CLASS_MAPPING``; the class found there is
            instantiated with no arguments.
    """

    def __init__(self, in_channels, num_residues, hidden_channels, out_channels, activation):
        super().__init__()
        # Attribute names ("linears", "non_linear") are state_dict keys; keep them stable.
        heads = [nn.Linear(in_channels, out_channels) for _slice in range(num_residues - 1)]
        self.linears = nn.ModuleList(heads)
        # The last slice gets a nonlinear head instead of a plain linear map.
        self.non_linear = nn.Sequential(
            nn.Linear(in_channels, hidden_channels),
            ACT_CLASS_MAPPING[activation](),
            nn.Linear(hidden_channels, out_channels),
        )

    def forward(self, embedding_0):
        """Return the sum of all head outputs for ``embedding_0``.

        NOTE(review): after ``squeeze(2)`` the tensor is expected to be
        ``[n_nodes, in_channels, num_residues]``. ``squeeze(2)`` only drops
        axis 2 when it has size 1, so this assumes the encoder emits a
        size-1 axis there. TODO confirm against the encoder.
        """
        slices = embedding_0.squeeze(2)
        contributions = [head(slices[:, :, k]) for k, head in enumerate(self.linears)]
        # The last slice always goes through the MLP, however many linear heads exist.
        contributions.append(self.non_linear(slices[:, :, -1]))
        summed = torch.stack(contributions, dim=0).sum(dim=0)
        return summed.squeeze(-1)
class PosEGNN(nn.Module):
    """GotenNet encoder plus a per-node energy readout, with autograd forces/stress.

    Submodule and buffer names (``distance``, ``encoder``, ``readout``,
    ``e0_mean``, ``atomic_res_total_mean``, ``atomic_res_total_std``) are
    state_dict keys and must not be renamed.
    """

    def __init__(self, config):
        """Build the model from a nested ``config`` mapping.

        Keys read:
            ``config["encoder"]``: kwargs for ``GotenNet``; its ``"cutoff"``
                entry also configures ``BatchedPeriodicDistance``.
            ``config["decoder"]``: kwargs for ``NodeInvariantReadout``.
            ``config["e0_mean"]``: per-atomic-number reference term, indexed by
                ``data.z`` in ``compute_properties`` (stored as a buffer).
            ``config["atomic_res_total_mean"]`` / ``config["atomic_res_total_std"]``:
                used to de-normalise the readout output (stored as buffers).
        """
        super().__init__()
        self.distance = BatchedPeriodicDistance(config["encoder"]["cutoff"])
        self.encoder = GotenNet(**config["encoder"])
        self.readout = NodeInvariantReadout(**config["decoder"])
        self.register_buffer("e0_mean", torch.tensor(config["e0_mean"]))
        self.register_buffer("atomic_res_total_mean", torch.tensor(config["atomic_res_total_mean"]))
        self.register_buffer("atomic_res_total_std", torch.tensor(config["atomic_res_total_std"]))

    @staticmethod
    def _scatter_sum(src, index, dim_size=None):
        """Sum rows of ``src`` into ``dim_size`` buckets chosen by ``index`` along dim 0.

        This is a pure-PyTorch equivalent of ``torch_scatter.scatter(..., reduce="sum")``,
        so the module does not depend on ``torch_scatter``. When ``dim_size`` is
        None it is taken as ``index.max() + 1``, or 0 for an empty index.
        Autograd flows back to ``src``.
        """
        if dim_size is None:
            dim_size = int(index.max()) + 1 if index.numel() > 0 else 0
        out = src.new_zeros((dim_size, *src.shape[1:]))
        return out.index_add(0, index, src)

    def forward(self, data):
        """Compute cutoff edges and run the encoder; return the encoder's output dict.

        Side effects on ``data``:
            - enables ``requires_grad`` on ``data.pos``;
            - overwrites ``data.pos``, ``data.box`` and ``data.displacements``
              with the outputs of ``get_symmetric_displacement``;
            - stores ``data.cutoff_edge_index``, ``data.cutoff_edge_distance``,
              ``data.cutoff_edge_vec`` and ``data.cutoff_shifts_idx`` from
              ``self.distance``.
        """
        # Positions must be differentiable so compute_properties can take dE/dpos.
        data.pos.requires_grad_(True)
        # NOTE(review): presumably applies a differentiable strain (data.displacements)
        # to pos/box so that stress can be obtained as dE/d(strain). Confirm in .utils.
        data.pos, data.box, data.displacements = get_symmetric_displacement(
            data.pos, data.box, data.num_graphs, data.batch
        )
        (
            data.cutoff_edge_index,
            data.cutoff_edge_distance,
            data.cutoff_edge_vec,
            data.cutoff_shifts_idx,
        ) = self.distance(data.pos, data.box, data.batch)
        embedding_dict = self.encoder(
            data.z,
            data.pos,
            data.cutoff_edge_index,
            data.cutoff_edge_distance,
            data.cutoff_edge_vec,
        )
        return embedding_dict

    def compute_properties(self, data, compute_stress=True):
        """Predict per-graph energy, forces and (optionally) stress.

        Args:
            data: batched graph object; ``forward`` reads ``pos``, ``box``, ``z``,
                ``batch`` and ``num_graphs`` from it and mutates it (see ``forward``).
            compute_stress: when True, also differentiate w.r.t. ``data.displacements``.

        Returns:
            dict with keys:
                ``"total_energy"``: per-graph sum of ``e0_mean[z]`` plus the
                    de-normalised readout output, length ``data.num_graphs``.
                ``"force"``: ``-dE/d(data.pos)``.
                ``"stress"`` (only if ``compute_stress``):
                    ``-(dE/d(data.displacements)) / |det(data.box)|``, with entries
                    whose magnitude is >= 1e10 or NaN replaced by zero.
        """
        output = {}
        embedding_dict = self.forward(data)

        # Energy: de-normalised per-node residual plus per-species reference term.
        node_e_res = self.readout(embedding_dict["embedding_0"])
        node_e_res = node_e_res * self.atomic_res_total_std + self.atomic_res_total_mean
        node_e0 = self.e0_mean[data.z]
        num_graphs = data.num_graphs
        total_e_res = self._scatter_sum(node_e_res, data.batch, num_graphs)
        total_e0 = self._scatter_sum(node_e0, data.batch, num_graphs)
        total_energy = total_e0 + total_e_res
        output["total_energy"] = total_energy

        # Gradients: create_graph in training keeps force/stress differentiable for the loss.
        inputs = [data.pos, data.displacements] if compute_stress else [data.pos]
        grads = torch.autograd.grad(
            outputs=[total_energy],
            inputs=inputs,
            grad_outputs=[torch.ones_like(total_energy)],
            retain_graph=self.training,
            create_graph=self.training,
        )

        output["force"] = -grads[0]
        if compute_stress:
            virial = grads[1]
            # |det(box)| is the cell volume when box holds the lattice vectors (assumed).
            volume = torch.det(data.box).abs().view(-1, 1, 1)
            stress = virial / volume
            # Zero out inf/NaN/huge values, e.g. from a zero-volume (non-periodic) box.
            stress = torch.where(torch.abs(stress) < 1e10, stress, torch.zeros_like(stress))
            output["stress"] = -stress
        return output