Torch detection bbox differs from JAX models?
I was using this model for detection and I noticed that the detection output tokens (the bbox) differ quite a bit between this model (when running locally) and the online HF space (which uses the JAX model).
The torch-version detections seem to be less accurate.
Often it is the second bbox coordinate (bottom-right) that is off, whereas the first coordinate (top-left) is usually the same.
I'll give an example below:
torch local: <loc0379><loc0120><loc0761><loc0703> mug
jax hf space (from [here](https://huggingface.co/spaces/big-vision/paligemma)): <loc0379><loc0120><loc0759><loc0731> mug
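In case it helps with the comparison, this is how I turn the `<loc####>` tokens back into pixel coordinates. This is just a sketch assuming the documented PaliGemma convention (four values binned into 1024 buckets, in y_min, x_min, y_max, x_max order, scaled by the image size); the helper name is mine:

```python
import re

def loc_tokens_to_box(detection: str, width: int, height: int):
    """Convert '<locYYYY><locXXXX>...' tokens to an (x_min, y_min, x_max, y_max) pixel box.

    Assumes the PaliGemma convention: four values in 1024 bins,
    ordered y_min, x_min, y_max, x_max and normalized to the image size.
    """
    vals = [int(v) for v in re.findall(r"<loc(\d{4})>", detection)]
    y_min, x_min, y_max, x_max = vals[:4]
    return (
        x_min / 1024 * width,
        y_min / 1024 * height,
        x_max / 1024 * width,
        y_max / 1024 * height,
    )

# e.g. loc_tokens_to_box("<loc0379><loc0120><loc0761><loc0703> mug", 1920, 1080)
```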
original image:
code to reproduce local coords:
```python
import numpy as np
from PIL import Image
import requests
import torch
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

processor = AutoProcessor.from_pretrained("google/paligemma-3b-mix-448")
model = PaliGemmaForConditionalGeneration.from_pretrained(
    "google/paligemma-3b-mix-448",
    device_map="cuda:0",
    revision="bfloat16",
    torch_dtype=torch.bfloat16,
).eval()

prompt = "detect mug"
url = "2024-06-10_17-46.png"
image = Image.open(url)
# url = "https://huggingface.co/spaces/big-vision/paligemma/resolve/main/examples/cc_fox.jpg?download=true"
# image = Image.open(requests.get(url, stream=True).raw)
image = np.array(image)[..., :3]  # drop the alpha channel if present

inputs = processor(text=prompt, images=image, return_tensors="pt")
inputs = {name: tensor.cuda() for name, tensor in inputs.items()}

# Generate
generate_ids = model.generate(**inputs, max_length=2000)
output = processor.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print(output)
```
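To quantify how far apart the two outputs are (rather than eyeballing the tokens), a quick IoU check on the boxes can help. This is just a sketch using the loc values from my example above, reordered to x_min, y_min, x_max, y_max under the y/x convention mentioned earlier:

```python
def iou(box_a, box_b):
    """Intersection-over-union of two (x_min, y_min, x_max, y_max) boxes."""
    ix_min, iy_min = max(box_a[0], box_b[0]), max(box_a[1], box_b[1])
    ix_max, iy_max = min(box_a[2], box_b[2]), min(box_a[3], box_b[3])
    inter = max(0, ix_max - ix_min) * max(0, iy_max - iy_min)
    area_a = (box_a[2] - box_a[0]) * (box_a[3] - box_a[1])
    area_b = (box_b[2] - box_b[0]) * (box_b[3] - box_b[1])
    return inter / (area_a + area_b - inter)

# loc bin values from the two runs, reordered to (x_min, y_min, x_max, y_max)
torch_box = (120, 379, 703, 761)
jax_box = (120, 379, 731, 759)
print(iou(torch_box, jax_box))
```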
Are these differences expected? Qualitatively it feels as if the results are a lot better with the JAX model.
I'd expect the bboxes to be the same for all versions.
Question: is the image preprocessing exactly the same?
If not, that would explain the difference.
I noticed the same, why is this happening?
@emanuelevivoli, I don't have a clue so far. If you have any suggestions, be my guest :) I ended up using GroundingDINO for now, as I did not want to go through the hassle of installing the JAX support stack for PaliGemma. But PaliGemma feels superior imo, so I would be happy to see this issue fixed.
@gusthema, images get resized etc. by the HF processor, so I would expect it to be the same. Any particular steps I should look into?
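One way to check is to print the preprocessing config and the stats of the produced pixel_values, and compare them with what the big_vision/JAX pipeline feeds the model. This is only a sketch; I'm assuming the processor exposes its image processor as `processor.image_processor` and that `processor` and `image` from the repro snippet above are in scope:

```python
# Sketch: inspect what the HF processor actually produces for the image.
inputs = processor(text="detect mug", images=image, return_tensors="pt")
pixel_values = inputs["pixel_values"]

print(processor.image_processor)   # resize size, resample method, mean/std
print(pixel_values.shape)          # should be (1, 3, 448, 448) for the 448 checkpoint
print(pixel_values.min(), pixel_values.max(), pixel_values.mean())
```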
Hello! Thanks all for the super-detailed reports, I'll take a look to see if we can find the reason for the discrepancy.
We're still investigating. A temporary workaround, as noted here, is to disable key-value caching by passing use_cache=False when calling generate. This produces tokens very similar to the ones from the JAX pipeline. There are still minor differences, mostly due to numerical differences in the pre-processing algorithms.
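For anyone else hitting this, the workaround applied to the repro snippet above looks like this (just passing the flag through to generate):

```python
# Temporary workaround until the fix lands: disable the KV cache during generation.
generate_ids = model.generate(**inputs, max_length=2000, use_cache=False)
output = processor.batch_decode(
    generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False
)[0]
print(output)
```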
Thanks to everyone here for reporting! I had time to check this today. It is indeed due to a miscalculation of the attention mask in the generation step, causing it to miss part of the past context. I should be able to patch this by tomorrow at the latest.
With a quick fix of the attention mask I'm getting results that are still slightly different from JAX; as @pcuenq said, it's mostly numerical fluctuations.
Thanks so much!
You're welcome! And @tlpss, I tested with your originally reported example; it seems to work as well now.
Once https://github.com/huggingface/transformers/pull/31587 is merged to main, you'll be able to use PaliGemma from transformers:main (and in the next release of transformers), and detection/segmentation tasks should be fine.
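If you want to try it before the release, something like this should work (a sketch, assuming a pip-based environment):

```python
# Install transformers from the main branch to pick up the fix, e.g.:
#   pip install git+https://github.com/huggingface/transformers.git
import transformers
print(transformers.__version__)  # verify you are on a post-fix build
```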