# This is an adaptation of Magcache from https://github.com/Zehong-Ma/MagCache/ import numpy as np import torch def nearest_interp(src_array, target_length): src_length = len(src_array) if target_length == 1: return np.array([src_array[-1]]) scale = (src_length - 1) / (target_length - 1) mapped_indices = np.round(np.arange(target_length) * scale).astype(int) return src_array[mapped_indices] def set_magcache_params(dit, mag_ratios, num_steps, no_cfg): print(f'using Magcache') dit.__class__.forward = magcache_forward dit.cnt = 0 dit.num_steps = num_steps * 2 dit.magcache_thresh = 0.12 dit.K = 2 dit.accumulated_err = [0.0, 0.0] dit.accumulated_steps = [0, 0] dit.accumulated_ratio = [1.0, 1.0] dit.retention_ratio = 0.2 dit.residual_cache = [None, None] dit.mag_ratios = np.array([1.0]*2 + mag_ratios) dit.no_cfg = no_cfg if len(dit.mag_ratios) != num_steps * 2: print(f'interpolate MAG RATIOS: curr len {len(dit.mag_ratios)}') mag_ratio_con = nearest_interp(dit.mag_ratios[0::2], num_steps) mag_ratio_ucon = nearest_interp(dit.mag_ratios[1::2], num_steps) interpolated_mag_ratios = np.concatenate( [mag_ratio_con.reshape(-1, 1), mag_ratio_ucon.reshape(-1, 1)], axis=1).reshape(-1) dit.mag_ratios = interpolated_mag_ratios @torch.compile(mode="max-autotune-no-cudagraphs") def magcache_forward( self, x, text_embed, pooled_text_embed, time, visual_rope_pos, text_rope_pos, scale_factor=(1.0, 1.0, 1.0), sparse_params=None ): text_embed, time_embed, text_rope, visual_embed = self.before_text_transformer_blocks( text_embed, time, pooled_text_embed, x, text_rope_pos) for text_transformer_block in self.text_transformer_blocks: text_embed = text_transformer_block(text_embed, time_embed, text_rope) visual_embed, visual_shape, to_fractal, visual_rope = self.before_visual_transformer_blocks( visual_embed, visual_rope_pos, scale_factor, sparse_params) skip_forward = False ori_visual_embed = visual_embed if self.cnt>=int(self.num_steps*self.retention_ratio): cur_mag_ratio = self.mag_ratios[self.cnt] self.accumulated_ratio[self.cnt%2] = self.accumulated_ratio[self.cnt%2]*cur_mag_ratio self.accumulated_steps[self.cnt%2] += 1 cur_skip_err = np.abs(1-self.accumulated_ratio[self.cnt%2]) self.accumulated_err[self.cnt%2] += cur_skip_err if self.accumulated_err[self.cnt%2]= self.num_steps: self.cnt = 0 self.accumulated_ratio = [1.0, 1.0] self.accumulated_err = [0.0, 0.0] self.accumulated_steps = [0, 0] return x