import torch
import numpy as np
import os
import glob
import random
import matplotlib
import imageio
matplotlib.use('Agg')
import matplotlib.pyplot as plt
from analysis.molecule_builder import get_bond_order


##############
### Files ####
###########-->
def save_xyz_file(path, one_hot, positions, atom_decoder, id_from=0,
                  name='molecule', batch_mask=None):
    """Write one ``.xyz`` file per molecule found in a (flattened) batch.

    Args:
        path: Output directory. It is created when missing.
        one_hot: Tensor whose rows are per-atom one-hot type encodings
            (``argmax`` over ``dim=1`` selects the atom type).
        positions: Tensor of per-atom coordinates; columns 0, 1 and 2 are
            written as x, y and z.
        atom_decoder: Sequence mapping an atom-type index to the symbol that
            is written to the file.
        id_from: Offset added to the batch id to form the file number.
        name: File name prefix; files are named ``<name>_<NNN>.xyz``.
        batch_mask: Optional tensor assigning every atom to a molecule id.
            When omitted, all atoms are treated as one molecule with id 0.
    """
    if path:
        # An empty path means "current directory"; nothing to create then.
        os.makedirs(path, exist_ok=True)

    if batch_mask is None:
        batch_mask = torch.zeros(len(one_hot))

    for batch_i in torch.unique(batch_mask):
        cur_batch_mask = (batch_mask == batch_i)
        n_atoms = int(torch.sum(cur_batch_mask).item())
        atoms = torch.argmax(one_hot[cur_batch_mask], dim=1).tolist()
        batch_pos = positions[cur_batch_mask].tolist()

        file_id = int(batch_i.item()) + id_from
        # os.path.join keeps the files inside ``path`` even when the caller
        # did not terminate it with a separator.
        file_path = os.path.join(path, name + '_' + "%03d.xyz" % file_id)
        with open(file_path, "w") as f:
            f.write("%d\n\n" % n_atoms)
            for atom_i in range(n_atoms):
                symbol = atom_decoder[atoms[atom_i]]
                px, py, pz = batch_pos[atom_i][:3]
                f.write("%s %.9f %.9f %.9f\n" % (symbol, px, py, pz))


def load_molecule_xyz(file, dataset_info):
    """Read a single molecule from an ``.xyz`` file.

    The first line holds the atom count, the second line is skipped, and each
    following line is ``<symbol> <x> <y> <z>`` separated by whitespace.

    Args:
        file: Path of the ``.xyz`` file.
        dataset_info: Mapping providing ``'atom_decoder'`` (its length is the
            one-hot width) and ``'atom_encoder'`` (symbol -> type index).

    Returns:
        ``(positions, one_hot)``: a ``(n_atoms, 3)`` tensor of coordinates and
        a ``(n_atoms, n_types)`` one-hot tensor of atom types.
    """
    with open(file, encoding='utf8') as f:
        n_atoms = int(f.readline())
        one_hot = torch.zeros(n_atoms, len(dataset_info['atom_decoder']))
        positions = torch.zeros(n_atoms, 3)
        f.readline()  # Skip the comment line.
        atom_lines = f.readlines()

    for i in range(n_atoms):
        # split() without an argument tolerates repeated spaces, tabs and the
        # trailing newline, which split(' ') does not.
        fields = atom_lines[i].split()
        atom_type = fields[0]
        one_hot[i, dataset_info['atom_encoder'][atom_type]] = 1
        positions[i, :] = torch.tensor([float(e) for e in fields[1:4]])
    return positions, one_hot


def load_xyz_files(path, shuffle=True):
    """Return the ``.xyz`` files directly inside ``path``.

    Args:
        path: Directory to search (non-recursive).
        shuffle: When true, the list is shuffled in place with ``random``.

    Returns:
        List of matching file paths.
    """
    files = glob.glob(path + "/*.xyz")
    if shuffle:
        random.shuffle(files)
    return files
# <----########
### Files ####
##############


