\section{Flow Matching Architecture with Classifier-Free Guidance}
\label{sec:flow_model}
Our flow matching model employs a transformer-based architecture with classifier-free guidance (CFG) for controllable antimicrobial peptide generation. The model operates in the compressed embedding space (80 dimensions) and uses continuous normalizing flows to transform noise into biologically meaningful protein representations.
\subsection{Flow Matching Framework}
Flow matching provides a simulation-free approach to training continuous normalizing flows by directly regressing the vector field. Given a source distribution $p_0$ (Gaussian noise) and target distribution $p_1$ (compressed AMP embeddings), flow matching learns a vector field $v_\theta(x, t)$ that transports samples along optimal transport paths.
\subsubsection{Flow Matching Objective}
\label{sec:flow_objective}
The flow matching loss minimizes the difference between the predicted and true vector fields:
\begin{align}
\mathcal{L}_{\text{FM}}(\theta) &= \mathbb{E}_{t \sim U[0,1], x_1 \sim p_1, x_0 \sim p_0} \left[ \|v_\theta(x_t, t) - u_t(x_t)\|_2^2 \right] \label{eq:flow_matching_loss}
\end{align}
where $x_t = (1-t)x_0 + tx_1$ is the linear interpolation path and $u_t(x_t) = x_1 - x_0$ is the true vector field along this path.
For conditional generation with classifier-free guidance, we extend this to:
\begin{align}
\mathcal{L}_{\text{CFG}}(\theta) &= \mathbb{E}_{t, x_1, x_0, c} \left[ \|v_\theta(x_t, t, c) - u_t(x_t)\|_2^2 \right] \label{eq:cfg_flow_loss}
\end{align}
where $c$ represents the conditioning information (AMP/non-AMP labels).
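For concreteness, the conditional objective in Equation~\eqref{eq:cfg_flow_loss} amounts to a few tensor operations per batch. The following is a minimal PyTorch sketch, assuming \texttt{model(x\_t, t, c)} returns the predicted vector field; it illustrates the loss computation rather than reproducing the exact training code.
\begin{verbatim}
import torch

def cfg_flow_matching_loss(model, x1, c):
    """Single-batch flow matching loss along the linear path
    x_t = (1 - t) x_0 + t x_1 with target vector field u_t = x_1 - x_0."""
    B = x1.shape[0]
    x0 = torch.randn_like(x1)                   # samples from p_0 = N(0, I)
    t = torch.rand(B, device=x1.device)         # t ~ U[0, 1]
    t_exp = t.view(B, *([1] * (x1.dim() - 1)))  # broadcast over (L, 80)
    xt = (1.0 - t_exp) * x0 + t_exp * x1        # interpolation path
    ut = x1 - x0                                # true vector field
    v_pred = model(xt, t, c)                    # conditional prediction
    return torch.mean((v_pred - ut) ** 2)       # MSE regression objective
\end{verbatim}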
\subsubsection{Conditional Vector Field with CFG}
\label{sec:cfg_vector_field}
During inference, classifier-free guidance combines conditional and unconditional predictions:
\begin{align}
\tilde{v}_\theta(x_t, t, c) &= v_\theta(x_t, t, \emptyset) + w \cdot (v_\theta(x_t, t, c) - v_\theta(x_t, t, \emptyset)) \label{eq:cfg_combination}
\end{align}
where $w$ is the guidance scale, $v_\theta(x_t, t, c)$ is the conditional prediction, and $v_\theta(x_t, t, \emptyset)$ is the unconditional prediction.
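In code, Equation~\eqref{eq:cfg_combination} corresponds to two forward passes per integration step. A minimal sketch, assuming the mask label (id 2, introduced in Section~\ref{sec:label_processing}) selects the unconditional branch:
\begin{verbatim}
import torch

@torch.no_grad()
def guided_vector_field(model, x_t, t, c, w, mask_id=2):
    """Combine conditional and unconditional predictions with guidance scale w."""
    c_uncond = torch.full_like(c, mask_id)      # mask label = unconditional branch
    v_cond = model(x_t, t, c)                   # v_theta(x_t, t, c)
    v_uncond = model(x_t, t, c_uncond)          # v_theta(x_t, t, mask)
    return v_uncond + w * (v_cond - v_uncond)   # guided vector field
\end{verbatim}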
\subsection{Transformer-Based Architecture}
The flow matching model employs a 12-layer transformer with long skip connections, sinusoidal time embeddings, and learned positional encodings optimized for protein sequences.
\subsubsection{Model Architecture Specifications}
\label{sec:architecture_specs}
The key architectural hyperparameters are listed below (a configuration sketch follows the list):
\begin{itemize}
\item \textbf{Input Dimension}: 80 (compressed embedding space)
\item \textbf{Hidden Dimension}: 480 (model dimension)
\item \textbf{Transformer Layers}: 12 with long skip connections
\item \textbf{Attention Heads}: 16 multi-head attention heads
\item \textbf{Feedforward Dimension}: 3072 (6.4$\times$ hidden dimension)
\item \textbf{Maximum Sequence Length}: 25 (after hourglass pooling)
\item \textbf{Activation Function}: GELU throughout the network
\item \textbf{Dropout Rate}: 0.1 during training
\end{itemize}
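These specifications can be collected into a single configuration object. The sketch below is purely illustrative; the field names are ours, not identifiers from the actual implementation.
\begin{verbatim}
from dataclasses import dataclass

@dataclass
class FlowModelConfig:
    input_dim: int = 80      # compressed embedding dimension
    hidden_dim: int = 480    # transformer model dimension
    num_layers: int = 12     # transformer layers with long skip connections
    num_heads: int = 16      # multi-head attention heads
    ffn_dim: int = 3072      # feedforward dimension (6.4x hidden)
    max_len: int = 25        # sequence length after hourglass pooling
    dropout: float = 0.1     # dropout rate during training
\end{verbatim}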
\subsubsection{Time Embedding Architecture}
\label{sec:time_embedding}
Time information is encoded using sinusoidal embeddings following the ProtFlow methodology:
\begin{align}
\text{PE}(t, 2i) &= \sin\left(\frac{t}{10000^{2i/d}}\right) \label{eq:sin_time_embed}\\
\text{PE}(t, 2i+1) &= \cos\left(\frac{t}{10000^{2i/d}}\right) \label{eq:cos_time_embed}\\
\mathbf{t}_{\text{emb}} &= \text{MLP}(\text{PE}(t)) \in \mathbb{R}^{480} \label{eq:time_mlp}
\end{align}
where $d = 480$ is the hidden dimension and the MLP consists of two linear layers with GELU activation.
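A minimal PyTorch sketch of this time-embedding module is given below; whether the sine and cosine channels are interleaved or concatenated is an implementation detail that does not affect the downstream MLP, and we concatenate here for brevity.
\begin{verbatim}
import math
import torch
import torch.nn as nn

class TimeEmbedding(nn.Module):
    """Sinusoidal features of the flow time t followed by a two-layer GELU MLP."""
    def __init__(self, dim=480):
        super().__init__()
        self.dim = dim
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(),
                                 nn.Linear(dim, dim))

    def forward(self, t):                      # t: (B,) with values in [0, 1]
        half = self.dim // 2
        freqs = torch.exp(-math.log(10000.0)
                          * torch.arange(half, device=t.device) / half)
        args = t[:, None] * freqs[None, :]     # (B, dim/2): t / 10000^{2i/d}
        pe = torch.cat([torch.sin(args), torch.cos(args)], dim=-1)
        return self.mlp(pe)                    # (B, 480)
\end{verbatim}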
\subsubsection{Long Skip Connections}
\label{sec:skip_connections}
The model incorporates U-ViT style long skip connections to preserve information flow:
\begin{align}
\mathbf{h}^{(i)} &= \text{TransformerLayer}^{(i)}(\mathbf{h}^{(i-1)} + \mathbf{t}_{\text{emb}}) \label{eq:transformer_layer}\\
\mathbf{h}^{(i)} &= \mathbf{h}^{(i)} + \text{SkipProj}^{(i-1)}(\mathbf{h}^{(i-2)}) \quad \text{for } i > 1 \label{eq:skip_connection}
\end{align}
where $\text{SkipProj}^{(i)}$ are learned linear projections for each skip connection.
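A condensed sketch of the skip pattern (following the indexing used in Algorithm~\ref{alg:flow_forward}) is shown below; \texttt{nn.TransformerEncoderLayer} stands in for the model's transformer block, which is an assumption of this sketch.
\begin{verbatim}
import torch
import torch.nn as nn

class SkipTransformer(nn.Module):
    """Transformer stack with U-ViT style long skip connections: intermediate
    layers receive a projected copy of an earlier hidden state in addition to
    the conditioning signal."""
    def __init__(self, dim=480, num_layers=12, num_heads=16, ffn_dim=3072):
        super().__init__()
        self.layers = nn.ModuleList([
            nn.TransformerEncoderLayer(dim, num_heads, ffn_dim, dropout=0.1,
                                       activation="gelu", batch_first=True)
            for _ in range(num_layers)])
        self.skip_proj = nn.ModuleList([nn.Linear(dim, dim)
                                        for _ in range(num_layers - 2)])

    def forward(self, h, cond):                # h, cond: (B, L, dim)
        skips = []
        last = len(self.layers) - 1
        for i, layer in enumerate(self.layers):
            if 0 < i < last:                   # add long skip from earlier state
                h = h + self.skip_proj[i - 1](skips[i - 1])
            if i < last:
                skips.append(h)                # store for later layers
            h = layer(h + cond)                # add conditioning, then transform
        return h
\end{verbatim}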
\subsection{Classifier-Free Guidance Implementation}
CFG enables controllable generation by training a single model to handle both conditional and unconditional generation, then combining predictions during inference.
\subsubsection{Label Processing Architecture}
\label{sec:label_processing}
The model processes three types of labels:
\begin{itemize}
\item \textbf{AMP (0)}: Sequences with MIC $< 100$~$\mu$g/mL
\item \textbf{Non-AMP (1)}: Sequences with MIC $\geq 100$~$\mu$g/mL
\item \textbf{Mask (2)}: Unknown/unconditional generation
\end{itemize}
Label embeddings are processed through a dedicated MLP:
\begin{align}
\mathbf{l}_{\text{raw}} &= \text{Embedding}(c) \in \mathbb{R}^{256} \label{eq:label_embedding}\\
\mathbf{l}_{\text{hidden}} &= \text{GELU}(\mathbf{l}_{\text{raw}} \mathbf{W}_1 + \mathbf{b}_1) \label{eq:label_hidden}\\
\mathbf{l}_{\text{emb}} &= \text{GELU}(\mathbf{l}_{\text{hidden}} \mathbf{W}_2 + \mathbf{b}_2) \in \mathbb{R}^{480} \label{eq:label_final}
\end{align}
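A corresponding PyTorch sketch is shown below; the intermediate width of the label MLP is not fixed by the equations above, so we assume it equals the hidden dimension.
\begin{verbatim}
import torch.nn as nn

class LabelEmbedding(nn.Module):
    """Map the three condition labels (AMP=0, non-AMP=1, mask=2) to the hidden
    dimension via a 256-d embedding and a two-layer GELU MLP."""
    def __init__(self, num_labels=3, embed_dim=256, hidden_dim=480):
        super().__init__()
        self.embed = nn.Embedding(num_labels, embed_dim)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim), nn.GELU())

    def forward(self, c):                      # c: (B,) integer labels
        return self.mlp(self.embed(c))         # (B, 480)
\end{verbatim}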
\subsubsection{Condition Integration Strategy}
\label{sec:condition_integration}
We employ a concatenation-based approach for integrating time and label information:
\begin{align}
\mathbf{c}_{\text{concat}} &= \text{Concat}(\mathbf{t}_{\text{emb}}, \mathbf{l}_{\text{emb}}) \in \mathbb{R}^{960} \label{eq:concat_conditions}\\
\mathbf{c}_{\text{proj}} &= \text{MLP}_{\text{proj}}(\mathbf{c}_{\text{concat}}) \in \mathbb{R}^{480} \label{eq:condition_projection}
\end{align}
The projected conditioning is added to each transformer layer:
\begin{align}
\mathbf{h}^{(i)} &= \text{TransformerLayer}^{(i)}(\mathbf{h}^{(i-1)} + \mathbf{c}_{\text{proj}}) \label{eq:conditioned_transformer}
\end{align}
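The fusion of Equations~\eqref{eq:concat_conditions} and~\eqref{eq:condition_projection} can be sketched as follows; the depth of $\text{MLP}_{\text{proj}}$ is an assumption of this sketch.
\begin{verbatim}
import torch
import torch.nn as nn

class ConditionProjection(nn.Module):
    """Concatenate time and label embeddings (960-d) and project back to the
    480-d hidden space that is added at every transformer layer."""
    def __init__(self, hidden_dim=480):
        super().__init__()
        self.proj = nn.Sequential(
            nn.Linear(2 * hidden_dim, hidden_dim), nn.GELU(),
            nn.Linear(hidden_dim, hidden_dim))

    def forward(self, t_emb, l_emb):           # each (B, 480)
        return self.proj(torch.cat([t_emb, l_emb], dim=-1))   # (B, 480)
\end{verbatim}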
\subsubsection{CFG Training Strategy}
\label{sec:cfg_training}
During training, 15\% of samples are randomly masked (set to label 2) to enable unconditional generation:
\begin{align}
c_{\text{train}} = \begin{cases}
c & \text{with probability } 0.85 \\
2 & \text{with probability } 0.15
\end{cases} \label{eq:cfg_masking}
\end{align}
This masking strategy ensures the model learns both conditional and unconditional generation capabilities.
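The masking rule in Equation~\eqref{eq:cfg_masking} is a one-line operation on the label tensor, sketched here in PyTorch:
\begin{verbatim}
import torch

def apply_cfg_masking(labels, p_drop=0.15, mask_id=2):
    """Randomly replace condition labels with the mask id so the same network
    also learns the unconditional vector field."""
    drop = torch.rand(labels.shape, device=labels.device) < p_drop
    return torch.where(drop, torch.full_like(labels, mask_id), labels)
\end{verbatim}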
\subsection{Training Methodology and Optimization}
The model is trained with mixed-precision optimization tuned for the H100 GPU architecture.
\subsubsection{Training Hyperparameters}
\label{sec:training_hyperparams}
\begin{itemize}
\item \textbf{Batch Size}: 512 (maximizing H100 utilization)
\item \textbf{Training Epochs}: 2000 epochs
\item \textbf{Base Learning Rate}: $8 \times 10^{-4}$
\item \textbf{Minimum Learning Rate}: $4 \times 10^{-4}$
\item \textbf{Warmup Steps}: 4000 steps
\item \textbf{Weight Decay}: 0.01
\item \textbf{Gradient Clipping}: 0.5 (tight clipping for stability)
\item \textbf{Mixed Precision}: BF16 for H100 optimization
\end{itemize}
\subsubsection{Learning Rate Scheduling}
\label{sec:advanced_lr_scheduling}
Training uses a two-phase learning rate schedule: linear warmup followed by cosine annealing to the minimum learning rate:
\begin{align}
\text{lr}_{\text{warmup}}(t) &= \text{lr}_{\text{base}} \cdot \frac{t}{T_{\text{warmup}}} \quad \text{for } t \leq T_{\text{warmup}} \label{eq:flow_warmup}\\
\text{lr}_{\text{cosine}}(t) &= \text{lr}_{\text{min}} + \frac{1}{2}(\text{lr}_{\text{base}} - \text{lr}_{\text{min}})\left(1 + \cos\left(\frac{\pi(t - T_{\text{warmup}})}{T_{\text{total}} - T_{\text{warmup}}}\right)\right) \label{eq:flow_cosine}
\end{align}
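In PyTorch, this schedule can be realized by chaining a linear-warmup scheduler with cosine annealing; the sketch below assumes \texttt{model} and \texttt{total\_steps} are already defined and mirrors the scheduler setup in Algorithm~\ref{alg:h100_training}.
\begin{verbatim}
import torch
from torch.optim.lr_scheduler import LinearLR, CosineAnnealingLR, SequentialLR

# `model` and `total_steps` (28,000 in our runs) are assumed to be defined.
optimizer = torch.optim.AdamW(model.parameters(), lr=8e-4, weight_decay=0.01)
warmup = LinearLR(optimizer, start_factor=1e-8, end_factor=1.0, total_iters=4000)
cosine = CosineAnnealingLR(optimizer, T_max=total_steps - 4000, eta_min=4e-4)
scheduler = SequentialLR(optimizer, schedulers=[warmup, cosine], milestones=[4000])
# scheduler.step() is called once per optimizer step, as in Algorithm 4.
\end{verbatim}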
\subsubsection{H100 GPU Optimizations}
\label{sec:h100_optimizations}
Training is optimized for the H100 architecture with several performance enhancements (see the setup sketch after this list):
\begin{itemize}
\item \textbf{TensorFloat-32 (TF32)}: Enabled for matrix operations
\item \textbf{Mixed Precision Training}: BF16 with automatic loss scaling
\item \textbf{Torch Compilation}: JIT compilation for 20--30\% speedup
\item \textbf{Data Loading}: 32 parallel workers for optimal throughput
\item \textbf{Memory Management}: Gradient checkpointing for large batches
\end{itemize}
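The settings above translate to a short setup block; the sketch below assumes \texttt{model}, \texttt{dataset}, and a loss function such as the \texttt{cfg\_flow\_matching\_loss} sketch from Section~\ref{sec:flow_objective} are defined, and is not the full training script.
\begin{verbatim}
import torch

# TF32, compilation, data loading, and BF16 autocast (assumed setup sketch).
torch.backends.cuda.matmul.allow_tf32 = True     # TF32 matrix multiplies
torch.backends.cudnn.allow_tf32 = True           # TF32 in cuDNN kernels
model = torch.compile(model)                     # JIT compilation
loader = torch.utils.data.DataLoader(dataset, batch_size=512, num_workers=32)
for x1, c in loader:
    with torch.autocast(device_type="cuda", dtype=torch.bfloat16):
        loss = cfg_flow_matching_loss(model, x1.cuda(), c.cuda())
    # backward pass, gradient clipping, and optimizer step follow Algorithm 4
\end{verbatim}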
\subsubsection{Training Dataset and Statistics}
\label{sec:training_data}
The model is trained on a comprehensive dataset of antimicrobial peptides:
\begin{itemize}
\item \textbf{Total Samples}: 6,983 validated sequences
\item \textbf{AMP Sequences}: 3,306 (47.3\%)
\item \textbf{Non-AMP Sequences}: 3,677 (52.7\%)
\item \textbf{CFG Masked}: 698 samples (10\%) for unconditional training
\item \textbf{Sequence Length}: Fixed at 50 amino acids (25 after compression)
\item \textbf{Training Steps}: 28,000 total steps (14 batches per epoch $\times$ 2000 epochs)
\end{itemize}
\subsection{Training Results and Performance}
The model converged stably over 2.3 hours of training on a single H100 GPU.
\subsubsection{Training Convergence}
\label{sec:training_convergence}
\begin{itemize}
\item \textbf{Final Loss}: 1.318 (mean squared error)
\item \textbf{Best Validation Loss}: 0.021476
\item \textbf{Training Time}: 2.3 hours on H100
\item \textbf{GPU Utilization}: $\sim$70\,GB memory usage (91\% of H100)
\item \textbf{Training Speed}: 0.1--3.4 steps/second (increasing with warmup)
\item \textbf{Convergence}: Stable convergence without overfitting
\end{itemize}
\subsubsection{Model Performance Metrics}
\label{sec:model_performance}
\begin{itemize}
\item \textbf{Parameter Count}: 50,779,584 parameters
\item \textbf{Model Size}: $\sim$607\,MB checkpoint file
\item \textbf{Inference Speed}: $\sim$1000 sequences/second
\item \textbf{Memory Requirements}: $\sim$12\,GB for inference
\item \textbf{CFG Effectiveness}: Clear differentiation between conditional/unconditional generation
\end{itemize}
\subsubsection{CFG Scale Analysis}
\label{sec:cfg_scale_analysis}
Different CFG scales produce distinct generation characteristics:
\begin{itemize}
\item \textbf{CFG Scale 0.0}: Unconditional generation, maximum diversity
\item \textbf{CFG Scale 3.0}: Weak conditioning, balanced diversity/control
\item \textbf{CFG Scale 7.5}: Strong conditioning, optimal for AMP generation
\item \textbf{CFG Scale 15.0}: Very strong conditioning, may reduce diversity
\end{itemize}
HMD-AMP validation results show that a CFG scale of 7.5 achieves the best performance, with a 20\% AMP classification rate.
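Generation with a chosen CFG scale reduces to the Euler loop of Algorithm~\ref{alg:cfg_generation}. A condensed PyTorch sketch is given below, assuming \texttt{model(x, t, labels)} matches the forward pass of Algorithm~\ref{alg:flow_forward}:
\begin{verbatim}
import torch

@torch.no_grad()
def sample_with_cfg(model, num_samples, c=0, w=7.5, num_steps=25,
                    seq_len=25, dim=80, mask_id=2, device="cuda"):
    """Euler integration of the CFG-guided vector field from t=0 to t=1."""
    x = torch.randn(num_samples, seq_len, dim, device=device)   # x_0 ~ N(0, I)
    labels = torch.full((num_samples,), c, dtype=torch.long, device=device)
    uncond = torch.full_like(labels, mask_id)
    dt = 1.0 / num_steps
    for i in range(num_steps):
        t = torch.full((num_samples,), i / num_steps, device=device)
        v_cond = model(x, t, labels)               # conditional prediction
        v_uncond = model(x, t, uncond)             # unconditional prediction
        v = v_uncond + w * (v_cond - v_uncond)     # classifier-free guidance
        x = x + dt * v                             # Euler step
    return x                                       # approximate sample from p_1
\end{verbatim}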
\begin{algorithm}[h]
\caption{Flow Matching Model Forward Pass}
\label{alg:flow_forward}
\begin{algorithmic}[1]
\REQUIRE Compressed embeddings $\mathbf{x} \in \mathbb{R}^{B \times L \times 80}$
\REQUIRE Time steps $\mathbf{t} \in \mathbb{R}^{B}$
\REQUIRE Condition labels $\mathbf{c} \in \mathbb{Z}^{B}$ (optional)
\ENSURE Vector field prediction $\mathbf{v} \in \mathbb{R}^{B \times L \times 80}$
\STATE \textbf{// Stage 1: Input Processing}
\STATE $\mathbf{h} \leftarrow \text{LinearProj}_{80 \rightarrow 480}(\mathbf{x})$ \COMMENT{Project to hidden dimension}
\STATE $\mathbf{h} \leftarrow \mathbf{h} + \mathbf{P}[:, :L, :]$ \COMMENT{Add positional embeddings}
\STATE \textbf{// Stage 2: Time Embedding}
\STATE $\mathbf{t} \leftarrow \mathbf{t}.\text{unsqueeze}(-1)$ if $\mathbf{t}.\text{dim}() = 1$ \COMMENT{Ensure 2D}
\FOR{$i = 0$ to $d/2 - 1$}
\STATE $\text{emb}[:, 2i] \leftarrow \sin(\mathbf{t} / 10000^{2i/d})$
\STATE $\text{emb}[:, 2i+1] \leftarrow \cos(\mathbf{t} / 10000^{2i/d})$
\ENDFOR
\STATE $\mathbf{t}_{\text{emb}} \leftarrow \text{MLP}_{\text{time}}(\text{emb})$ \COMMENT{Process through time MLP}
\STATE $\mathbf{t}_{\text{emb}} \leftarrow \mathbf{t}_{\text{emb}}.\text{unsqueeze}(1).\text{expand}(-1, L, -1)$
\STATE \textbf{// Stage 3: Conditional Processing (if CFG enabled)}
\IF{$\text{use\_cfg}$ and $\mathbf{c}$ is not None}
\STATE $\mathbf{l}_{\text{emb}} \leftarrow \text{Embedding}(\mathbf{c})$ \COMMENT{Embed labels}
\STATE $\mathbf{l}_{\text{emb}} \leftarrow \text{MLP}_{\text{label}}(\mathbf{l}_{\text{emb}})$ \COMMENT{Process labels}
\STATE $\mathbf{l}_{\text{emb}} \leftarrow \mathbf{l}_{\text{emb}}.\text{unsqueeze}(1).\text{expand}(-1, L, -1)$
\STATE $\mathbf{c}_{\text{concat}} \leftarrow \text{Concat}(\mathbf{t}_{\text{emb}}, \mathbf{l}_{\text{emb}})$ \COMMENT{Concatenate conditions}
\STATE $\mathbf{c}_{\text{proj}} \leftarrow \text{MLP}_{\text{proj}}(\mathbf{c}_{\text{concat}})$ \COMMENT{Project to hidden dim}
\ELSE
\STATE $\mathbf{c}_{\text{proj}} \leftarrow \mathbf{t}_{\text{emb}}$ \COMMENT{Use only time embedding}
\ENDIF
\STATE \textbf{// Stage 4: Transformer Processing with Skip Connections}
\STATE $\text{skip\_features} \leftarrow []$ \COMMENT{Initialize skip connection storage}
\FOR{$i = 0$ to $11$} \COMMENT{12 transformer layers}
\IF{$i > 0$ and $i < 11$} \COMMENT{Add skip connections}
\STATE $\mathbf{s} \leftarrow \text{skip\_features}[i-1]$
\STATE $\mathbf{s} \leftarrow \text{SkipProj}^{(i-1)}(\mathbf{s})$
\STATE $\mathbf{h} \leftarrow \mathbf{h} + \mathbf{s}$
\ENDIF
\IF{$i < 11$} \COMMENT{Store for future skip connections}
\STATE $\text{skip\_features}.\text{append}(\mathbf{h}.\text{clone}())$
\ENDIF
\STATE $\mathbf{h} \leftarrow \mathbf{h} + \mathbf{c}_{\text{proj}}$ \COMMENT{Add conditioning}
\STATE $\mathbf{h} \leftarrow \text{TransformerLayer}^{(i)}(\mathbf{h})$ \COMMENT{Apply transformer}
\ENDFOR
\STATE \textbf{// Stage 5: Output Projection}
\STATE $\mathbf{v} \leftarrow \text{LinearProj}_{480 \rightarrow 80}(\mathbf{h})$ \COMMENT{Project to output dimension}
\RETURN $\mathbf{v}$
\end{algorithmic}
\end{algorithm}
\begin{algorithm}[h]
\caption{Classifier-Free Guidance Training}
\label{alg:cfg_training}
\begin{algorithmic}[1]
\REQUIRE Training dataset $\mathcal{D} = \{(\mathbf{x}_i, c_i)\}_{i=1}^N$
\REQUIRE CFG dropout rate $p_{\text{drop}} = 0.15$
\REQUIRE Flow matching model $f_\theta$
\ENSURE Trained CFG-enabled flow model $f_{\theta^*}$
\FOR{$\text{epoch} = 1$ to $2000$}
\FOR{$\text{batch} \in \text{DataLoader}(\mathcal{D}, \text{batch\_size}=512)$}
\STATE $\{\mathbf{x}_{\text{batch}}, \mathbf{c}_{\text{batch}}\} \leftarrow \text{batch}$
\STATE \textbf{// Apply CFG masking}
\STATE $\text{mask} \leftarrow \text{Bernoulli}(p_{\text{drop}})$ \COMMENT{Random masking}
\STATE $\mathbf{c}_{\text{masked}} \leftarrow \text{where}(\text{mask}, 2, \mathbf{c}_{\text{batch}})$ \COMMENT{2 = unconditional}
\STATE \textbf{// Sample time and create interpolation path}
\STATE $\mathbf{t} \leftarrow \text{Uniform}(0, 1, \text{size}=(B,))$
\STATE $\mathbf{x}_0 \leftarrow \mathcal{N}(0, \mathbf{I})$ \COMMENT{Gaussian noise}
\STATE $\mathbf{x}_1 \leftarrow \mathbf{x}_{\text{batch}}$ \COMMENT{Target embeddings}
\STATE $\mathbf{x}_t \leftarrow (1 - \mathbf{t}) \mathbf{x}_0 + \mathbf{t} \mathbf{x}_1$ \COMMENT{Linear interpolation}
\STATE $\mathbf{u}_t \leftarrow \mathbf{x}_1 - \mathbf{x}_0$ \COMMENT{True vector field}
\STATE \textbf{// Forward pass}
\STATE $\mathbf{v}_{\text{pred}} \leftarrow f_\theta(\mathbf{x}_t, \mathbf{t}, \mathbf{c}_{\text{masked}})$
\STATE \textbf{// Compute flow matching loss}
\STATE $\mathcal{L} \leftarrow \|\mathbf{v}_{\text{pred}} - \mathbf{u}_t\|_2^2$ \COMMENT{MSE loss}
\STATE \textbf{// Backward pass with mixed precision}
\STATE $\text{scaler.scale}(\mathcal{L}).\text{backward}()$
\STATE $\text{scaler.unscale\_}(\text{optimizer})$
\STATE $\text{clip\_grad\_norm\_}(\theta, 0.5)$ \COMMENT{Gradient clipping}
\STATE $\text{scaler.step}(\text{optimizer})$
\STATE $\text{scaler.update}()$
\STATE $\text{scheduler.step}()$
\ENDFOR
\ENDFOR
\RETURN $\theta^*$
\end{algorithmic}
\end{algorithm}
\begin{algorithm}[h]
\caption{CFG-Enhanced Generation Process}
\label{alg:cfg_generation}
\begin{algorithmic}[1]
\REQUIRE Trained flow model $f_\theta$
\REQUIRE CFG scale $w \in \mathbb{R}^+$
\REQUIRE Condition label $c \in \{0, 1\}$ (0=AMP, 1=Non-AMP)
\REQUIRE Number of integration steps $N = 25$
\ENSURE Generated sequence embeddings $\mathbf{x}_1$
\STATE \textbf{// Initialize with Gaussian noise}
\STATE $\mathbf{x}_0 \leftarrow \mathcal{N}(0, \mathbf{I})$ \COMMENT{Sample initial noise}
\STATE \textbf{// Numerical integration with CFG}
\FOR{$i = 0$ to $N-1$}
\STATE $t \leftarrow i / N$ \COMMENT{Current time step}
\STATE \textbf{// Conditional prediction}
\STATE $\mathbf{v}_{\text{cond}} \leftarrow f_\theta(\mathbf{x}_t, t, c)$
\STATE \textbf{// Unconditional prediction}
\STATE $\mathbf{v}_{\text{uncond}} \leftarrow f_\theta(\mathbf{x}_t, t, 2)$ \COMMENT{2 = mask/unconditional}
\STATE \textbf{// Apply classifier-free guidance}
\STATE $\mathbf{v}_{\text{guided}} \leftarrow \mathbf{v}_{\text{uncond}} + w \cdot (\mathbf{v}_{\text{cond}} - \mathbf{v}_{\text{uncond}})$
\STATE \textbf{// Euler integration step}
\STATE $dt \leftarrow 1.0 / N$
\STATE $\mathbf{x}_{t+dt} \leftarrow \mathbf{x}_t + dt \cdot \mathbf{v}_{\text{guided}}$
\STATE $\mathbf{x}_t \leftarrow \mathbf{x}_{t+dt}$
\ENDFOR
\STATE $\mathbf{x}_1 \leftarrow \mathbf{x}_t$ \COMMENT{Final generated embedding}
\RETURN $\mathbf{x}_1$
\end{algorithmic}
\end{algorithm}
\begin{algorithm}[h]
\caption{H100-Optimized Training Pipeline}
\label{alg:h100_training}
\begin{algorithmic}[1]
\REQUIRE Dataset $\mathcal{D}$, Model $f_\theta$, H100 GPU
\ENSURE Optimally trained model $f_{\theta^*}$
\STATE \textbf{// H100 Optimizations Setup}
\STATE $\text{torch.backends.cuda.matmul.allow\_tf32} \leftarrow \text{True}$
\STATE $\text{torch.backends.cudnn.allow\_tf32} \leftarrow \text{True}$
\STATE $\text{model} \leftarrow \text{torch.compile}(f_\theta)$ \COMMENT{JIT compilation}
\STATE $\text{scaler} \leftarrow \text{GradScaler}()$ \COMMENT{Mixed precision}
\STATE \textbf{// Optimizer Setup}
\STATE $\text{optimizer} \leftarrow \text{AdamW}(\theta, \text{lr}=8e-4, \text{weight\_decay}=0.01)$
\STATE $\text{warmup\_sched} \leftarrow \text{LinearLR}(\text{start\_factor}=1e-8, \text{total\_iters}=4000)$
\STATE $\text{cosine\_sched} \leftarrow \text{CosineAnnealingLR}(\text{eta\_min}=4e-4)$
\STATE $\text{scheduler} \leftarrow \text{SequentialLR}([\text{warmup\_sched}, \text{cosine\_sched}])$
\STATE \textbf{// Data Loading Optimization}
\STATE $\text{dataloader} \leftarrow \text{DataLoader}(\mathcal{D}, \text{batch\_size}=512, \text{num\_workers}=32)$
\FOR{$\text{epoch} = 1$ to $2000$}
\STATE $\text{epoch\_loss} \leftarrow 0$
\FOR{$\text{batch} \in \text{dataloader}$}
\STATE \textbf{// Mixed precision forward pass}
\STATE \textbf{with} $\text{autocast}()$: \COMMENT{Mixed precision context}
\STATE \quad $\mathcal{L} \leftarrow \text{CFGFlowMatchingLoss}(\text{model}, \text{batch})$
\STATE \textbf{// Scaled backward pass}
\STATE $\text{scaler.scale}(\mathcal{L}).\text{backward}()$
\STATE $\text{scaler.unscale\_}(\text{optimizer})$
\STATE $\text{clip\_grad\_norm\_}(\theta, 0.5)$
\STATE $\text{scaler.step}(\text{optimizer})$
\STATE $\text{scaler.update}()$
\STATE $\text{scheduler.step}()$
\STATE $\text{epoch\_loss} \leftarrow \text{epoch\_loss} + \mathcal{L}.\text{item}()$
\ENDFOR
\IF{$\text{epoch} \bmod 300 = 0$} \COMMENT{Checkpoint every 300 epochs}
\STATE $\text{SaveCheckpoint}(\theta, \text{optimizer}, \text{scheduler}, \text{epoch})$
\ENDIF
\ENDFOR
\RETURN $\theta^*$
\end{algorithmic}
\end{algorithm}