# sanatan_ai/embeddings.py
from typing import Literal

import numpy as np
import tiktoken
from dotenv import load_dotenv
from openai import OpenAI
from sentence_transformers import SentenceTransformer

# Load environment variables (e.g. OPENAI_API_KEY) from a .env file.
load_dotenv()

# Local HuggingFace model (multilingual MiniLM; 384-dimensional embeddings).
hf_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

# OpenAI client; picks up OPENAI_API_KEY from the environment.
client = OpenAI()

# Tokenizer matching the OpenAI embeddings model, used for token counting and chunking.
tokenizer = tiktoken.encoding_for_model("text-embedding-3-large")

# -------------------------------
# Helpers
# -------------------------------
def _get_hf_embedding(texts: list[str]) -> list[list[float]]:
    """Get embeddings using the local HuggingFace SentenceTransformer."""
    return hf_model.encode(texts).tolist()


def chunk_text(text: str, max_tokens: int = 1000) -> list[str]:
    """Split text into chunks of at most max_tokens tokens each."""
    tokens = tokenizer.encode(text)
    return [tokenizer.decode(tokens[i:i + max_tokens]) for i in range(0, len(tokens), max_tokens)]
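
# A quick sketch of chunk_text's behavior, assuming cl100k-style tokenization
# (roughly 3 tokens per repeat of "very long text "): a ~3,000-token input
# yields about three chunks. Token slices are decoded independently, so a
# chunk boundary can fall mid-word, which is usually acceptable for embeddings.
#
#   chunks = chunk_text("very long text " * 1000)  # roughly 3,000 tokens
#   print(len(chunks))                             # -> 3 (approximately)
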
EMBED_DIM = 3072 # dimension of text-embedding-3-large
def _get_openai_embedding(texts: list[str]) -> list[list[float]]:
    """Get embeddings for a list of texts. If a text is too long, chunk + average."""
    final_embeddings = []
    for text in texts:
        if not text or not isinstance(text, str) or not text.strip():
            # Fallback: empty or non-string input gets a zero vector.
            final_embeddings.append([0.0] * EMBED_DIM)
            continue
        # Split into chunks if the text exceeds the model's 8192-token limit.
        if len(tokenizer.encode(text)) > 8192:
            chunks = chunk_text(text)
        else:
            chunks = [text]
        # Drop empty or whitespace-only chunks.
        clean_chunks = [c.strip() for c in chunks if isinstance(c, str) and c.strip()]
        if not clean_chunks:
            final_embeddings.append([0.0] * EMBED_DIM)
            continue
        try:
            response = client.embeddings.create(
                model="text-embedding-3-large",
                input=clean_chunks,
            )
            # One embedding per chunk; average them into a single vector per text.
            chunk_embeddings = [np.array(d.embedding) for d in response.data]
            avg_embedding = np.mean(chunk_embeddings, axis=0)
            final_embeddings.append(avg_embedding.tolist())
        except Exception as e:
            print(f"Embedding failed for text[:100]={text[:100]!r}, error={e}")
            final_embeddings.append([0.0] * EMBED_DIM)  # fallback
    return final_embeddings
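
# Note on the averaging step: a plain mean of chunk embeddings is a common
# approximation for a whole-document vector. If downstream code compares
# vectors by cosine similarity, it may help to L2-normalize the mean before
# storing it; a minimal sketch, not part of the function above:
#
#   norm = np.linalg.norm(avg_embedding)
#   if norm > 0:
#       avg_embedding = avg_embedding / norm
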
# Simple in-memory cache keyed by (backend, tuple of texts).
embedding_cache = {}


def get_embedding(texts: list[str], backend: Literal["hf", "openai"] = "hf") -> list[list[float]]:
    """Embed a batch of texts with the chosen backend, memoizing whole batches."""
    key = (backend, tuple(texts))  # tuples are hashable, lists are not
    if key in embedding_cache:
        return embedding_cache[key]
    if backend == "hf":
        embedding_cache[key] = _get_hf_embedding(texts)
    else:
        embedding_cache[key] = _get_openai_embedding(texts)
    return embedding_cache[key]
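
# Caveat: embedding_cache grows without bound and is keyed by whole batches,
# so calls that share texts but batch them differently all miss. A minimal
# per-text cache, sketched under the assumption that single-text lookups
# dominate (get_embedding_cached and _text_cache are hypothetical names):
#
#   _text_cache: dict[tuple[str, str], list[float]] = {}
#
#   def get_embedding_cached(text: str, backend: str = "hf") -> list[float]:
#       key = (backend, text)
#       if key not in _text_cache:
#           _text_cache[key] = get_embedding([text], backend=backend)[0]
#       return _text_cache[key]
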
# -------------------------------
# Example
# -------------------------------
if __name__ == "__main__":
    texts = [
        "short text example",
        "very long text " * 5000,  # long enough to exceed 8192 tokens, so it gets chunked
    ]
    embs = get_embedding(texts, backend="openai")
    print(len(embs), "embeddings returned")
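
    # Usage sketch, assuming the sentence-transformers weights are available
    # locally: the HF backend needs no API key, and a repeated call with
    # identical arguments is served straight from embedding_cache.
    hf_embs = get_embedding(texts, backend="hf")
    print(len(hf_embs), "HF embeddings,", len(hf_embs[0]), "dims each")
    assert get_embedding(texts, backend="hf") is hf_embs  # cache hit: same object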