Spaces:
Running
Running
Update models/tag2text.py
Browse files- models/tag2text.py +3 -0
models/tag2text.py
CHANGED
|
@@ -25,6 +25,8 @@ import numpy as np
|
|
| 25 |
def read_json(rpath):
|
| 26 |
with open(rpath, 'r') as f:
|
| 27 |
return json.load(f)
|
|
|
|
|
|
|
| 28 |
|
| 29 |
class Tag2Text_Caption(nn.Module):
|
| 30 |
def __init__(self,
|
|
@@ -132,6 +134,7 @@ class Tag2Text_Caption(nn.Module):
|
|
| 132 |
targets = torch.where(torch.sigmoid(logits) > self.threshold , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
|
| 133 |
|
| 134 |
tag = targets.cpu().numpy()
|
|
|
|
| 135 |
bs = image.size(0)
|
| 136 |
tag_input = []
|
| 137 |
for b in range(bs):
|
|
|
|
| 25 |
def read_json(rpath):
|
| 26 |
with open(rpath, 'r') as f:
|
| 27 |
return json.load(f)
|
| 28 |
+
|
| 29 |
+
delete_tag_index = [135]
|
| 30 |
|
| 31 |
class Tag2Text_Caption(nn.Module):
|
| 32 |
def __init__(self,
|
|
|
|
| 134 |
targets = torch.where(torch.sigmoid(logits) > self.threshold , torch.tensor(1.0).to(image.device), torch.zeros(self.num_class).to(image.device))
|
| 135 |
|
| 136 |
tag = targets.cpu().numpy()
|
| 137 |
+
tag[:,delete_tag_index] = 0
|
| 138 |
bs = image.size(0)
|
| 139 |
tag_input = []
|
| 140 |
for b in range(bs):
|