medchat / embedding_service.py
vihashini-18
i
0a5c991
"""
Module for handling embeddings and Pinecone operations
"""
from pinecone import Pinecone, ServerlessSpec
from sentence_transformers import SentenceTransformer
import numpy as np
import time
from typing import List, Dict, Any
from config import (
PINECONE_API_KEY,
INDEX_NAME,
NAMESPACE,
EMBEDDING_MODEL
)
class EmbeddingService:
    """Embed texts with a SentenceTransformer model and store/query them in Pinecone.

    On construction the model named by ``EMBEDDING_MODEL`` is loaded and the
    Pinecone index ``INDEX_NAME`` is opened, being created first if it is not
    already listed. All upserts and queries use the ``NAMESPACE`` namespace.
    """

    def __init__(self):
        """Load the embedding model and connect to the Pinecone index.

        Creates ``INDEX_NAME`` (cosine metric, serverless on aws/us-east-1)
        when it does not exist yet and waits until Pinecone reports it ready.

        Raises:
            TimeoutError: if a newly created index does not become ready in time.
        """
        print(f"Loading embedding model: {EMBEDDING_MODEL}")
        self.model = SentenceTransformer(EMBEDDING_MODEL)

        # Initialize Pinecone
        self.pc = Pinecone(api_key=PINECONE_API_KEY)
        self._ensure_index()
        self.index = self.pc.Index(INDEX_NAME)
        print("Pinecone connection established")

    def _ensure_index(self, poll_interval: float = 1.0, timeout: float = 300.0) -> None:
        """Create ``INDEX_NAME`` if it is missing and block until it reports ready."""
        if INDEX_NAME in [idx.name for idx in self.pc.list_indexes()]:
            return

        print(f"Creating index: {INDEX_NAME}")
        # Derive the dimension from the loaded model instead of hard-coding 384
        # (all-MiniLM-L6-v2), so a different EMBEDDING_MODEL still matches the
        # index. Fall back to 384 if the model cannot report its dimension.
        dimension = self.model.get_sentence_embedding_dimension() or 384
        self.pc.create_index(
            name=INDEX_NAME,
            dimension=dimension,
            metric='cosine',
            spec=ServerlessSpec(
                cloud='aws',
                region='us-east-1'
            )
        )
        # A fixed short sleep is no readiness guarantee; poll the index status
        # so the first upsert/query does not hit an index still initializing.
        deadline = time.monotonic() + timeout
        while not self.pc.describe_index(INDEX_NAME).status['ready']:
            if time.monotonic() >= deadline:
                raise TimeoutError(
                    f"Pinecone index {INDEX_NAME!r} not ready after {timeout} seconds"
                )
            time.sleep(poll_interval)

    def create_embeddings(self, texts: List[str]) -> List[List[float]]:
        """Encode ``texts`` with the loaded model and return plain Python lists.

        A progress bar is shown while encoding.
        """
        embeddings = self.model.encode(texts, show_progress_bar=True)
        return embeddings.tolist()

    @staticmethod
    def _build_vector(vector_id: str, doc: Dict[str, Any], embedding: List[float]) -> Dict[str, Any]:
        """Shape one document and its embedding into a Pinecone upsert record."""
        # Tolerate a missing or None 'metadata' entry instead of raising.
        meta = doc.get('metadata') or {}
        return {
            'id': vector_id,
            'values': embedding,
            'metadata': {
                'text': doc['text'],
                'source': doc['source'],
                'question': meta.get('question', ''),
                'answer': meta.get('answer', ''),
                'type': meta.get('type', ''),
            },
        }

    def upsert_documents(self, documents: List[Dict[str, Any]], batch_size: int = 100):
        """Embed ``documents`` and upsert them into the index in batches.

        Args:
            documents: dicts with required ``'text'`` and ``'source'`` keys and
                an optional ``'metadata'`` dict from which ``question``,
                ``answer`` and ``type`` are copied (each defaulting to ``''``).
            batch_size: maximum number of vectors sent per ``upsert`` call.

        Raises:
            ValueError: if ``batch_size`` is less than 1.
        """
        if batch_size < 1:
            raise ValueError(f"batch_size must be >= 1, got {batch_size}")

        print(f"Preparing to upload {len(documents)} documents...")
        if not documents:
            return

        embeddings = self.create_embeddings([doc['text'] for doc in documents])

        # One timestamp per call so every ID from the same upload shares it.
        # NOTE(review): two calls within the same second still generate the
        # same IDs (doc_0_<ts>, ...) and will overwrite each other's vectors.
        stamp = int(time.time())
        vectors = [
            self._build_vector(f"doc_{idx}_{stamp}", doc, embedding)
            for idx, (doc, embedding) in enumerate(zip(documents, embeddings))
        ]

        # Upload in batches
        total_batches = (len(vectors) + batch_size - 1) // batch_size
        for i in range(0, len(vectors), batch_size):
            self.index.upsert(vectors[i:i + batch_size], namespace=NAMESPACE)
            print(f"Uploaded batch {i // batch_size + 1}/{total_batches}")

        print(f"Successfully uploaded {len(documents)} documents to Pinecone")

    def search(self, query: str, top_k: int = 5) -> Any:
        """Query the index for the ``top_k`` stored vectors nearest to ``query``.

        The query is embedded with the same model used for upserts and sent to
        ``self.index.query`` in ``NAMESPACE`` with metadata included.

        Returns:
            The response object from ``self.index.query``, returned unchanged.
            NOTE(review): presumably matches are exposed via ``.matches`` —
            confirm against the installed Pinecone client version.
        """
        query_embedding = self.model.encode(query).tolist()
        return self.index.query(
            vector=query_embedding,
            top_k=top_k,
            namespace=NAMESPACE,
            include_metadata=True
        )