from typing import Literal
import numpy as np
from sentence_transformers import SentenceTransformer
from openai import OpenAI
from dotenv import load_dotenv
import tiktoken

load_dotenv()

# Local HuggingFace model (produces 384-dim vectors, vs. 3072 for the OpenAI model below)
hf_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")

# OpenAI client
client = OpenAI()

# Tokenizer matching the OpenAI embedding model (cl100k_base for the text-embedding-3 family)
tokenizer = tiktoken.encoding_for_model("text-embedding-3-large")

# -------------------------------
# Helpers
# -------------------------------
def _get_hf_embedding(texts: list[str]) -> list[list[float]]:
    """Get embeddings using HuggingFace SentenceTransformer."""
    return hf_model.encode(texts).tolist()

def chunk_text(text: str, max_tokens: int = 1000) -> list[str]:
    """Split text into chunks of at most max_tokens tokens, splitting on token boundaries."""
    tokens = tokenizer.encode(text)
    return [tokenizer.decode(tokens[i:i + max_tokens]) for i in range(0, len(tokens), max_tokens)]
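
# For example, a 2,500-token input yields three chunks of roughly 1000/1000/500
# tokens. Decoding at arbitrary token boundaries can split a word in half;
# that is usually acceptable for embedding purposes.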

# Model limits for text-embedding-3-large (per OpenAI's documentation)
MAX_INPUT_TOKENS = 8191   # maximum tokens per input
EMBED_DIM = 3072          # output embedding dimension

def _get_openai_embedding(texts: list[str]) -> list[list[float]]:
    """Get embeddings for a list of texts. If a text is too long, chunk + average."""
    final_embeddings = []

    for text in texts:
        if not text or not isinstance(text, str) or not text.strip():
            # Fallback: represent empty or invalid input with a zero vector
            final_embeddings.append([0.0] * EMBED_DIM)
            continue

        # Split into chunks if the text exceeds the model's input limit
        if len(tokenizer.encode(text)) > MAX_INPUT_TOKENS:
            chunks = chunk_text(text)
        else:
            chunks = [text]

        # Clean chunks
        clean_chunks = [c.strip() for c in chunks if isinstance(c, str) and c.strip()]
        if not clean_chunks:
            final_embeddings.append([0.0] * EMBED_DIM)
            continue

        try:
            response = client.embeddings.create(
                model="text-embedding-3-large",
                input=clean_chunks
            )
            # Average the chunk embeddings into a single vector for this text
            chunk_embeddings = [np.array(d.embedding) for d in response.data]
            avg_embedding = np.mean(chunk_embeddings, axis=0)
            final_embeddings.append(avg_embedding.tolist())
        except Exception as e:
            print(f"Embedding failed for text[:100]={text[:100]!r}, error={e}")
            final_embeddings.append([0.0] * EMBED_DIM)  # fallback

    return final_embeddings
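
# Note: OpenAI embeddings come back unit-normalized, but the mean of several
# unit vectors is generally shorter than unit length. If you score similarity
# with a raw dot product, consider re-normalizing averaged vectors first.
# Minimal sketch (the helper name `l2_normalize` is ours, not an OpenAI API):
def l2_normalize(vec: list[float]) -> list[float]:
    arr = np.array(vec)
    norm = np.linalg.norm(arr)
    return (arr / norm).tolist() if norm > 0 else vec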


# Simple in-memory cache, keyed by backend and the exact batch of texts
embedding_cache = {}

def get_embedding(texts: list[str], backend: Literal["hf", "openai"] = "hf") -> list[list[float]]:
    key = (backend, tuple(texts))  # lists are unhashable, so convert to tuple
    if key in embedding_cache:
        return embedding_cache[key]

    if backend == "hf":
        embedding_cache[key] = _get_hf_embedding(texts)
    else:
        embedding_cache[key] = _get_openai_embedding(texts)

    return embedding_cache[key]
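
# Design note: because the cache key is the whole batch, get_embedding(["a", "b"])
# followed by get_embedding(["a"]) misses the cache and re-embeds "a". A
# per-text cache avoids that; sketched below (`get_embedding_per_text` is our
# illustrative name, not part of the module's interface):
def get_embedding_per_text(texts: list[str], backend: Literal["hf", "openai"] = "hf") -> list[list[float]]:
    missing = [t for t in texts if (backend, t) not in embedding_cache]
    if missing:
        embed = _get_hf_embedding if backend == "hf" else _get_openai_embedding
        for t, emb in zip(missing, embed(missing)):
            embedding_cache[(backend, t)] = emb
    return [embedding_cache[(backend, t)] for t in texts]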

# -------------------------------
# Example
# -------------------------------
if __name__ == "__main__":
    texts = [
        "short text example",
        "very long text " * 3000  # ~9,000 tokens, exceeds the input limit, so it gets chunked
    ]
    embs = get_embedding(texts, backend="openai")
    print(len(embs), "embeddings returned")
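
    # Illustrative follow-up: cosine similarity between the two embeddings
    # (exact values depend on the model version; this is a sanity check only)
    a, b = np.array(embs[0]), np.array(embs[1])
    print("cosine similarity:", float(a @ b / (np.linalg.norm(a) * np.linalg.norm(b))))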