import json
import os
import torch
import torch.nn.functional as F
import esm
from tqdm import tqdm
import numpy as np

# ---------------- Configuration ----------------
DEVICE       = torch.device("cuda" if torch.cuda.is_available() else "cpu")
BATCH_SIZE   = 32    # sequences per ESM-2 forward pass (sized for GPU memory)
MAX_SEQ_LEN  = 50    # max sequence length for AMPs
MIN_SEQ_LEN  = 2     # minimum length for filtering
CANONICAL_AA = set('ACDEFGHIKLMNPQRSTVWY')

print(f"Using device: {DEVICE}")
if torch.cuda.is_available():
    print(f"GPU: {torch.cuda.get_device_name()}")
    print(f"GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB")

# ---------------- Sequence Loading ----------------
def read_peptides_json(json_file):
    """
    Read and filter sequences from the all_peptides_data.json file.
    Extracts sequences from both main peptides and their monomers.
    Filters:
      - Only canonical 20 AAs
      - Sequence length between MIN_SEQ_LEN and MAX_SEQ_LEN
      - Non-empty sequences
    Returns:
      List of (seq_id, sequence) tuples.
    """
    print(f"Loading peptides from {json_file}...")
    with open(json_file, 'r') as f:
        data = json.load(f)
    
    seqs = []
    processed_ids = set()
    
    for item in tqdm(data, desc="Processing peptides"):
        # Process main peptide sequence
        if 'sequence' in item and item['sequence']:
            seq = item['sequence'].upper().strip()
            if (MIN_SEQ_LEN <= len(seq) <= MAX_SEQ_LEN and 
                all(aa in CANONICAL_AA for aa in seq)):
                seq_id = f"main_{item.get('id', 'unk')}"
                if seq_id not in processed_ids:
                    seqs.append((seq_id, seq))
                    processed_ids.add(seq_id)
        
        # Process monomer sequences
        if 'monomers' in item and item['monomers']:
            for monomer in item['monomers']:
                if 'sequence' in monomer and monomer['sequence']:
                    seq = monomer['sequence'].upper().strip()
                    if (MIN_SEQ_LEN <= len(seq) <= MAX_SEQ_LEN and 
                        all(aa in CANONICAL_AA for aa in seq)):
                        seq_id = f"monomer_{monomer.get('id', 'unk')}"
                        if seq_id not in processed_ids:
                            seqs.append((seq_id, seq))
                            processed_ids.add(seq_id)
    
    print(f"Found {len(seqs)} valid sequences")
    return seqs
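
# For reference, a minimal sketch of the JSON layout read_peptides_json()
# assumes (inferred from the keys accessed above; the ids and sequences shown
# are illustrative, and real entries in all_peptides_data.json may carry
# additional fields):
#
#   [
#     {"id": 1,
#      "sequence": "GIGKFLHSAKKFGKAFVGEIMNS",
#      "monomers": [{"id": 10, "sequence": "GIGKFLHSA"}]},
#     ...
#   ]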

@torch.no_grad()
def get_per_residue_embeddings(model, alphabet, sequences, batch_size=BATCH_SIZE):
    """
    Compute per-residue ESM-2 embeddings for a list of (id, seq).
    Pads or truncates each embedding to shape [MAX_SEQ_LEN, D].
    Returns a dict {seq_id: tensor[MAX_SEQ_LEN, D]} on CPU.
    """
    model.eval()
    converter = alphabet.get_batch_converter()
    embeddings = {}
    
    print(f"Computing embeddings for {len(sequences)} sequences...")
    for i in tqdm(range(0, len(sequences), batch_size), desc="Computing embeddings"):
        batch = sequences[i:i+batch_size]
        labels, seqs, tokens = converter(batch)  # batch_converter returns (labels, strs, tokens)
        tokens = tokens.to(DEVICE)
        
        out = model(tokens, repr_layers=[33], return_contacts=False)
        reps = out['representations'][33]  # [B, L+2, D]
        
        for idx, sid in enumerate(labels):
            seq = seqs[idx]
            L = len(seq)
            # take per-residue embeddings and pad/truncate
            emb = reps[idx, 1:1+L, :]  # skip BOS/CLS at position 0; slice ends before EOS
            if L < MAX_SEQ_LEN:
                pad_len = MAX_SEQ_LEN - L
                emb = F.pad(emb, (0, 0, 0, pad_len))
            elif L > MAX_SEQ_LEN:
                emb = emb[:MAX_SEQ_LEN, :]
            embeddings[sid] = emb.cpu()
    
    return embeddings

def save_embeddings_for_compressor(embeddings, output_dir="/data2/edwardsun/flow_project/peptide_embeddings"):
    """
    Save embeddings in a format compatible with the compressor.
    Creates both individual files and a combined tensor.
    """
    os.makedirs(output_dir, exist_ok=True)
    
    # Save individual embeddings
    print(f"Saving individual embeddings to {output_dir}/...")
    for seq_id, emb in tqdm(embeddings.items(), desc="Saving individual files"):
        torch.save(emb, os.path.join(output_dir, f"{seq_id}.pt"))
    
    # Create and save combined tensor for compressor
    print("Creating combined tensor...")
    all_embeddings = []
    seq_ids = []
    
    for seq_id, emb in embeddings.items():
        all_embeddings.append(emb)
        seq_ids.append(seq_id)
    
    # Stack all embeddings
    combined_embeddings = torch.stack(all_embeddings)  # [N, MAX_SEQ_LEN, D]
    
    # Save combined tensor
    combined_path = os.path.join(output_dir, "all_peptide_embeddings.pt")
    torch.save(combined_embeddings, combined_path)
    
    # Save sequence IDs for reference
    seq_ids_path = os.path.join(output_dir, "sequence_ids.json")
    with open(seq_ids_path, 'w') as f:
        json.dump(seq_ids, f, indent=2)
    
    # Save metadata
    metadata = {
        "num_sequences": len(embeddings),
        "embedding_dim": combined_embeddings.shape[-1],
        "max_seq_len": MAX_SEQ_LEN,
        "device_used": str(DEVICE),
        "model_name": "esm2_t33_650M_UR50D"
    }
    metadata_path = os.path.join(output_dir, "metadata.json")
    with open(metadata_path, 'w') as f:
        json.dump(metadata, f, indent=2)
    
    print(f"Saved combined embeddings: {combined_path}")
    print(f"Combined tensor shape: {combined_embeddings.shape}")
    print(f"Memory usage: {combined_embeddings.element_size() * combined_embeddings.nelement() / 1e6:.1f} MB")
    
    return combined_path
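
# Minimal sketch of reading the combined tensor back, e.g. from compressor.py.
# The paths mirror the defaults above; the variable names are illustrative and
# the actual compressor loading code is not part of this script:
#
#   emb_dir = "/data2/edwardsun/flow_project/peptide_embeddings"
#   embs = torch.load(os.path.join(emb_dir, "all_peptide_embeddings.pt"))  # [N, MAX_SEQ_LEN, D]
#   with open(os.path.join(emb_dir, "sequence_ids.json")) as f:
#       ids = json.load(f)  # ids[i] labels embs[i]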

def create_compressor_dataset(embeddings, output_dir="/data2/edwardsun/flow_project/compressor_dataset"):
    """
    Create a dataset format specifically for the compressor training.
    """
    os.makedirs(output_dir, exist_ok=True)
    
    # Stack all embeddings
    all_embeddings = torch.stack(list(embeddings.values()))
    
    # Save as numpy array for easy loading
    np_path = os.path.join(output_dir, "peptide_embeddings.npy")
    np.save(np_path, all_embeddings.numpy())
    
    # Save as torch tensor
    torch_path = os.path.join(output_dir, "peptide_embeddings.pt")
    torch.save(all_embeddings, torch_path)
    
    print(f"Created compressor dataset:")
    print(f"  Shape: {all_embeddings.shape}")
    print(f"  Numpy: {np_path}")
    print(f"  Torch: {torch_path}")
    
    return torch_path
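
# Sketch of wrapping the saved tensor in a DataLoader for compressor training
# (assumed usage only; how compressor.py actually consumes the file is not
# defined in this script):
#
#   from torch.utils.data import DataLoader, TensorDataset
#   embs = torch.load("/data2/edwardsun/flow_project/compressor_dataset/peptide_embeddings.pt")
#   loader = DataLoader(TensorDataset(embs), batch_size=64, shuffle=True)
#   for (batch,) in loader:   # batch: [B, MAX_SEQ_LEN, D]
#       pass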

# ---------------- Main Execution ----------------
if __name__ == '__main__':
    # 1. Load model & tokenizer
    print("Loading ESM-2 model...")
    model_name = 'esm2_t33_650M_UR50D'
    model, alphabet = esm.pretrained.load_model_and_alphabet(model_name)
    model = model.to(DEVICE)
    print(f"Loaded {model_name}")

    # 2. Read and filter sequences from peptides JSON
    json_file = 'all_peptides_data.json'
    sequences = read_peptides_json(json_file)
    print(f"Loaded {len(sequences)} valid sequences from {json_file}")

    if len(sequences) == 0:
        print("No valid sequences found. Exiting.")
        raise SystemExit(1)

    # 3. Compute per-residue embeddings
    embeddings = get_per_residue_embeddings(model, alphabet, sequences)

    # 4. Save embeddings in multiple formats
    print("\nSaving embeddings...")
    
    # Save individual files and combined tensor
    combined_path = save_embeddings_for_compressor(embeddings)
    
    # Create compressor-specific dataset
    compressor_path = create_compressor_dataset(embeddings)
    
    print(f"\n✓ Successfully processed {len(embeddings)} peptide sequences")
    print(f"✓ Embeddings saved and ready for compressor training")
    print(f"✓ Use '{compressor_path}' in your compressor.py file")
    
    # Show some statistics
    sample_emb = next(iter(embeddings.values()))
    print(f"\nEmbedding statistics:")
    print(f"  Individual embedding shape: {sample_emb.shape}")
    print(f"  Embedding dimension: {sample_emb.shape[-1]}")
    print(f"  Data type: {sample_emb.dtype}")