Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -439,7 +439,7 @@ class UniMCPredict:
|
|
| 439 |
batch = [self.data_model.train_data.encode(
|
| 440 |
sample) for sample in batch_data]
|
| 441 |
batch = self.data_model.collate_fn(batch)
|
| 442 |
-
batch = {k: v.cuda() for k, v in batch.items()}
|
| 443 |
_, _, logits = self.model.model(**batch)
|
| 444 |
soft_logits = torch.nn.functional.softmax(logits, dim=-1)
|
| 445 |
logits = torch.argmax(soft_logits, dim=-1).detach().cpu().numpy()
|
|
|
|
| 439 |
batch = [self.data_model.train_data.encode(
|
| 440 |
sample) for sample in batch_data]
|
| 441 |
batch = self.data_model.collate_fn(batch)
|
| 442 |
+
# batch = {k: v.cuda() for k, v in batch.items()}
|
| 443 |
_, _, logits = self.model.model(**batch)
|
| 444 |
soft_logits = torch.nn.functional.softmax(logits, dim=-1)
|
| 445 |
logits = torch.argmax(soft_logits, dim=-1).detach().cpu().numpy()
|