import json
import os
import torch
import torch.nn.functional as F
import esm
from tqdm import tqdm
import numpy as np

# ---------------- Configuration ----------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32    # increased for GPU efficiency
MAX_SEQ_LEN = 50   # max sequence length for AMPs
MIN_SEQ_LEN = 2    # minimum length for filtering
CANONICAL_AA = set('ACDEFGHIKLMNPQRSTVWY')

print(f"Using device: {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# ---------------- Sequence Loading ----------------
def read_peptides_json(json_file):
    """
    Read and filter sequences from the all_peptides_data.json file.
    Extracts sequences from both main peptides and their monomers.

    Filters:
        - Only canonical 20 AAs
        - Sequence length between MIN_SEQ_LEN and MAX_SEQ_LEN
        - Non-empty sequences

    Returns:
        List of (seq_id, sequence) tuples.
    """
    print(f"Loading peptides from {json_file}...")
    with open(json_file, 'r') as f:
        data = json.load(f)

    seqs = []
    processed_ids = set()
    for item in tqdm(data, desc="Processing peptides"):
        # Process main peptide sequence
        if 'sequence' in item and item['sequence']:
            seq = item['sequence'].upper().strip()
            if (MIN_SEQ_LEN <= len(seq) <= MAX_SEQ_LEN and
                    all(aa in CANONICAL_AA for aa in seq)):
                seq_id = f"main_{item.get('id', 'unk')}"
                if seq_id not in processed_ids:
                    seqs.append((seq_id, seq))
                    processed_ids.add(seq_id)
        # Process monomer sequences
        if 'monomers' in item and item['monomers']:
            for monomer in item['monomers']:
                if 'sequence' in monomer and monomer['sequence']:
                    seq = monomer['sequence'].upper().strip()
                    if (MIN_SEQ_LEN <= len(seq) <= MAX_SEQ_LEN and
                            all(aa in CANONICAL_AA for aa in seq)):
                        seq_id = f"monomer_{monomer.get('id', 'unk')}"
                        if seq_id not in processed_ids:
                            seqs.append((seq_id, seq))
                            processed_ids.add(seq_id)

    print(f"Found {len(seqs)} valid sequences")
    return seqs
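
# Expected input layout (an assumption inferred from the keys accessed above;
# the real all_peptides_data.json may carry additional fields):
#
# [
#   {
#     "id": 1,
#     "sequence": "GIGKFLHSAKKFGKAFVGEIMNS",
#     "monomers": [{"id": "1a", "sequence": "GIGKFLHS"}]
#   },
#   ...
# ]
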
@torch.no_grad()
def get_per_residue_embeddings(model, alphabet, sequences, batch_size=BATCH_SIZE):
    """
    Compute per-residue ESM-2 embeddings for a list of (id, seq).
    Pads or truncates each embedding to shape [MAX_SEQ_LEN, D].
    Returns a dict {seq_id: tensor[MAX_SEQ_LEN, D]} on CPU.
    """
    model.eval()
    converter = alphabet.get_batch_converter()
    embeddings = {}
    print(f"Computing embeddings for {len(sequences)} sequences...")
    for i in tqdm(range(0, len(sequences), batch_size), desc="Computing embeddings"):
        batch = sequences[i:i + batch_size]
        labels, seqs = zip(*batch)
        _, _, tokens = converter(batch)
        tokens = tokens.to(DEVICE)
        out = model(tokens, repr_layers=[33], return_contacts=False)
        reps = out['representations'][33]  # [B, L+2, D]
        for idx, sid in enumerate(labels):
            seq = seqs[idx]
            L = len(seq)
            # Take per-residue embeddings and pad/truncate
            emb = reps[idx, 1:1 + L, :]  # remove BOS/EOS tokens
            if L < MAX_SEQ_LEN:
                pad_len = MAX_SEQ_LEN - L
                emb = F.pad(emb, (0, 0, 0, pad_len))
            elif L > MAX_SEQ_LEN:
                emb = emb[:MAX_SEQ_LEN, :]
            embeddings[sid] = emb.cpu()
    return embeddings
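
# Illustrative helper (not called by the pipeline below): because padded rows are
# all-zero, a boolean residue mask can be recovered downstream without storing the
# original sequence lengths. This assumes no real residue embedding is exactly the
# zero vector, which holds for ESM-2 outputs in practice.
def padding_mask_from_embedding(emb):
    """Return a [MAX_SEQ_LEN] bool tensor that is True at real residue positions."""
    return emb.abs().sum(dim=-1) > 0
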
def save_embeddings_for_compressor(embeddings, output_dir="/data2/edwardsun/flow_project/peptide_embeddings"):
    """
    Save embeddings in a format compatible with the compressor.
    Creates both individual files and a combined tensor.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Save individual embeddings
    print(f"Saving individual embeddings to {output_dir}/...")
    for seq_id, emb in tqdm(embeddings.items(), desc="Saving individual files"):
        torch.save(emb, os.path.join(output_dir, f"{seq_id}.pt"))

    # Create and save combined tensor for compressor
    print("Creating combined tensor...")
    all_embeddings = []
    seq_ids = []
    for seq_id, emb in embeddings.items():
        all_embeddings.append(emb)
        seq_ids.append(seq_id)

    # Stack all embeddings
    combined_embeddings = torch.stack(all_embeddings)  # [N, MAX_SEQ_LEN, D]

    # Save combined tensor
    combined_path = os.path.join(output_dir, "all_peptide_embeddings.pt")
    torch.save(combined_embeddings, combined_path)

    # Save sequence IDs for reference
    seq_ids_path = os.path.join(output_dir, "sequence_ids.json")
    with open(seq_ids_path, 'w') as f:
        json.dump(seq_ids, f, indent=2)

    # Save metadata
    metadata = {
        "num_sequences": len(embeddings),
        "embedding_dim": combined_embeddings.shape[-1],
        "max_seq_len": MAX_SEQ_LEN,
        "device_used": str(DEVICE),
        "model_name": "esm2_t33_650M_UR50D"
    }
    metadata_path = os.path.join(output_dir, "metadata.json")
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"Saved combined embeddings: {combined_path}")
    print(f"Combined tensor shape: {combined_embeddings.shape}")
    print(f"Memory usage: {combined_embeddings.element_size() * combined_embeddings.nelement() / 1e6:.1f} MB")
    return combined_path
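
# Illustrative helper (not called by the pipeline below): one way a downstream
# consumer could reload the combined tensor together with its row-aligned sequence
# IDs. The file names match what save_embeddings_for_compressor() writes above.
def load_combined_embeddings(output_dir="/data2/edwardsun/flow_project/peptide_embeddings"):
    """Reload the combined embedding tensor [N, MAX_SEQ_LEN, D] and its sequence IDs."""
    combined = torch.load(os.path.join(output_dir, "all_peptide_embeddings.pt"))
    with open(os.path.join(output_dir, "sequence_ids.json")) as f:
        seq_ids = json.load(f)  # seq_ids[i] labels combined[i]
    return combined, seq_ids
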
def create_compressor_dataset(embeddings, output_dir="/data2/edwardsun/flow_project/compressor_dataset"):
    """
    Create a dataset format specifically for compressor training.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Stack all embeddings
    all_embeddings = torch.stack(list(embeddings.values()))

    # Save as numpy array for easy loading
    np_path = os.path.join(output_dir, "peptide_embeddings.npy")
    np.save(np_path, all_embeddings.numpy())

    # Save as torch tensor
    torch_path = os.path.join(output_dir, "peptide_embeddings.pt")
    torch.save(all_embeddings, torch_path)

    print("Created compressor dataset:")
    print(f"  Shape: {all_embeddings.shape}")
    print(f"  Numpy: {np_path}")
    print(f"  Torch: {torch_path}")
    return torch_path
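
# Illustrative sketch (not called by the pipeline below): for large datasets the
# .npy file written above can be memory-mapped instead of loaded eagerly, so the
# compressor can read slices from disk on demand.
def open_compressor_dataset_mmap(output_dir="/data2/edwardsun/flow_project/compressor_dataset"):
    """Open the saved numpy array [N, MAX_SEQ_LEN, D] without loading it fully into RAM."""
    return np.load(os.path.join(output_dir, "peptide_embeddings.npy"), mmap_mode="r")
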
# ---------------- Main Execution ----------------
if __name__ == '__main__':
    # 1. Load model & alphabet
    print("Loading ESM-2 model...")
    model_name = 'esm2_t33_650M_UR50D'
    model, alphabet = esm.pretrained.load_model_and_alphabet(model_name)
    model = model.to(DEVICE)
    print(f"Loaded {model_name}")

    # 2. Read and filter sequences from the peptides JSON
    json_file = 'all_peptides_data.json'
    sequences = read_peptides_json(json_file)
    print(f"Loaded {len(sequences)} valid sequences from {json_file}")
    if len(sequences) == 0:
        print("No valid sequences found. Exiting.")
        raise SystemExit(1)

    # 3. Compute per-residue embeddings
    embeddings = get_per_residue_embeddings(model, alphabet, sequences)

    # 4. Save embeddings in multiple formats
    print("\nSaving embeddings...")
    # Save individual files and combined tensor
    combined_path = save_embeddings_for_compressor(embeddings)
    # Create compressor-specific dataset
    compressor_path = create_compressor_dataset(embeddings)

    print(f"\n✓ Successfully processed {len(embeddings)} peptide sequences")
    print("✓ Embeddings saved and ready for compressor training")
    print(f"✓ Use '{compressor_path}' in your compressor.py file")

    # Show some statistics
    sample_emb = next(iter(embeddings.values()))
    print("\nEmbedding statistics:")
    print(f"  Individual embedding shape: {sample_emb.shape}")
    print(f"  Embedding dimension: {sample_emb.shape[-1]}")
    print(f"  Data type: {sample_emb.dtype}")