def draw_sphere(ax, x, y, z, size, color, alpha):
    """Draw a shaded sphere surface centred at ``(x, y, z)`` on a 3D axes.

    Args:
        ax: 3D axes providing ``plot_surface``.
        x, y, z: Scalar centre coordinates.
        size: Sphere radius (the y extent is additionally scaled by 0.8).
        color: Surface colour.
        alpha: Surface opacity.
    """
    u = np.linspace(0, 2 * np.pi, 100)
    v = np.linspace(0, np.pi, 100)

    xs = size * np.outer(np.cos(u), np.sin(v))
    ys = size * np.outer(np.sin(u), np.sin(v)) * 0.8  # Correct for matplotlib.
    zs = size * np.outer(np.ones(np.size(u)), np.cos(v))

    ax.plot_surface(x + xs, y + ys, z + zs, rstride=2, cstride=2,
                    color=color, linewidth=0, alpha=alpha)


def plot_molecule(ax, positions, atom_type, alpha, spheres_3d, hex_bg_color,
                  dataset_info):
    """Draw the atoms and bonds of one molecule on an existing 3D axes.

    Args:
        ax: 3D axes to draw on.
        positions: Per-atom coordinates, indexed as ``positions[:, k]``.
        atom_type: Integer array of per-atom type indices.
        alpha: Base opacity for atoms and bonds.
        spheres_3d: Draw atoms as sphere surfaces instead of scatter markers.
        hex_bg_color: Colour used for the bond lines.
        dataset_info: Mapping providing ``'colors_dic'``, ``'radius_dic'``
            (both indexed by atom type) and ``'atom_decoder'``.
    """
    x = positions[:, 0]
    y = positions[:, 1]
    z = positions[:, 2]

    # Per-type lookup tables, e.g. Hydrogen, Carbon, Nitrogen, Oxygen,
    # Fluorine for QM9-style datasets.
    colors_dic = np.array(dataset_info['colors_dic'])
    radius_dic = np.array(dataset_info['radius_dic'])
    area_dic = 1500 * radius_dic ** 2

    areas = area_dic[atom_type]
    radii = radius_dic[atom_type]
    colors = colors_dic[atom_type]

    if spheres_3d:
        for i, j, k, s, c in zip(x, y, z, radii, colors):
            draw_sphere(ax, i.item(), j.item(), k.item(), 0.7 * s, c, alpha)
    else:
        ax.scatter(x, y, z, s=areas, alpha=0.9 * alpha, c=colors)

    atom_decoder = dataset_info['atom_decoder']
    line_width = 2
    n_atoms = len(x)
    for i in range(n_atoms):
        p1 = np.array([x[i], y[i], z[i]])
        atom1 = atom_decoder[atom_type[i]]
        for j in range(i + 1, n_atoms):
            p2 = np.array([x[j], y[j], z[j]])
            dist = np.sqrt(np.sum((p1 - p2) ** 2))
            atom2 = atom_decoder[atom_type[j]]

            bond_order = get_bond_order(atom1, atom2, dist)
            if bond_order > 0:
                # Order 4 is drawn thicker; every other bond order shares the
                # same width.
                linewidth_factor = 1.5 if bond_order == 4 else 1
                ax.plot([x[i], x[j]], [y[i], y[j]], [z[i], z[j]],
                        linewidth=line_width * linewidth_factor,
                        c=hex_bg_color, alpha=alpha)


def _new_3d_axes(camera_elev, camera_azim, bg):
    """Create a figure holding one bare 3D axes styled for molecule plots.

    Returns:
        ``(fig, ax, hex_bg_color)`` where ``hex_bg_color`` is the bond colour
        that goes with the requested background.
    """
    black = (0, 0, 0)
    white = (1, 1, 1)
    dark_bg = bg == 'black'
    hex_bg_color = '#FFFFFF' if dark_bg else '#666666'

    # Importing Axes3D registers the '3d' projection on older matplotlib.
    from mpl_toolkits.mplot3d import Axes3D  # noqa: F401
    fig = plt.figure()
    ax = fig.add_subplot(projection='3d')
    ax.set_aspect('auto')
    ax.view_init(elev=camera_elev, azim=camera_azim)
    ax.set_facecolor(black if dark_bg else white)

    for axis in (ax.xaxis, ax.yaxis, ax.zaxis):
        axis.pane.set_alpha(0)
    ax.set_axis_off()

    # ``ax.w_xaxis`` was removed in Matplotlib 3.8; ``ax.xaxis`` is the same
    # object and is already used above.
    ax.xaxis.line.set_color("black" if dark_bg else "white")
    return fig, ax, hex_bg_color


