TedYeh committed · a38dfb6 · Parent: 6d2ffd2
update predictor

Changed files: predictor.py (+3 −3)
predictor.py CHANGED

@@ -201,7 +201,7 @@ def inference(inp_img, classes = ['big', 'small'], epoch = 6):
     device = torch.device("cuda")
     translator= Translator(to_lang="zh-TW")
 
-    model = CUPredictor()
+    model = CUPredictor()
     model.load_state_dict(torch.load(f'models/model_{epoch}.pt'))
     # load image-to-text model
     processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
@@ -218,13 +218,13 @@ def inference(inp_img, classes = ['big', 'small'], epoch = 6):
     image_tensor = trans(inp_img)
     image_tensor = image_tensor.unsqueeze(0)
     with torch.no_grad():
-        inputs = image_tensor
+        inputs = image_tensor
         outputs_c, outputs_h, outputs_b, outputs_w, outputs_hi = model(inputs)
         _, preds = torch.max(outputs_c, 1)
         idx = preds.numpy()[0]
 
         # unconditional image captioning
-        inputs = processor(inp_img, return_tensors="pt")
+        inputs = processor(inp_img, return_tensors="pt")
         out = model_blip.generate(**inputs)
         description = processor.decode(out[0], skip_special_tokens=True)
         description_tw = translator.translate(description)
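For context, the first hunk sets up the classifier checkpoint before inference. A minimal sketch of that loading pattern, assuming CUPredictor is the nn.Module defined elsewhere in predictor.py; the torch.cuda.is_available() fallback, map_location, and eval() call are illustrative additions, not part of the committed code:

import torch

from predictor import CUPredictor  # assumed import path within this repo

def load_predictor(epoch: int = 6) -> torch.nn.Module:
    # The committed code hard-codes torch.device("cuda"); fall back to CPU here.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = CUPredictor()
    # map_location lets a CUDA-trained checkpoint load on a CPU-only machine.
    state = torch.load(f'models/model_{epoch}.pt', map_location=device)
    model.load_state_dict(state)
    model.to(device)
    model.eval()  # disable dropout/batch-norm updates during inference
    return model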
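The second hunk covers the unconditional BLIP captioning and zh-TW translation path. A self-contained sketch of that flow, assuming inp_img is a PIL image, that model_blip is the BlipForConditionalGeneration checkpoint matching the processor, and that Translator comes from the translate package the diff already uses:

from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from translate import Translator

processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model_blip = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")
translator = Translator(to_lang="zh-TW")

def caption_zh_tw(inp_img: Image.Image) -> tuple[str, str]:
    # Unconditional captioning: pixel inputs only, no text prompt.
    inputs = processor(inp_img, return_tensors="pt")
    out = model_blip.generate(**inputs)
    description = processor.decode(out[0], skip_special_tokens=True)
    # Translate the English caption to Traditional Chinese.
    return description, translator.translate(description)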