Commit
·
548d485
1
Parent(s):
9578f22
Update ced_model/modeling_ced.py
Browse files
ced_model/modeling_ced.py
CHANGED
|
@@ -457,9 +457,7 @@ class CedModel(CedPreTrainedModel):
|
|
| 457 |
n_splits = 1
|
| 458 |
|
| 459 |
x = self.forward_features(x)
|
| 460 |
-
|
| 461 |
-
x = torch.flatten(x, 0, 1)
|
| 462 |
-
x = torch.unsqueeze(x, 0)
|
| 463 |
|
| 464 |
return SequenceClassifierOutput(logits=x)
|
| 465 |
|
|
|
|
| 457 |
n_splits = 1
|
| 458 |
|
| 459 |
x = self.forward_features(x)
|
| 460 |
+
x = torch.reshape(x, (x.shape[0] // n_splits, -1, x.shape[-1]))
|
|
|
|
|
|
|
| 461 |
|
| 462 |
return SequenceClassifierOutput(logits=x)
|
| 463 |
|