Embeddings differ depending on batch contents (when sequence lengths differ)

#28
by praateek-nv - opened

Maybe this is known, but I ran into this today: the model produces different embeddings for sequences of exactly 129 tokens (and similarly 193 and 257; notice they are all of the form `64*k + 1`, one past a multiple of 64) when batched with shorter sequences. The embeddings differ by ~1% from the ground truth (individual encoding), despite proper attention masking.

It might be due to RoPE or sliding-window attention, but I'm curious about it, since the failure always sits exactly one token past one of these boundaries. This doesn't happen with sentence-transformers/all-MiniLM-L6-v2, presumably for architectural reasons (it would be nice to understand which part of the architecture causes this).
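As a sanity check on the boundary hypothesis, here is a small sketch (my addition, not from the original report) that prints the tokenized lengths of the test strings. It assumes the raw tokenizer reflects what `model.encode` sees; EmbeddingGemma's prompt templates, if applied, would shift these counts:

```python
# Sanity check: print the tokenized lengths of the test strings.
# Assumes the raw tokenizer matches what model.encode sees; prompt
# templates (if any) would shift these counts.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("google/embeddinggemma-300m")
for label, text in [("TEXT_113", ("Hello " * 111).strip()),
                    ("TEXT_129", ("Hello " * 127).strip())]:
    print(label, len(tok(text)["input_ids"]))
```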

Detailed Test Results

| Test Case | embeddinggemma-300m | all-MiniLM-L6-v2 |
|---|---|---|
| 129 + 129 tokens (same length) | ✅ OK (0.00000013) | ✅ OK (0.00000002) |
| 113 + 128 tokens (below boundary) | ✅ OK (0.00000015) | ✅ OK (0.00000004) |
| 128 + 129 tokens (crosses boundary) | ❌ BUG (0.01030827) | ✅ OK (0.00000002) |
| 113 + 129 tokens (crosses boundary) | ❌ BUG (0.01030827) | ✅ OK (0.00000002) |
| 113 + 130 tokens (skips 129) | ✅ OK (0.00000017) | ✅ OK (0.00000002) |
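For reference, a minimal sketch of how a sweep like the table above can be reproduced. The `make_text` helper is my own; it relies on the same "Hello"-repetition trick as the Quick Test below, so the exact token counts are an assumption:

```python
# Sketch: sweep the length pairs from the table and report the max
# per-coordinate difference between solo and batched encoding.
from sentence_transformers import SentenceTransformer
import numpy as np

model = SentenceTransformer("google/embeddinggemma-300m", device="cuda")

def make_text(n_tokens: int) -> str:
    # Assumes each repeated "Hello" tokenizes to one token, plus two
    # special tokens; in practice, verify counts with the tokenizer.
    return ("Hello " * (n_tokens - 2)).strip()

for a, b in [(129, 129), (113, 128), (128, 129), (113, 129), (113, 130)]:
    text_a, text_b = make_text(a), make_text(b)
    solo = model.encode([text_b], batch_size=1)
    batched = model.encode([text_a, text_b], batch_size=2)
    print(a, b, np.abs(solo - batched[1:2]).max())
```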

Quick Test

```python
from sentence_transformers import SentenceTransformer
import numpy as np

model = SentenceTransformer("google/embeddinggemma-300m", device="cuda")

# Create texts that tokenize to exactly 113 and 129 tokens
TEXT_113 = ("Hello " * 111).strip()  # 113 tokens
TEXT_129 = ("Hello " * 127).strip()  # 129 tokens

# Baseline: encode TEXT_129 alone
baseline = model.encode([TEXT_129], batch_size=1)

# Batch it with the shorter text
batched = model.encode([TEXT_113, TEXT_129], batch_size=2)

# Compare the two embeddings of TEXT_129
diff = np.abs(baseline - batched[1:2]).max()
print(f"Difference: {diff:.8f}")  # Shows ~0.01 (a 1% error!)
```

Here is a GitHub gist as well: https://gist.github.com/praateekmahajan/17abaf7bfe435cd6cbb98ac6d0650377

Hi @praateek-nv, apologies for the late response. I was unable to reproduce your reported results after running the provided Quick Test code. My result, a difference of 0.00000024, aligned more closely with the all-MiniLM-L6-v2 model's expected output. Could you please advise if I'm missing any steps?
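In case it helps narrow down the gap, numeric differences like this often come down to dtype, device, or library versions; a quick way to capture those (a sketch, reusing `model` from the Quick Test) is:

```python
# Capture environment details that commonly explain reproduction gaps.
import torch
import transformers
import sentence_transformers

print("torch:", torch.__version__, "| CUDA:", torch.version.cuda)
print("transformers:", transformers.__version__)
print("sentence-transformers:", sentence_transformers.__version__)
print("model dtype:", next(model.parameters()).dtype)
```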
