# kandinsky/magcache_utils.py
# This is an adaptation of MagCache from https://github.com/Zehong-Ma/MagCache/
import numpy as np
import torch


def nearest_interp(src_array, target_length):
    """Nearest-neighbour resampling of a 1-D numpy array to target_length entries."""
    src_length = len(src_array)
    if target_length == 1:
        return np.array([src_array[-1]])
    # Map each target index back onto the source grid and round to the nearest source index.
    scale = (src_length - 1) / (target_length - 1)
    mapped_indices = np.round(np.arange(target_length) * scale).astype(int)
    return src_array[mapped_indices]
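# Worked example: nearest_interp(np.array([1.0, 0.9, 0.8, 0.7]), 2) maps target indices
# {0, 1} to source indices {0, 3} and returns array([1.0, 0.7]).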


def set_magcache_params(dit, mag_ratios, num_steps, no_cfg):
    """Patch `dit` in place so it uses the MagCache forward pass.

    `mag_ratios` is a list of calibrated magnitude ratios; it is resampled below if its
    length does not match `num_steps * 2` (one conditional + one unconditional ratio per step).
    """
    print('Using MagCache')
    dit.__class__.forward = magcache_forward
    # The model is called twice per denoising step (cond + uncond), so count up to num_steps * 2.
    dit.cnt = 0
    dit.num_steps = num_steps * 2
    # Skip a call only while the accumulated error stays below this threshold.
    dit.magcache_thresh = 0.12
    # At most K consecutive skipped calls per branch.
    dit.K = 2
    # Index 0 tracks the conditional branch, index 1 the unconditional one.
    dit.accumulated_err = [0.0, 0.0]
    dit.accumulated_steps = [0, 0]
    dit.accumulated_ratio = [1.0, 1.0]
    # Never skip during the first 20% of calls.
    dit.retention_ratio = 0.2
    dit.residual_cache = [None, None]
    # Prepend 1.0 for the first cond/uncond pair, which is always computed in full.
    dit.mag_ratios = np.array([1.0] * 2 + mag_ratios)
    dit.no_cfg = no_cfg
    if len(dit.mag_ratios) != num_steps * 2:
        # Resample the calibrated ratios to the requested number of steps, interpolating the
        # conditional and unconditional sequences separately.
        print(f'Interpolating mag_ratios from length {len(dit.mag_ratios)} to {num_steps * 2}')
        mag_ratio_con = nearest_interp(dit.mag_ratios[0::2], num_steps)
        mag_ratio_ucon = nearest_interp(dit.mag_ratios[1::2], num_steps)
        interpolated_mag_ratios = np.concatenate(
            [mag_ratio_con.reshape(-1, 1), mag_ratio_ucon.reshape(-1, 1)], axis=1).reshape(-1)
        dit.mag_ratios = interpolated_mag_ratios
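
# A sketch of the layout produced above: mag_ratios is interleaved as
# [cond_0, uncond_0, cond_1, uncond_1, ...], so mag_ratios[0::2] holds the conditional
# ratios and mag_ratios[1::2] the unconditional ones; each half is resampled with
# nearest_interp and then re-interleaved back into per-step (cond, uncond) order.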


@torch.compile(mode="max-autotune-no-cudagraphs")
def magcache_forward(
    self,
    x,
    text_embed,
    pooled_text_embed,
    time,
    visual_rope_pos,
    text_rope_pos,
    scale_factor=(1.0, 1.0, 1.0),
    sparse_params=None,
):
    # Text path: build the text/time embeddings (and the initial visual embedding), then run
    # the text transformer blocks. MagCache only ever skips the visual blocks below.
    text_embed, time_embed, text_rope, visual_embed = self.before_text_transformer_blocks(
        text_embed, time, pooled_text_embed, x, text_rope_pos)
    for text_transformer_block in self.text_transformer_blocks:
        text_embed = text_transformer_block(text_embed, time_embed, text_rope)

    # Prepare the visual tokens and rotary embeddings for the visual blocks.
    visual_embed, visual_shape, to_fractal, visual_rope = self.before_visual_transformer_blocks(
        visual_embed, visual_rope_pos, scale_factor, sparse_params)

    # MagCache: once past the retention window, estimate how far a cached residual has
    # drifted and decide whether the visual transformer blocks can be skipped this call.
    skip_forward = False
    ori_visual_embed = visual_embed
    if self.cnt >= int(self.num_steps * self.retention_ratio):
        # cnt % 2 separates the conditional and unconditional passes.
        cur_mag_ratio = self.mag_ratios[self.cnt]
        self.accumulated_ratio[self.cnt % 2] = self.accumulated_ratio[self.cnt % 2] * cur_mag_ratio
        self.accumulated_steps[self.cnt % 2] += 1
        cur_skip_err = np.abs(1 - self.accumulated_ratio[self.cnt % 2])
        self.accumulated_err[self.cnt % 2] += cur_skip_err
        if self.accumulated_err[self.cnt % 2] < self.magcache_thresh and self.accumulated_steps[self.cnt % 2] <= self.K:
            skip_forward = True
            residual_visual_embed = self.residual_cache[self.cnt % 2]
        else:
            # Error budget exhausted (or too many consecutive skips): run this call fully.
            self.accumulated_err[self.cnt % 2] = 0
            self.accumulated_steps[self.cnt % 2] = 0
            self.accumulated_ratio[self.cnt % 2] = 1.0

    if skip_forward:
        # Reuse the cached residual instead of running the visual transformer blocks.
        visual_embed = visual_embed + residual_visual_embed
    else:
        for visual_transformer_block in self.visual_transformer_blocks:
            visual_embed = visual_transformer_block(visual_embed, text_embed, time_embed,
                                                    visual_rope, sparse_params)
        # Cache the residual contributed by the visual blocks for reuse on later calls.
        residual_visual_embed = visual_embed - ori_visual_embed
    self.residual_cache[self.cnt % 2] = residual_visual_embed

    x = self.after_blocks(visual_embed, visual_shape, to_fractal, text_embed, time_embed)

    # With CFG the model is called twice per denoising step (cond + uncond); without it a
    # single call covers the step, so the counter advances by two.
    if self.no_cfg:
        self.cnt += 2
    else:
        self.cnt += 1
    if self.cnt >= self.num_steps:
        # End of the sampling run: reset the cache state for the next generation.
        self.cnt = 0
        self.accumulated_ratio = [1.0, 1.0]
        self.accumulated_err = [0.0, 0.0]
        self.accumulated_steps = [0, 0]
    return x
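

# Minimal usage sketch (not part of the original module; the names `dit`,
# `calibrated_mag_ratios` and `num_inference_steps` are placeholders for whatever the
# surrounding pipeline provides):
#
#     calibrated_mag_ratios = [...]  # per-call magnitude ratios from a calibration run
#     set_magcache_params(dit, calibrated_mag_ratios,
#                         num_steps=num_inference_steps, no_cfg=False)
#     # From now on dit.forward() follows the MagCache schedule, skipping the visual
#     # transformer blocks whenever the accumulated error stays under magcache_thresh.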