vikramvasudevan commited on
Commit
74c37c0
·
verified ·
1 Parent(s): cbc9372

Upload folder using huggingface_hub

Browse files
Files changed (7) hide show
  1. Dockerfile +0 -8
  2. README.md +3 -3
  3. app.py +1 -3
  4. config.py +11 -4
  5. copy_chromadb.py +41 -12
  6. db.py +7 -2
  7. embeddings.py +69 -6
Dockerfile CHANGED
@@ -1,9 +1,5 @@
1
  FROM python:3.12-slim
2
 
3
- # Add near the top of Dockerfile
4
- ENV HF_HOME=/app/hf_cache
5
- RUN mkdir -p $HF_HOME && chmod 777 $HF_HOME
6
-
7
  # Avoid interactive prompts during build
8
  ENV DEBIAN_FRONTEND=noninteractive
9
 
@@ -34,9 +30,5 @@ RUN pip install --no-cache-dir -r requirements.txt
34
  COPY . /app
35
  WORKDIR /app
36
 
37
- RUN useradd -m appuser
38
- RUN mkdir -p /app/chroma_db && chown -R appuser:appuser /app
39
- USER appuser
40
-
41
  # Default command (Gradio, Streamlit, or Python)
42
  CMD ["python", "app.py"]
 
1
  FROM python:3.12-slim
2
 
 
 
 
 
3
  # Avoid interactive prompts during build
4
  ENV DEBIAN_FRONTEND=noninteractive
5
 
 
30
  COPY . /app
31
  WORKDIR /app
32
 
 
 
 
 
33
  # Default command (Gradio, Streamlit, or Python)
34
  CMD ["python", "app.py"]
README.md CHANGED
@@ -1,7 +1,7 @@
1
  ---
2
  title: sanatan_ai
3
  app_file: app.py
4
- sdk: docker
 
5
  python_version: 3.12
6
- emoji: 👀
7
- ---
 
1
  ---
2
  title: sanatan_ai
3
  app_file: app.py
4
+ sdk: gradio
5
+ sdk_version: 5.38.0
6
  python_version: 3.12
7
+ ---
 
app.py CHANGED
@@ -468,6 +468,4 @@ with gr.Blocks(
468
  textbox=message_textbox,
469
  )
470
 
471
- port = int(os.environ.get("PORT", 7860))
472
- app.launch(server_name="0.0.0.0", server_port=port)
473
-
 
468
  textbox=message_textbox,
469
  )
470
 
471
+ app.launch()
 
 
config.py CHANGED
@@ -159,6 +159,7 @@ class SanatanConfig:
159
  "title": "4000 Divya Prabandham",
160
  "output_dir": "./output/divya_prabandham",
161
  "collection_name": "divya_prabandham",
 
162
  "metadata_fields": [
163
  {
164
  "name": "prabandham_code",
@@ -381,8 +382,7 @@ class SanatanConfig:
381
  "Show detailed commentary for sloka 2 from Chathusloki",
382
  "What is the role of Sri Devi in the universe according to the Chathusloki?",
383
  ],
384
- "llm_hints" : [
385
- ]
386
  },
387
  {
388
  "name": "sri_stavam",
@@ -420,9 +420,9 @@ class SanatanConfig:
420
  "Show detailed commentary for sloka 2 from Sri Stavam",
421
  "What is the role of Sri Devi in the universe according to the Sri Stavam?",
422
  ],
423
- "llm_hints" : [
424
  "if the user asks for nth sloka, do a metadata search on the `verse` field."
425
- ]
426
  },
427
  ]
428
 
@@ -445,3 +445,10 @@ class SanatanConfig:
445
  f"metadata_field: [{filter.metadata_field}] not allowed in collection [{collection_name}]. Here are the allowed fields with their descriptions: {scripture["metadata_fields"]}"
446
  )
447
  return True
 
 
 
 
 
 
 
 
159
  "title": "4000 Divya Prabandham",
160
  "output_dir": "./output/divya_prabandham",
161
  "collection_name": "divya_prabandham",
162
+ "collection_embedding_fn": "openai",
163
  "metadata_fields": [
164
  {
165
  "name": "prabandham_code",
 
382
  "Show detailed commentary for sloka 2 from Chathusloki",
383
  "What is the role of Sri Devi in the universe according to the Chathusloki?",
384
  ],
385
+ "llm_hints": [],
 
386
  },
387
  {
388
  "name": "sri_stavam",
 
420
  "Show detailed commentary for sloka 2 from Sri Stavam",
421
  "What is the role of Sri Devi in the universe according to the Sri Stavam?",
422
  ],
423
+ "llm_hints": [
424
  "if the user asks for nth sloka, do a metadata search on the `verse` field."
425
+ ],
426
  },
427
  ]
428
 
 
445
  f"metadata_field: [{filter.metadata_field}] not allowed in collection [{collection_name}]. Here are the allowed fields with their descriptions: {scripture["metadata_fields"]}"
