Spaces:
Running
on
Zero
Running
on
Zero
Update models/unet.py
Browse files- models/unet.py +21 -1
models/unet.py
CHANGED
|
@@ -9,6 +9,25 @@ from einops.layers.torch import Rearrange
|
|
| 9 |
from einops import rearrange
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
import os
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
|
| 14 |
MONITOR_ATTN = []
|
|
@@ -808,7 +827,8 @@ class MotionCLR(nn.Module):
|
|
| 808 |
# text encoder
|
| 809 |
self.embed_text = nn.Linear(clip_dim, text_latent_dim)
|
| 810 |
self.clip_version = clip_version
|
| 811 |
-
self.clip_model = self.load_and_freeze_clip(clip_version)
|
|
|
|
| 812 |
textTransEncoderLayer = nn.TransformerEncoderLayer(
|
| 813 |
d_model=text_latent_dim,
|
| 814 |
nhead=text_num_heads,
|
|
|
|
| 9 |
from einops import rearrange
|
| 10 |
import matplotlib.pyplot as plt
|
| 11 |
import os
|
| 12 |
+
import torch.nn as nn
|
| 13 |
+
|
| 14 |
+
# Custom LayerNorm class to handle fp16
|
| 15 |
+
class CustomLayerNorm(nn.LayerNorm):
|
| 16 |
+
def forward(self, x: torch.Tensor):
|
| 17 |
+
if self.weight.dtype == torch.float32:
|
| 18 |
+
orig_type = x.dtype
|
| 19 |
+
ret = super().forward(x.type(torch.float32))
|
| 20 |
+
return ret.type(orig_type)
|
| 21 |
+
else:
|
| 22 |
+
return super().forward(x)
|
| 23 |
+
|
| 24 |
+
# Function to replace LayerNorm in CLIP model with CustomLayerNorm
|
| 25 |
+
def replace_layer_norm(model):
|
| 26 |
+
for name, module in model.named_children():
|
| 27 |
+
if isinstance(module, nn.LayerNorm):
|
| 28 |
+
setattr(model, name, CustomLayerNorm(module.normalized_shape, elementwise_affine=module.elementwise_affine))
|
| 29 |
+
else:
|
| 30 |
+
replace_layer_norm(module) # Recursively apply to all submodules
|
| 31 |
|
| 32 |
|
| 33 |
MONITOR_ATTN = []
|
|
|
|
| 827 |
# text encoder
|
| 828 |
self.embed_text = nn.Linear(clip_dim, text_latent_dim)
|
| 829 |
self.clip_version = clip_version
|
| 830 |
+
self.clip_model = self.load_and_freeze_clip(clip_version)
|
| 831 |
+
replace_layer_norm(self.clip_model)
|
| 832 |
textTransEncoderLayer = nn.TransformerEncoderLayer(
|
| 833 |
d_model=text_latent_dim,
|
| 834 |
nhead=text_num_heads,
|