fix typo
Browse files- modeling_tunbert.py +1 -1
modeling_tunbert.py
CHANGED
|
@@ -42,7 +42,7 @@ class TunBERT(PreTrainedModel):
|
|
| 42 |
def process(self,**inputs):
|
| 43 |
with torch.no_grad():
|
| 44 |
out = self.forward(**inputs)
|
| 45 |
-
out = torch.argmax(
|
| 46 |
return ["positive" if index == 0 else "negative" for index in out.tolist()]
|
| 47 |
|
| 48 |
|
|
|
|
| 42 |
def process(self,**inputs):
|
| 43 |
with torch.no_grad():
|
| 44 |
out = self.forward(**inputs)
|
| 45 |
+
out = torch.argmax(out.logits,dim=1)
|
| 46 |
return ["positive" if index == 0 else "negative" for index in out.tolist()]
|
| 47 |
|
| 48 |
|