def _set_cubic_limits(ax, axis_lim):
    """Apply the symmetric range ``[-axis_lim, axis_lim]`` to all three axes."""
    ax.set_xlim(-axis_lim, axis_lim)
    ax.set_ylim(-axis_lim, axis_lim)
    ax.set_zlim(-axis_lim, axis_lim)


def _save_or_show(fig, save_path, spheres_3d):
    """Save ``fig`` to ``save_path`` (or show it when ``None``), then close it."""
    dpi = 120 if spheres_3d else 50

    if save_path is not None:
        fig.savefig(save_path, bbox_inches='tight', pad_inches=0.0, dpi=dpi)

        if spheres_3d:
            # Sphere renders are brightened by 40% and written back in place.
            img = imageio.imread(save_path)
            img_brighter = np.clip(img * 1.4, 0, 255).astype('uint8')
            imageio.imsave(save_path, img_brighter)
    else:
        plt.show()
    plt.close(fig)


def plot_data3d(positions, atom_type, dataset_info, camera_elev=0,
                camera_azim=0, save_path=None, spheres_3d=False,
                bg='black', alpha=1.):
    """Render one molecule in 3D and save or show the figure.

    Args:
        positions: Tensor of per-atom coordinates (``.abs().max().item()`` is
            used to size the view).
        atom_type: Integer array of per-atom type indices.
        dataset_info: Dataset description forwarded to ``plot_molecule``.
        camera_elev: Camera elevation passed to ``view_init``.
        camera_azim: Camera azimuth passed to ``view_init``.
        save_path: Image path to write; the figure is shown when ``None``.
        spheres_3d: Draw atoms as spheres (also raises the dpi and brightens
            the saved image).
        bg: ``'black'`` for a black background, anything else for white.
        alpha: Opacity of atoms and bonds.
    """
    fig, ax, hex_bg_color = _new_3d_axes(camera_elev, camera_azim, bg)

    plot_molecule(ax, positions, atom_type, alpha, spheres_3d,
                  hex_bg_color, dataset_info)

    # The same limits are used for every dataset: scale with the molecule
    # extent, but stay within [3.2, 40].
    max_value = positions.abs().max().item()
    axis_lim = min(40, max(max_value / 1.5 + 0.3, 3.2))
    _set_cubic_limits(ax, axis_lim)

    _save_or_show(fig, save_path, spheres_3d)


def plot_data3d_uncertainty(
        all_positions, all_atom_types, dataset_info, camera_elev=0,
        camera_azim=0, save_path=None, spheres_3d=False, bg='black',
        alpha=1.):
    """Overlay several molecules on one 3D axes and save or show the figure.

    Args:
        all_positions: Indexable collection of per-molecule coordinate
            tensors; the first entry alone determines the axis limits.
        all_atom_types: Matching collection of per-atom type index arrays.
        dataset_info: Dataset description; ``'name'`` selects the axis
            scaling and must contain ``'qm9'`` or equal ``'geom'`` or
            ``'pdbbind'``.
        camera_elev: Camera elevation passed to ``view_init``.
        camera_azim: Camera azimuth passed to ``view_init``.
        save_path: Image path to write; the figure is shown when ``None``.
        spheres_3d: Draw atoms as spheres.
        bg: ``'black'`` for a black background, anything else for white.
        alpha: Opacity of atoms and bonds.

    Raises:
        ValueError: If ``dataset_info['name']`` is not a supported dataset.
    """
    fig, ax, hex_bg_color = _new_3d_axes(camera_elev, camera_azim, bg)

    for i in range(len(all_positions)):
        plot_molecule(ax, all_positions[i], all_atom_types[i], alpha,
                      spheres_3d, hex_bg_color, dataset_info)

    dataset_name = dataset_info['name']
    if 'qm9' in dataset_name:
        extent_divisor = 1.0
    elif dataset_name in ('geom', 'pdbbind'):
        extent_divisor = 2.0
    else:
        raise ValueError(dataset_name)

    max_value = all_positions[0].abs().max().item()
    axis_lim = min(40, max(max_value / extent_divisor + 0.3, 3.2))
    _set_cubic_limits(ax, axis_lim)

    _save_or_show(fig, save_path, spheres_3d)


