sasha HF Staff commited on
Commit
e6fd470
·
1 Parent(s): 8fe1df6
Files changed (1) hide show
  1. app.py +1 -2
app.py CHANGED
@@ -19,14 +19,13 @@ ds = dataset["train"]
19
 
20
 
21
  def query(image, top_k=4):
22
- inputs = feature_extractor(image.convert("RGB"), return_tensors="pt")
23
  model_output = model(**inputs)
24
  embedding = model_output.pooler_output.detach()
25
  results = index.query(embedding, k=top_k)
26
  inx = results[0][0].tolist()
27
  logits = results[1][0].tolist()
28
  images = ds.select(inx)["image"]
29
-
30
  captions = ds.select(inx)["label"]
31
  images_with_captions = [(i, c) for i, c in zip(images, captions)]
32
  labels_with_probs = dict(zip(captions, logits))
 
19
 
20
 
21
  def query(image, top_k=4):
22
+ inputs = feature_extractor(image, return_tensors="pt")
23
  model_output = model(**inputs)
24
  embedding = model_output.pooler_output.detach()
25
  results = index.query(embedding, k=top_k)
26
  inx = results[0][0].tolist()
27
  logits = results[1][0].tolist()
28
  images = ds.select(inx)["image"]
 
29
  captions = ds.select(inx)["label"]
30
  images_with_captions = [(i, c) for i, c in zip(images, captions)]
31
  labels_with_probs = dict(zip(captions, logits))