Spaces:
Runtime error
Runtime error
Commit
·
9ce67d0
1
Parent(s):
7518be4
Update app.py
Browse files
app.py
CHANGED
|
@@ -18,8 +18,8 @@ blip_model_base = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-
|
|
| 18 |
blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
|
| 19 |
blip_model_large = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
|
| 20 |
|
| 21 |
-
vilt_processor = AutoProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
| 22 |
-
vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
| 23 |
|
| 24 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 25 |
|
|
@@ -27,7 +27,7 @@ git_model_base.to(device)
|
|
| 27 |
blip_model_base.to(device)
|
| 28 |
git_model_large.to(device)
|
| 29 |
blip_model_large.to(device)
|
| 30 |
-
vilt_model.to(device)
|
| 31 |
|
| 32 |
def generate_answer_git(processor, model, image, question):
|
| 33 |
# prepare image
|
|
@@ -41,7 +41,7 @@ def generate_answer_git(processor, model, image, question):
|
|
| 41 |
generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=128)#50)
|
| 42 |
generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 43 |
|
| 44 |
-
return generated_answer
|
| 45 |
|
| 46 |
|
| 47 |
def generate_answer_blip(processor, model, image, question):
|
|
|
|
| 18 |
blip_processor_large = AutoProcessor.from_pretrained("Salesforce/blip-vqa-capfilt-large")
|
| 19 |
blip_model_large = BlipForQuestionAnswering.from_pretrained("Salesforce/blip-vqa-capfilt-large")
|
| 20 |
|
| 21 |
+
# vilt_processor = AutoProcessor.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
| 22 |
+
# vilt_model = ViltForQuestionAnswering.from_pretrained("dandelin/vilt-b32-finetuned-vqa")
|
| 23 |
|
| 24 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 25 |
|
|
|
|
| 27 |
blip_model_base.to(device)
|
| 28 |
git_model_large.to(device)
|
| 29 |
blip_model_large.to(device)
|
| 30 |
+
# vilt_model.to(device)
|
| 31 |
|
| 32 |
def generate_answer_git(processor, model, image, question):
|
| 33 |
# prepare image
|
|
|
|
| 41 |
generated_ids = model.generate(pixel_values=pixel_values, input_ids=input_ids, max_length=128)#50)
|
| 42 |
generated_answer = processor.batch_decode(generated_ids, skip_special_tokens=True)
|
| 43 |
|
| 44 |
+
return generated_answer.replace(question, '').replace(question.lower(), '').strip()
|
| 45 |
|
| 46 |
|
| 47 |
def generate_answer_blip(processor, model, image, question):
|