Update app.py
Browse files
app.py
CHANGED
|
@@ -82,7 +82,7 @@ class AttentionBlock(nn.Module):
|
|
| 82 |
|
| 83 |
self.do = nn.Dropout(dropout)
|
| 84 |
|
| 85 |
-
self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).
|
| 86 |
|
| 87 |
def forward(self, query, key, value, mask=None):
|
| 88 |
batch_size = query.shape[0]
|
|
|
|
| 82 |
|
| 83 |
self.do = nn.Dropout(dropout)
|
| 84 |
|
| 85 |
+
self.scale = torch.sqrt(torch.FloatTensor([hid_dim // n_heads])).to(device)
|
| 86 |
|
| 87 |
def forward(self, query, key, value, mask=None):
|
| 88 |
batch_size = query.shape[0]
|