Atom Bioworks
commited on
Create mcts.py
Browse files
mcts.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import timeit
|
| 3 |
+
import torch
|
| 4 |
+
from utils import rna2vec
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
|
| 7 |
+
#Node
|
| 8 |
+
class Node:
|
| 9 |
+
#init
|
| 10 |
+
def __init__(self, letter="", parent=None, root=False, last=False, depth=0, states=8):
|
| 11 |
+
self.exploitation_score = 0 # Exploitaion score
|
| 12 |
+
self.visits = 1 #How many visits
|
| 13 |
+
self.letter = letter #This node's letter
|
| 14 |
+
self.parent = parent #This node's parent node
|
| 15 |
+
self.states = states #How many states in node
|
| 16 |
+
self.children = np.array([None for _ in range(self.states)]) #This node's children
|
| 17 |
+
self.children_stat = np.zeros(self.states, dtype=bool) #Which stat are expanded
|
| 18 |
+
self.root = root # Is root? boolean
|
| 19 |
+
self.last = last # Is last node?
|
| 20 |
+
self.depth = depth # My depth
|
| 21 |
+
self.letters =["A_", "C_", "G_", "T_", "_A", "_C", "_G", "_T"]
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
#next_node
|
| 25 |
+
def next_node(self, child=0): #Return next node
|
| 26 |
+
assert self.children_stat[child] == True, "No child in here."
|
| 27 |
+
|
| 28 |
+
return self.children[child]
|
| 29 |
+
|
| 30 |
+
#back_parent
|
| 31 |
+
def back_parent(self): #Go back to parent
|
| 32 |
+
return self.parent, letters_map[self.letter]
|
| 33 |
+
|
| 34 |
+
#generate_child
|
| 35 |
+
def generate_child(self, child=0, last=False): #Generate child
|
| 36 |
+
assert self.children_stat[child] == False, "Already tree generated child at here"
|
| 37 |
+
|
| 38 |
+
self.children[child] = Node(letter=self.letters[child], parent=self, last=last, depth=self.depth+1, states=self.states) #New node
|
| 39 |
+
self.children_stat[child] = True #Stat = True
|
| 40 |
+
|
| 41 |
+
return self.children[child]
|
| 42 |
+
|
| 43 |
+
#backpropagation
|
| 44 |
+
def backpropagation(self, score=0):
|
| 45 |
+
self.visits += 1 # +1 to visit
|
| 46 |
+
if self.root == True: # if root, then stop
|
| 47 |
+
return self.exploitation_score
|
| 48 |
+
|
| 49 |
+
else:
|
| 50 |
+
self.exploitation_score += score #Add score to exploitation score
|
| 51 |
+
return self.parent.backpropagation(score=score) #Backpropagation to parent node
|
| 52 |
+
|
| 53 |
+
#UCT
|
| 54 |
+
def UCT(self):
|
| 55 |
+
return (self.exploitation_score / self.visits) + np.sqrt(np.log(self.parent.visits) / (2 * self.visits)) #UCT score
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
#MCTS
|
| 59 |
+
class MCTS:
|
| 60 |
+
def __init__(self, target_encoded, depth=20, iteration=1000, states=8, target_protein="", device='cpu', esm_alphabet=None):
|
| 61 |
+
self.states = states #How many states
|
| 62 |
+
self.root = Node(letter="", parent=None, root=True, last=False, states=self.states) #root node
|
| 63 |
+
self.depth = depth #Maximum depth
|
| 64 |
+
self.iteration = iteration #iteration for expand
|
| 65 |
+
self.target_protein = target_protein #target protein's amino acid sequence
|
| 66 |
+
self.device = device
|
| 67 |
+
self.encoded_targetprotein = target_encoded
|
| 68 |
+
self.base = ""
|
| 69 |
+
self.candidate = ""
|
| 70 |
+
self.letters =["A_", "C_", "G_", "T_", "_A", "_C", "_G", "_T"]
|
| 71 |
+
self.esm_alphabet = esm_alphabet
|
| 72 |
+
self.nt_tokenizer = AutoTokenizer.from_pretrained("InstaDeepAI/nucleotide-transformer-v2-50m-multi-species", trust_remote_code=True)
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def make_candidate(self, classifier):
|
| 76 |
+
now = self.root
|
| 77 |
+
n = 0 # rounds
|
| 78 |
+
start_time = timeit.default_timer() #timer start
|
| 79 |
+
|
| 80 |
+
while len(self.base) < self.depth * 2: #If now is last node, then stop
|
| 81 |
+
n += 1
|
| 82 |
+
print(n, "round start!!!")
|
| 83 |
+
for _ in range(self.iteration):
|
| 84 |
+
now = self.select(classifier, now=now) #Select & Expand
|
| 85 |
+
|
| 86 |
+
terminate_time = timeit.default_timer()
|
| 87 |
+
time = terminate_time-start_time
|
| 88 |
+
|
| 89 |
+
base = self.find_best_subsequence() #Find best subsequence
|
| 90 |
+
self.base = base
|
| 91 |
+
|
| 92 |
+
# print("best subsequence:", base)
|
| 93 |
+
# print("Depth:", int(len(base)/2))
|
| 94 |
+
# print("%02d:%02d:%2f" % ((time//3600), (time//60)%60, time%60))
|
| 95 |
+
# print("=" * 80)
|
| 96 |
+
|
| 97 |
+
self.root = Node(letter="", parent=None, root=True, last=False, states=self.states, depth=len(self.base)/2)
|
| 98 |
+
now = self.root
|
| 99 |
+
|
| 100 |
+
self.candidate = self.base
|
| 101 |
+
|
| 102 |
+
return self.candidate
|
| 103 |
+
|
| 104 |
+
#selection
|
| 105 |
+
def select(self, classifier, now=None):
|
| 106 |
+
if now.depth == self.depth: #If last node, then stop
|
| 107 |
+
return self.root
|
| 108 |
+
|
| 109 |
+
next_node = 0
|
| 110 |
+
if np.sum(now.children_stat) == self.states: #If every child is expanded, then go to best child
|
| 111 |
+
best = 0
|
| 112 |
+
for i in range(self.states):
|
| 113 |
+
if best < now.children[i].UCT():
|
| 114 |
+
next_node = i
|
| 115 |
+
best = now.children[i].UCT()
|
| 116 |
+
|
| 117 |
+
else: #If not, then random
|
| 118 |
+
next_node = np.random.randint(0, self.states)
|
| 119 |
+
if now.children_stat[next_node] == False: #If selected child is not expanded, then expand and simulate
|
| 120 |
+
next_node = self.expand(classifier, child=next_node, now=now)
|
| 121 |
+
|
| 122 |
+
return self.root #start iteration at this node
|
| 123 |
+
|
| 124 |
+
return now.next_node(child=next_node)
|
| 125 |
+
|
| 126 |
+
#expand
|
| 127 |
+
def expand(self, classifier, child=None, now=None):
|
| 128 |
+
last = False
|
| 129 |
+
if now.depth == (self.depth-1): #If depth of this node is maximum depth -1, then next node is last
|
| 130 |
+
last = True
|
| 131 |
+
|
| 132 |
+
expanded_node = now.generate_child(child=child, last=last) #Expand
|
| 133 |
+
|
| 134 |
+
score = self.simulate(classifier, target=expanded_node) #Simulate
|
| 135 |
+
expanded_node.backpropagation(score=score) #Backporpagation
|
| 136 |
+
|
| 137 |
+
return child
|
| 138 |
+
|
| 139 |
+
#simulate
|
| 140 |
+
def simulate(self, classifier, target=None):
|
| 141 |
+
now = target #Target node
|
| 142 |
+
sim_seq = ""
|
| 143 |
+
|
| 144 |
+
while now.root != True: #Parent's letters
|
| 145 |
+
sim_seq = now.letter + sim_seq
|
| 146 |
+
now = now.parent
|
| 147 |
+
|
| 148 |
+
sim_seq = self.base + sim_seq
|
| 149 |
+
|
| 150 |
+
for i in range((self.depth * 2) - len(sim_seq)): #Random child letters
|
| 151 |
+
r = np.random.randint(0,self.states)
|
| 152 |
+
sim_seq += self.letters[r]
|
| 153 |
+
|
| 154 |
+
sim_seq = self.reconstruct(sim_seq)
|
| 155 |
+
scores = []
|
| 156 |
+
|
| 157 |
+
classifier.eval().to('cuda')
|
| 158 |
+
with torch.no_grad():
|
| 159 |
+
sim_seq = np.array([sim_seq])
|
| 160 |
+
|
| 161 |
+
apta_toks = self.nt_tokenizer.batch_encode_plus(sim_seq, return_tensors='pt', padding='max_length', max_length=275)['input_ids']
|
| 162 |
+
apta_attention_mask = apta_toks != self.nt_tokenizer.pad_token_id
|
| 163 |
+
prot_attention_mask = self.encoded_targetprotein != self.esm_alphabet.padding_idx
|
| 164 |
+
score, _, _, _ = classifier(apta_toks.to('cuda'), self.encoded_targetprotein.to('cuda'), apta_attention_mask.to('cuda'), prot_attention_mask.to('cuda'))
|
| 165 |
+
|
| 166 |
+
return score
|
| 167 |
+
|
| 168 |
+
#recommend
|
| 169 |
+
def get_candidate(self):
|
| 170 |
+
return self.reconstruct(self.candidate)
|
| 171 |
+
|
| 172 |
+
def find_best_subsequence(self):
|
| 173 |
+
now = self.root
|
| 174 |
+
stop = False
|
| 175 |
+
base = self.base
|
| 176 |
+
|
| 177 |
+
for _ in range((self.depth*2) - len(base)):
|
| 178 |
+
best = 0
|
| 179 |
+
next_node = 0
|
| 180 |
+
for j in range(self.states):
|
| 181 |
+
if now.children_stat[j] == True:
|
| 182 |
+
if best < now.children[j].UCT():
|
| 183 |
+
next_node = j
|
| 184 |
+
best = now.children[j].UCT()
|
| 185 |
+
|
| 186 |
+
now = now.next_node(child=next_node)
|
| 187 |
+
base += now.letter
|
| 188 |
+
|
| 189 |
+
# if current node has no expanded children, stop reconstructing.
|
| 190 |
+
if np.sum(now.children_stat) == 0:
|
| 191 |
+
break
|
| 192 |
+
|
| 193 |
+
return base
|
| 194 |
+
|
| 195 |
+
#reconstruct
|
| 196 |
+
def reconstruct(self, seq=""):
|
| 197 |
+
r_seq = ""
|
| 198 |
+
for i in range(0, len(seq), 2):
|
| 199 |
+
if seq[i] == '_':
|
| 200 |
+
r_seq = r_seq + seq[i+1]
|
| 201 |
+
else:
|
| 202 |
+
r_seq = seq[i] + r_seq
|
| 203 |
+
return r_seq
|
| 204 |
+
|
| 205 |
+
def reset(self):
|
| 206 |
+
self.base = ""
|
| 207 |
+
self.candidate = ""
|
| 208 |
+
self.root = Node(letter="", parent=None, root=True, last=False, states=self.states)
|