finetrainers/patches/dependencies/diffusers/rms_norm.py
import torch
import torch.nn as nn

from diffusers.utils import is_torch_npu_available, is_torch_version


def _patched_rms_norm_forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
    if is_torch_npu_available():
        import torch_npu

        if self.weight is not None:
            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)
        hidden_states = torch_npu.npu_rms_norm(hidden_states, self.weight, epsilon=self.eps)[0]
        if self.bias is not None:
            hidden_states = hidden_states + self.bias
    elif is_torch_version(">=", "2.4"):
        ### ===== <Modified> =====
        input_dtype = hidden_states.dtype
        if self.weight is not None:
            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)
        hidden_states = nn.functional.rms_norm(
            hidden_states, normalized_shape=(hidden_states.shape[-1],), weight=self.weight, eps=self.eps
        )
        if self.bias is not None:
            hidden_states = hidden_states + self.bias
        hidden_states = hidden_states.to(input_dtype)
        ### ===== </Modified> =====
    else:
        input_dtype = hidden_states.dtype
        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)

        if self.weight is not None:
            # convert into half-precision if necessary
            if self.weight.dtype in [torch.float16, torch.bfloat16]:
                hidden_states = hidden_states.to(self.weight.dtype)
            hidden_states = hidden_states * self.weight
            if self.bias is not None:
                hidden_states = hidden_states + self.bias
        else:
            hidden_states = hidden_states.to(input_dtype)

    return hidden_states
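
The span between the `<Modified>` markers is the only deviation from the upstream diffusers `RMSNorm.forward`: it records the incoming dtype before the weight cast and converts the result back to it after `nn.functional.rms_norm`, so the torch >= 2.4 branch returns hidden states in the caller's dtype even when the norm weight is float16/bfloat16. The NPU and pre-2.4 branches are left as in the upstream implementation. Below is a minimal sketch of how such a replacement forward could be monkey-patched in; the `diffusers.models.normalization.RMSNorm` import location and the module path for this file are assumptions inferred from the file location above, not something this file itself declares.

    # Assumed import location of diffusers' RMSNorm class.
    from diffusers.models.normalization import RMSNorm

    # Module path inferred from the file location shown above.
    from finetrainers.patches.dependencies.diffusers.rms_norm import _patched_rms_norm_forward

    # Override the class-level forward so every RMSNorm instance created by
    # diffusers models routes through the patched implementation.
    RMSNorm.forward = _patched_rms_norm_forward

After applying the override on torch >= 2.4, a float32 input comes back as float32 even when the norm weight is bfloat16, instead of inheriting the weight's dtype.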