Hugging Face Spaces — app.py (CHANGED)
Commit: "bugfix" — resolves a runtime error in the Space.
Browse files
|
@@ -15,7 +15,7 @@ model = BlipForConditionalGeneration.from_pretrained(
|
|
| 15 |
|
| 16 |
def inference(raw_image, model_n, question, strategy):
    """Run the selected task on *raw_image* and return a caption string.

    Parameters:
        raw_image: the input image; passed directly to ``processor`` —
            presumably a PIL image from the Gradio input, TODO confirm.
        model_n: task selector; only 'Image Captioning' is handled in this
            visible chunk.
        question: not used on the captioning path (likely for a VQA branch
            outside this view).
        strategy: decoding strategy — "Beam search" or anything else for
            sampling.

    Returns:
        ``'caption: <text>'`` for the captioning task; implicitly ``None``
        for any other ``model_n`` (matching the original control flow).
    """
    if model_n == 'Image Captioning':
        # BUG FIX: the previous code never preprocessed the image and called
        # model.generate() with no inputs at all, which raised the runtime
        # error this commit resolves. Encode once and reuse for both branches.
        inputs = processor(raw_image).to(device, torch.float16)
        with torch.no_grad():
            if strategy == "Beam search":
                config = GenerationConfig(
                    # NOTE(review): the diff view omits some kwargs here
                    # (original file lines 22-23, e.g. num_beams) — restore
                    # them from the full file before relying on this block.
                    max_length=20,
                    min_length=5,
                )
            else:
                config = GenerationConfig(
                    do_sample=True,
                    # NOTE(review): the diff view omits a kwarg here
                    # (original file line 31) — restore it from the full file.
                    max_length=20,
                    min_length=5,
                )
            captions = model.generate(**inputs, generation_config=config)
        caption = processor.decode(captions[0], skip_special_tokens=True)
        # Strip all spaces — presumably a Chinese-captioning model where
        # inter-token spaces are tokenizer artifacts; confirm before reuse.
        caption = caption.replace(' ', '')
        return 'caption: '+caption
|
|
|
|
| 15 |
|
| 16 |
def inference(raw_image, model_n, question, strategy):
    """Generate a caption for *raw_image* when the captioning task is selected.

    Parameters:
        raw_image: image handed straight to ``processor`` (presumably a PIL
            image from the UI — TODO confirm).
        model_n: task name; only 'Image Captioning' is acted on here.
        question: unused on this path.
        strategy: "Beam search" selects deterministic decoding; any other
            value selects sampling.

    Returns:
        ``'caption: <text>'``, or ``None`` for other tasks (same as the
        original's implicit fall-through).
    """
    if model_n != 'Image Captioning':
        # Other tasks are not handled in this chunk; preserve the original
        # implicit-None behavior explicitly.
        return None

    # Encode the image once; both decoding branches consume the same tensors.
    encoded = processor(raw_image).to(device, torch.float16)

    # The two branches differ only in the sampling switch. NOTE(review): the
    # diff view this was reconstructed from hides a few GenerationConfig
    # kwargs; only the visible ones are reproduced.
    if strategy == "Beam search":
        gen_cfg = GenerationConfig(
            max_length=20,
            min_length=5,
        )
    else:
        gen_cfg = GenerationConfig(
            do_sample=True,
            max_length=20,
            min_length=5,
        )

    with torch.no_grad():
        output_ids = model.generate(**encoded, generation_config=gen_cfg)

    decoded = processor.decode(output_ids[0], skip_special_tokens=True)
    # Remove every space from the decoded text — presumably tokenizer
    # artifacts for a non-space-delimited language; confirm before reuse.
    decoded = decoded.replace(' ', '')
    return 'caption: '+decoded
|