def plot_grid():
    """Show a demo ``ImageGrid`` filled with four small test images."""
    import matplotlib.pyplot as plt
    from mpl_toolkits.axes_grid1 import ImageGrid

    im1 = np.arange(100).reshape((10, 10))
    im2 = im1.T
    im3 = np.flipud(im1)
    im4 = np.fliplr(im2)

    fig = plt.figure(figsize=(10., 10.))
    grid = ImageGrid(fig, 111,  # similar to subplot(111)
                     nrows_ncols=(6, 6),  # creates a 6x6 grid of axes
                     axes_pad=0.1,  # pad between axes in inch.
                     )

    # Iterating over the grid returns the Axes; only the first four are used.
    for ax, im in zip(grid, [im1, im2, im3, im4]):
        ax.imshow(im)

    plt.show()


def visualize(path, dataset_info, max_num=25, wandb=None, spheres_3d=False):
    """Render up to ``max_num`` ``.xyz`` files from ``path`` as PNG images.

    Each image is written next to its source file with a ``.png`` extension.

    Args:
        path: Directory containing the ``.xyz`` files.
        dataset_info: Dataset description used for loading and plotting.
        max_num: Maximum number of (randomly ordered) files to render.
        wandb: Optional wandb-like module; when given, every image is logged
            under the key ``'molecule'``.
        spheres_3d: Draw atoms as spheres.
    """
    files = load_xyz_files(path)[0:max_num]
    for file in files:
        positions, one_hot = load_molecule_xyz(file, dataset_info)
        atom_type = torch.argmax(one_hot, dim=1).numpy()
        png_path = os.path.splitext(file)[0] + '.png'
        plot_data3d(positions, atom_type, dataset_info=dataset_info,
                    save_path=png_path, spheres_3d=spheres_3d)

        if wandb is not None:
            # Log image(s)
            im = plt.imread(png_path)
            wandb.log({'molecule': [wandb.Image(im, caption=png_path)]})


def _save_chain_gif(save_paths, mode, wandb):
    """Combine rendered frames into ``output.gif`` and optionally log it.

    The GIF is written into the directory of the first frame.
    """
    imgs = [imageio.imread(fn) for fn in save_paths]
    dirname = os.path.dirname(save_paths[0])
    gif_path = dirname + '/output.gif'
    print(f'Creating gif with {len(imgs)} images')
    # NOTE: holding the final frame (imgs.extend([imgs[-1]] * 10)) is
    # currently disabled.
    imageio.mimsave(gif_path, imgs, subrectangles=True)

    if wandb is not None:
        wandb.log({mode: [wandb.Video(gif_path, caption=gif_path)]})


def visualize_chain(path, dataset_info, wandb=None, spheres_3d=False,
                    mode="chain"):
    """Render every ``.xyz`` file in ``path`` (sorted) and build a GIF.

    Args:
        path: Directory containing the frames of the chain.
        dataset_info: Dataset description used for loading and plotting.
        wandb: Optional wandb-like module; when given, the GIF is logged
            under the key ``mode``.
        spheres_3d: Draw atoms as spheres.
        mode: Key used when logging to ``wandb``.
    """
    files = sorted(load_xyz_files(path))
    save_paths = []

    for file in files:
        positions, one_hot = load_molecule_xyz(file, dataset_info=dataset_info)
        atom_type = torch.argmax(one_hot, dim=1).numpy()
        fn = os.path.splitext(file)[0] + '.png'
        plot_data3d(positions, atom_type, dataset_info=dataset_info,
                    save_path=fn, spheres_3d=spheres_3d, alpha=1.0)
        save_paths.append(fn)

    _save_chain_gif(save_paths, mode, wandb)


