import torch import torch.nn as nn import torch.optim as optim from torch.utils.data import DataLoader from torch.optim.lr_scheduler import CosineAnnealingLR, LinearLR, SequentialLR from torch.nn.parallel import DistributedDataParallel as DDP from torch.utils.data.distributed import DistributedSampler import torch.distributed as dist import numpy as np from tqdm import tqdm import json import os import argparse # Import your existing components from compressor_with_embeddings import Compressor, Decompressor, PrecomputedEmbeddingDataset from final_flow_model import AMPFlowMatcherCFGConcat, SinusoidalTimeEmbedding from cfg_dataset import CFGFlowDataset, create_cfg_dataloader # ---------------- Configuration ---------------- ESM_DIM = 1280 # ESM-2 hidden dim (esm2_t33_650M_UR50D) COMP_RATIO = 16 # compression factor COMP_DIM = ESM_DIM // COMP_RATIO MAX_SEQ_LEN = 50 # Actual sequence length from final_sequence_encoder.py BATCH_SIZE = 64 # Per GPU batch size (256 total across 4 GPUs) - increased for faster training EPOCHS = 5000 # Extended to 5K iterations for more comprehensive training (~50 minutes) BASE_LR = 1e-4 # initial learning rate LR_MIN = 2e-5 # minimum learning rate for cosine schedule WARMUP_STEPS = 100 # Reduced warmup for shorter training def setup_distributed(): """Setup distributed training.""" if 'RANK' in os.environ and 'WORLD_SIZE' in os.environ: rank = int(os.environ["RANK"]) world_size = int(os.environ['WORLD_SIZE']) local_rank = int(os.environ['LOCAL_RANK']) else: print('Not using distributed mode') return None, None, None torch.cuda.set_device(local_rank) dist.init_process_group(backend='nccl', init_method='env://') dist.barrier() return rank, world_size, local_rank class AMPFlowTrainerMultiGPU: """ Multi-GPU training pipeline for AMP generation using ProtFlow methodology. """ def __init__(self, embeddings_path, cfg_data_path, rank, world_size, local_rank): self.rank = rank self.world_size = world_size self.local_rank = local_rank self.device = torch.device(f'cuda:{local_rank}') self.embeddings_path = embeddings_path self.cfg_data_path = cfg_data_path # Load ALL pre-computed embeddings (only on main process) if self.rank == 0: print(f"Loading ALL AMP embeddings from {embeddings_path}...") # Try to load the combined embeddings file first (FULL DATA) combined_path = os.path.join(embeddings_path, "all_peptide_embeddings.pt") if os.path.exists(combined_path): print(f"Loading combined embeddings from {combined_path} (FULL DATA)...") self.embeddings = torch.load(combined_path, map_location=self.device) print(f"✓ Loaded ALL embeddings: {self.embeddings.shape}") else: print("Combined embeddings file not found, loading individual files...") # Fallback to individual files import glob embedding_files = glob.glob(os.path.join(embeddings_path, "*.pt")) embedding_files = [f for f in embedding_files if not f.endswith('metadata.json') and not f.endswith('sequence_ids.json') and not f.endswith('all_peptide_embeddings.pt')] print(f"Found {len(embedding_files)} individual embedding files") # Load and stack all embeddings embeddings_list = [] for file_path in embedding_files: try: embedding = torch.load(file_path) if embedding.dim() == 2: # (seq_len, hidden_dim) embeddings_list.append(embedding) else: print(f"Warning: Skipping {file_path} - unexpected shape {embedding.shape}") except Exception as e: print(f"Warning: Could not load {file_path}: {e}") if not embeddings_list: raise ValueError("No valid embeddings found!") self.embeddings = torch.stack(embeddings_list) print(f"Loaded {len(self.embeddings)} embeddings from individual files") # Compute normalization statistics print("Computing preprocessing statistics...") self._compute_preprocessing_stats() # Broadcast statistics to all processes if self.rank == 0: stats_tensor = torch.stack([ self.stats['mean'], self.stats['std'], self.stats['min'], self.stats['max'] ]).to(self.device) else: stats_tensor = torch.zeros(4, ESM_DIM, device=self.device) dist.broadcast(stats_tensor, src=0) if self.rank != 0: self.stats = { 'mean': stats_tensor[0], 'std': stats_tensor[1], 'min': stats_tensor[2], 'max': stats_tensor[3] } # Initialize models self._initialize_models() def _compute_preprocessing_stats(self): """Compute preprocessing statistics (only on main process).""" # Flatten all embeddings flat = self.embeddings.view(-1, ESM_DIM) # 1. Z-score normalization statistics feat_mean = flat.mean(0) feat_std = flat.std(0) + 1e-8 # 2. Truncation statistics (after z-score) z_score_normalized = (flat - feat_mean) / feat_std z_score_clamped = torch.clamp(z_score_normalized, -4, 4) # 3. Min-max normalization statistics (after truncation) feat_min = z_score_clamped.min(0)[0] feat_max = z_score_clamped.max(0)[0] # Store statistics self.stats = { 'mean': feat_mean, 'std': feat_std, 'min': feat_min, 'max': feat_max } # Save statistics for later use torch.save(self.stats, 'normalization_stats.pt') if self.rank == 0: print("✓ Preprocessing statistics computed and saved to normalization_stats.pt") def _initialize_models(self): """Initialize models for distributed training.""" # Load pre-trained compressor and decompressor self.compressor = Compressor().to(self.device) self.decompressor = Decompressor().to(self.device) # Load trained weights self.compressor.load_state_dict(torch.load('final_compressor_model.pth', map_location=self.device)) self.decompressor.load_state_dict(torch.load('final_decompressor_model.pth', map_location=self.device)) # Initialize flow matching model with CFG self.flow_model = AMPFlowMatcherCFGConcat( hidden_dim=480, compressed_dim=COMP_DIM, n_layers=12, n_heads=16, dim_ff=3072, max_seq_len=25, use_cfg=True ).to(self.device) # Wrap with DDP self.flow_model = DDP(self.flow_model, device_ids=[self.local_rank], find_unused_parameters=True) if self.rank == 0: print("✓ Initialized models for distributed training") print(f" - Flow model parameters: {sum(p.numel() for p in self.flow_model.parameters()):,}") print(f" - Using {self.world_size} GPUs") def _preprocess_batch(self, batch): """Apply preprocessing to a batch of embeddings.""" # 1. Z-score normalization h_norm = (batch - self.stats['mean'].to(batch.device)) / self.stats['std'].to(batch.device) # 2. Truncation (saturation) of outliers h_trunc = torch.clamp(h_norm, min=-4.0, max=4.0) # 3. Min-max normalization per dimension h_min = self.stats['min'].to(batch.device) h_max = self.stats['max'].to(batch.device) h_scaled = (h_trunc - h_min) / (h_max - h_min + 1e-8) h_scaled = torch.clamp(h_scaled, 0.0, 1.0) return h_scaled def train_flow_matching(self): """Train the flow matching model using distributed training.""" if self.rank == 0: print("Step 3: Training Flow Matching model (Multi-GPU)...") # Create CFG dataset and distributed data loader try: # Try to use CFG dataset with real labels dataset = CFGFlowDataset( embeddings_path=self.embeddings_path, cfg_data_path=self.cfg_data_path, use_masked_labels=True, max_seq_len=MAX_SEQ_LEN, device=self.device ) print("✓ Using CFG dataset with real labels") except Exception as e: print(f"Warning: Could not load CFG dataset: {e}") print("Falling back to random labels (not recommended for CFG)") # Fallback to original dataset with random labels dataset = PrecomputedEmbeddingDataset(self.embeddings_path) sampler = DistributedSampler(dataset, num_replicas=self.world_size, rank=self.rank) dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, sampler=sampler, num_workers=4) # Initialize optimizer optimizer = optim.AdamW( self.flow_model.parameters(), lr=BASE_LR, betas=(0.9, 0.98), weight_decay=0.01, eps=1e-6 ) # LR scheduling: warmup -> cosine warmup_sched = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=WARMUP_STEPS) cosine_sched = CosineAnnealingLR(optimizer, T_max=EPOCHS, eta_min=LR_MIN) scheduler = SequentialLR(optimizer, [warmup_sched, cosine_sched], milestones=[WARMUP_STEPS]) # Training loop self.flow_model.train() total_steps = 0 if self.rank == 0: print(f"Starting training for {EPOCHS} iterations with FULL DATA...") print(f"Total batch size: {BATCH_SIZE * self.world_size}") print(f"Steps per epoch: {len(dataloader)}") print(f"Total samples: {len(dataset):,}") print(f"Estimated time: ~30-45 minutes (using ALL data)") for epoch in range(EPOCHS): sampler.set_epoch(epoch) # Ensure different shuffling per epoch for batch_idx, batch_data in enumerate(dataloader): # Handle different data formats if isinstance(batch_data, dict) and 'embeddings' in batch_data: # CFG dataset format x = batch_data['embeddings'].to(self.device) labels = batch_data['labels'].to(self.device) else: # Original dataset format - use random labels x = batch_data.to(self.device) labels = torch.randint(0, 3, (x.shape[0],), device=self.device) batch_size = x.shape[0] # Apply preprocessing x_processed = self._preprocess_batch(x) # Compress to latent space with torch.no_grad(): z = self.compressor(x_processed, self.stats) # Sample random noise eps = torch.randn_like(z) # Sample random time t = torch.rand(batch_size, device=self.device) # Interpolate between data and noise xt = t.view(batch_size, 1, 1) * eps + (1 - t.view(batch_size, 1, 1)) * z # Target vector field for rectified flow ut = eps - z # Use real labels from CFG dataset or random labels as fallback # labels are already defined above based on dataset type # Predict vector field with CFG vt_pred = self.flow_model(xt, t, labels=labels) # CFM loss loss = ((vt_pred - ut) ** 2).mean() # Backward pass optimizer.zero_grad() loss.backward() # Gradient clipping torch.nn.utils.clip_grad_norm_(self.flow_model.parameters(), 1.0) optimizer.step() scheduler.step() total_steps += 1 # Logging (only on main process) - more frequent for short training if self.rank == 0 and total_steps % 10 == 0: progress = (total_steps / EPOCHS) * 100 label_dist = torch.bincount(labels, minlength=3) print(f"Step {total_steps}/{EPOCHS} ({progress:.1f}%): Loss = {loss.item():.6f}, LR = {scheduler.get_last_lr()[0]:.2e}, Labels: AMP={label_dist[0]}, Non-AMP={label_dist[1]}, Mask={label_dist[2]}") # Save checkpoint (only on main process) - more frequent for short training if self.rank == 0 and total_steps % 100 == 0: self._save_checkpoint(total_steps, loss.item()) # Validation (only on main process) - more frequent for short training if self.rank == 0 and total_steps % 200 == 0: self._validate() # Save final model (only on main process) if self.rank == 0: self._save_checkpoint(total_steps, loss.item(), is_final=True) print("✓ Flow matching training completed!") def _save_checkpoint(self, step, loss, is_final=False): """Save training checkpoint (only on main process).""" # Get the underlying model from DDP model_state_dict = self.flow_model.module.state_dict() checkpoint = { 'step': step, 'flow_model_state_dict': model_state_dict, 'loss': loss, } if is_final: torch.save(checkpoint, 'amp_flow_model_final_full_data.pth') print(f"✓ Final model saved: amp_flow_model_final_full_data.pth") else: torch.save(checkpoint, f'amp_flow_checkpoint_full_data_step_{step}.pth') print(f"✓ Checkpoint saved: amp_flow_checkpoint_full_data_step_{step}.pth") def _validate(self): """Validate the model by generating a few samples.""" print("Generating validation samples...") self.flow_model.eval() with torch.no_grad(): # Generate a few samples eps = torch.randn(4, 25, COMP_DIM, device=self.device) xt = eps.clone() # 25-step generation with CFG (using AMP label) labels = torch.full((4,), 0, device=self.device) # 0 = AMP for step in range(25): t = torch.ones(4, device=self.device) * (1.0 - step/25) vt = self.flow_model(xt, t, labels=labels) dt = 1.0 / 25 xt = xt + vt * dt # Decompress decompressed = self.decompressor(xt) # Apply reverse preprocessing m, s, mn, mx = self.stats['mean'].to(self.device), self.stats['std'].to(self.device), self.stats['min'].to(self.device), self.stats['max'].to(self.device) decompressed = decompressed * (mx - mn + 1e-8) + mn decompressed = decompressed * s + m print(f" Generated samples shape: {decompressed.shape}") print(f" Sample stats - Mean: {decompressed.mean():.4f}, Std: {decompressed.std():.4f}") self.flow_model.train() def main(): """Main training function with distributed setup.""" parser = argparse.ArgumentParser() parser.add_argument('--local_rank', type=int, default=0) parser.add_argument('--cfg_data_path', type=str, default='/data2/edwardsun/flow_project/test_uniprot_processed/uniprot_processed_data.json', help='Path to FULL CFG training data with real labels') args = parser.parse_args() # Setup distributed training rank, world_size, local_rank = setup_distributed() if rank == 0: print("=== Multi-GPU AMP Flow Matching Training Pipeline with FULL DATA ===") print("This implements the complete ProtFlow methodology for AMP generation.") print("Training for 5,000 iterations (~30-45 minutes) using ALL available data.") print() # Check if required files exist required_files = [ 'final_compressor_model.pth', 'final_decompressor_model.pth', '/data2/edwardsun/flow_project/peptide_embeddings/' ] for file in required_files: if not os.path.exists(file): print(f"❌ Missing required file: {file}") print("Please ensure you have:") print("1. Run final_sequence_encoder.py to generate embeddings") print("2. Run compressor_with_embeddings.py to train compressor/decompressor") return # Check if CFG data exists if not os.path.exists(args.cfg_data_path): print(f"⚠️ CFG data not found: {args.cfg_data_path}") print("Training will use random labels (not recommended for CFG)") print("To use real labels, run uniprot_data_processor.py first") else: print(f"✓ CFG data found: {args.cfg_data_path}") print("✓ All required files found!") print() # Initialize trainer trainer = AMPFlowTrainerMultiGPU( embeddings_path='/data2/edwardsun/flow_project/peptide_embeddings/', cfg_data_path=args.cfg_data_path, rank=rank, world_size=world_size, local_rank=local_rank ) # Train flow matching model trainer.train_flow_matching() if rank == 0: print("\n=== Multi-GPU Training Complete with FULL DATA ===") print("Your AMP flow matching model trained on ALL available data!") print("Next steps:") print("1. Test the model: python generate_amps.py") print("2. Compare performance with previous model") print("3. Implement reflow for 1-step generation") print("4. Add conditioning for toxicity (future project)") if __name__ == "__main__": main()