446
  )
447
  return True
448
+
449
+ def get_embedding_for_collection(self, collection_name: str):
450
+ scripture = self.get_scripture_by_collection(collection_name)
451
+ embedding_fn = "hf" # default is huggingface sentence transformaers
452
+ if "collection_embedding_fn" in scripture:
453
+ embedding_fn = scripture["collection_embedding_fn"] # overridden in config
454
+ return embedding_fn
copy_chromadb.py CHANGED
@@ -1,22 +1,51 @@
 
1
  import chromadb
2
  from tqdm import tqdm # Optional: For progress bar
3
 
4
- # Connect to source and destination local persistent clients
5
- source_client = chromadb.PersistentClient(
6
- path="../vedam_ai/chromadb-store"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7
  )
 
 
 
 
 
 
 
 
 
 
 
 
8
  destination_client = chromadb.PersistentClient(path="./chromadb-store")
9
 
10
- source_collection_name = "sri_stavam"
11
- destination_collection_name = "sri_stavam"
12
 
13
  # Get the source collection
14
  source_collection = source_client.get_collection(source_collection_name)
15
 
16
  # Retrieve all data from the source collection
17
- source_data = source_collection.get(
18
- include=["documents", "metadatas", "embeddings"]
19
- )
20
 
21
  # Create or get the destination collection
22
  if destination_client.get_or_create_collection(destination_collection_name):
@@ -35,11 +64,11 @@ total_records = len(source_data["ids"])
35
  print(f"Copying {total_records} records in batches of {BATCH_SIZE}...")
36
 
37
  for i in tqdm(range(0, total_records, BATCH_SIZE)):
38
- batch_ids = source_data["ids"][i:i + BATCH_SIZE]
39
- batch_docs = source_data["documents"][i:i + BATCH_SIZE]
40
- batch_metas = source_data["metadatas"][i:i + BATCH_SIZE]
41
  batch_embeds = (
42
- source_data["embeddings"][i:i + BATCH_SIZE]
43
  if "embeddings" in source_data and source_data["embeddings"] is not None
44
  else None
45
  )
 
1
+ import argparse
2
  import chromadb
3
  from tqdm import tqdm # Optional: For progress bar
4
 
5
+ db_config = {
6
+ "youtube_db": {
7
+ "source_db_path": "../youtube_surfer_ai_agent/youtube_db",
8
+ "source_collection_name": "yt_metadata",
9
+ "destination_collection_name": "yt_metadata",
10
+ },
11
+ "divya_prabandham": {
12
+ "source_db_path": "../uveda_analyzer/chromadb_store",
13
+ "source_collection_name": "divya_prabandham",
14
+ "destination_collection_name": "divya_prabandham",
15
+ },
16
+ }
17
+
18
+ parser = argparse.ArgumentParser(description="My app with database parameter")
19
+
20
+ parser.add_argument(
21
+ "--db",
22
+ type=str,
23
+ required=True,
24
+ choices=list(db_config.keys()),
25
+ help=f"Id of the database to use. allowed_values : {', '.join(db_config.keys())}",
26
  )
27
+
28
+ args = parser.parse_args()
29
+
30
+ db_id = args.db
31
+
32
+ if db_id is None:
33
+ raise Exception(f"No db provided!")
34
+ if db_id not in db_config:
35
+ raise Exception(f"db with id {db_id} not found!")
36
+
37
+ # Connect to source and destination local persistent clients
38
+ source_client = chromadb.PersistentClient(path=db_config[db_id]["source_db_path"])
39
  destination_client = chromadb.PersistentClient(path="./chromadb-store")
40
 
41
+ source_collection_name = db_config[db_id]["source_collection_name"]
42
+ destination_collection_name = db_config[db_id]["destination_collection_name"]
43
 
44
  # Get the source collection
45
  source_collection = source_client.get_collection(source_collection_name)
46
 
47
  # Retrieve all data from the source collection
48
+ source_data = source_collection.get(include=["documents", "metadatas", "embeddings"])
 
 
49
 
50
  # Create or get the destination collection
51
  if destination_client.get_or_create_collection(destination_collection_name):
 
64
  print(f"Copying {total_records} records in batches of {BATCH_SIZE}...")
65
 
66
  for i in tqdm(range(0, total_records, BATCH_SIZE)):
67
+ batch_ids = source_data["ids"][i : i + BATCH_SIZE]
68
+ batch_docs = source_data["documents"][i : i + BATCH_SIZE]
69
+ batch_metas = source_data["metadatas"][i : i + BATCH_SIZE]
70
  batch_embeds = (
71
+ source_data["embeddings"][i : i + BATCH_SIZE]
72
  if "embeddings" in source_data and source_data["embeddings"] is not None
73
  else None
74
  )
db.py CHANGED
@@ -34,10 +34,13 @@ class SanatanDatabase:
34
  logger.info("Vector Semantic Search for [%s] in [%s]", query, collection_name)