def visualize_chain_uncertainty(
        path, dataset_info, wandb=None, spheres_3d=False, mode="chain"):
    """Render a chain where each frame overlays three consecutive files.

    Frame ``i`` shows files ``i``, ``i + 1`` and ``i + 2`` (sorted order) at
    half opacity, so the last two files never start a frame.

    Args:
        path: Directory containing the frames of the chain.
        dataset_info: Dataset description used for loading and plotting.
        wandb: Optional wandb-like module; when given, the GIF is logged
            under the key ``mode``.
        spheres_3d: Draw atoms as spheres.
        mode: Key used when logging to ``wandb``.
    """
    files = sorted(load_xyz_files(path))
    save_paths = []

    for file, file2, file3 in zip(files, files[1:], files[2:]):
        # load_molecule_xyz returns exactly (positions, one_hot).
        positions, one_hot = load_molecule_xyz(
            file, dataset_info=dataset_info)
        positions2, one_hot2 = load_molecule_xyz(
            file2, dataset_info=dataset_info)
        positions3, one_hot3 = load_molecule_xyz(
            file3, dataset_info=dataset_info)

        all_positions = torch.stack([positions, positions2, positions3], dim=0)
        all_one_hot = torch.stack([one_hot, one_hot2, one_hot3], dim=0)
        all_atom_type = torch.argmax(all_one_hot, dim=2).numpy()

        fn = os.path.splitext(file)[0] + '.png'
        plot_data3d_uncertainty(
            all_positions, all_atom_type, dataset_info=dataset_info,
            save_path=fn, spheres_3d=spheres_3d, alpha=0.5)
        save_paths.append(fn)

    _save_chain_gif(save_paths, mode, wandb)


if __name__ == '__main__':
    # plot_grid()
    import qm9.dataset as dataset
    from configs.datasets_config import qm9_with_h, geom_with_h
    # Interactive backend for local inspection (macOS only).
    matplotlib.use('macosx')

    task = "visualize_molecules"
    task_dataset = 'geom'

    if task_dataset == 'qm9':
        dataset_info = qm9_with_h

        class Args:
            batch_size = 1
            num_workers = 0
            filter_n_atoms = None
            datadir = 'qm9/temp'
            dataset = 'qm9'
            remove_h = False

        cfg = Args()

        dataloaders, _ = dataset.retrieve_dataloaders(cfg)

        for data in dataloaders['train']:
            positions = data['positions'].view(-1, 3)
            positions_centered = positions - positions.mean(dim=0, keepdim=True)
            one_hot = data['one_hot'].view(-1, 5).type(torch.float32)
            atom_type = torch.argmax(one_hot, dim=1).numpy()

            plot_data3d(
                positions_centered, atom_type, dataset_info=dataset_info,
                spheres_3d=True)

    elif task_dataset == 'geom':
        files = load_xyz_files('outputs/data')
        for file in files:
            x, one_hot = load_molecule_xyz(file, dataset_info=geom_with_h)

            positions = x.view(-1, 3)
            positions_centered = positions - positions.mean(dim=0, keepdim=True)
            one_hot = one_hot.view(-1, 16).type(torch.float32)
            atom_type = torch.argmax(one_hot, dim=1).numpy()

            # Drop rows whose three coordinates are all exactly zero.
            mask = (x == 0).sum(1) != 3
            positions_centered = positions_centered[mask]
            atom_type = atom_type[mask]

            plot_data3d(
                positions_centered, atom_type, dataset_info=geom_with_h,
                spheres_3d=False)

    else:
        raise ValueError(task_dataset)