import json
import os

import numpy as np
import torch
import torch.nn.functional as F
from tqdm import tqdm

import esm

DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32
MAX_SEQ_LEN = 50
MIN_SEQ_LEN = 2
CANONICAL_AA = set('ACDEFGHIKLMNPQRSTVWY')

print(f"Using device: {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")


def read_peptides_json(json_file):
    """
    Read and filter sequences from the all_peptides_data.json file.

    Extracts sequences from both the main peptides and their monomers
    (each item is expected to carry optional 'sequence', 'id', and
    'monomers' fields, where every monomer again has 'sequence' and 'id').

    Filters:
    - only the 20 canonical amino acids
    - sequence length between MIN_SEQ_LEN and MAX_SEQ_LEN
    - non-empty sequences

    Returns:
        List of (seq_id, sequence) tuples.
    """
    print(f"Loading peptides from {json_file}...")
    with open(json_file, 'r') as f:
        data = json.load(f)

    seqs = []
    processed_ids = set()

    for item in tqdm(data, desc="Processing peptides"):
        # Main peptide sequence
        if 'sequence' in item and item['sequence']:
            seq = item['sequence'].upper().strip()
            if (MIN_SEQ_LEN <= len(seq) <= MAX_SEQ_LEN and
                    all(aa in CANONICAL_AA for aa in seq)):
                seq_id = f"main_{item.get('id', 'unk')}"
                if seq_id not in processed_ids:
                    seqs.append((seq_id, seq))
                    processed_ids.add(seq_id)

        # Monomer sequences
        if 'monomers' in item and item['monomers']:
            for monomer in item['monomers']:
                if 'sequence' in monomer and monomer['sequence']:
                    seq = monomer['sequence'].upper().strip()
                    if (MIN_SEQ_LEN <= len(seq) <= MAX_SEQ_LEN and
                            all(aa in CANONICAL_AA for aa in seq)):
                        seq_id = f"monomer_{monomer.get('id', 'unk')}"
                        if seq_id not in processed_ids:
                            seqs.append((seq_id, seq))
                            processed_ids.add(seq_id)

    print(f"Found {len(seqs)} valid sequences")
    return seqs


@torch.no_grad()
def get_per_residue_embeddings(model, alphabet, sequences, batch_size=BATCH_SIZE):
    """
    Compute per-residue ESM-2 embeddings for a list of (seq_id, sequence) tuples.

    Pads or truncates each embedding to shape [MAX_SEQ_LEN, D].

    Returns:
        Dict {seq_id: tensor[MAX_SEQ_LEN, D]} on CPU.
    """
    model.eval()
    converter = alphabet.get_batch_converter()
    embeddings = {}

    print(f"Computing embeddings for {len(sequences)} sequences...")
    for i in tqdm(range(0, len(sequences), batch_size), desc="Computing embeddings"):
        batch = sequences[i:i + batch_size]
        labels, seqs = zip(*batch)
        _, _, tokens = converter(batch)
        tokens = tokens.to(DEVICE)

        out = model(tokens, repr_layers=[33], return_contacts=False)
        reps = out['representations'][33]

        for idx, sid in enumerate(labels):
            seq = seqs[idx]
            L = len(seq)

            # Position 0 is the BOS token, so the residues occupy positions 1..L.
            emb = reps[idx, 1:1 + L, :]
            if L < MAX_SEQ_LEN:
                pad_len = MAX_SEQ_LEN - L
                emb = F.pad(emb, (0, 0, 0, pad_len))  # zero-pad along the sequence axis
            elif L > MAX_SEQ_LEN:
                emb = emb[:MAX_SEQ_LEN, :]
            embeddings[sid] = emb.cpu()

    return embeddings
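
# Note: for esm2_t33_650M_UR50D the final-layer representation size is 1280,
# so every tensor returned above is [MAX_SEQ_LEN, 1280]; padded positions are
# all-zero rows, which downstream code can use to rebuild a padding mask.
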
def save_embeddings_for_compressor(embeddings, output_dir="/data2/edwardsun/flow_project/peptide_embeddings"):
    """
    Save embeddings in a format compatible with the compressor.

    Creates both individual per-sequence files and a combined tensor.
    """
    os.makedirs(output_dir, exist_ok=True)

    # One file per sequence: <seq_id>.pt containing a [MAX_SEQ_LEN, D] tensor
    print(f"Saving individual embeddings to {output_dir}/...")
    for seq_id, emb in tqdm(embeddings.items(), desc="Saving individual files"):
        torch.save(emb, os.path.join(output_dir, f"{seq_id}.pt"))

    # Combined tensor of shape [N, MAX_SEQ_LEN, D], in the same order as seq_ids
    print("Creating combined tensor...")
    all_embeddings = []
    seq_ids = []
    for seq_id, emb in embeddings.items():
        all_embeddings.append(emb)
        seq_ids.append(seq_id)

    combined_embeddings = torch.stack(all_embeddings)

    combined_path = os.path.join(output_dir, "all_peptide_embeddings.pt")
    torch.save(combined_embeddings, combined_path)

    seq_ids_path = os.path.join(output_dir, "sequence_ids.json")
    with open(seq_ids_path, 'w') as f:
        json.dump(seq_ids, f, indent=2)

    metadata = {
        "num_sequences": len(embeddings),
        "embedding_dim": combined_embeddings.shape[-1],
        "max_seq_len": MAX_SEQ_LEN,
        "device_used": str(DEVICE),
        "model_name": "esm2_t33_650M_UR50D"
    }
    metadata_path = os.path.join(output_dir, "metadata.json")
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)

    print(f"Saved combined embeddings: {combined_path}")
    print(f"Combined tensor shape: {combined_embeddings.shape}")
    print(f"Memory usage: {combined_embeddings.element_size() * combined_embeddings.nelement() / 1e6:.1f} MB")

    return combined_path
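
# Quick sanity check on the individual files (hypothetical seq_id shown; any
# entry from sequence_ids.json works):
#   emb = torch.load(os.path.join(output_dir, "main_1.pt"))
#   assert emb.shape[0] == MAX_SEQ_LEN
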
def create_compressor_dataset(embeddings, output_dir="/data2/edwardsun/flow_project/compressor_dataset"):
    """
    Create a dataset format specifically for compressor training.
    """
    os.makedirs(output_dir, exist_ok=True)

    # Stack into a single [N, MAX_SEQ_LEN, D] tensor
    all_embeddings = torch.stack(list(embeddings.values()))

    # Save both NumPy and PyTorch copies
    np_path = os.path.join(output_dir, "peptide_embeddings.npy")
    np.save(np_path, all_embeddings.numpy())

    torch_path = os.path.join(output_dir, "peptide_embeddings.pt")
    torch.save(all_embeddings, torch_path)

    print("Created compressor dataset:")
    print(f"  Shape: {all_embeddings.shape}")
    print(f"  Numpy: {np_path}")
    print(f"  Torch: {torch_path}")

    return torch_path
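
# A minimal sketch of the loading side (assumed usage in compressor.py; the
# masking and DataLoader setup are illustrative, not an existing API in this repo):
#
#   import torch
#   from torch.utils.data import TensorDataset, DataLoader
#
#   embs = torch.load("/data2/edwardsun/flow_project/compressor_dataset/peptide_embeddings.pt")
#   mask = (embs.abs().sum(-1) > 0)          # padded rows were zero-filled above
#   loader = DataLoader(TensorDataset(embs, mask), batch_size=64, shuffle=True)
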
if __name__ == '__main__':
    # Load the pretrained ESM-2 model and its alphabet
    print("Loading ESM-2 model...")
    model_name = 'esm2_t33_650M_UR50D'
    model, alphabet = esm.pretrained.load_model_and_alphabet(model_name)
    model = model.to(DEVICE)
    print(f"Loaded {model_name}")

    # Read and filter the peptide sequences
    json_file = 'all_peptides_data.json'
    sequences = read_peptides_json(json_file)
    print(f"Loaded {len(sequences)} valid sequences from {json_file}")

    if len(sequences) == 0:
        print("No valid sequences found. Exiting.")
        raise SystemExit(1)

    # Compute per-residue embeddings
    embeddings = get_per_residue_embeddings(model, alphabet, sequences)

    # Save in both output formats
    print("\nSaving embeddings...")
    combined_path = save_embeddings_for_compressor(embeddings)
    compressor_path = create_compressor_dataset(embeddings)

    print(f"\n✓ Successfully processed {len(embeddings)} peptide sequences")
    print("✓ Embeddings saved and ready for compressor training")
    print(f"✓ Use '{compressor_path}' in your compressor.py file")

    # Report basic statistics for one embedding
    sample_emb = next(iter(embeddings.values()))
    print("\nEmbedding statistics:")
    print(f"  Individual embedding shape: {sample_emb.shape}")
    print(f"  Embedding dimension: {sample_emb.shape[-1]}")
    print(f"  Data type: {sample_emb.dtype}")