35
  collection = self.chroma_client.get_or_create_collection(name=collection_name)
36
  response = collection.query(
37
- query_embeddings=[get_embedding(query)],
 
 
38
  # query_texts=[query],
39
  n_results=n_results,
40
  )
 
41
  return response
42
 
43
  def search_for_literal(
@@ -137,7 +140,9 @@ class SanatanDatabase:
137
  )
138
  collection = self.chroma_client.get_or_create_collection(name=collection_name)
139
  response = collection.query(
140
- query_embeddings=[get_embedding(query)],
 
 
141
  where=metadata_where_clause.to_chroma_where(),
142
  # query_texts=[query],
143
  n_results=n_results,
 
34
  logger.info("Vector Semantic Search for [%s] in [%s]", query, collection_name)
35
  collection = self.chroma_client.get_or_create_collection(name=collection_name)
36
  response = collection.query(
37
+ query_embeddings=get_embedding(
38
+ [query], SanatanConfig().get_embedding_for_collection(collection_name)
39
+ ),
40
  # query_texts=[query],
41
  n_results=n_results,
42
  )
43
+ # logger.info("number of matches = %d", len(response["metadatas"]))
44
  return response
45
 
46
  def search_for_literal(
 
140
  )
141
  collection = self.chroma_client.get_or_create_collection(name=collection_name)
142
  response = collection.query(
143
+ query_embeddings=get_embedding(
144
+ [query], SanatanConfig().get_embedding_for_collection(collection_name)
145
+ ),
146
  where=metadata_where_clause.to_chroma_where(),
147
  # query_texts=[query],
148
  n_results=n_results,
embeddings.py CHANGED
@@ -1,9 +1,72 @@
1
-
 
2
  from sentence_transformers import SentenceTransformer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- # Step 1: Load SentenceTransformer model
5
- # model = SentenceTransformer("all-MiniLM-L6-v2")
6
- model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
 
 
 
 
 
7
 
8
- def get_embedding(text: str) -> list:
9
- return model.encode(text).tolist()
 
 
 
 
 
 
 
 
 
1
+ from typing import Literal
2
+ import numpy as np
3
  from sentence_transformers import SentenceTransformer
4
+ from openai import OpenAI
5
+ from dotenv import load_dotenv
6
+ import tiktoken
7
+
8
+ load_dotenv()
9
+
10
+ # Local HuggingFace model
11
+ hf_model = SentenceTransformer("sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2")
12
+
13
+ # OpenAI client
14
+ client = OpenAI()
15
+
16
+ # Choose tokenizer for embeddings model
17
+ tokenizer = tiktoken.encoding_for_model("text-embedding-3-large")
18
+
19
+ # -------------------------------
20
+ # Helpers
21
+ # -------------------------------
22
+ def _get_hf_embedding(texts: list[str]) -> list[list[float]]:
23
+ """Get embeddings using HuggingFace SentenceTransformer."""
24
+ return hf_model.encode(texts).tolist()
25
+
26
+ def chunk_text(text: str, max_tokens: int = 1000) -> list[str]:
27
+ tokens = tokenizer.encode(text)
28
+ return [tokenizer.decode(tokens[i:i+max_tokens]) for i in range(0, len(tokens), max_tokens)]
29
+
30
+ def _get_openai_embedding(texts: list[str]) -> list[list[float]]:
31
+ """Get embeddings for a list of texts. If a text is too long, chunk + average."""
32
+ final_embeddings = []
33
+
34
+ for text in texts:
35
+ # Split into chunks if too long
36
+ if len(tokenizer.encode(text)) > 8192:
37
+ chunks = chunk_text(text)
38
+ else:
39
+ chunks = [text]
40
+
41
+ # Call API on all chunks at once
42
+ response = client.embeddings.create(
43
+ model="text-embedding-3-large",
44
+ input=chunks
45
+ )
46
+ chunk_embeddings = [np.array(d.embedding) for d in response.data]
47
+
48
+ # Average embeddings if multiple chunks
49
+ avg_embedding = np.mean(chunk_embeddings, axis=0)
50
+ final_embeddings.append(avg_embedding.tolist())
51
+
52
+ return final_embeddings
53
 
54
+ def get_embedding(texts: list[str], backend: Literal["hf","openai"] = "hf") -> list[list[float]]:
55
+ """
56
+ Get embeddings for a list of texts.
57
+ backend = "openai" or "hf"
58
+ """
59
+ if backend == "hf":
60
+ return _get_hf_embedding(texts)
61
+ return _get_openai_embedding(texts)
62
 
63
+ # -------------------------------
64
+ # Example
65
+ # -------------------------------
66
+ if __name__ == "__main__":
67
+ texts = [
68
+ "short text example",
69
+ "very long text " * 2000 # will get chunked
70
+ ]
71
+ embs = get_embedding(texts, backend="openai")
72
+ print(len(embs), "embeddings returned")