import os, sys
import shutil
import glob
import torch
import numpy as np
import copy
from itertools import groupby
from operator import itemgetter
import json
import re
import random
import matplotlib.pyplot as plt
import pandas as pd
from tqdm import tqdm
import Bio
from icecream import ic

DEVICE = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

conversion = 'ARNDCQEGHILKMFPSTWYVX-'
### IF ADDING A NEW POTENTIAL, MAKE SURE TO ADD IT TO THE DICTIONARY AT THE BOTTOM ###

# TEMPLATE CLASS
class Potential:

    def get_gradients(self, seq):
        '''
        EVERY POTENTIAL CLASS MUST RETURN GRADIENTS
        '''
        sys.exit('ERROR: POTENTIAL HAS NOT BEEN IMPLEMENTED')
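
# Note (inferred from the subclasses below, not stated in the original): each
# concrete potential overrides get_gradients(self, seq) and returns an L x 21
# tensor of logit updates that the sampler adds to the sequence logits.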

class AACompositionalBias(Potential):
    """
    Bias the sampled sequence toward a target amino acid composition.

    T = number of timesteps to set up diffuser with
    schedule = type of noise schedule to use: linear, cosine, gaussian
    noise = type of distribution to sample from; DEFAULT - normal_gaussian
    """

    def __init__(self, args, features, potential_scale, DEVICE):
        self.L = features['L']
        self.DEVICE = DEVICE
        self.frac_seq_to_weight = args['frac_seq_to_weight']
        self.add_weight_every_n = args['add_weight_every_n']
        self.aa_weights_json = args['aa_weights_json']
        self.one_weight_per_position = args['one_weight_per_position']
        self.aa_weight = args['aa_weight']
        self.aa_spec = args['aa_spec']
        self.aa_composition = args['aa_composition']
        self.potential_scale = potential_scale
        self.aa_max_potential = None

        # Load per-amino-acid weights from JSON if provided
        if self.aa_weights_json is not None:
            with open(self.aa_weights_json, 'r') as f:
                aa_weights = json.load(f)
        else:
            aa_weights = {}

        # Build a flat 21-dim weight vector, then broadcast it to L x 21
        # (float dtype so gradient-based updates later are not integer-truncated)
        aa_weights_to_add = [0 for _ in range(21)]
        for k, v in aa_weights.items():
            aa_weights_to_add[conversion.index(k)] = v
        self.aa_weights_to_add = torch.tensor(aa_weights_to_add, dtype=torch.float)[None].repeat(self.L, 1).to(self.DEVICE, non_blocking=True)
        # Determine how the AA compositional bias is being specified
        if self.add_weight_every_n > 1 or self.frac_seq_to_weight > 0:
            assert (self.add_weight_every_n > 1) ^ (self.frac_seq_to_weight > 0), 'use either --add_weight_every_n or --frac_seq_to_weight but not both'
            weight_mask = torch.zeros_like(self.aa_weights_to_add)
            if self.add_weight_every_n > 1:
                idxs_to_unmask = torch.arange(0, self.L, self.add_weight_every_n)
            else:
                indexs = np.arange(0, self.L).tolist()
                idxs_to_unmask = random.sample(indexs, int(self.frac_seq_to_weight * self.L))
                idxs_to_unmask.sort()
            weight_mask[idxs_to_unmask, :] = 1
            self.aa_weights_to_add *= weight_mask

            if self.one_weight_per_position:
                # Keep a single randomly chosen weighted amino acid per position
                for p in range(self.aa_weights_to_add.shape[0]):
                    where_ones = torch.where(self.aa_weights_to_add[p, :] > 0)[0].tolist()
                    if len(where_ones) > 0:
                        w_sample = random.sample(where_ones, 1)[0]
                        self.aa_weights_to_add[p, :w_sample] = 0
                        self.aa_weights_to_add[p, w_sample + 1:] = 0

        elif self.aa_spec is not None:
            assert self.aa_weight is not None, 'please specify --aa_weight'
            # Use the specified repeat unit to bias the sequence (X = unbiased position)
            repeat_len = len(self.aa_spec)
            weight_split = [float(x) for x in self.aa_weight.split(',')]
            aa_idxs = []
            for k, c in enumerate(self.aa_spec):
                if c != 'X':
                    assert c in conversion, f'the letter you have chosen is not an amino acid: {c}'
                    aa_idxs.append((k, conversion.index(c)))
            if len(weight_split) > 1:
                assert len(aa_idxs) == len(weight_split), 'need to give the same number of weights as AAs in the weight spec'

            self.aa_weights_to_add = torch.zeros(self.L, 21)
            for p, w in zip(aa_idxs, weight_split):
                x, a = p
                self.aa_weights_to_add[x, a] = w
            # Tile the repeat unit out to the full sequence length
            self.aa_weights_to_add = self.aa_weights_to_add[:repeat_len, :].repeat(self.L // repeat_len + 1, 1)[:self.L].to(self.DEVICE, non_blocking=True)

        elif self.aa_composition is not None:
            # e.g. 'W0.2,E0.4' -> [('W', 0.2), ('E', 0.4)]
            self.aa_comp = [(x[0], float(x[1:])) for x in self.aa_composition.split(',')]
            self.aa_max_potential = 0  # placeholder so it is not None
            assert sum([f for aa, f in self.aa_comp]) <= 1, 'total sequence fraction specified in aa_composition is > 1'
        else:
            sys.exit('You are missing an argument needed to use the aa_bias potential')
    def get_gradients(self, seq):
        '''
        seq = L x 21 logits
        Return gradients to update the sequence with for the next pass.
        '''
        if self.aa_max_potential is not None:
            soft_seq = torch.softmax(seq, dim=1)
            print('ADDING SOFTMAXED SEQUENCE POTENTIAL')
            aa_weights_to_add_list = []
            for aa, f in self.aa_comp:
                aa_weights_to_add_copy = self.aa_weights_to_add.clone()
                soft_seq_tmp = soft_seq.clone().detach().requires_grad_(True)
                aa_idx = conversion.index(aa)
                # Positions with the top-k probability for this amino acid
                where_add = torch.topk(soft_seq_tmp[:, aa_idx], int(f * self.L))[1]
                # Set up the aa potential: a one-hot target for this amino acid at every position
                aa_potential = torch.zeros(21)
                aa_potential[aa_idx] = 1.0
                aa_potential = aa_potential.repeat(self.L, 1).to(self.DEVICE, non_blocking=True)
                # Apply the "loss"
                aa_comp_loss = torch.sum(torch.sum((aa_potential - soft_seq_tmp) ** 2, dim=1) ** 0.5)
                # Get gradients
                aa_comp_loss.backward()
                update_grads = soft_seq_tmp.grad
                # Push the top-k positions toward this amino acid and all others away from it
                for k in range(self.L):
                    if k in where_add:
                        aa_weights_to_add_copy[k, :] = -update_grads[k, :] * self.potential_scale
                    else:
                        aa_weights_to_add_copy[k, :] = update_grads[k, :] * self.potential_scale
                aa_weights_to_add_list.append(aa_weights_to_add_copy)
            # Average the per-amino-acid updates into a single L x 21 bias
            aa_weights_to_add_array = torch.stack(aa_weights_to_add_list)
            self.aa_weights_to_add = torch.mean(aa_weights_to_add_array.float(), 0)

        return self.aa_weights_to_add
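
# Hedged usage sketch (not in the original module; the args keys and scale are
# assumptions): bias 20% of a length-100 sequence toward tryptophan, adding the
# returned L x 21 bias to the logits before the next diffusion step.
#
#   pot = AACompositionalBias({'aa_composition': 'W0.2', 'aa_weights_json': None,
#                              'frac_seq_to_weight': 0, 'add_weight_every_n': 1,
#                              'one_weight_per_position': False, 'aa_weight': None,
#                              'aa_spec': None}, {'L': 100}, 10.0, DEVICE)
#   seq_logits = torch.randn(100, 21, device=DEVICE)
#   seq_logits = seq_logits + pot.get_gradients(seq_logits)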

class HydrophobicBias(Potential):
    """
    Calculate loss with respect to soft_seq of the sequence hydropathy index (Kyte and Doolittle, 1982).

    T = number of timesteps to set up diffuser with
    schedule = type of noise schedule to use: linear, cosine, gaussian
    noise = type of distribution to sample from; DEFAULT - normal_gaussian
    """

    def __init__(self, args, features, potential_scale, DEVICE):
        self.target_score = args['hydrophobic_score']
        self.potential_scale = potential_scale
        self.loss_type = args['hydrophobic_loss_type']
        print(f'USING {self.loss_type} LOSS TYPE...')

        # -----------------------------------------------------------------------
        # --------------------- GRAVY index data structures ---------------------
        # -----------------------------------------------------------------------
        # AA conversion (21 letters, matching the 21 logit columns)
        self.alpha_1 = list("ARNDCQEGHILKMFPSTWYVX")

        # Dictionary to convert amino acids to their hydropathy index
        self.gravy_dict = {'C': 2.5, 'D': -3.5, 'S': -0.8, 'Q': -3.5, 'K': -3.9,
                           'I': 4.5, 'P': -1.6, 'T': -0.7, 'F': 2.8, 'N': -3.5,
                           'G': -0.4, 'H': -3.2, 'L': 3.8, 'R': -4.5, 'W': -0.9,
                           'A': 1.8, 'V': 4.2, 'E': -3.5, 'Y': -1.3, 'M': 1.9, 'X': 0, '-': 0}

        self.gravy_list = [self.gravy_dict[a] for a in self.alpha_1]
        # -----------------------------------------------------------------------

        print(f'GUIDING SEQUENCES TO HAVE TARGET GRAVY SCORE OF: {self.target_score}')
    def get_gradients(self, seq):
        """
        Calculate gradients with respect to the GRAVY index of the input seq.

        Arguments
        ---------
        seq : tensor
            L x 21 logits after saving seq_out from xt

        Returns
        -------
        gradients : tensor
            gradients of soft_seq with respect to the loss on the GRAVY score
        """
        # Get the GRAVY matrix based on the length of seq
        gravy_matrix = torch.tensor(self.gravy_list)[None].repeat(seq.shape[0], 1).to(DEVICE)

        # Get softmax of seq; move to the device first so soft_seq stays a leaf
        # tensor whose .grad is populated by backward()
        soft_seq = torch.softmax(seq, dim=-1).to(DEVICE).requires_grad_(True)

        if self.loss_type == 'simple':
            # RMSE between the mean expected hydropathy and the target score
            gravy_score = torch.mean(torch.sum(soft_seq * gravy_matrix, dim=-1), dim=0)
            loss = ((gravy_score - self.target_score) ** 2) ** 0.5
            loss.backward()
            self.gradients = soft_seq.grad
        elif self.loss_type == 'complex':
            # MSE between each position's expected hydropathy and the target score
            loss = torch.mean((torch.sum(soft_seq * gravy_matrix, dim=-1) - self.target_score) ** 2)
            loss.backward()
            self.gradients = soft_seq.grad

        return -self.gradients * self.potential_scale
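
# Hedged usage sketch (args keys and scale are assumptions, not the original
# driver code): pull a random length-100 sequence toward a GRAVY score of -1.0.
#
#   pot = HydrophobicBias({'hydrophobic_score': -1.0,
#                          'hydrophobic_loss_type': 'simple'}, {'L': 100}, 10.0, DEVICE)
#   seq_logits = torch.randn(100, 21, device=DEVICE)
#   seq_logits = seq_logits + pot.get_gradients(seq_logits)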

class ChargeBias(Potential):
    """
    Calculate losses and get gradients with respect to soft_seq for the sequence charge at a given pH.

    T = number of timesteps to set up diffuser with
    schedule = type of noise schedule to use: linear, cosine, gaussian
    noise = type of distribution to sample from; DEFAULT - normal_gaussian
    """

    def __init__(self, args, features, potential_scale, DEVICE):
        self.target_charge = args['target_charge']
        self.pH = args['target_pH']
        self.loss_type = args['charge_loss_type']
        self.potential_scale = potential_scale
        self.L = features['L']
        self.DEVICE = DEVICE

        # -----------------------------------------------------------------------
        # ------------------------ pI data structures ---------------------------
        # -----------------------------------------------------------------------
        # pKa lists, one entry per residue type in `conversion`
        pos_pKs_list = [[0.0, 12.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 5.98, 0.0, 0.0, 10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
        neg_pKs_list = [[0.0, 0.0, 0.0, 4.05, 9.0, 0.0, 4.45, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 10.0, 0.0, 0.0]]
        cterm_pKs_list = [[0.0, 0.0, 0.0, 4.55, 0.0, 0.0, 4.75, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]]
        nterm_pKs_list = [[7.59, 0.0, 0.0, 0.0, 0.0, 0.0, 7.7, 0.0, 0.0, 0.0, 0.0, 0.0, 7.0, 0.0, 8.36, 6.93, 6.82, 0.0, 0.0, 7.44, 0.0]]

        # Convert pKa lists to tensors
        self.cterm_pKs = torch.tensor(cterm_pKs_list)
        self.nterm_pKs = torch.tensor(nterm_pKs_list)
        self.pos_pKs = torch.tensor(pos_pKs_list)
        self.neg_pKs = torch.tensor(neg_pKs_list)

        # Repeat side-chain pKs L - 2 times to populate all non-terminal residue positions
        pos_pKs_repeat = self.pos_pKs.repeat(self.L - 2, 1)
        neg_pKs_repeat = self.neg_pKs.repeat(self.L - 2, 1)

        # Concatenate with the N-term and C-term pKas to get full L x 21 pKa matrices
        self.pos_pKs_matrix = torch.cat((torch.zeros_like(self.nterm_pKs), pos_pKs_repeat, self.nterm_pKs)).to(DEVICE)
        self.neg_pKs_matrix = torch.cat((self.cterm_pKs, neg_pKs_repeat, torch.zeros_like(self.cterm_pKs))).to(DEVICE)

        # Indices of positive, neutral, and negative residues
        self.cterm_charged_idx = torch.nonzero(self.cterm_pKs)
        self.cterm_neutral_idx = torch.nonzero(self.cterm_pKs == 0)
        self.nterm_charged_idx = torch.nonzero(self.nterm_pKs)
        self.nterm_neutral_idx = torch.nonzero(self.nterm_pKs == 0)
        self.pos_pKs_idx = torch.tensor([[1, 8, 11]])
        self.neg_pKs_idx = torch.tensor([[3, 4, 6, 18]])
        self.neutral_pKs_idx = torch.tensor([[0, 2, 5, 7, 9, 10, 12, 13, 14, 15, 16, 17, 19, 20]])
        # -----------------------------------------------------------------------

        print(f'OPTIMIZING SEQUENCE TO HAVE CHARGE = {self.target_charge}\nAT pH = {self.pH}')
    def sum_tensor_indices(self, indices, tensor):
        # Sum tensor values at an explicit list of (row, col) index pairs
        total = 0
        for idx in indices:
            i, j = idx[0], idx[1]
            total += tensor[i][j]
        return total

    def sum_tensor_indices_2(self, indices, tensor):
        # Column indices as a long tensor on the right device
        j = indices.clone().detach().long().to(self.DEVICE)
        # Select the values using advanced indexing and sum along dim=-1
        row_sums = tensor[:, j].sum(dim=-1)
        # Reshape the result to an L x 1 tensor
        return row_sums.reshape(-1, 1).clone().detach()
    def make_table(self, L):
        """
        Make a table of all (positive, neutral, negative) charge counts (i, j, k)
        such that:
            i + j + k = L
            (1 * i) + (0 * j) + (-1 * k) = self.target_charge

        Arguments:
            L: int
                - length of sequence, defined as seq.shape[0]

        Returns:
            table: N x 3 tensor
                - all combinations of (i, j, k) that satisfy the conditions above
        """
        table = []
        for i in range(L):
            for j in range(L):
                for k in range(L):
                    # Require i + j + k = L and a net charge (i - k) equal to target_charge,
                    # and exclude zero entries, since having no positive, no negative,
                    # or no neutral residues is not realistic
                    if i + j + k == L and i - k == self.target_charge and i != 0 and j != 0 and k != 0:
                        table.append([i, j, k])
        return torch.tensor(np.array(table))
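
    # Worked example (illustrative, not from the original source): for L = 5 and
    # target_charge = 1, the only (i, j, k) with i + j + k = 5, i - k = 1, and no
    # zero entries is (2, 2, 1), so make_table(5) returns tensor([[2, 2, 1]]).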
    def classify_resis(self, seq):
        """
        Classify each position in seq as either positive, neutral, or negative.
        Classification = max( [sum(positive residue logits), sum(neutral residue logits), sum(negative residue logits)] )

        Arguments:
            seq: L x 21 tensor
                - sequence logits from the model

        Returns:
            charges: tensor
                - 1 x 3 tensor counting the total # of each charge type in the input sequence
                - charges[0] = # positive residues
                - charges[1] = # neutral residues
                - charges[2] = # negative residues
            classification_df: DataFrame
                - per-position summed probabilities, their max, and the final classification
                  (1 is positive, 0 is neutral, -1 is negative)
        """
        L = seq.shape[0]
        # Get softmax of seq; move to the device before flagging requires_grad
        soft_seq = torch.softmax(seq.clone(), dim=-1).to(self.DEVICE).requires_grad_(True)

        # Sum the softmaxed probabilities of the charged and neutral residue sets along dim=-1 (21 amino acids):
        # Sum across C-term pKs
        sum_cterm_charged = self.sum_tensor_indices(self.cterm_charged_idx, soft_seq).item()
        sum_cterm_neutral = self.sum_tensor_indices(self.cterm_neutral_idx, soft_seq).item()

        # Classify the C-term as negative or neutral
        cterm_max = max(sum_cterm_charged, sum_cterm_neutral)
        if cterm_max == sum_cterm_charged:
            cterm_class = torch.tensor([[-1]]).to(self.DEVICE)
        else:
            cterm_class = torch.tensor([[0]]).to(self.DEVICE)
        # C-term row: [sum_pos, sum_neutral, sum_neg, max, classification]
        cterm_df = torch.tensor([[0, sum_cterm_neutral, sum_cterm_charged, cterm_max, cterm_class.item()]]).to(self.DEVICE)

        # Sum across positive, neutral, and negative pKs for the non-terminal residues
        sum_pos = self.sum_tensor_indices_2(self.pos_pKs_idx, soft_seq[1:L - 1, ...]).to(self.DEVICE)
        sum_neg = self.sum_tensor_indices_2(self.neg_pKs_idx, soft_seq[1:L - 1, ...]).to(self.DEVICE)
        sum_neutral = self.sum_tensor_indices_2(self.neutral_pKs_idx, soft_seq[1:L - 1, ...]).to(self.DEVICE)

        # Classify non-terminal residues along dim=-1
        middle_max, _ = torch.max(torch.stack((sum_pos, sum_neg, sum_neutral), dim=-1), dim=-1)
        middle_max = middle_max.to(self.DEVICE)

        # Create an (L - 2) x 1 tensor to store the result
        middle_class = torch.zeros((L - 2, 1), dtype=torch.long).to(self.DEVICE)
        # Set the values of the result tensor based on which sum had the maximum value
        middle_class[sum_neg == middle_max] = -1
        middle_class[sum_neutral == middle_max] = 0
        middle_class[sum_pos == middle_max] = 1

        # Sum across N-term pKs
        sum_nterm_charged = self.sum_tensor_indices(self.nterm_charged_idx, soft_seq).item()
        sum_nterm_neutral = self.sum_tensor_indices(self.nterm_neutral_idx, soft_seq).item()

        # Classify the N-term as negative or neutral
        nterm_max = max(sum_nterm_charged, sum_nterm_neutral)
        if nterm_max == sum_nterm_charged:
            nterm_class = torch.tensor([[-1]]).to(self.DEVICE)
        else:
            nterm_class = torch.tensor([[0]]).to(self.DEVICE)
        nterm_df = torch.tensor([[sum_nterm_charged, sum_nterm_neutral, 0, nterm_max, nterm_class.item()]]).to(self.DEVICE)

        # Middle-residue rows (cast the classification to float so all columns share a dtype)
        middle_df_2 = torch.cat((sum_pos, sum_neutral, sum_neg, middle_max, middle_class.float()), dim=-1).to(self.DEVICE)

        # Concat C-term, middle, and N-term rows into one master df with all summed probs, max, and final classification
        full_tens_np = torch.cat((cterm_df, middle_df_2, nterm_df), dim=0).detach().cpu().numpy()
        classification_df = pd.DataFrame(full_tens_np)
        classification_df.rename(columns={0: 'sum_pos',
                                          1: 'sum_neutral', 2: 'sum_neg', 3: 'max', 4: 'classification'},
                                 inplace=True, errors='raise')

        # Count the number of positive, neutral, and negative positions, stored as 1, 0, -1 respectively
        charge_classification = torch.cat((cterm_class, middle_class, nterm_class), dim=0).to(self.DEVICE)
        charges = [torch.sum(charge_classification == 1).item(),
                   torch.sum(charge_classification == 0).item(),
                   torch.sum(charge_classification == -1).item()]
        return torch.tensor(charges), classification_df
    def get_target_charge_ratios(self, table, charges):
        """
        Find the row of `table` closest to the observed charge counts.

        Arguments:
            table: N x 3 tensor of all combinations of positive, neutral, and negative
                   charge counts that obey the conditions in make_table
            charges: 1 x 3 tensor
                - counts of each charge type in the input sequence
                - charges[0] = # positive residues
                - charges[1] = # neutral residues
                - charges[2] = # negative residues

        Returns:
            target_charge_tensor: tensor
                - 1 x 3 tensor, the row of table closest to the charges of the input sequence
        """
        # Compute the difference between the charges and each row of the table
        diff = table - charges
        # Compute the squared Euclidean distance between the charges and each row of the table
        sq_distance = torch.sum(diff ** 2, dim=-1)
        # Find the index of the row with the smallest distance
        min_idx = torch.argmin(sq_distance)
        # Return the corresponding row of the table
        return table[min_idx]
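
    # Illustrative example (not from the original source): with charges =
    # tensor([4, 90, 6]) and a table row [3, 95, 2], the squared distance to that
    # row is 1 + 25 + 16 = 42; whichever row minimizes this distance is returned
    # as the drafting target.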
    def draft_resis(self, classification_df, target_charge_tensor):
        """
        Based on target_charge_tensor, draft the top (i, j, k) positive, neutral, and negative
        positions from classification_df and return the idealized guided_charge_classification.
        guided_charge_classification determines whether the gradients should be positive or negative.

        Draft-pick algorithm for determining guided_charge_classification:
            1) Define how many positive, negative, and neutral charges are needed
            2) The first charge drafted is the sign of the target charge, otherwise the opposite charge
            3) From the classification_df of the currently sampled sequence, choose the position with
               the highest probability of being the current draft charge
            4) Make that residue +1, 0, or -1 in guided_charge_classification to dictate the sign of the gradients
            5) Keep drafting that residue charge until it is used up, then move to the next type

        Arguments:
            classification_df: DataFrame
                - per-position summed probabilities and classifications from classify_resis
            target_charge_tensor: tensor
                - 1 x 3 tensor, the row of table closest to the charges of the input sequence

        Returns:
            guided_charge_classification: L x 1 tensor
                - populated with 1 = positive, 0 = neutral, -1 = negative
                - in get_gradients, the gradients are signed by guided_charge_classification to determine
                  which direction to guide each position, given the current sequence distribution and the target charge
        """
        # Define the target number of positive, neutral, and negative charges
        charge_dict = {'pos': target_charge_tensor[0].detach().clone(),
                       'neutral': target_charge_tensor[1].detach().clone(),
                       'neg': target_charge_tensor[2].detach().clone()}

        # Determine which charge to start drafting
        if self.target_charge > 0:
            start_charge = 'pos'
        elif self.target_charge < 0:
            start_charge = 'neg'
        else:
            start_charge = 'neutral'

        # Initialize guided_charge_classification
        guided_charge_classification = torch.zeros((classification_df.shape[0], 1))

        # Start drafting
        draft_charge = start_charge
        while charge_dict[draft_charge] > 0:
            # Find the residue with the max probability for the current draft charge
            max_residue_idx = classification_df['sum_' + draft_charge].idxmax()
            # Populate guided_charge_classification with the appropriate charge
            if draft_charge == 'pos':
                guided_charge_classification[max_residue_idx] = 1
            elif draft_charge == 'neg':
                guided_charge_classification[max_residue_idx] = -1
            else:
                guided_charge_classification[max_residue_idx] = 0
            # Remove the selected row from classification_df
            classification_df = classification_df.drop(max_residue_idx)
            # Update the charges dictionary
            charge_dict[draft_charge] -= 1
            # Switch to the next charge type once the current one has been depleted
            if charge_dict[draft_charge] == 0:
                if draft_charge == start_charge:
                    draft_charge = 'neg' if start_charge == 'pos' else 'pos'
                elif draft_charge == 'neg':
                    draft_charge = 'pos'
                elif draft_charge == 'pos':
                    draft_charge = 'neg'
                else:
                    draft_charge = 'neutral'
        return guided_charge_classification.requires_grad_()
    def get_gradients(self, seq):
        """
        Calculate gradients with respect to the SEQUENCE CHARGE at the target pH.

        Arguments
        ---------
        seq : tensor
            L x 21 logits after saving seq_out from xt

        Returns
        -------
        gradients : tensor
            gradients of soft_seq with respect to the loss on the partial charge
        """
        # Get softmax of seq; move to the device first so soft_seq stays a leaf
        # tensor whose .grad is populated by backward()
        soft_seq = torch.softmax(seq, dim=-1).to(DEVICE).requires_grad_(True)

        # Henderson-Hasselbalch partial charges, nonzero only for titratable residues:
        # positive groups carry 1 / (10^(pH - pKa) + 1), negative groups 1 / (10^(pKa - pH) + 1)
        pos_charge = torch.where(self.pos_pKs_matrix != 0, 1 / (10 ** (self.pH - self.pos_pKs_matrix) + 1.0), 0.0)
        neg_charge = torch.where(self.neg_pKs_matrix != 0, 1 / (10 ** (self.neg_pKs_matrix - self.pH) + 1.0), 0.0)

        if self.loss_type == 'simple':
            # Calculate the net partial charge of soft_seq
            partial_charge = torch.sum(soft_seq * (pos_charge - neg_charge))
            print(f'CURRENT PARTIAL CHARGE: {partial_charge.item()}')
            # RMSE loss on partial_charge
            loss = ((partial_charge - self.target_charge) ** 2) ** 0.5
            loss.backward()
            self.gradients = soft_seq.grad

        elif self.loss_type == 'simple2':
            # Same net partial charge and loss as 'simple' (kept as a separate option)
            partial_charge = torch.sum(soft_seq * (pos_charge - neg_charge))
            print(f'CURRENT PARTIAL CHARGE: {partial_charge.item()}')
            loss = ((partial_charge - self.target_charge) ** 2) ** 0.5
            loss.backward()
            self.gradients = soft_seq.grad

        elif self.loss_type == 'complex':
            # Preprocessing using the helper methods above
            table = self.make_table(seq.shape[0])
            charges, classification_df = self.classify_resis(seq)
            target_charge_tensor = self.get_target_charge_ratios(table, charges)
            guided_charge_classification = self.draft_resis(classification_df, target_charge_tensor)

            # Per-position partial charge: softmax * partial charge matrix, summed over the residue axis
            soft_partial_charge = soft_seq * (pos_charge - neg_charge)
            partial_charge = torch.sum(soft_partial_charge, dim=-1)
            print(f'CURRENT PARTIAL CHARGE: {partial_charge.sum().item()}')

            # Mean absolute difference between the drafted per-position targets and the current partial charges
            loss = torch.mean(((guided_charge_classification.to(self.DEVICE) - partial_charge.unsqueeze(1).to(self.DEVICE)) ** 2) ** 0.5)
            loss.backward()
            self.gradients = soft_seq.grad

        return -self.gradients * self.potential_scale

class PSSMbias(Potential):
    """
    Bias sampling toward a user-supplied position-specific scoring matrix (PSSM).
    """

    def __init__(self, args, features, potential_scale, DEVICE):
        self.features = features
        self.args = args
        self.potential_scale = potential_scale
        self.DEVICE = DEVICE
        # Load an L x 21 PSSM from a comma-separated file
        self.PSSM = np.loadtxt(args['PSSM'], delimiter=",", dtype=float)
        self.PSSM = torch.from_numpy(self.PSSM).to(self.DEVICE)

    def get_gradients(self, seq):
        return self.PSSM * self.potential_scale

### ADD NEW POTENTIALS INTO THE DICTIONARY BELOW ###
POTENTIALS = {'aa_bias': AACompositionalBias, 'charge': ChargeBias, 'hydrophobic': HydrophobicBias, 'PSSM': PSSMbias}
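
# Hedged usage sketch of the registry above (the driver code and args keys are
# assumptions for illustration, not part of the original module): look up a
# potential by name, then add its returned L x 21 bias to the sequence logits
# at each diffusion step.
if __name__ == '__main__':
    features = {'L': 100}
    args = {'target_charge': -5, 'target_pH': 7.4, 'charge_loss_type': 'simple'}
    potential = POTENTIALS['charge'](args, features, potential_scale=1.0, DEVICE=DEVICE)
    seq_logits = torch.randn(features['L'], 21, device=DEVICE)
    seq_logits = seq_logits + potential.get_gradients(seq_logits)
    print(seq_logits.shape)  # torch.Size([100, 21])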