import math

import torch
import torch.nn as nn
import torch.nn.functional as F


class SinusoidalTimeEmbedding(nn.Module):
    """Sinusoidal time embedding as used in the ProtFlow paper."""

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, time):
        device = time.device
        half_dim = self.dim // 2
        embeddings = math.log(10000) / (half_dim - 1)
        embeddings = torch.exp(torch.arange(half_dim, device=device) * -embeddings)
        # Flatten time to 1D [B] so the outer product below is exactly [B, half_dim],
        # regardless of whether time arrives as (B,), (B, 1) or with extra singleton dims
        time = time.reshape(-1).float()
        embeddings = time.unsqueeze(-1) * embeddings.unsqueeze(0)            # [B, half_dim]
        embeddings = torch.cat((embeddings.sin(), embeddings.cos()), dim=-1)  # [B, dim]
        return embeddings


class LabelMLP(nn.Module):
    """
    MLP for processing class labels into embeddings.
    This approach processes labels separately from time embeddings.
    """

    def __init__(self, num_classes=3, hidden_dim=480, mlp_dim=256):
        super().__init__()
        self.num_classes = num_classes

        # MLP to process labels
        self.label_mlp = nn.Sequential(
            nn.Embedding(num_classes, mlp_dim),
            nn.Linear(mlp_dim, mlp_dim),
            nn.GELU(),
            nn.Linear(mlp_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # Initialize embeddings
        nn.init.normal_(self.label_mlp[0].weight, std=0.02)

    def forward(self, labels):
        """
        Args:
            labels: (B,) tensor of class labels
                - 0: AMP (MIC < 100)
                - 1: Non-AMP (MIC >= 100)
                - 2: Mask (unknown MIC, used as the unconditional class)

        Returns:
            embeddings: (B, hidden_dim) tensor of processed label embeddings
        """
        return self.label_mlp(labels)


class AMPFlowMatcherCFGConcat(nn.Module):
    """
    Flow matching model with classifier-free guidance (CFG) using a concatenation approach.

    - 12-layer transformer with long skip connections
    - Time embedding + MLP-processed label embedding (concatenated, then projected)
    - Optimized for peptide sequences (max length 50, i.e. 25 latent positions after pooling)
    """

    def __init__(self, hidden_dim=480, compressed_dim=30, n_layers=12, n_heads=16,
                 dim_ff=3072, dropout=0.1, max_seq_len=25, use_cfg=True):
        super().__init__()
        self.hidden_dim = hidden_dim
        self.compressed_dim = compressed_dim
        self.n_layers = n_layers
        self.max_seq_len = max_seq_len
        self.use_cfg = use_cfg

        # Time embedding
        self.time_embed = nn.Sequential(
            SinusoidalTimeEmbedding(hidden_dim),
            nn.Linear(hidden_dim, hidden_dim),
            nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim)
        )

        # CFG components using the concatenation approach
        if use_cfg:
            self.label_mlp = LabelMLP(num_classes=3, hidden_dim=hidden_dim)
            # Projection layer for concatenated time + label embeddings
            self.condition_proj = nn.Sequential(
                nn.Linear(hidden_dim * 2, hidden_dim),  # 2 for time + label
                nn.GELU(),
                nn.Linear(hidden_dim, hidden_dim)
            )

        # Projection layers for compressed space
        self.compress_proj = nn.Linear(compressed_dim, hidden_dim)
        self.decompress_proj = nn.Linear(hidden_dim, compressed_dim)

        # Positional encoding for peptide sequences
        self.pos_embed = nn.Parameter(torch.randn(1, max_seq_len, hidden_dim))

        # Transformer layers with long skip connections
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(
                d_model=hidden_dim,
                nhead=n_heads,
                dim_feedforward=dim_ff,
                dropout=dropout,
                activation='gelu',
                batch_first=True
            )
            for _ in range(n_layers)
        ])

        # Long skip connections (U-ViT style)
        self.skip_projections = nn.ModuleList([
            nn.Linear(hidden_dim, hidden_dim)
            for _ in range(n_layers - 1)
        ])

        # Output projection
        self.output_proj = nn.Linear(hidden_dim, compressed_dim)

    def forward(self, x, t, labels=None, mask=None):
        """
        Args:
            x: compressed latent (B, L, compressed_dim) - AMP embeddings
            t: time scalar (B,) or (B, 1)
            labels: class labels (B,) for CFG - 0=AMP, 1=Non-AMP, 2=Mask
            mask: key padding mask (B, L), True at positions to ignore (optional)
        """
        B, L, D = x.shape

        # Project to hidden dimension
        x = self.compress_proj(x)  # (B, L, hidden_dim)

        # Add positional encoding
        if L <= self.max_seq_len:
            x = x + self.pos_embed[:, :L, :]

        # Time embedding: SinusoidalTimeEmbedding flattens t internally,
        # so both (B,) and (B, 1) inputs are accepted
        t_emb = self.time_embed(t)                    # (B, hidden_dim)
        t_emb = t_emb.unsqueeze(1).expand(-1, L, -1)  # (B, L, hidden_dim)

        # CFG: process label embedding if enabled
        if self.use_cfg and labels is not None:
            # Process labels through MLP
            label_emb = self.label_mlp(labels)                    # (B, hidden_dim)
            label_emb = label_emb.unsqueeze(1).expand(-1, L, -1)  # (B, L, hidden_dim)

            # Concatenate time and label embeddings, then project (concat-then-project conditioning)
            combined_emb = torch.cat([t_emb, label_emb], dim=-1)  # (B, L, hidden_dim*2)
            projected_emb = self.condition_proj(combined_emb)     # (B, L, hidden_dim)
        else:
            projected_emb = t_emb  # just use the time embedding if no CFG

        # Store intermediate representations for skip connections
        skip_features = []

        # Pass through transformer layers with skip connections
        for i, layer in enumerate(self.layers):
            # Add skip connection from an earlier layer
            if 0 < i < len(self.layers) - 1:
                skip_feat = skip_features[i - 1]
                skip_feat = self.skip_projections[i - 1](skip_feat)
                x = x + skip_feat

            # Store current features for future skip connections
            if i < len(self.layers) - 1:
                skip_features.append(x.clone())

            # Add the projected condition embedding before EACH layer
            x = x + projected_emb

            # Apply transformer layer
            x = layer(x, src_key_padding_mask=mask)

        # Project back to compressed dimension
        x = self.output_proj(x)  # (B, L, compressed_dim)

        return x
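

# --- Hedged example: label dropout for classifier-free guidance training ---
# The model reserves class index 2 as the "mask" (unconditional) label, so during
# training a fraction of labels is typically replaced by that index so one network
# learns both conditional and unconditional velocities. The training loop itself is
# not part of this file; `MASK_LABEL` and `p_uncond` below are illustrative names
# for a minimal sketch, not an established API.

MASK_LABEL = 2  # index of the "unknown MIC" / unconditional class


def drop_labels_for_cfg(labels: torch.Tensor, p_uncond: float = 0.1) -> torch.Tensor:
    """Randomly replace labels with MASK_LABEL with probability p_uncond (sketch)."""
    drop = torch.rand(labels.shape[0], device=labels.device) < p_uncond
    return torch.where(drop, torch.full_like(labels, MASK_LABEL), labels)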


class AMPProtFlowPipelineCFG:
    """
    Complete ProtFlow pipeline for AMP generation with CFG.
    """

    def __init__(self, compressor, decompressor, flow_model, device='cuda'):
        self.compressor = compressor
        self.decompressor = decompressor
        self.flow_model = flow_model
        self.device = device

        # Load normalization stats
        self.stats = torch.load('normalization_stats.pt', map_location=device)

    def generate_amps_cfg(self, num_samples=100, num_steps=25, cfg_scale=7.5, condition_label=0):
        """
        Generate AMP samples using CFG.

        Args:
            num_samples: Number of samples to generate
            num_steps: Number of ODE solving steps
            cfg_scale: CFG guidance scale (higher = stronger conditioning)
            condition_label: 0=AMP, 1=Non-AMP, 2=Mask
        """
        print(f"Generating {num_samples} samples with CFG (label={condition_label}, scale={cfg_scale})...")

        # Process in batches
        batch_size = min(num_samples, 32)
        all_samples = []

        with torch.no_grad():  # no gradients are needed during sampling
            for i in range(0, num_samples, batch_size):
                current_batch = min(batch_size, num_samples - i)

                # Initialize with noise
                eps = torch.randn(current_batch, self.flow_model.max_seq_len,
                                  self.flow_model.compressed_dim, device=self.device)

                # ODE solving steps with CFG
                xt = eps.clone()
                for step in range(num_steps):
                    t = torch.ones(current_batch, device=self.device) * (1.0 - step / num_steps)

                    # CFG: combine conditional and unconditional predictions
                    if cfg_scale > 0:
                        # With condition
                        vt_cond = self.flow_model(
                            xt, t,
                            labels=torch.full((current_batch,), condition_label, device=self.device))
                        # Without condition (mask label)
                        vt_uncond = self.flow_model(
                            xt, t,
                            labels=torch.full((current_batch,), 2, device=self.device))
                        # CFG extrapolation: push the conditional prediction away from the unconditional one
                        vt = vt_uncond + cfg_scale * (vt_cond - vt_uncond)
                    else:
                        # No CFG, use the mask label
                        vt = self.flow_model(
                            xt, t,
                            labels=torch.full((current_batch,), 2, device=self.device))

                    # Euler step for backward integration (t: 1 -> 0)
                    # Use a negative dt to integrate from noise (t=1) to data (t=0)
                    dt = -1.0 / num_steps
                    xt = xt + vt * dt

                all_samples.append(xt)

        # Concatenate all batches
        generated = torch.cat(all_samples, dim=0)

        # Decompress and reverse the normalization
        with torch.no_grad():
            decompressed = self.decompressor(generated)

            # Apply reverse normalization
            m, s, mn, mx = self.stats['mean'], self.stats['std'], self.stats['min'], self.stats['max']
            decompressed = decompressed * (mx - mn + 1e-8) + mn
            decompressed = decompressed * s + m

        return generated, decompressed
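

# --- Hedged example: one flow-matching training step ---
# The sampler above integrates from t = 1 (noise) towards t = 0 (data) with a negative
# Euler step, which is consistent with the linear path
#     x_t = (1 - t) * x_data + t * eps,    target velocity u = eps - x_data.
# The actual training code is not included in this file; the sketch below only
# illustrates that convention together with CFG label dropout. `optimizer` and
# `x_data` (compressed, normalized latents of shape (B, L, compressed_dim)) are
# assumed to exist elsewhere.


def flow_matching_step(flow_model, x_data, labels, optimizer, p_uncond=0.1):
    """Single illustrative training step under the t=1-is-noise convention (sketch)."""
    B = x_data.shape[0]
    t = torch.rand(B, device=x_data.device)                        # t ~ U(0, 1)
    eps = torch.randn_like(x_data)                                 # noise endpoint
    xt = (1.0 - t).view(B, 1, 1) * x_data + t.view(B, 1, 1) * eps  # point on the linear path
    target = eps - x_data                                          # constant velocity along the path
    labels = drop_labels_for_cfg(labels, p_uncond)                 # CFG label dropout
    pred = flow_model(xt, t, labels=labels)
    loss = F.mse_loss(pred, target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()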


# Example usage
if __name__ == "__main__":
    # Initialize the final AMP flow model with CFG using the concatenation approach
    flow_model = AMPFlowMatcherCFGConcat(
        hidden_dim=480,
        compressed_dim=30,   # 16x compression of 480
        n_layers=12,
        n_heads=16,
        dim_ff=3072,
        max_seq_len=25,      # for AMP sequences (max length 50, halved by pooling)
        use_cfg=True
    )

    print(f"FINAL AMP Flow Model with CFG (Concat+Proj) parameters: "
          f"{sum(p.numel() for p in flow_model.parameters()):,}")

    # Test forward pass
    batch_size = 4
    seq_len = 20
    compressed_dim = 30

    x = torch.randn(batch_size, seq_len, compressed_dim)
    t = torch.rand(batch_size)
    labels = torch.randint(0, 3, (batch_size,))  # random labels

    with torch.no_grad():
        output = flow_model(x, t, labels=labels)

    print(f"Input shape: {x.shape}")
    print(f"Output shape: {output.shape}")
    print(f"Time input shape: {t.shape}")
    print(f"Labels: {labels}")
    print("🎯 FINAL AMP Flow Model with CFG (Concat+Proj) ready for training!")
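
    # --- Hedged example: sampling with the pipeline at different guidance scales ---
    # AMPProtFlowPipelineCFG additionally needs trained compressor/decompressor modules
    # and a 'normalization_stats.pt' file, none of which are defined in this file, so the
    # lines below are left commented out as a sketch; `compressor` and `decompressor` are
    # assumed to be loaded elsewhere.
    #
    # pipeline = AMPProtFlowPipelineCFG(compressor, decompressor, flow_model, device='cpu')
    # for scale in (1.0, 3.0, 7.5):
    #     latents, embeddings = pipeline.generate_amps_cfg(
    #         num_samples=16, num_steps=25, cfg_scale=scale, condition_label=0)
    #     print(f"cfg_scale={scale}: latents {latents.shape}, decompressed {embeddings.shape}")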