# FlowAMP / final_sequence_encoder.py
import json
import os
import torch
import torch.nn.functional as F
import esm
from tqdm import tqdm
import numpy as np
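# NOTE: `import esm` assumes the fair-esm package (pip install fair-esm), which
# provides esm.pretrained.load_model_and_alphabet() used in the main block below.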
# ---------------- Configuration ----------------
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE = 32 # increased for GPU efficiency
MAX_SEQ_LEN = 50 # max sequence length for AMPs
MIN_SEQ_LEN = 2 # minimum length for filtering
CANONICAL_AA = set('ACDEFGHIKLMNPQRSTVWY')
print(f"Using device: {DEVICE}")
if torch.cuda.is_available():
print(f"GPU: {torch.cuda.get_device_name()}")
print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")
# ---------------- Sequence Loading ----------------
def read_peptides_json(json_file):
"""
Read and filter sequences from the all_peptides_data.json file.
Extracts sequences from both main peptides and their monomers.
Filters:
- Only canonical 20 AAs
- Sequence length between MIN_SEQ_LEN and MAX_SEQ_LEN
- Non-empty sequences
Returns:
List of (seq_id, sequence) tuples.
"""
print(f"Loading peptides from {json_file}...")
with open(json_file, 'r') as f:
data = json.load(f)
seqs = []
processed_ids = set()
    for i, item in enumerate(tqdm(data, desc="Processing peptides")):
        # Process main peptide sequence
        if 'sequence' in item and item['sequence']:
            seq = item['sequence'].upper().strip()
            if (MIN_SEQ_LEN <= len(seq) <= MAX_SEQ_LEN and
                    all(aa in CANONICAL_AA for aa in seq)):
                # Fall back to the item index so entries without an 'id' keep unique seq_ids
                seq_id = f"main_{item.get('id', i)}"
                if seq_id not in processed_ids:
                    seqs.append((seq_id, seq))
                    processed_ids.add(seq_id)
        # Process monomer sequences
        if 'monomers' in item and item['monomers']:
            for j, monomer in enumerate(item['monomers']):
                if 'sequence' in monomer and monomer['sequence']:
                    seq = monomer['sequence'].upper().strip()
                    if (MIN_SEQ_LEN <= len(seq) <= MAX_SEQ_LEN and
                            all(aa in CANONICAL_AA for aa in seq)):
                        seq_id = f"monomer_{monomer.get('id', f'{i}_{j}')}"
                        if seq_id not in processed_ids:
                            seqs.append((seq_id, seq))
                            processed_ids.add(seq_id)
print(f"Found {len(seqs)} valid sequences")
return seqs
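# Minimal sketch (inferred from the parsing logic in read_peptides_json above) of the
# JSON layout this script expects. The ids and sequences here are made-up placeholders;
# the real all_peptides_data.json may carry additional fields, which are simply ignored.
_EXAMPLE_PEPTIDE_RECORD = {
    "id": "example_001",                      # becomes seq_id "main_example_001"
    "sequence": "GLFDIVKKVVGALGSL",           # canonical AAs, length within [MIN_SEQ_LEN, MAX_SEQ_LEN]
    "monomers": [
        {"id": "example_001_m1", "sequence": "GLFDIVKK"},  # becomes seq_id "monomer_example_001_m1"
    ],
}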
@torch.no_grad()
def get_per_residue_embeddings(model, alphabet, sequences, batch_size=BATCH_SIZE):
"""
Compute per-residue ESM-2 embeddings for a list of (id, seq).
Pads or truncates each embedding to shape [MAX_SEQ_LEN, D].
Returns a dict {seq_id: tensor[MAX_SEQ_LEN, D]} on CPU.
"""
model.eval()
converter = alphabet.get_batch_converter()
embeddings = {}
print(f"Computing embeddings for {len(sequences)} sequences...")
for i in tqdm(range(0, len(sequences), batch_size), desc="Computing embeddings"):
batch = sequences[i:i+batch_size]
labels, seqs = zip(*batch)
_, _, tokens = converter(batch)
tokens = tokens.to(DEVICE)
out = model(tokens, repr_layers=[33], return_contacts=False)
reps = out['representations'][33] # [B, L+2, D]
for idx, sid in enumerate(labels):
seq = seqs[idx]
L = len(seq)
# take per-residue embeddings and pad/truncate
emb = reps[idx, 1:1+L, :] # Remove CLS and EOS tokens
if L < MAX_SEQ_LEN:
pad_len = MAX_SEQ_LEN - L
emb = F.pad(emb, (0, 0, 0, pad_len))
elif L > MAX_SEQ_LEN:
emb = emb[:MAX_SEQ_LEN, :]
embeddings[sid] = emb.cpu()
return embeddings
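# Hedged sketch (assumption, not used anywhere in this script): a downstream model that
# consumes the padded [MAX_SEQ_LEN, D] embeddings usually also needs a mask marking real
# residues vs. zero padding. One way to rebuild it from the (id, seq) list returned by
# read_peptides_json():
def _example_padding_mask(sequences, max_len=MAX_SEQ_LEN):
    """Return a [N, max_len] bool tensor: True for real residues, False for padding."""
    mask = torch.zeros(len(sequences), max_len, dtype=torch.bool)
    for row, (_, seq) in enumerate(sequences):
        mask[row, :min(len(seq), max_len)] = True
    return mask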
def save_embeddings_for_compressor(embeddings, output_dir="/data2/edwardsun/flow_project/peptide_embeddings"):
"""
Save embeddings in a format compatible with the compressor.
Creates both individual files and a combined tensor.
"""
os.makedirs(output_dir, exist_ok=True)
# Save individual embeddings
print(f"Saving individual embeddings to {output_dir}/...")
for seq_id, emb in tqdm(embeddings.items(), desc="Saving individual files"):
torch.save(emb, os.path.join(output_dir, f"{seq_id}.pt"))
# Create and save combined tensor for compressor
print("Creating combined tensor...")
all_embeddings = []
seq_ids = []
for seq_id, emb in embeddings.items():
all_embeddings.append(emb)
seq_ids.append(seq_id)
# Stack all embeddings
combined_embeddings = torch.stack(all_embeddings) # [N, MAX_SEQ_LEN, D]
# Save combined tensor
combined_path = os.path.join(output_dir, "all_peptide_embeddings.pt")
torch.save(combined_embeddings, combined_path)
# Save sequence IDs for reference
seq_ids_path = os.path.join(output_dir, "sequence_ids.json")
with open(seq_ids_path, 'w') as f:
json.dump(seq_ids, f, indent=2)
# Save metadata
metadata = {
"num_sequences": len(embeddings),
"embedding_dim": combined_embeddings.shape[-1],
"max_seq_len": MAX_SEQ_LEN,
"device_used": str(DEVICE),
"model_name": "esm2_t33_650M_UR50D"
}
metadata_path = os.path.join(output_dir, "metadata.json")
with open(metadata_path, 'w') as f:
json.dump(metadata, f, indent=2)
print(f"Saved combined embeddings: {combined_path}")
print(f"Combined tensor shape: {combined_embeddings.shape}")
print(f"Memory usage: {combined_embeddings.element_size() * combined_embeddings.nelement() / 1e6:.1f} MB")
return combined_path
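# Hedged sketch (assumption about downstream usage, not part of the original pipeline):
# how another script could read back the artifacts written by save_embeddings_for_compressor().
# The default directory mirrors that function's default; adjust it to your own layout.
def _example_load_saved_embeddings(output_dir="/data2/edwardsun/flow_project/peptide_embeddings"):
    """Load the combined tensor, its sequence ids, and the metadata JSON."""
    combined = torch.load(os.path.join(output_dir, "all_peptide_embeddings.pt"))  # [N, MAX_SEQ_LEN, D]
    with open(os.path.join(output_dir, "sequence_ids.json")) as f:
        seq_ids = json.load(f)
    with open(os.path.join(output_dir, "metadata.json")) as f:
        metadata = json.load(f)
    return combined, seq_ids, metadata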
def create_compressor_dataset(embeddings, output_dir="/data2/edwardsun/flow_project/compressor_dataset"):
"""
Create a dataset format specifically for the compressor training.
"""
os.makedirs(output_dir, exist_ok=True)
# Stack all embeddings
all_embeddings = torch.stack(list(embeddings.values()))
# Save as numpy array for easy loading
np_path = os.path.join(output_dir, "peptide_embeddings.npy")
np.save(np_path, all_embeddings.numpy())
# Save as torch tensor
torch_path = os.path.join(output_dir, "peptide_embeddings.pt")
torch.save(all_embeddings, torch_path)
print(f"Created compressor dataset:")
print(f" Shape: {all_embeddings.shape}")
print(f" Numpy: {np_path}")
print(f" Torch: {torch_path}")
return torch_path
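# Hedged sketch (assumption: the real compressor.py may load its data differently): wrapping
# the saved tensor in a DataLoader so a compressor can train on [B, MAX_SEQ_LEN, D] batches.
def _example_compressor_loader(torch_path, batch_size=64):
    """Yield shuffled batches from the tensor written by create_compressor_dataset().
    Note: each batch is a 1-tuple, i.e. (embeddings,)."""
    from torch.utils.data import DataLoader, TensorDataset  # local import keeps the sketch self-contained
    data = torch.load(torch_path)  # [N, MAX_SEQ_LEN, D]
    return DataLoader(TensorDataset(data), batch_size=batch_size, shuffle=True)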
# ---------------- Main Execution ----------------
if __name__ == '__main__':
    # 1. Load ESM-2 model & alphabet (the alphabet provides the batch converter / tokenizer)
print("Loading ESM-2 model...")
model_name = 'esm2_t33_650M_UR50D'
model, alphabet = esm.pretrained.load_model_and_alphabet(model_name)
model = model.to(DEVICE)
print(f"Loaded {model_name}")
# 2. Read and filter sequences from peptides JSON
json_file = 'all_peptides_data.json'
sequences = read_peptides_json(json_file)
print(f"Loaded {len(sequences)} valid sequences from {json_file}")
    if len(sequences) == 0:
        print("No valid sequences found. Exiting.")
        raise SystemExit(1)  # safer than exit(), which is only guaranteed in interactive sessions
# 3. Compute per-residue embeddings
embeddings = get_per_residue_embeddings(model, alphabet, sequences)
# 4. Save embeddings in multiple formats
print("\nSaving embeddings...")
# Save individual files and combined tensor
combined_path = save_embeddings_for_compressor(embeddings)
# Create compressor-specific dataset
compressor_path = create_compressor_dataset(embeddings)
print(f"\n✓ Successfully processed {len(embeddings)} peptide sequences")
print(f"✓ Embeddings saved and ready for compressor training")
print(f"✓ Use '{compressor_path}' in your compressor.py file")
# Show some statistics
sample_emb = next(iter(embeddings.values()))
print(f"\nEmbedding statistics:")
print(f" Individual embedding shape: {sample_emb.shape}")
print(f" Embedding dimension: {sample_emb.shape[-1]}")
print(f" Data type: {sample_emb.dtype}")