Update inference.py
inference.py (+2 -2)
@@ -128,10 +128,10 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     half = dim // 2

     # Does not block the CUDA stream, but differs from the official Flux code by about 1e-4:
-    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)
+    # freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32, device=t.device) / half)

     # Blocks the CUDA stream, but consistent with the official code:
-
+    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)

     args = t[:, None].float() * freqs[None]
     embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
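The change swaps which of the two freqs variants is active: freqs is now built on the CPU and then moved with .to(t.device), matching the official Flux code, instead of being created directly on t's device. For context, below is a minimal runnable sketch of how the surrounding timestep_embedding function reads after this commit; everything outside the hunk (the time_factor scaling, the odd-dim padding, and the final dtype cast) follows the public FLUX reference implementation and is an assumption, not part of this diff.

import math

import torch
from torch import Tensor


def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 1000.0):
    # Sinusoidal timestep embedding. Only the freqs lines below come from this
    # commit; the rest follows the FLUX reference implementation (assumption).
    t = time_factor * t
    half = dim // 2

    # Variant chosen by this commit: build freqs on the CPU, then copy to t's
    # device. Per the comments in the diff, this is numerically consistent with
    # the official code but blocks the CUDA stream for the host-to-device copy;
    # creating the tensor with device=t.device avoids that but differs by ~1e-4.
    freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half).to(t.device)

    args = t[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Pad with a zero column when dim is odd (FLUX reference behaviour).
        embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
    if torch.is_floating_point(t):
        embedding = embedding.to(t)
    return embedding


# Example: a batch of 4 timesteps embedded into 256 dimensions -> shape (4, 256).
# emb = timestep_embedding(torch.rand(4), 256)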