import json import os import torch import torch.nn.functional as F import esm from tqdm import tqdm import numpy as np # ---------------- Configuration ---------------- DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") BATCH_SIZE = 32 # increased for GPU efficiency MAX_SEQ_LEN = 50 # max sequence length for AMPs MIN_SEQ_LEN = 2 # minimum length for filtering CANONICAL_AA = set('ACDEFGHIKLMNPQRSTVWY') print(f"Using device: {DEVICE}") if torch.cuda.is_available(): print(f"GPU: {torch.cuda.get_device_name()}") print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB") # ---------------- Sequence Loading ---------------- def read_peptides_json(json_file): """ Read and filter sequences from the all_peptides_data.json file. Extracts sequences from both main peptides and their monomers. Filters: - Only canonical 20 AAs - Sequence length between MIN_SEQ_LEN and MAX_SEQ_LEN - Non-empty sequences Returns: List of (seq_id, sequence) tuples. """ print(f"Loading peptides from {json_file}...") with open(json_file, 'r') as f: data = json.load(f) seqs = [] processed_ids = set() for item in tqdm(data, desc="Processing peptides"): # Process main peptide sequence if 'sequence' in item and item['sequence']: seq = item['sequence'].upper().strip() if (MIN_SEQ_LEN <= len(seq) <= MAX_SEQ_LEN and all(aa in CANONICAL_AA for aa in seq)): seq_id = f"main_{item.get('id', 'unk')}" if seq_id not in processed_ids: seqs.append((seq_id, seq)) processed_ids.add(seq_id) # Process monomer sequences if 'monomers' in item and item['monomers']: for monomer in item['monomers']: if 'sequence' in monomer and monomer['sequence']: seq = monomer['sequence'].upper().strip() if (MIN_SEQ_LEN <= len(seq) <= MAX_SEQ_LEN and all(aa in CANONICAL_AA for aa in seq)): seq_id = f"monomer_{monomer.get('id', 'unk')}" if seq_id not in processed_ids: seqs.append((seq_id, seq)) processed_ids.add(seq_id) print(f"Found {len(seqs)} valid sequences") return seqs @torch.no_grad() def get_per_residue_embeddings(model, alphabet, sequences, batch_size=BATCH_SIZE): """ Compute per-residue ESM-2 embeddings for a list of (id, seq). Pads or truncates each embedding to shape [MAX_SEQ_LEN, D]. Returns a dict {seq_id: tensor[MAX_SEQ_LEN, D]} on CPU. """ model.eval() converter = alphabet.get_batch_converter() embeddings = {} print(f"Computing embeddings for {len(sequences)} sequences...") for i in tqdm(range(0, len(sequences), batch_size), desc="Computing embeddings"): batch = sequences[i:i+batch_size] labels, seqs = zip(*batch) _, _, tokens = converter(batch) tokens = tokens.to(DEVICE) out = model(tokens, repr_layers=[33], return_contacts=False) reps = out['representations'][33] # [B, L+2, D] for idx, sid in enumerate(labels): seq = seqs[idx] L = len(seq) # take per-residue embeddings and pad/truncate emb = reps[idx, 1:1+L, :] # Remove CLS and EOS tokens if L < MAX_SEQ_LEN: pad_len = MAX_SEQ_LEN - L emb = F.pad(emb, (0, 0, 0, pad_len)) elif L > MAX_SEQ_LEN: emb = emb[:MAX_SEQ_LEN, :] embeddings[sid] = emb.cpu() return embeddings def save_embeddings_for_compressor(embeddings, output_dir="/data2/edwardsun/flow_project/peptide_embeddings"): """ Save embeddings in a format compatible with the compressor. Creates both individual files and a combined tensor. """ os.makedirs(output_dir, exist_ok=True) # Save individual embeddings print(f"Saving individual embeddings to {output_dir}/...") for seq_id, emb in tqdm(embeddings.items(), desc="Saving individual files"): torch.save(emb, os.path.join(output_dir, f"{seq_id}.pt")) # Create and save combined tensor for compressor print("Creating combined tensor...") all_embeddings = [] seq_ids = [] for seq_id, emb in embeddings.items(): all_embeddings.append(emb) seq_ids.append(seq_id) # Stack all embeddings combined_embeddings = torch.stack(all_embeddings) # [N, MAX_SEQ_LEN, D] # Save combined tensor combined_path = os.path.join(output_dir, "all_peptide_embeddings.pt") torch.save(combined_embeddings, combined_path) # Save sequence IDs for reference seq_ids_path = os.path.join(output_dir, "sequence_ids.json") with open(seq_ids_path, 'w') as f: json.dump(seq_ids, f, indent=2) # Save metadata metadata = { "num_sequences": len(embeddings), "embedding_dim": combined_embeddings.shape[-1], "max_seq_len": MAX_SEQ_LEN, "device_used": str(DEVICE), "model_name": "esm2_t33_650M_UR50D" } metadata_path = os.path.join(output_dir, "metadata.json") with open(metadata_path, 'w') as f: json.dump(metadata, f, indent=2) print(f"Saved combined embeddings: {combined_path}") print(f"Combined tensor shape: {combined_embeddings.shape}") print(f"Memory usage: {combined_embeddings.element_size() * combined_embeddings.nelement() / 1e6:.1f} MB") return combined_path def create_compressor_dataset(embeddings, output_dir="/data2/edwardsun/flow_project/compressor_dataset"): """ Create a dataset format specifically for the compressor training. """ os.makedirs(output_dir, exist_ok=True) # Stack all embeddings all_embeddings = torch.stack(list(embeddings.values())) # Save as numpy array for easy loading np_path = os.path.join(output_dir, "peptide_embeddings.npy") np.save(np_path, all_embeddings.numpy()) # Save as torch tensor torch_path = os.path.join(output_dir, "peptide_embeddings.pt") torch.save(all_embeddings, torch_path) print(f"Created compressor dataset:") print(f" Shape: {all_embeddings.shape}") print(f" Numpy: {np_path}") print(f" Torch: {torch_path}") return torch_path # ---------------- Main Execution ---------------- if __name__ == '__main__': # 1. Load model & tokenizer print("Loading ESM-2 model...") model_name = 'esm2_t33_650M_UR50D' model, alphabet = esm.pretrained.load_model_and_alphabet(model_name) model = model.to(DEVICE) print(f"Loaded {model_name}") # 2. Read and filter sequences from peptides JSON json_file = 'all_peptides_data.json' sequences = read_peptides_json(json_file) print(f"Loaded {len(sequences)} valid sequences from {json_file}") if len(sequences) == 0: print("No valid sequences found. Exiting.") exit(1) # 3. Compute per-residue embeddings embeddings = get_per_residue_embeddings(model, alphabet, sequences) # 4. Save embeddings in multiple formats print("\nSaving embeddings...") # Save individual files and combined tensor combined_path = save_embeddings_for_compressor(embeddings) # Create compressor-specific dataset compressor_path = create_compressor_dataset(embeddings) print(f"\nāœ“ Successfully processed {len(embeddings)} peptide sequences") print(f"āœ“ Embeddings saved and ready for compressor training") print(f"āœ“ Use '{compressor_path}' in your compressor.py file") # Show some statistics sample_emb = next(iter(embeddings.values())) print(f"\nEmbedding statistics:") print(f" Individual embedding shape: {sample_emb.shape}") print(f" Embedding dimension: {sample_emb.shape[-1]}") print(f" Data type: {sample_emb.dtype}")