Spaces:
Runtime error
Runtime error
Update autoregressive/models/generate.py
Browse files
autoregressive/models/generate.py
CHANGED
|
@@ -60,7 +60,9 @@ def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sa
|
|
| 60 |
logits = logits[:, -1, :] / max(temperature, 1e-5)
|
| 61 |
if top_k > 0 or top_p < 1.0:
|
| 62 |
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
| 63 |
-
probs = F.softmax(logits
|
|
|
|
|
|
|
| 64 |
# values, indices = torch.max(probs, dim=1, keepdim=True)
|
| 65 |
# mask = (probs == values).float()
|
| 66 |
# probs = probs * (1 - mask)
|
|
@@ -71,8 +73,6 @@ def sample(logits, temperature: float=1.0, top_k: int=2000, top_p: float=1.0, sa
|
|
| 71 |
# add to fix 'nan' and 'inf'
|
| 72 |
probs = torch.where(torch.isnan(probs), torch.tensor(0.0), probs)
|
| 73 |
probs = torch.clamp(probs, min=0, max=None)
|
| 74 |
-
print(probs.sum())
|
| 75 |
-
print(probs)
|
| 76 |
print(f'inf:{torch.any(torch.isinf(probs))}')
|
| 77 |
print(f'nan: {torch.any(torch.isnan(probs))}')
|
| 78 |
|
|
|
|
| 60 |
logits = logits[:, -1, :] / max(temperature, 1e-5)
|
| 61 |
if top_k > 0 or top_p < 1.0:
|
| 62 |
logits = top_k_top_p_filtering(logits, top_k=top_k, top_p=top_p)
|
| 63 |
+
probs = F.softmax(logits, dim=-1)
|
| 64 |
+
print(probs.sum())
|
| 65 |
+
print(probs)
|
| 66 |
# values, indices = torch.max(probs, dim=1, keepdim=True)
|
| 67 |
# mask = (probs == values).float()
|
| 68 |
# probs = probs * (1 - mask)
|
|
|
|
| 73 |
# add to fix 'nan' and 'inf'
|
| 74 |
probs = torch.where(torch.isnan(probs), torch.tensor(0.0), probs)
|
| 75 |
probs = torch.clamp(probs, min=0, max=None)
|
|
|
|
|
|
|
| 76 |
print(f'inf:{torch.any(torch.isinf(probs))}')
|
| 77 |
print(f'nan: {torch.any(torch.isnan(probs))}')
|
| 78 |
|