from typing import Optional, Tuple

import torch

from diffusers.models.embeddings import get_3d_rotary_pos_embed
from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid


def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    patch_size_t: Optional[int] = None,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # The VAE downsamples each spatial dimension by `vae_scale_factor_spatial`,
    # and the transformer patchifies the latents by `patch_size`, so the RoPE
    # grid is defined over latent patches rather than pixels.
    grid_height = height // (vae_scale_factor_spatial * patch_size)
    grid_width = width // (vae_scale_factor_spatial * patch_size)
    base_size_width = base_width // (vae_scale_factor_spatial * patch_size)
    base_size_height = base_height // (vae_scale_factor_spatial * patch_size)

    if patch_size_t is None:
        # No temporal patching: crop the spatial grid to the base resolution so
        # the RoPE frequencies stay consistent across aspect ratios and sizes.
        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
        )
    else:
        # Temporal patching is enabled: round the frame count up to a whole
        # number of temporal patches before building the embedding.
        base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(base_size_height, base_size_width),
        )

    # Move both tensors to the requested device (a no-op when `device` is None).
    freqs_cos = freqs_cos.to(device=device)
    freqs_sin = freqs_sin.to(device=device)
    return freqs_cos, freqs_sin
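

# Minimal usage sketch, not part of the original helper. The inputs below are
# assumptions: 49 pixel frames correspond to 13 latent frames under the 4x
# temporal compression of the CogVideoX VAE, and the defaults (patch_size=2,
# vae_scale_factor_spatial=8, attention_head_dim=64) should be checked against
# the transformer config of the checkpoint being trained.
if __name__ == "__main__":
    freqs_cos, freqs_sin = prepare_rotary_positional_embeddings(
        height=480,
        width=720,
        num_frames=13,  # latent frames, not pixel frames
        device=torch.device("cpu"),
    )
    # Expected shape: (temporal_size * grid_height * grid_width, attention_head_dim),
    # i.e. (13 * 30 * 45, 64) for these inputs.
    print(freqs_cos.shape, freqs_sin.shape)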