vikramvasudevan commited on
Commit
d2bda67
·
verified ·
1 Parent(s): b2bbee4

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. copy_chromadb.py +5 -2
  2. db.py +20 -1
copy_chromadb.py CHANGED
@@ -10,10 +10,11 @@ destination_collection_name = "vishnu_sahasranamam"
10
  source_collection = source_client.get_collection(source_collection_name)
11
 
12
  # Retrieve all data from the source collection
13
- source_data = source_collection.get(ids=source_collection.peek()["ids"]) # Efficiently get all IDs
14
 
15
  # Create or get the destination collection
16
  if destination_client.get_or_create_collection(destination_collection_name):
 
17
  destination_client.delete_collection(destination_collection_name)
18
 
19
  destination_collection = destination_client.get_or_create_collection(
@@ -29,4 +30,6 @@ destination_collection.add(
29
  embeddings=source_data.get("embeddings") # Include embeddings if they exist in source
30
  )
31
 
32
- print("Collection copied successfully!")
 
 
 
10
  source_collection = source_client.get_collection(source_collection_name)
11
 
12
  # Retrieve all data from the source collection
13
+ source_data = source_collection.get()
14
 
15
  # Create or get the destination collection
16
  if destination_client.get_or_create_collection(destination_collection_name):
17
+ print("Deleting existing collection", destination_collection_name)
18
  destination_client.delete_collection(destination_collection_name)
19
 
20
  destination_collection = destination_client.get_or_create_collection(
 
30
  embeddings=source_data.get("embeddings") # Include embeddings if they exist in source
31
  )
32
 
33
+ print("Collection copied successfully!")
34
+ print("Total records in source collection = ", source_collection.count())
35
+ print("Total records in destination collection = ", destination_collection.count())
db.py CHANGED
@@ -28,8 +28,27 @@ class SanatanDatabase:
28
  )
29
 
30
  def search(self, collection_name: str, query: str, n_results=2):
 
31
  collection = self.chroma_client.get_or_create_collection(name=collection_name)
32
  response = collection.query(
33
- query_embeddings=[get_embedding(query)], n_results=n_results
 
 
34
  )
35
  return response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  )
29
 
30
  def search(self, collection_name: str, query: str, n_results=2):
31
+ logger.info("Searching for [%s] in [%s]", query, collection_name)
32
  collection = self.chroma_client.get_or_create_collection(name=collection_name)
33
  response = collection.query(
34
+ query_embeddings=[get_embedding(query)],
35
+ # query_texts=[query],
36
+ n_results=n_results
37
  )
38
  return response
39
+
40
+ def count(self, collection_name: str):
41
+ logger.info("Getting total records in [%s]", collection_name)
42
+ collection = self.chroma_client.get_or_create_collection(name=collection_name)
43
+ return collection.count()
44
+
45
+ if __name__ == "__main__":
46
+ collection_name="vishnu_sahasranamam"
47
+ database = SanatanDatabase()
48
+ print("count = ", database.count(collection_name))
49
+ while True:
50
+ query = input("Search for: ")
51
+ if(query.strip() == ""):
52
+ break
53
+ response = database.search(collection_name=collection_name, query=query, n_results=1)
54
+ print(response["metadatas"][0][0]["translation"])