jpkarthikeyan committed on
Commit
ebc6cc2
·
verified ·
1 Parent(s): a2219d5

Upload folder using huggingface_hub

Browse files
Files changed (37) hide show
  1. .gitattributes +1 -0
  2. .gitignore +29 -0
  3. .idea/.gitignore +8 -0
  4. .idea/Fly-Kite-AirLine-RAG-Project.iml +8 -0
  5. .idea/inspectionProfiles/profiles_settings.xml +6 -0
  6. .idea/misc.xml +7 -0
  7. .idea/modules.xml +8 -0
  8. .idea/workspace.xml +160 -0
  9. DOCUMENTS/Flykite_Airlines_HR_Policy.pdf +3 -0
  10. Dockerfile +18 -0
  11. HostingIntoHuggingFace.py +73 -0
  12. README.md +13 -10
  13. app.py +272 -0
  14. configuration.py +62 -0
  15. create_user_db.py +47 -0
  16. eval_data.json +33 -0
  17. main.py +541 -0
  18. rag_scripts/__pycache__/interfaces.cpython-312.pyc +0 -0
  19. rag_scripts/__pycache__/rag_pipeline.cpython-312.pyc +0 -0
  20. rag_scripts/documents_processing/__pycache__/chunking.cpython-312.pyc +0 -0
  21. rag_scripts/documents_processing/chunking.py +134 -0
  22. rag_scripts/embedding/__pycache__/embedder.cpython-312.pyc +0 -0
  23. rag_scripts/embedding/embedder.py +53 -0
  24. rag_scripts/embedding/vector_db/__pycache__/chroma_db.cpython-312.pyc +0 -0
  25. rag_scripts/embedding/vector_db/__pycache__/faiss_db.cpython-312.pyc +0 -0
  26. rag_scripts/embedding/vector_db/__pycache__/pinecone_db.cpython-312.pyc +0 -0
  27. rag_scripts/embedding/vector_db/chroma_db.py +103 -0
  28. rag_scripts/embedding/vector_db/faiss_db.py +145 -0
  29. rag_scripts/embedding/vector_db/pinecone_db.py +202 -0
  30. rag_scripts/evaluation/__pycache__/evaluator.cpython-312.pyc +0 -0
  31. rag_scripts/evaluation/evaluator.py +525 -0
  32. rag_scripts/interfaces.py +129 -0
  33. rag_scripts/llm/__pycache__/llmResponse.cpython-312.pyc +0 -0
  34. rag_scripts/llm/llmResponse.py +63 -0
  35. rag_scripts/rag_pipeline.py +206 -0
  36. requirements.txt +11 -0
  37. users.db +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ DOCUMENTS/Flykite_Airlines_HR_Policy.pdf filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ .env
2
+ __pycache__/
3
+ .python
4
+ env/
5
+ venv/
6
+ *.egg-info
7
+ dist/
8
+ build/
9
+ *.pyd
10
+ *.pyo
11
+
12
+
13
+ .vscode/
14
+ *.code-workspace
15
+
16
+ DATA/chroma_temp_db_eval/
17
+ DATA/faiss_temp_db_eval/
18
+ DATA/pinecone_temp_db_eval/
19
+ DATA/
20
+
21
+ *.log
22
+ *.key
23
+ *.pem
24
+
25
+ #LLM
26
+ *.bin
27
+ *.h5
28
+ *.pt
29
+
.idea/.gitignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ # Default ignored files
2
+ /shelf/
3
+ /workspace.xml
4
+ # Editor-based HTTP Client requests
5
+ /httpRequests/
6
+ # Datasource local storage ignored files
7
+ /dataSources/
8
+ /dataSources.local.xml
.idea/Fly-Kite-AirLine-RAG-Project.iml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <module type="PYTHON_MODULE" version="4">
3
+ <component name="NewModuleRootManager">
4
+ <content url="file://$MODULE_DIR$" />
5
+ <orderEntry type="jdk" jdkName="Python 3.12" jdkType="Python SDK" />
6
+ <orderEntry type="sourceFolder" forTests="false" />
7
+ </component>
8
+ </module>
.idea/inspectionProfiles/profiles_settings.xml ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ <component name="InspectionProjectProfileManager">
2
+ <settings>
3
+ <option name="USE_PROJECT_PROFILE" value="false" />
4
+ <version value="1.0" />
5
+ </settings>
6
+ </component>
.idea/misc.xml ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="Black">
4
+ <option name="sdkName" value="Python 3.12" />
5
+ </component>
6
+ <component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12" project-jdk-type="Python SDK" />
7
+ </project>
.idea/modules.xml ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="ProjectModuleManager">
4
+ <modules>
5
+ <module fileurl="file://$PROJECT_DIR$/.idea/Fly-Kite-AirLine-RAG-Project.iml" filepath="$PROJECT_DIR$/.idea/Fly-Kite-AirLine-RAG-Project.iml" />
6
+ </modules>
7
+ </component>
8
+ </project>
.idea/workspace.xml ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <?xml version="1.0" encoding="UTF-8"?>
2
+ <project version="4">
3
+ <component name="AutoImportSettings">
4
+ <option name="autoReloadType" value="SELECTIVE" />
5
+ </component>
6
+ <component name="ChangeListManager">
7
+ <list default="true" id="43d0b45c-4ac8-44fd-8129-96b2dd008826" name="Changes" comment="" />
8
+ <option name="SHOW_DIALOG" value="false" />
9
+ <option name="HIGHLIGHT_CONFLICTS" value="true" />
10
+ <option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
11
+ <option name="LAST_RESOLUTION" value="IGNORE" />
12
+ </component>
13
+ <component name="FileTemplateManagerImpl">
14
+ <option name="RECENT_TEMPLATES">
15
+ <list>
16
+ <option value="Python Script" />
17
+ </list>
18
+ </option>
19
+ </component>
20
+ <component name="FlaskConsoleOptions" custom-start-script="import sys; print('Python %s on %s' % (sys.version, sys.platform)); sys.path.extend([WORKING_DIR_AND_PYTHON_PATHS])&#10;from flask.cli import ScriptInfo, NoAppException&#10;for module in [&quot;main.py&quot;, &quot;wsgi.py&quot;, &quot;app.py&quot;]:&#10; try: locals().update(ScriptInfo(app_import_path=module, create_app=None).load_app().make_shell_context()); print(&quot;\nFlask App: %s&quot; % app.import_name); break&#10; except NoAppException: pass">
21
+ <envs>
22
+ <env key="FLASK_APP" value="app" />
23
+ </envs>
24
+ <option name="myCustomStartScript" value="import sys; print('Python %s on %s' % (sys.version, sys.platform)); sys.path.extend([WORKING_DIR_AND_PYTHON_PATHS])&#10;from flask.cli import ScriptInfo, NoAppException&#10;for module in [&quot;main.py&quot;, &quot;wsgi.py&quot;, &quot;app.py&quot;]:&#10; try: locals().update(ScriptInfo(app_import_path=module, create_app=None).load_app().make_shell_context()); print(&quot;\nFlask App: %s&quot; % app.import_name); break&#10; except NoAppException: pass" />
25
+ <option name="myEnvs">
26
+ <map>
27
+ <entry key="FLASK_APP" value="app" />
28
+ </map>
29
+ </option>
30
+ </component>
31
+ <component name="ProjectColorInfo">{
32
+ &quot;associatedIndex&quot;: 1
33
+ }</component>
34
+ <component name="ProjectId" id="33s38NvoyHdVLzaauzWfPb9TnxM" />
35
+ <component name="ProjectViewState">
36
+ <option name="hideEmptyMiddlePackages" value="true" />
37
+ <option name="showLibraryContents" value="true" />
38
+ </component>
39
+ <component name="PropertiesComponent">{
40
+ &quot;keyToString&quot;: {
41
+ &quot;ModuleVcsDetector.initialDetectionPerformed&quot;: &quot;true&quot;,
42
+ &quot;Python.HRPolicyApp.executor&quot;: &quot;Run&quot;,
43
+ &quot;Python.create_user_db.executor&quot;: &quot;Run&quot;,
44
+ &quot;RunOnceActivity.ShowReadmeOnStart&quot;: &quot;true&quot;,
45
+ &quot;RunOnceActivity.TerminalTabsStorage.copyFrom.TerminalArrangementManager.252&quot;: &quot;true&quot;,
46
+ &quot;node.js.detected.package.eslint&quot;: &quot;true&quot;,
47
+ &quot;node.js.detected.package.tslint&quot;: &quot;true&quot;,
48
+ &quot;node.js.selected.package.eslint&quot;: &quot;(autodetect)&quot;,
49
+ &quot;node.js.selected.package.tslint&quot;: &quot;(autodetect)&quot;,
50
+ &quot;nodejs_package_manager_path&quot;: &quot;npm&quot;,
51
+ &quot;settings.editor.selected.configurable&quot;: &quot;preferences.pluginManager&quot;,
52
+ &quot;vue.rearranger.settings.migration&quot;: &quot;true&quot;
53
+ }
54
+ }</component>
55
+ <component name="SharedIndexes">
56
+ <attachedChunks>
57
+ <set>
58
+ <option value="bundled-js-predefined-d6986cc7102b-3aa1da707db6-JavaScript-PY-252.26830.99" />
59
+ <option value="bundled-python-sdk-164cda30dcd9-0af03a5fa574-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-252.26830.99" />
60
+ </set>
61
+ </attachedChunks>
62
+ </component>
63
+ <component name="TaskManager">
64
+ <task active="true" id="Default" summary="Default task">
65
+ <changelist id="43d0b45c-4ac8-44fd-8129-96b2dd008826" name="Changes" comment="" />
66
+ <created>1760091804577</created>
67
+ <option name="number" value="Default" />
68
+ <option name="presentableId" value="Default" />
69
+ <updated>1760091804577</updated>
70
+ <workItem from="1760091890212" duration="859000" />
71
+ <workItem from="1760093438967" duration="2452000" />
72
+ <workItem from="1760096239520" duration="21456000" />
73
+ <workItem from="1760158047614" duration="33908000" />
74
+ <workItem from="1760276025638" duration="8385000" />
75
+ <workItem from="1760361998945" duration="294000" />
76
+ <workItem from="1760696201916" duration="16790000" />
77
+ <workItem from="1760716047458" duration="13878000" />
78
+ <workItem from="1760764472535" duration="4802000" />
79
+ <workItem from="1760802622597" duration="4499000" />
80
+ <workItem from="1760844577587" duration="42402000" />
81
+ </task>
82
+ <servers />
83
+ </component>
84
+ <component name="TypeScriptGeneratedFilesManager">
85
+ <option name="version" value="3" />
86
+ </component>
87
+ <component name="XDebuggerManager">
88
+ <breakpoint-manager>
89
+ <breakpoints>
90
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
91
+ <url>file://$PROJECT_DIR$/rag_scripts/rag_pipeline.py</url>
92
+ <option name="timeStamp" value="2" />
93
+ </line-breakpoint>
94
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
95
+ <url>file://$PROJECT_DIR$/rag_scripts/evaluation/evaluator.py</url>
96
+ <line>245</line>
97
+ <option name="timeStamp" value="3" />
98
+ </line-breakpoint>
99
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
100
+ <url>file://$PROJECT_DIR$/rag_scripts/documents_processing/chunking.py</url>
101
+ <line>49</line>
102
+ <option name="timeStamp" value="4" />
103
+ </line-breakpoint>
104
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
105
+ <url>file://$PROJECT_DIR$/rag_scripts/evaluation/evaluator.py</url>
106
+ <line>396</line>
107
+ <option name="timeStamp" value="6" />
108
+ </line-breakpoint>
109
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
110
+ <url>file://$PROJECT_DIR$/rag_scripts/evaluation/evaluator.py</url>
111
+ <line>163</line>
112
+ <option name="timeStamp" value="10" />
113
+ </line-breakpoint>
114
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
115
+ <url>file://$PROJECT_DIR$/rag_scripts/rag_pipeline.py</url>
116
+ <line>151</line>
117
+ <option name="timeStamp" value="11" />
118
+ </line-breakpoint>
119
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
120
+ <url>file://$PROJECT_DIR$/rag_scripts/llm/llmResponse.py</url>
121
+ <line>28</line>
122
+ <option name="timeStamp" value="12" />
123
+ </line-breakpoint>
124
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
125
+ <url>file://$PROJECT_DIR$/main.py</url>
126
+ <line>237</line>
127
+ <option name="timeStamp" value="13" />
128
+ </line-breakpoint>
129
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
130
+ <url>file://$PROJECT_DIR$/main.py</url>
131
+ <line>296</line>
132
+ <option name="timeStamp" value="14" />
133
+ </line-breakpoint>
134
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
135
+ <url>file://$PROJECT_DIR$/rag_scripts/evaluation/evaluator.py</url>
136
+ <line>451</line>
137
+ <option name="timeStamp" value="15" />
138
+ </line-breakpoint>
139
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
140
+ <url>file://$PROJECT_DIR$/rag_scripts/llm/llmResponse.py</url>
141
+ <line>27</line>
142
+ <option name="timeStamp" value="16" />
143
+ </line-breakpoint>
144
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
145
+ <url>file://$PROJECT_DIR$/rag_scripts/llm/llmResponse.py</url>
146
+ <option name="timeStamp" value="17" />
147
+ </line-breakpoint>
148
+ <line-breakpoint enabled="true" suspend="THREAD" type="python-line">
149
+ <url>file://$PROJECT_DIR$/main.py</url>
150
+ <line>248</line>
151
+ <option name="timeStamp" value="18" />
152
+ </line-breakpoint>
153
+ </breakpoints>
154
+ </breakpoint-manager>
155
+ </component>
156
+ <component name="com.intellij.coverage.CoverageDataManagerImpl">
157
+ <SUITE FILE_PATH="coverage/Fly_Kite_AirLine_RAG_Project$HRPolicyApp.coverage" NAME="HRPolicyApp Coverage Results" MODIFIED="1760718510975" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
158
+ <SUITE FILE_PATH="coverage/Fly_Kite_AirLine_RAG_Project$create_user_db.coverage" NAME="create_user_db Coverage Results" MODIFIED="1760724265707" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
159
+ </component>
160
+ </project>
DOCUMENTS/Flykite_Airlines_HR_Policy.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:1a5cdeec566a672375800654a4d004e4c3b7907578bd8c177ac5755e1cdb6cd0
3
+ size 262937
Dockerfile ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM python:3.12-slim-bookworm
2
+ LABEL authors="karthikeyan"
3
+
4
+ RUN apt-get update && apt-get install -y \
5
+ build-essential \
6
+ && rm -rf /var/lib/apt/lists/*
7
+
8
+ WORKDIR /app
9
+
10
+ COPY . .
11
+
12
+ #COPY requirements.txt .
13
+ RUN pip install --no-cache-dir -r requirements.txt
14
+ RUN mkdir -p data/TunedDB data/ChromaDB data/FIASS_DB
15
+
16
+ EXPOSE 7860
17
+
18
+ CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0", "--server.enableXsrfProtection=false"]
HostingIntoHuggingFace.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import inspect
3
+ import traceback
4
+ from huggingface_hub import HfApi, create_repo, login,hf_hub_download
5
+ from huggingface_hub.utils import RepositoryNotFoundError
6
+ from configuration import Configuration
7
+ class HostingInHuggingFace:
8
+ def __init__(self):
9
+ self.base_path = Configuration.PROJECT_ROOT
10
+ self.hf_token = Configuration.HF_TOKEN
11
+ self.repo_id = 'jpkarthikeyan/FlyKiteAirlines'
12
+
13
+ def CreatingSpaceInHF(self):
14
+ print(f"Function Name {inspect.currentframe().f_code.co_name}")
15
+ api = HfApi()
16
+ try:
17
+ print(f"Checking for {self.repo_id} is correct or not")
18
+ api.repo_info(repo_id = self.repo_id,
19
+ repo_type='space',
20
+ token = self.hf_token)
21
+ print(f"Space {self.repo_id} already exists")
22
+ except RepositoryNotFoundError:
23
+ create_repo(repo_id=self.repo_id,
24
+ repo_type='space',
25
+ space_sdk='docker',
26
+ private=False,
27
+ token=self.hf_token)
28
+ print(f"Space created in {self.repo_id}")
29
+ except Exception as ex:
30
+ print(f"Exception in creating space {ex}")
31
+ traceback.print_exc()
32
+ finally:
33
+ print('-'*50)
34
+
35
+
36
+ def UploadDeploymentFile(self):
37
+ print(f"Function Name {inspect.currentframe().f_code.co_name}")
38
+ try:
39
+ api = HfApi(token=self.hf_token)
40
+ #directory_to_upload = os.path.join(self.base_path,'Deployment')
41
+ directory_to_upload = self.base_path
42
+ print(f"Directory to upload {directory_to_upload} into HF Space {self.repo_id}")
43
+ api.upload_folder(repo_id=self.repo_id, folder_path=directory_to_upload,
44
+ repo_type='space')
45
+ print(f"Successfully upload {directory_to_upload} into {self.repo_id}")
46
+ return True
47
+ except Exception as ex:
48
+ print(f"Exception occured {ex}")
49
+ print(traceback.print_exc())
50
+ return False
51
+ finally:
52
+ print('-'*50)
53
+
54
+ def ToRunPipeline(self):
55
+ try:
56
+ self.CreatingSpaceInHF()
57
+ if self.UploadDeploymentFile():
58
+ print('Deployment pipeline completed')
59
+ return True
60
+ else:
61
+ print('Deployment pipeline failed')
62
+ return False
63
+ except Exception as ex:
64
+ print(f"Exception occured {ex}")
65
+ print(traceback.print_exc())
66
+ return False
67
+ finally:
68
+ print('-'*50)
69
+
70
+
71
+ if __name__ == '__main__':
72
+ hosting = HostingInHuggingFace()
73
+ hosting.ToRunPipeline()
README.md CHANGED
@@ -1,10 +1,13 @@
1
- ---
2
- title: FlyKiteAirlines
3
- emoji: 🌍
4
- colorFrom: gray
5
- colorTo: yellow
6
- sdk: docker
7
- pinned: false
8
- ---
9
-
10
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
1
+ ---
2
+ title: FlyKite Airlines HR Policy
3
+ emoji: ✈️ 🤗
4
+ colorFrom: blue
5
+ colorTo: green
6
+ sdk: docker
7
+ sdk_version: 3.9
8
+ app_file: app.py
9
+ app_type: streamlit
10
+ pinned: false
11
+ license: mit
12
+ ---
13
+ The streamlit app is frontend for RAG App for FlyKite Airlines HR Policy
app.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import subprocess
3
+ import os
4
+ import json
5
+ import sqlite3
6
+ import hashlib
7
+ from typing import Optional, Dict, Any
8
+
9
+ # Assuming Configuration.py and other rag_scripts are in the same directory or accessible via sys.path
10
+ # You might need to adjust sys.path if your project structure is different
11
+ import sys
12
+ # Make sure this path is correct for your setup
13
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))
14
+ from configuration import Configuration
15
+
16
+ # --- Helper functions (from your main.py or adapted) ---
17
+
18
+ # Removed create_user_db() as per your request
19
+
20
def authenticate_user(username, password, db_path: str = 'users.db') -> Optional[Dict[str, str]]:
    """Validate credentials against the SQLite users table.

    Passwords are stored as unsalted SHA-256 hexdigests (matching
    create_user_db.py), so the supplied password is hashed the same way
    before comparison.  NOTE(review): unsalted SHA-256 is weak for password
    storage — a salted KDF (bcrypt/argon2) would be safer, but it must be
    changed together with the DB seeding script.

    Args:
        username: login name to look up.
        password: plaintext password supplied by the user.
        db_path: path to the SQLite database. Defaults to the original
            hard-coded 'users.db', so existing callers are unaffected.

    Returns:
        Dict with username/role/department/location on success, else None.
    """
    hashed_password = hashlib.sha256(password.encode()).hexdigest()
    conn = sqlite3.connect(db_path)
    try:
        cursor = conn.cursor()
        cursor.execute(
            "SELECT username, jobrole, department, location FROM users WHERE username = ? AND password = ?",
            (username, hashed_password)
        )
        user = cursor.fetchone()
    finally:
        # BUG FIX: the original leaked the connection if the query raised.
        conn.close()
    if user:
        return {"username": user[0], "role": user[1], "department": user[2], "location": user[3]}
    return None
33
+
34
def run_rag_job(job_type: str, **kwargs) -> Dict[str, Any]:
    """
    Executes a RAG job using the main.py script as a subprocess.
    Returns the output as a dictionary (assuming main.py prints JSON).
    """
    # Translate keyword arguments into CLI flags understood by main.py.
    cmd = [sys.executable, 'main.py', '--job', job_type]
    for key, value in kwargs.items():
        if value is None or value == '':
            continue  # unset parameters are simply omitted
        flag = f'--{key.replace("_", "-")}'
        if isinstance(value, bool):
            # Booleans behave like store_true switches: present only when True.
            if value:
                cmd.append(flag)
        else:
            cmd.extend([flag, str(value)])

    # Forward the logged-in user's context so main.py can filter results.
    if 'user_context' in st.session_state and st.session_state.user_context:
        cmd.extend(['--user-context', json.dumps(st.session_state.user_context)])

    st.write(f"Running command: `{' '.join(cmd)}`")  # For debugging
    try:
        completed = subprocess.run(cmd, capture_output=True, text=True, check=True)
        # main.py prints progress lines followed by a JSON payload; scan the
        # stdout bottom-up and keep the first line that parses as JSON.
        parsed: Dict[str, Any] = {}
        for raw_line in reversed(completed.stdout.strip().split('\n')):
            try:
                parsed = json.loads(raw_line)
            except json.JSONDecodeError:
                continue
            else:
                break
        return parsed
    except subprocess.CalledProcessError as e:
        st.error(f"Error running RAG job '{job_type}': {e.stderr}")
        return {"error": e.stderr}
    except Exception as e:
        st.error(f"An unexpected error occurred: {e}")
        return {"error": str(e)}
74
+
75
+ # --- Streamlit UI Components ---
76
+
77
def home_page():
    """Landing page: shows the login form, or account details once signed in."""
    st.title("Welcome to Flykite RAG System")

    # Seed session-state defaults on first visit.
    if 'logged_in' not in st.session_state:
        st.session_state.logged_in = False
    if 'user_info' not in st.session_state:
        st.session_state.user_info = None

    if st.session_state.logged_in:
        account = st.session_state.user_info
        st.write(f"You are logged in as **{account['username']}** (Role: **{account['role']}**)")
        if st.button("Logout"):
            # Drop all authentication state and redraw the page.
            st.session_state.logged_in = False
            st.session_state.user_info = None
            st.session_state.user_context = None
            st.rerun()
        return

    st.subheader("Login")
    with st.form("login_form"):
        entered_user = st.text_input("Username")
        entered_pwd = st.text_input("Password", type="password")
        submitted = st.form_submit_button("Login")

        if submitted:
            account = authenticate_user(entered_user, entered_pwd)
            if account is None:
                st.error("Invalid username or password.")
            else:
                # Persist identity plus the context used to filter RAG answers.
                st.session_state.logged_in = True
                st.session_state.user_info = account
                st.session_state.user_context = {
                    "role": account['role'],
                    "department": account['department'],
                    "location": account['location']
                }
                st.success(f"Logged in as {account['username']} ({account['role']})")
                st.rerun()
113
+
114
def admin_page():
    """Admin-only dashboard: hypertune the RAG stack and smoke-test queries."""
    st.title("Admin Dashboard")
    st.write(f"Logged in as: {st.session_state.user_info['username']} (Role: {st.session_state.user_info['role']})")

    is_admin = bool(st.session_state.user_info) and st.session_state.user_info['role'] == 'admin'
    if not is_admin:
        st.warning("You do not have administrative privileges to view this page.")
        if st.button("Go to User Page"):
            st.session_state.page = "User"
            st.rerun()
        return

    # --- Hypertuning section ---
    st.header("RAG Hypertuning")
    st.info("Run hyperparameter tuning to find the best RAG configuration and build a tuned index.")

    with st.form("hypertune_form"):
        st.write("Hypertuning parameters (default values from main.py if not overridden):")
        tuning_model = st.selectbox("LLM Model for Hypertuning Evaluation",
                                    options=["mixtral-8x7b-32768", "llama2-70b-4096", "gemma-7b-it"],
                                    key="llm_model_ht_select")
        iterations = st.number_input("Number of Hyper-tuning Iterations (n_iter)", min_value=1, value=1, step=1, key="n_iter_ht")
        start_tuning = st.form_submit_button("Run Hypertune Job")

        if start_tuning:
            st.write("Starting RAG Hypertuning. This may take a while...")
            with st.spinner("Running hypertuning..."):
                outcome = run_rag_job('eval-hypertune', llm_model=tuning_model, n_iter=iterations)
                if outcome and "error" not in outcome:
                    st.success("Hypertuning completed and tuned index built!")
                    st.json(outcome)
                else:
                    st.error("Hypertuning failed.")

    # --- Manual RAG testing section ---
    st.header("RAG Testing")
    st.info("Test the RAG pipeline with a specific query, optionally using the tuned database.")

    with st.form("rag_test_form"):
        probe_query = st.text_area("Enter a test query for the RAG system:",
                                   value="What is the policy on annual leave?",
                                   key="test_query_input")
        tuned_db = st.checkbox("Use Tuned RAG Database (if hypertuned previously)", value=True, key="use_tuned_db_checkbox")
        raw_only = st.checkbox("Display Raw Retrieved Documents only (no LLM)", key="display_raw_docs_checkbox")
        top_k = st.slider("Number of documents to retrieve (k)", min_value=1, max_value=10, value=5, key="k_value_slider")
        run_test = st.form_submit_button("Run RAG Test Query")

        if run_test:
            st.write("Running RAG test query...")
            with st.spinner("Getting RAG response..."):
                # Reuse the model picked for hypertuning when present,
                # otherwise the configured default.
                chosen_llm = st.session_state.get('llm_model_ht_select', Configuration.DEFAULT_GROQ_LLM_MODEL)
                outcome = run_rag_job('search',
                                      query=probe_query,
                                      use_tuned=tuned_db,
                                      raw=raw_only,
                                      k=top_k,
                                      llm_model=chosen_llm)

                if outcome and "error" not in outcome:
                    st.success("RAG Test Query Completed!")
                    st.subheader("RAG Response:")
                    if raw_only:
                        st.json(outcome)
                    else:
                        if 'response' in outcome and 'summary' in outcome['response']:
                            st.write(outcome['response']['summary'])
                            if 'sources' in outcome['response'] and outcome['response']['sources']:
                                st.subheader("Sources:")
                                for src in outcome['response']['sources']:
                                    st.markdown(f"- **Document ID:** {src.get('document_id', 'N/A')}, **Page:** {src.get('page', 'N/A')}, **Section:** {src.get('section', 'N/A')}, **Clause:** {src.get('clause', 'N/A')}")
                        else:
                            st.json(outcome)
                        if 'evaluation' in outcome:
                            st.subheader("Evaluation Results:")
                            st.json(outcome['evaluation'])
                else:
                    st.error("RAG test query failed.")
191
+
192
def user_page():
    """End-user page: free-form questions against the HR policy RAG index."""
    st.title("Flykite HR Policy Query")
    st.write(f"Logged in as: {st.session_state.user_info['username']} (Role: {st.session_state.user_info['role']})")

    st.info("Ask any question about the Flykite Airlines HR policy document.")

    with st.form("user_query_form"):
        question = st.text_area("Your Query:", height=100, key="user_query_input")
        answer_mode = st.radio("Choose Response Type:",
                               options=["LLM Tuned Response (RAG + LLM)", "RAG Raw Response (Retrieved Docs Only)"],
                               index=0, key="response_type_radio")
        doc_count = st.slider("Number of documents to consider (k)", min_value=1, max_value=10, value=5, key="k_value_user_slider")
        asked = st.form_submit_button("Get Answer")

    if asked and not question:
        st.warning("Please enter a query.")
    elif asked:
        st.subheader("Response:")
        with st.spinner("Fetching answer..."):
            want_raw = (answer_mode == "RAG Raw Response (Retrieved Docs Only)")
            # The user page always queries with the configured default model
            # against the tuned index when one is available.
            outcome = run_rag_job('search',
                                  query=question,
                                  raw=want_raw,
                                  k=doc_count,
                                  use_tuned=True,
                                  llm_model=Configuration.DEFAULT_GROQ_LLM_MODEL)

            if outcome and "error" not in outcome:
                if want_raw:
                    st.json(outcome)
                elif 'response' in outcome and 'summary' in outcome['response']:
                    st.markdown(outcome['response']['summary'])
                    if 'sources' in outcome['response'] and outcome['response']['sources']:
                        st.subheader("Sources:")
                        for src in outcome['response']['sources']:
                            st.markdown(f"- **Document ID:** {src.get('document_id', 'N/A')}, **Page:** {src.get('page', 'N/A')}, **Section:** {src.get('section', 'N/A')}, **Clause:** {src.get('clause', 'N/A')}")
                else:
                    st.json(outcome)
            else:
                st.error("Failed to get a response. Please try again.")
239
+
240
+ # --- Main Application Logic ---
241
def main_app():
    """Top-level router: sidebar navigation plus per-page dispatch."""
    st.sidebar.title("Navigation")
    if 'logged_in' not in st.session_state:
        st.session_state.logged_in = False
    if 'page' not in st.session_state:
        st.session_state.page = "Home"

    if not st.session_state.logged_in:
        # Anonymous visitors always land on the login/home page.
        st.session_state.page = "Home"
        home_page()
        return

    # Sidebar buttons switch pages by mutating session state; admins get
    # one extra entry for the dashboard.
    st.sidebar.button("Home", on_click=lambda: st.session_state.update(page="Home"))
    if st.session_state.user_info and st.session_state.user_info['role'] == 'admin':
        st.sidebar.button("Admin Dashboard", on_click=lambda: st.session_state.update(page="Admin"))
    st.sidebar.button("User Query", on_click=lambda: st.session_state.update(page="User"))

    # Render whichever page is currently selected.
    current = st.session_state.page
    if current == "Home":
        home_page()
    elif current == "Admin":
        admin_page()
    elif current == "User":
        user_page()


if __name__ == "__main__":
    main_app()
configuration.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os

# Optional dependency: without python-dotenv the .env file is simply skipped
# and plain process environment variables are used instead.
try:
    from dotenv import load_dotenv
except ImportError:
    def load_dotenv():
        return False

load_dotenv()


class Configuration:
    """Central project configuration resolved from environment variables.

    BUG FIX: the original called int(os.getenv(...)) and
    os.path.join(..., os.getenv(...)) on variables that had no default, so a
    missing/incomplete .env crashed the import with TypeError.  Every lookup
    now has a fallback; behaviour is unchanged whenever the variables are set.
    """
    GROQ_API_KEY = os.getenv('GROQ_API_KEY', 'default_groq_key')
    OPEN_API_KEY = os.getenv('OPENAI_API_KEY', 'default_openai_key')
    HF_TOKEN = os.getenv('HF_TOKEN', 'default_hf_token')

    # Fallback chunking values are reasonable guesses for the unset case —
    # TODO(review): confirm against the values in the project's .env.
    DEFAULT_CHUNK_SIZE = int(os.getenv('DEFAULT_CHUNK_SIZE', '500'))
    DEFAULT_CHUNK_OVERLAP = int(os.getenv('DEFAULT_CHUNK_OVERLAP', '50'))
    DEFAULT_SENTENCE_TRANSFORMER_MODEL = os.getenv('DEFAULT_SENTENCE_TRANSFORMER_MODEL', 'all-MiniLM-L6-v2')
    DEFAULT_GROQ_LLM_MODEL = os.getenv('DEFAULT_GROQ_LLM_MODEL', 'llama-3.3-70b-versatile')
    DEFAULT_RERANKER = os.getenv('DEFAULT_RERANKER')

    PROJECT_ROOT_BASE = os.path.abspath(os.path.dirname(__file__))
    print(f"PROJECT_ROOT {PROJECT_ROOT_BASE}")

    # When this file lives in a src/ folder the project root is one level up.
    if os.path.basename(PROJECT_ROOT_BASE) == "src":
        PROJECT_ROOT = os.path.abspath(os.path.join(PROJECT_ROOT_BASE, '..'))
    else:
        PROJECT_ROOT = PROJECT_ROOT_BASE

    DOCUMENTS_DIR = os.path.join(PROJECT_ROOT, 'DOCUMENTS')
    print(f"DOCUMENTS_DIR {DOCUMENTS_DIR}")
    DATA_DIR = os.path.join(PROJECT_ROOT, 'DATA')

    # Default matches the PDF shipped in DOCUMENTS/.
    PDF_FILE_NAME = os.getenv('PDF_FILE_NAME', 'Flykite_Airlines_HR_Policy.pdf')
    FULL_PDF_PATH = os.path.join(DOCUMENTS_DIR, PDF_FILE_NAME)
    print(f"FULL_PDF_PATH: {FULL_PDF_PATH} ")

    CHROMA_DB_PATH = os.path.join(DATA_DIR, os.getenv('CHROMA_DB_PATH', 'ChromaDB'))
    FAISS_DB_PATH = os.path.join(DATA_DIR, os.getenv('FAISS_DB_PATH', 'FAISS_DB'))

    COLLECTION_NAME = os.getenv('COLLECTION_NAME', 'flykite_hr_policy')
    EVAL_DATA_PATH = os.path.join(PROJECT_ROOT, os.getenv('EVAL_DATA_PATH', 'eval_data.json'))

    PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
    PINECONE_CLOUD = os.getenv('PINECONE_CLOUD', 'aws')
    PINECONE_REGION = os.getenv('PINECONE_REGION', 'us-east-1')

    # Make sure the on-disk layout exists before any component touches it.
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(DOCUMENTS_DIR, exist_ok=True)
    os.makedirs(CHROMA_DB_PATH, exist_ok=True)
    os.makedirs(FAISS_DB_PATH, exist_ok=True)

    if not os.path.exists(FULL_PDF_PATH):
        print(f"PDF not found in {FULL_PDF_PATH}")
    else:
        print(f"PDF file found in {FULL_PDF_PATH}")

    if not os.path.exists(EVAL_DATA_PATH):
        print(f"eval json file not found in {EVAL_DATA_PATH}")
    else:
        print(f"eval file found in {EVAL_DATA_PATH}")
create_user_db.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sqlite3
2
+ import hashlib
3
+ import traceback
4
+
5
+ from argon2 import hash_password
6
+
7
+
8
class CreateUserDB():
    """Create and seed the local SQLite ``users.db`` with default accounts.

    The password column stores a SHA-256 hex digest (no salt) — suitable for
    a demo login only, not for production credential storage.
    """

    def encrypttion_pwd(self, password):
        """Return the SHA-256 hex digest of *password*.

        Name kept as-is (including the typo) for backward compatibility with
        existing callers.
        """
        return hashlib.sha256(password.encode()).hexdigest()

    def setup_database(self):
        """Create the ``users`` table if missing and insert default users once.

        Idempotent: seeding only happens when the table is empty.
        """
        # Bug fix: bind conn before the try so the finally clause cannot hit
        # an UnboundLocalError when sqlite3.connect() itself raises.
        conn = None
        try:
            conn = sqlite3.connect('users.db')
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS users (
                username TEXT PRIMARY KEY,
                password TEXT,
                jobrole TEXT,
                department TEXT,
                location TEXT)
                ''')
            cursor.execute("SELECT COUNT(*) FROM users")
            if cursor.fetchone()[0] == 0:
                users = [('admin', self.encrypttion_pwd('admin'), 'admin', 'admin', 'chennai'),
                         ('user1', self.encrypttion_pwd('user'), 'user', 'manager', 'chennai'),
                         ('user2', self.encrypttion_pwd('user'), 'user', 'pilot', 'chennai'),
                         ('user3', self.encrypttion_pwd('user'), 'user', 'engineer', 'chennai'),
                         ('user4', self.encrypttion_pwd('user'), 'user', 'cabin crew', 'chennai'),
                         ('user5', self.encrypttion_pwd('user'), 'user', 'ground staff', 'chennai')]
                cursor.executemany("INSERT INTO users VALUES(?,?,?,?,?)", users)
                conn.commit()
            else:
                # Bug fix: corrected misspelled log message ("alerady").
                print("Database already contains users")

        except Exception as ex:
            traceback.print_exc()
            print(f"Exception occurred: {ex}")
        finally:
            if conn:
                conn.close()
43
+
44
+
45
if __name__ == "__main__":
    # Script entry point: create/seed the users database.
    CreateUserDB().setup_database()
eval_data.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "query":"what are the Criteria for Probation Extension?",
4
+ "expected_keywords":["60%", "PIP", "5 working days", "7 calendar days","measurable target"],
5
+ "expected_answer_snippet":"Extensions are granted only if: a. The employee has achieved at least 60% of their probationary objectives. b. A written PIP with measurable targets is issued within 5 working days of the original probation end date. c. The Department Head and HR Manager both sign off on the extension request. Employees will be notified of extensions in writing at least 7 calendar days before the probation end date."
6
+ },
7
+ {
8
+ "query":"Explain Exit Procedure & Timeline of FlyKite airline",
9
+ "expected_keywords":["3 working days", "within 7 working days", "Exit interview","no later than 5 working days"],
10
+ "expected_answer_snippet":"Company property (ID badge, uniforms, devices) must be returned on or before the last working day. Clearance forms must be signed by Finance, IT, and Admin within 3 working days post-termination. Final salary settlement is processed within 7 working days of clearance completion. Exit interview: Mandatory and to be scheduled no later than 5 working days after last working day. Failure to attend delays issuance of relieving letter."
11
+
12
+ },
13
+ {
14
+ "query":"Elaborate Allowable Expenses and Reimbursement Procedures",
15
+ "expected_keywords":["Eligibility","Exclusions", "Submission Deadlines", "Appeals","directly work-related","itemized","within 15 calendar days","7 working days"],
16
+ "expected_answer_snippet":"1. Eligibility: Expenses must be directly work-related and supported by itemized receipts. Per diem limits: ₹1,200/day domestic, ₹4,000/day international. Exclusions: Alcohol, entertainment unrelated to work, and non-economy travel (unless pre-approved) are not reimbursable. Submission Deadlines: Claims must be fi led within 15 calendar days of incurring expense. Appeals: Appeal must be submitted within 7 working days of claim denial with"
17
+ },
18
+ {
19
+ "query": "What are the documents required for special leave approval ?",
20
+ "expected_keywords": ["Death certificate", "funeral", "notice","obituary","court summons", "jury duty letter"],
21
+ "expected_answer_snippet": "Special leave is approved only upon submission of: 1. Death certificate, funeral notice, or obituary (bereavement). 2. Official court summons or jury duty letter (legal obligations). 3. Medical certificate from a registered practitioner (emergency care). 4. Government-issued disaster report or evacuation notice (natural disasters). 5. All documents must be submitted within 5 working days of returning to duty."
22
+ },
23
+ {
24
+ "query":"explain in detail about the Attendance and Absence Management of the flykite airlne hr policy",
25
+ "expected_keywords":["Core Hours","Notification","Consequences"],
26
+ "expected_answer_snippet": "1. Core Hours: 9:30 AM – 6:00 PM IST. 2. Notification: Planned absence: Email supervisor at least 1 day before. Emergency absence: Call within 1 hour of shift start. 3. Consequences: 3 unreported absences in 60 days → Written warning. 5 unreported absences in 90 days → Termination review."
27
+ },
28
+ {
29
+ "query":"Explain the compensation and termination policy of the flykite airline hr policy",
30
+ "expected_keywords":["Salary review cycle", "Mid-year adjustments", "Benefits eligibility","Termination Impact"],
31
+ "expected_answer_snippet":"1. Salary Review Cycle: April annually. 2. Mid-Year Adjustments: Only on written approval from CEO & CFO. 3. Benefits Eligibility Health insurance starts after 30 days service (full-time). Retirement plan enrollment: Within 60 days of confirmation. 4. Termination Impact: Benefits end on last working day unless law mandates otherwise."
32
+ }
33
+ ]
main.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import json
4
+ import argparse
5
+ import warnings
6
+ import traceback
7
+ import logging
8
+ import chromadb
9
+ import hashlib
10
+ import sqlite3
11
+ import regex as re
12
+ from pinecone import Pinecone
13
+ from typing import Optional, Dict, Any
14
+ from sentence_transformers import SentenceTransformer, util
15
+
16
+
17
+ os.environ["TF_CPP_MIN_LOG_LEVEL"]="3"
18
+ warnings.filterwarnings("ignore")
19
+
20
+ sys.path.insert(0,os.path.abspath(os.path.join(os.path.dirname(__file__),'src')))
21
+ from sentence_transformers import SentenceTransformer
22
+ from configuration import Configuration
23
+ from rag_scripts.rag_pipeline import RAGPipeline
24
+ from rag_scripts.documents_processing.chunking import PyMuPDFChunker
25
+ from rag_scripts.embedding.embedder import SentenceTransformerEmbedder
26
+ from rag_scripts.embedding.vector_db.chroma_db import chromaDBVectorDB
27
+ from rag_scripts.embedding.vector_db.faiss_db import FAISSVectorDB
28
+ from rag_scripts.embedding.vector_db.pinecone_db import PineconeVectorDB
29
+ from rag_scripts.llm.llmResponse import GROQLLM
30
+ from rag_scripts.evaluation.evaluator import RAGEvaluator
31
+
32
+ class RAGOperations:
33
+ VALID_VECTOR_DB = {'chroma','faiss','pinecone'}
34
+
35
+ @staticmethod
36
+ def check_db(vector_db_type: str, db_path: str, collection_name: str) -> bool:
37
+ try:
38
+ if vector_db_type not in RAGOperations.VALID_VECTOR_DB:
39
+ print(f"Invalid Vector DB: {vector_db_type}")
40
+ raise
41
+ if vector_db_type.lower() == 'pinecone':
42
+ pc = Pinecone(api_key=Configuration.PINECONE_API_KEY)
43
+ return collection_name in pc.list_indexes().names()
44
+ elif vector_db_type.lower() == 'chroma':
45
+ return os.path.exists(db_path) and os.listdir(db_path)
46
+ elif vector_db_type.lower() == "faiss":
47
+ faiss_index_file = os.path.join(db_path,f"{collection_name}.faiss")
48
+ faiss_doc_store_file = os.path.join(db_path,f"{collection_name}_docs.pkl")
49
+ return os.path.exists(faiss_index_file) and os.path.exists(faiss_doc_store_file)
50
+ except Exception as ex:
51
+ print(f"Exception in checking {vector_db_type} existence")
52
+ traceback.print_exc()
53
+ return False
54
+
55
    @staticmethod
    def get_pipeline_params(args: argparse.Namespace, use_tuned: bool = False) -> Dict[str, Any]:
        """Assemble the parameter dict used to construct a RAGPipeline.

        Precedence: CLI args are the baseline; if DATA_DIR/best_params.json
        exists (written by the hypertune job) its values override the CLI and
        force the tuned DB layout, regardless of the use_tuned argument.
        Exits the process on any error.
        """
        try:
            best_param_path = os.path.join(Configuration.DATA_DIR, 'best_params.json')
            # Baseline parameters straight from the CLI / Configuration.
            params = {
                'document_path': Configuration.FULL_PDF_PATH,
                'chunk_size': args.chunk_size,
                'chunk_overlap': args.chunk_overlap,
                'embedding_model_name': args.embedding_model,
                'vector_db_type': args.vector_db_type,
                'llm_model_name': args.llm_model,
                'db_path': None,  # filled in below once the DB type is final
                'collection_name': Configuration.COLLECTION_NAME,
                'vector_db': None,  # instantiated later by initialize_pipeline
                'temperature': args.temperature,
                'top_p': args.top_p,
                'max_tokens': args.max_tokens,
                're_ranker_model': args.re_ranker_model
            }

            if os.path.exists(best_param_path):
                # NOTE(review): file opened in 'rb' but parsed as JSON text —
                # works because json.load accepts binary UTF-8 streams.
                with open(best_param_path, 'rb') as f:
                    best_params = json.load(f)
                print(f"Best params: {best_params} from the file {best_param_path}")

                # Tuned values win over CLI values when present.
                params.update({
                    'vector_db_type': best_params.get('vector_db_type', params['vector_db_type']),
                    'embedding_model_name': best_params.get('embedding_model', params['embedding_model_name']),
                    'chunk_overlap': best_params.get('chunk_overlap', params['chunk_overlap']),
                    'chunk_size': best_params.get('chunk_size', params['chunk_size']),
                    're_ranker_model': best_params.get('re_ranker_model', params['re_ranker_model'])})
                # Presence of best_params.json forces the tuned DB layout.
                use_tuned = True

            if use_tuned:
                # Tuned indexes live under DATA/TunedDB/<db_type> with a
                # 'tuned-' collection prefix; pinecone is remote (no path).
                tuned_db_type = params['vector_db_type']
                params['db_path'] = os.path.join(Configuration.DATA_DIR, 'TunedDB', tuned_db_type) if tuned_db_type != 'pinecone' else ""
                params['collection_name'] = 'tuned-' + Configuration.COLLECTION_NAME
                if tuned_db_type in ['chroma', 'faiss']:
                    os.makedirs(params['db_path'], exist_ok=True)
                print(f"Tuned db path: {params['db_path']}")
            else:
                # Default per-backend storage locations from Configuration.
                params['db_path'] = (Configuration.CHROMA_DB_PATH if params['vector_db_type'] == 'chroma'
                                     else Configuration.FAISS_DB_PATH if params['vector_db_type'] == 'faiss'
                                     else "")
                if params['vector_db_type'] in ['chroma', 'faiss']:
                    os.makedirs(params['db_path'], exist_ok=True)
                    print(f"Created directory for {params['vector_db_type']} at {params['db_path']}")

            return params
        except Exception as ex:
            print(f"Exception in get_pipeline_params: {ex}")
            traceback.print_exc()
            sys.exit(1)
108
+
109
+
110
    @staticmethod
    def check_embedding_dimension(vector_db_type: str, db_path: str,
                                  collection_name: str, embedding_model: str) -> bool:
        """Check that the stored chroma collection's embedding dimension
        matches the dimension produced by *embedding_model*.

        Only chroma is validated; other backends always return True.
        Returns False on mismatch or any error (caller then rebuilds).
        """
        if vector_db_type != 'chroma':
            return True
        try:
            client = chromadb.PersistentClient(path=db_path)
            collection = client.get_collection(collection_name)
            model = SentenceTransformer(embedding_model)
            # Encode a dummy string to learn the current model's output dim.
            sample_embedding = model.encode(["test"])[0]
            try:
                # Private attribute — may not exist on all chroma versions.
                expected_dim = collection._embedding_function.dim
            except AttributeError:
                # Fallback: infer the dimension from one stored vector.
                peek_result = collection.peek(limit=1)
                # NOTE(review): current chromadb releases expose 'embeddings'
                # (plural) in peek() results, so this 'embedding' key may
                # never match and the function would return False — confirm
                # against the installed chromadb version.
                if 'embedding' in peek_result and peek_result['embedding']:
                    expected_dim = len(peek_result['embedding'][0])
                else:
                    return False
            actual_dim = len(sample_embedding)
            print(f"Expected dimension: {expected_dim} Actual dimension: {actual_dim}")
            return expected_dim == actual_dim
        except Exception as ex:
            print(f"Error checking embedding dimension: {ex}")
            return False
134
+
135
+
136
+ @staticmethod
137
+ def initialize_pipeline(params: dict[str,Any]) -> RAGPipeline:
138
+ try:
139
+ embedder = SentenceTransformerEmbedder(model_name=params['embedding_model_name'])
140
+ chunkerObj = PyMuPDFChunker(
141
+ pdf_path=params['document_path'],
142
+ chunk_size=params['chunk_size'],
143
+ chunk_overlap=params['chunk_overlap'])
144
+ llm_model = params['llm_model_name']
145
+ vector_db = None
146
+ if params['vector_db_type'] == 'chroma':
147
+ vector_db = chromaDBVectorDB(embedder=embedder,
148
+ db_path=params['db_path'],
149
+ collection_name=params['collection_name'])
150
+ elif params['vector_db_type'] == 'faiss':
151
+ vector_db = FAISSVectorDB(embedder=embedder,
152
+ db_path=params['db_path'],
153
+ collection_name=params['collection_name'] )
154
+ elif params['vector_db_type'] == 'pinecone':
155
+ vector_db = PineconeVectorDB(embedder=embedder,
156
+ db_path=params['db_path'],
157
+ collection_name=params['collection_name'])
158
+ else:
159
+ raise ValueError(f"Unknown vector_db_type: {params['vector_db_type']}")
160
+
161
+ return RAGPipeline( document_path=params['document_path'],
162
+ chunker=chunkerObj, embedder=embedder,
163
+ vector_db=vector_db,
164
+ llm=GROQLLM(model_name= llm_model),
165
+ re_ranker_model_name=params['re_ranker_model'] if params['re_ranker_model'] else Configuration.DEFAULT_RERANKER,)
166
+ except Exception as ex:
167
+ print(f"Exception in pipeline initialize: {ex}")
168
+ traceback.print_exc()
169
+ sys.exit(1)
170
+
171
+ @staticmethod
172
+ def run_build_job(args: argparse.Namespace) -> None:
173
+ try:
174
+ params = RAGOperations.get_pipeline_params(args)
175
+ pipeline = RAGOperations.initialize_pipeline(params)
176
+ pipeline.build_index()
177
+ print(f"RAG Build JOB completed")
178
+ except Exception as ex:
179
+ print(f"Exception in run build job: {ex}")
180
+ traceback.print_exc()
181
+ sys.exit(1)
182
+
183
+
184
    @staticmethod
    def run_search_job(args: argparse.Namespace, user_info: Dict[str, str]) -> None:
        """Answer one user query against the vector index.

        Two modes:
          * --raw: return the most similar retrieved chunks directly
            (cosine-ranked, paragraph-truncated), no LLM call.
          * default: full RAG — retrieve, then generate an answer via the LLM.
        The answer is evaluated against eval_data.json when the query matches
        an entry there, otherwise via an LLM-as-judge fallback.
        Returns the JSON-serialized output (despite the -> None annotation).
        """
        try:
            params = RAGOperations.get_pipeline_params(args, use_tuned=args.use_tuned)
            vector_db_type = params['vector_db_type']
            db_path = params['db_path']
            collection_name = params['collection_name']

            pipeline = RAGOperations.initialize_pipeline(params)
            db_exists = RAGOperations.check_db(vector_db_type, db_path, collection_name)

            if args.use_rag:
                # Ensure a usable index exists: build when missing/empty, and
                # rebuild when the stored embedding dimension no longer
                # matches the configured embedding model.
                if not db_exists:
                    pipeline.build_index()
                elif pipeline.vector_db.count_documents() == 0:
                    pipeline.build_index()
                elif not RAGOperations.check_embedding_dimension(vector_db_type, db_path,
                                                                 collection_name, params['embedding_model_name']):
                    print(f"Embedding dimension mismatch. rebuilding the index")
                    pipeline.vector_db.delete_collection(collection_name)
                    pipeline.build_index()
                else:
                    print(f"Using existing {vector_db_type} database with collection: {collection_name}")

            if pipeline.vector_db.count_documents() == 0:
                print(f"No Documents found in vector database after re-build")
                sys.exit(1)

            evaluator = RAGEvaluator(eval_data_path=Configuration.EVAL_DATA_PATH,
                                     pdf_path=Configuration.FULL_PDF_PATH)

            user_query = args.query if args.query else (
                input("Enter your Query: "))
            if user_query.lower() == 'exit':
                return
            # Role/location/department are passed to the pipeline so answers
            # can be scoped to the logged-in user.
            user_context = {"role": user_info['role'],
                            "location": user_info['location'],
                            "department": user_info['department']}

            # Look up ground truth for this query in eval_data.json, if any.
            expected_answers = None
            expected_keywords = []
            query_found = False
            try:
                with open(Configuration.EVAL_DATA_PATH, 'r') as f:
                    eval_data = json.load(f)
                for item in eval_data:
                    if item.get('query').strip().lower() == user_query.strip().lower():
                        expected_keywords = item.get('expected_keywords', [])
                        expected_answers = item.get('expected_answer_snippet', "")
                        query_found = True
                        break
                if not expected_keywords and not expected_answers:
                    print(f"No evaluation data found for query in json")
            except Exception as ex:
                print(f"No json file : {ex}")

            retrieved_documents = []
            if args.raw:
                # Raw mode: over-fetch (k*2), then re-rank by cosine
                # similarity and truncate to the most relevant paragraphs.
                retrieved_documents = pipeline.retrieve_raw_documents(
                    user_query, k=args.k * 2)
                print("Raw documents retrieved")
                print(json.dumps(retrieved_documents, indent=4))
                if not retrieved_documents:
                    response = {"summary": "No relevant documents found",
                                "sources": []}
                else:
                    query_embedding = evaluator.embedder.encode(user_query,
                                                                convert_to_tensor=True, normalize_embeddings=True)
                    # Rank whole chunks by query-chunk cosine similarity.
                    similarities = [(doc, util.cos_sim(query_embedding,
                                                       evaluator.embedder.encode(doc['content'],
                                                                                 convert_to_tensor=True,
                                                                                 normalize_embeddings=True)).item())
                                    for doc in retrieved_documents]
                    similarities.sort(key=lambda x: x[1], reverse=True)
                    top_docs = similarities[:min(3, len(similarities))]
                    truncated_content = []
                    for doc, sim in top_docs:
                        # Within each chunk, keep only the paragraphs most
                        # similar to the query (threshold 0.3, max 2).
                        content_paragraphs = re.split(r'\n\s*\n', doc['content'].strip())
                        para_sims = [(para, util.cos_sim(query_embedding,
                                                         evaluator.embedder.encode(para.strip(), convert_to_tensor=True,
                                                                                   normalize_embeddings=True)).item())
                                     for para in content_paragraphs if para.strip()]
                        para_sims.sort(key=lambda x: x[1], reverse=True)

                        top_paras = [para for para, para_sim in para_sims[:2] if para_sim >= 0.3]
                        if len(top_paras) < 1:  # Fallback to at least one paragraph
                            top_paras = [para for para, _ in para_sims[:1]]
                        truncated_content.append('\n\n'.join(top_paras))
                    response = {
                        "summary": "\n".join(truncated_content),
                        "sources": [{"document_id": f"DOC {idx+1}",
                                     "page": str(doc['metadata'].get("page_number", "NA")),
                                     "section": doc['metadata'].get("section", "NA"),
                                     "clause": doc['metadata'].get("clause", "NA")}
                                    for idx, (doc, _) in enumerate(top_docs)]}

            else:
                # Full RAG: the pipeline retrieves and calls the LLM.
                print("LLM+RAG")
                response = pipeline.query(user_query, k=args.k,
                                          include_metadata=True,
                                          user_context=user_context
                                          )
                retrieved_documents = pipeline.retrieve_raw_documents(
                    user_query, k=args.k)

            final_expected_answer = expected_answers if expected_answers is not None else ""
            additional_eval_metrices = {}
            if not query_found:
                # No ground truth: synthesize a reference from the retrieved
                # chunks and score with an LLM-as-judge instead.
                print(f"No query found in eval_Data.json: {user_query}")
                raw_reference_for_score = evaluator._syntesize_raw_reference(retrieved_documents)
                if not final_expected_answer.strip():
                    final_expected_answer = raw_reference_for_score

                retrieved_documents_content = [doc.get('content', '') for doc in retrieved_documents]
                llm_as_judge = evaluator._evaluate_with_llm(user_query, response.get('summary', ''), retrieved_documents_content)
                if llm_as_judge:
                    additional_eval_metrices.update(llm_as_judge)
                    output = {"query": user_query, "response": response, "evaluation": additional_eval_metrices}
                    print(json.dumps(output, indent=4))
                    return json.dumps(output)
                else:
                    output = {"query": user_query, "response": response, "evaluation": llm_as_judge}
                    print(json.dumps(output, indent=4))
                    return json.dumps(output)

            else:
                # Ground truth available: run keyword/snippet-based evaluation.
                eval_result = evaluator.evaluate_response(user_query, response, retrieved_documents,
                                                          expected_keywords, expected_answers)
                output = {"query": user_query, "response": response, "evaluation": eval_result}
                print(json.dumps(output, indent=2, ensure_ascii=False))

                return json.dumps(output)

        except Exception as ex:
            print(f"Exception in run search job {ex}")
            traceback.print_exc()
328
+
329
    @staticmethod
    def run_hypertune_job(args: argparse.Namespace) -> None:
        """Grid/random-search RAG hyperparameters, persist the winner, and
        build a fresh 'tuned-' index with the best configuration.

        Side effects: writes DATA_DIR/best_params.json (later picked up by
        get_pipeline_params) and populates DATA_DIR/TunedDB/<backend>.
        """
        try:
            evaluator = RAGEvaluator(eval_data_path=Configuration.EVAL_DATA_PATH,
                                     pdf_path=Configuration.FULL_PDF_PATH)

            # Search space. NOTE: only 'pinecone' is searched as backend and
            # n_iter is hard-coded to 1 here (args.n_iter is not used).
            result = evaluator.evaluate_combined_params_grid(
                chunk_size_to_test=[512, 1024, 2048],
                chunk_overlap_to_test=[100, 200, 400],
                embedding_models_to_test=["all-MiniLM-L6-v2",
                                          "all-mpnet-base-v2",
                                          "paraphrase-MiniLM-L3-v2",
                                          "multi-qa-mpnet-base-dot-v1"],
                vector_db_types_to_test=['pinecone'],
                llm_model_name=args.llm_model,
                re_ranker_model=["cross-encoder/ms-marco-MiniLM-L-6-v2",
                                 "cross-encoder/ms-marco-TinyBERT-L-2"],
                search_type='random', n_iter=1)
            best_parameter = result['best_params']
            best_score = result['best_score']
            pkl_file = result['pkl_file']
            best_metrics = result['best_metrics']

            # Persist the winning parameters for later runs.
            best_param_path = os.path.join(Configuration.DATA_DIR, 'best_params.json')

            with open(best_param_path, 'w') as f:
                json.dump(best_parameter, f, indent=4)

            # Build the tuned index under DATA/TunedDB/<backend> with a
            # 'tuned-' collection prefix (pinecone is remote, no local dir).
            tuned_db = best_parameter['vector_db_type']
            tuned_path = os.path.join(Configuration.DATA_DIR, 'TunedDB', tuned_db)
            if tuned_db != 'pinecone':
                os.makedirs(tuned_path, exist_ok=True)
            tuned_collection_name = "tuned-" + Configuration.COLLECTION_NAME

            tuned_params = {
                'document_path': Configuration.FULL_PDF_PATH,
                'chunk_size': best_parameter.get('chunk_size', Configuration.DEFAULT_CHUNK_SIZE),
                'chunk_overlap': best_parameter.get('chunk_overlap', Configuration.DEFAULT_CHUNK_OVERLAP),
                'embedding_model_name': best_parameter.get('embedding_model', Configuration.DEFAULT_SENTENCE_TRANSFORMER_MODEL),
                'vector_db_type': tuned_db,
                'llm_model_name': args.llm_model,
                'db_path': tuned_path if tuned_db != 'pinecone' else "",
                'collection_name': tuned_collection_name,
                'vector_db': None,
                're_ranker_model': best_parameter.get('re_ranker', Configuration.DEFAULT_RERANKER)
            }

            # Prefer the exact key written by the tuner when present.
            if 're_ranker_model' in best_parameter:
                tuned_params['re_ranker_model'] = best_parameter['re_ranker_model']
            else:
                tuned_params['re_ranker_model'] = Configuration.DEFAULT_RERANKER

            tuned_pipeline = RAGOperations.initialize_pipeline(tuned_params)
            tuned_pipeline.build_index()

        except Exception as ex:
            print(f"Exception in hypertune: {ex} ")
            traceback.print_exc()
391
+
392
+
393
+ @staticmethod
394
+ def run_llm_with_prompt(args: argparse.Namespace,run_type: str) -> None:
395
+ try:
396
+ params = RAGOperations.get_pipeline_params(args,
397
+ use_tuned=args.use_tuned)
398
+ pipeline = RAGOperations.initialize_pipeline(params)
399
+
400
+
401
+ evaluator = RAGEvaluator(eval_data_path=Configuration.EVAL_DATA_PATH,
402
+ pdf_path=Configuration.FULL_PDF_PATH)
403
+
404
+ system_message = (
405
+ "You are an expert assistant for Flykite Airlines HR Policy Queries."
406
+ "Provide concise, accurate and policy-specific answers based solely on the the provided context."
407
+ "Structured your response clearly, using bullet points, newlines if applicable. "
408
+ "If the context lacks information, state that clearly and speculation."
409
+ ) if run_type == 'prompting' else None
410
+
411
+ user_query = input("Enter your query: ")
412
+ expected_answer = None
413
+ expected_keywords = []
414
+ try:
415
+ with open(Configuration.EVAL_DATA_PATH, 'r') as f:
416
+ eval_data= json.load(f)
417
+ for item in eval_data:
418
+ expected_answer = item.get('expected_answer_snippet',"")
419
+ expected_keywords = item.get('expected_keywords',[])
420
+ break
421
+ except Exception as ex:
422
+ print(f"Error loading eval_data.json for query {user_query}: {ex}")
423
+
424
+ if run_type == 'prompting':
425
+ prompt = (
426
+ f"You are an expert assistant for Flykite Airlines HR Policy Queries."
427
+ f"Answer the following question with a structured response, using bullet points or sections where applicable"
428
+ f"Base your answer solely on the query and avoid hallucination"
429
+ f"Question: \n {user_query} \n"
430
+ f"Answer: ")
431
+
432
+ else:
433
+ prompt = user_query
434
+
435
+ response = pipeline.llm.generate_response(
436
+ prompt=prompt,
437
+ system_message=system_message,
438
+ temperature = args.temperature,
439
+ top_p = args.top_p,
440
+ max_tokens = args.max_tokens
441
+ )
442
+ retreived_documents = []
443
+
444
+ eval_result = evaluator.evaluate_response(user_query,
445
+ response,
446
+ retreived_documents,
447
+ expected_keywords,
448
+ expected_answer)
449
+
450
+ output = { "query":user_query,
451
+ "response": {
452
+ "summary: ":response.strip(),
453
+ "source: ":["LLM Response Not RAG loaded"]},
454
+ "evaluation": eval_result }
455
+
456
+
457
+ print(json.dumps(output, indent=2))
458
+
459
+ except Exception as ex:
460
+ print(f"Exception in LLm_prompting response: {ex}")
461
+ traceback.print_exc()
462
+ sys.exit(1)
463
+
464
+ @staticmethod
465
+ def login() -> Dict[str,str]:
466
+ username = input("Enter your username: ")
467
+ password = input("Enter your password: ")
468
+
469
+ hashed_password = hashlib.sha256(password.encode()).hexdigest()
470
+ try:
471
+ conn = sqlite3.connect('users.db')
472
+ cursor = conn.cursor()
473
+ cursor.execute(
474
+ "SELECT username,jobrole,department,location FROM users WHERE username = ? AND password = ?",
475
+ (username, hashed_password)
476
+ )
477
+ user = cursor.fetchone()
478
+ print(f"{user}")
479
+ conn.close()
480
+ if user:
481
+ return {"username": user[0], "role": user[1],"department": user[2],"location": user[3]}
482
+ else:
483
+ print("Invalid username or password")
484
+ sys.exit(1)
485
+
486
+ except sqlite3.Error as ex:
487
+
488
+ return False
489
+
490
+
491
+
492
+
493
+
494
+
495
def main():
    """CLI entry point: parse arguments, authenticate, dispatch the job."""
    # Parse args BEFORE prompting for credentials so `--help` and argument
    # errors do not require a login first.
    parser = argparse.ArgumentParser(description='RAG PIPELINE OPERATIONS')

    parser.add_argument('--job', type=str, required=True,
                        choices=['rag-build', 'search', 'eval-hypertune', 'llm', 'prompting'],
                        help='job to execute build RAG index, search and hypertune')

    parser.add_argument('--raw', action='store_true',
                        help='display raw retrieved documents in json format')
    parser.add_argument('--chunk_size', type=int, default=Configuration.DEFAULT_CHUNK_SIZE, help="Document chunking")
    parser.add_argument('--chunk_overlap', type=int, default=Configuration.DEFAULT_CHUNK_OVERLAP, help="Document overlap")
    parser.add_argument('--embedding_model', type=str, default=Configuration.DEFAULT_SENTENCE_TRANSFORMER_MODEL, help="Embedding model")
    parser.add_argument('--vector_db_type', type=str, default="chroma", choices=['chroma', 'faiss', 'pinecone'], help='vector database type')
    parser.add_argument('--llm_model', type=str, default=Configuration.DEFAULT_GROQ_LLM_MODEL)
    parser.add_argument('--use-tuned', action='store_true', help='Use the tuned DB for search')
    parser.add_argument('--k', type=int, default=5, help='number of doc to retrieve')
    parser.add_argument('--query', type=str, default=None, help='query from user')
    parser.add_argument('--temperature', type=float, default=0.1, help='LLM temperature for response generation')
    parser.add_argument('--top_p', type=float, default=0.95, help='LLM top_p for response generation')
    parser.add_argument('--max_tokens', type=int, default=1000, help='LLM max tokens for response generation')
    parser.add_argument('--use-rag', type=lambda x: x.lower() == 'true', default=True, help='llm or prompting')
    parser.add_argument('--user-context', type=str, default=None, help='user context as JSON e.g., {"role": "admin", "location":"chennai", "designation":"engineer"}')
    parser.add_argument('--fine-tune', action='store_true', help='use fine-tuned model')
    parser.add_argument('--n-iter', type=int, default=10, help='number of iterations')
    parser.add_argument('--re_ranker_model', type=str, default=Configuration.DEFAULT_RERANKER, help="ReRanking the retrieval documents")
    args = parser.parse_args()

    user_info = RAGOperations.login()
    # Bug fix: the two f-strings were concatenated with no separator,
    # producing e.g. "Logged in as user1manager at chennai".
    print(f"Logged in as {user_info['username']} - "
          f"{user_info['department']} at {user_info['location']}")

    if args.job == 'eval-hypertune' and user_info['role'] != 'admin':
        print(f"Access denied: Hypertune execution is forbidden")
        # Bug fix: previously only printed the denial and then fell through,
        # so non-admin users still ran the hypertune job below.
        sys.exit(1)

    if args.job == 'rag-build':
        RAGOperations.run_build_job(args)
    elif args.job == 'search':
        RAGOperations.run_search_job(args, user_info)
    elif args.job == 'eval-hypertune':
        RAGOperations.run_hypertune_job(args)
    elif args.job in ['llm', 'prompting']:
        RAGOperations.run_llm_with_prompt(args, args.job)
538
+
539
+
540
+ if __name__ == '__main__':
541
+ main()
rag_scripts/__pycache__/interfaces.cpython-312.pyc ADDED
Binary file (5.93 kB). View file
 
rag_scripts/__pycache__/rag_pipeline.cpython-312.pyc ADDED
Binary file (10.4 kB). View file
 
rag_scripts/documents_processing/__pycache__/chunking.cpython-312.pyc ADDED
Binary file (8.42 kB). View file
 
rag_scripts/documents_processing/chunking.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import hashlib
3
+ import traceback
4
+ import pymupdf as fitz
5
+ import regex as re
6
+ from typing import List, Dict, Tuple, Any
7
+ from langchain_text_splitters import RecursiveCharacterTextSplitter
8
+
9
+ from configuration import Configuration
10
+ from rag_scripts.interfaces import IDocumentChunker
11
+
12
+
13
+ class PyMuPDFChunker(IDocumentChunker):
14
+
15
    def __init__(self, pdf_path: str, chunk_size: int = Configuration.DEFAULT_CHUNK_SIZE,
                 chunk_overlap: int = Configuration.DEFAULT_CHUNK_OVERLAP):
        """Prepare a recursive character splitter for the given PDF.

        Args:
            pdf_path: path to the source PDF (must exist).
            chunk_size: target characters per chunk.
            chunk_overlap: characters shared between consecutive chunks.

        Raises:
            FileNotFoundError: if pdf_path does not exist.
        """
        if not os.path.exists(pdf_path):
            raise FileNotFoundError(f"PDF file not found at: {pdf_path}")

        self.pdf_path = pdf_path
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        # NOTE(review): several separators below are regex lookbehinds, but
        # RecursiveCharacterTextSplitter treats separators as literal strings
        # unless is_separator_regex=True is passed — as written they will
        # likely never match and splitting falls back to "\n\n"/"\n"/" "/"" .
        # Confirm intent before changing.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", "(?<=\. )\n",
                        "(?<=[a-z0-9]\.)", "(?<=\? )", "(?<=\! )",
                        " ", ""])

        print(f"Initialized PyMuPDFChunker for {os.path.basename(pdf_path)} with chunk_size = {chunk_size} and chunk overlap = {chunk_overlap}")
33
+
34
+ def _clean_text(self, text: str) -> str:
35
+ try:
36
+ text = text.replace('\u25cf','').replace('\u2022','')
37
+ text = text.replace('\u201c','').replace('\u201d','')
38
+ text = text.replace('\u2013','-').replace('\u2014','-').replace('\u2015','-')
39
+ text = re.sub(r'\n\s*\n',' ',text)
40
+ text = re.sub(r' {2,}', ' ',text)
41
+ text = text.replace('\\nb','')
42
+ text = '\n'.join([line.strip() for line in text.split('\n')])
43
+ text = re.sub(r'[^\x20-\x7E\t\n\r]', '', text)
44
+ return text.strip()
45
+ except Exception as ex:
46
+ traceback.print_exc()
47
+ return text.strip()
48
+
49
+
50
+ def hash_document(self):
51
+ try:
52
+ hasher = hashlib.sha256()
53
+ with open(self.pdf_path, 'rb') as fl:
54
+ while chunk:=fl.read(8192):
55
+ hasher.update(chunk)
56
+ return hasher.hexdigest()
57
+ except Exception as ex:
58
+ print(f"Exception hashing PDF {self.pdf_path}: {ex}")
59
+ traceback.print_exc()
60
+ raise ValueError(f"Failed to hash PDF: {ex}")
61
+
62
+ def chunk_documents(self) -> Tuple[str, List[Dict[str, Any]]]:
63
+ try:
64
+ doc_hash = self.hash_document()
65
+ doc_chunks =[]
66
+
67
+ with fitz.open(self.pdf_path) as document:
68
+ for idx, page in enumerate(document):
69
+ page_text = page.get_text()
70
+ if page_text.strip():
71
+ section = self._extract_section(page_text)
72
+ clause = self._extract_clause(page_text)
73
+ page_text = self._clean_text(page_text)
74
+ chunked_page = self.text_splitter.split_text(page_text)
75
+ for chunk_idx, chunk_content in enumerate(chunked_page):
76
+ if not chunk_content.strip():
77
+ continue
78
+ doc_chunks.append({
79
+ "content":chunk_content.strip(),
80
+ "metadata":{
81
+ "document_id":doc_hash,
82
+ "source_file":os.path.basename(self.pdf_path),
83
+ "page_number": idx+1,
84
+ "chunk_id":f"{doc_hash} - {idx + 1} -{chunk_idx}",
85
+ "section": section or "Unknown Section",
86
+ "clause": clause or "Unknown Clause",
87
+ "chunk_index_on_page": chunk_idx
88
+ }
89
+ })
90
+ if not doc_chunks:
91
+ print(f"No text or chunks extracted from {self.pdf_path} after cleaning")
92
+ return doc_hash,[]
93
+ else:
94
+ print(f"success Document chunked {self.pdf_path} into {len(doc_hash)} chunks")
95
+ return doc_hash,doc_chunks
96
+
97
+
98
+
99
+ except Exception as ex:
100
+ print(f"Exception in document chunking {ex}")
101
+ traceback.print_exc()
102
+ return self.hash_document(), []
103
+
104
+
105
+ def _extract_section(self, text: str) -> str:
106
+ match_major = re.search(r'^(?:[IVX]+\.?\s+|[A-Z]\.?\s+|[0-9]+\.?\s+)(.+)', text, re.MULTILINE)
107
+ if match_major:
108
+ return match_major.group(0)
109
+
110
+ match_firstline = re.search(r'^\s*([A-Za-z0-9][\w\s,&\'-]+?)\s*$', text, re.MULTILINE)
111
+ if match_firstline:
112
+ return match_firstline.group(1).strip()
113
+
114
+ return None
115
+
116
+
117
    def _extract_clause(self, text) -> str:
        """Best-effort extraction of the first clause/bullet from raw page text.

        Matches the first bullet-style item (•, -, *, "1.", "(a)", "a)" ...)
        up to the next bullet, blank line, or end of text. Very short clauses
        (< 10 words) are extended with the following item. Falls back to the
        first paragraph; returns None when nothing matches.
        """
        match = re.search(r'^(?:(?:•|●|-|\*|\d+\.|\([a-z]\)|\([A-Z]\)|\w\))\s*)(.+?)(?=\n(?:•|●|-|\*|\d+\.|\([a-z]\)|\([A-Z]\)|\w\))|\n\n|\Z)', text, re.MULTILINE | re.DOTALL)
        if match:
            clause = match.group(1).strip()
            if len(clause.split()) < 10:
                # Clause looks truncated — append the next bullet item
                # (searched in the remainder of the text) for context.
                next_match = re.search(
                    r'^(?:(?:•|●|-|\*|\d+\.|\([a-z]\)|\([A-Z]\)|\w\))\s*)(.+?)(?=\n(?:•|●|-|\*|\d+\.|\([a-z]\)|\([A-Z]\)|\w\))|\n\n|\Z)',
                    text[match.end():], re.MULTILINE | re.DOTALL)
                if next_match:
                    clause += " " + next_match.group(0).strip()
            return clause

        # Fallback: first non-heading paragraph (up to a blank line).
        match_para = re.search(r'^(?!#|\s*$).*?\n\n', text, re.DOTALL)
        if match_para:
            return match_para.group(0).strip()
        return None
133
+
134
+
rag_scripts/embedding/__pycache__/embedder.cpython-312.pyc ADDED
Binary file (3.5 kB). View file
 
rag_scripts/embedding/embedder.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import os
import traceback
from typing import List

from sentence_transformers import SentenceTransformer

from configuration import Configuration
from rag_scripts.interfaces import IEmbedder

# NOTE: removed the accidental `from traits.trait_types import self` auto-import.
# A module-level name `self` is never legitimate, pulls in an unneeded
# third-party dependency, and silently masked the `selfself` parameter typo in
# SentenceTransformerEmbedder.rank (where `self` then resolved to the import).
10
class SentenceTransformerEmbedder(IEmbedder):
    """IEmbedder implementation backed by a sentence-transformers model."""

    def __init__(self, model_name: str = Configuration.DEFAULT_SENTENCE_TRANSFORMER_MODEL):
        """Load the sentence-transformers model named *model_name*.

        Raises:
            Exception: re-raised when the model cannot be loaded.
        """
        self.model_name = model_name
        try:
            self.model = SentenceTransformer(self.model_name)
            print(f'Sentence Transformer loaded {self.model_name}')
        except Exception as ex:
            print(f"Exception in loading Sentence Transformer {self.model_name}")
            traceback.print_exc()
            raise

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Embed a batch of texts.

        Returns one embedding per input text; on failure, one empty list per
        input so callers keep positional alignment.
        """
        try:
            return self.model.encode(texts).tolist()
        except Exception as ex:
            print(f"Exception in embedding: {ex}")
            traceback.print_exc()
            return [[] for _ in texts]

    def embed_query(self, query: str) -> List[float]:
        """Embed a single query string; returns [] on failure."""
        try:
            print(f"Embedding query: {query}")
            return self.model.encode(query).tolist()
        except Exception as ex:
            print(f"Exception in query embedding: {ex}")
            traceback.print_exc()
            return []

    def rank(self, Query: str, documents: List[str]) -> List[float]:
        """Score each document against *Query*.

        Returns [] when there are no documents or on failure.

        Bug fixed: the first parameter was spelled ``selfself``, so ``self``
        inside the body resolved to a bogus module-level import instead of the
        instance and the method always raised.
        """
        if not documents:
            return []
        try:
            sentence_pairs = [[Query, doc] for doc in documents]
            # NOTE(review): SentenceTransformer exposes .encode, not .predict;
            # .predict is the CrossEncoder API. This method looks like it
            # expects a CrossEncoder re-ranker model — confirm intended class.
            scores = self.model.predict(sentence_pairs)
            return scores.tolist()
        except Exception as ex:
            print(f"Exception in ranking: {ex}")
            traceback.print_exc()
            return []
52
+
53
+
rag_scripts/embedding/vector_db/__pycache__/chroma_db.cpython-312.pyc ADDED
Binary file (6.96 kB). View file
 
rag_scripts/embedding/vector_db/__pycache__/faiss_db.cpython-312.pyc ADDED
Binary file (9.43 kB). View file
 
rag_scripts/embedding/vector_db/__pycache__/pinecone_db.cpython-312.pyc ADDED
Binary file (10.7 kB). View file
 
rag_scripts/embedding/vector_db/chroma_db.py ADDED
@@ -0,0 +1,103 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chromadb
2
+ import traceback
3
+ from typing import List, Dict, Any
4
+ from chromadb.api.types import EmbeddingFunction
5
+ from configuration import Configuration
6
+ from rag_scripts.interfaces import IVector
7
+ from rag_scripts.interfaces import IEmbedder
8
+ from rag_scripts.embedding.embedder import SentenceTransformerEmbedder
9
+
10
class chromaDBEmbeddingFunction(EmbeddingFunction):
    """Adapter exposing an IEmbedder through Chroma's EmbeddingFunction protocol."""

    def __init__(self, embedder: IEmbedder):
        # All embedding work is delegated to the injected embedder.
        self.embedder = embedder

    def __call__(self, texts: List[str]) -> List[List[float]]:
        """Chroma callback: embed *texts* via the wrapped embedder."""
        return self.embedder.embed_texts(texts)
17
+
18
class chromaDBVectorDB(IVector):
    """IVector implementation backed by a persistent ChromaDB collection."""

    def __init__(self, embedder: IEmbedder, db_path: str = Configuration.CHROMA_DB_PATH,
                 collection_name: str = Configuration.COLLECTION_NAME):
        super().__init__(embedder, db_path, collection_name)

        self.client = chromadb.PersistentClient(path=self.db_path)
        self.chroma_embed_function = chromaDBEmbeddingFunction(self.embedder)
        self.collection = self._get_or_create_collection()
        print(f"Chroma DB initialized path = {self.db_path}, collection = {self.collection_name}")

    def _get_or_create_collection(self):
        """Return the configured collection, creating it if absent."""
        return self.client.get_or_create_collection(
            name=self.collection_name,
            embedding_function=self.chroma_embed_function)

    def add_chunks(self, documents: List[Dict[str, Any]]) -> List[str]:
        """Add chunk documents (dicts with 'content' and 'metadata') to Chroma.

        Returns the chunk ids that were added, or [] on failure / empty input.
        """
        try:
            if not documents:
                return []
            ids = [doc['metadata']['chunk_id'] for doc in documents]
            contents = [doc['content'] for doc in documents]
            metadatas = [doc['metadata'] for doc in documents]

            self.collection.add(documents=contents, metadatas=metadatas, ids=ids)

            print(f"Added {len(ids)} chunks to chroma db")
            return ids
        except Exception as ex:
            print(f"Exception in adding chunks to chromaDB: {ex}")
            traceback.print_exc()  # bug fix: was `traceback.print_exc` (attribute access, never called)
            return []

    def get_document_hash_ids(self, document_hash: str) -> List[str]:
        """Return stored chunk ids whose metadata document_id equals *document_hash*.

        NOTE(review): `limit=1` returns at most one id — adequate as an
        existence probe, but the plural name suggests all ids; confirm intent.
        """
        try:
            result = self.collection.get(where={"document_id": document_hash}, limit=1)
            return result['ids'] if result and result['ids'] else []
        except Exception as ex:
            print(f"Exception getting document hash {document_hash} from chromadb {ex}")
            traceback.print_exc()  # bug fix: was missing the call parentheses
            return []

    def search(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """Retrieve the top-*k* chunks for *query*.

        Returns a list of {'content', 'metadata', 'distance'} dicts, [] on failure.
        """
        try:
            search_result = self.collection.query(
                query_texts=[query],
                n_results=k,
                include=['documents', 'metadatas', 'distances']
            )
            retrieved_documents = []
            if search_result and search_result['documents'] and search_result['metadatas']:
                # Chroma nests results per-query; we issued one query, hence [0].
                for indx in range(len(search_result['documents'][0])):
                    retrieved_documents.append({
                        "content": search_result['documents'][0][indx],
                        "metadata": search_result['metadatas'][0][indx],
                        "distance": search_result['distances'][0][indx]})
            print(f"Retrieved {len(retrieved_documents)} documents from chroma DB for query: '{query}'")

            return retrieved_documents
        except Exception as ex:
            print(f"Exception in Chroma db search: {ex}")
            traceback.print_exc()
            return []

    def delete_collection(self, collection_name: str):
        """Delete *collection_name*, then re-create this instance's own collection."""
        try:
            self.client.delete_collection(name=collection_name)
            print(f"Chroma DB collection {collection_name} deleted")
            self.collection = self._get_or_create_collection()
        except Exception as ex:
            print(f"Exception in deleting the chroma db collection: {collection_name}")
            traceback.print_exc()

    def count_documents(self) -> int:
        """Number of chunks stored in the collection; 0 on failure."""
        try:
            return self.collection.count()
        except Exception as ex:
            print(f"Exception in counting documents in chroma db: {ex}")
            traceback.print_exc()
            return 0
102
+
103
+
rag_scripts/embedding/vector_db/faiss_db.py ADDED
@@ -0,0 +1,145 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import faiss
3
+ import pickle
4
+ import traceback
5
+ import numpy as np
6
+ from typing import List, Dict, Any
7
+
8
+ from configuration import Configuration
9
+ from rag_scripts.interfaces import IVector, IEmbedder
10
+ from rag_scripts.embedding.embedder import SentenceTransformerEmbedder
11
+
12
class FAISSVectorDB(IVector):
    """IVector implementation backed by an on-disk FAISS flat-L2 index.

    FAISS stores only vectors, so chunk payloads (content + metadata) live in a
    parallel pickled list (`doc_store`) whose positions mirror index positions.
    """

    def __init__(self, embedder: IEmbedder, db_path: str = Configuration.FAISS_DB_PATH,
                 collection_name: str = Configuration.COLLECTION_NAME):
        super().__init__(embedder, db_path, collection_name)
        self.index = None
        self.doc_store: List[Dict[str, Any]] = []

        # <collection>.faiss holds the vectors, <collection>_docs.pkl the payloads.
        self.collection_file = os.path.join(self.db_path, f"{self.collection_name}.faiss")
        self.doc_store_file = os.path.join(self.db_path, f"{self.collection_name}_docs.pkl")

        os.makedirs(self.db_path, exist_ok=True)

        self._load_index()

        print(f"FAISSDB Initialized: path='{self.db_path}', collection='{self.collection_name}'")

    def _load_index(self):
        """Load index + doc store from disk if both files exist; else start empty."""
        try:
            if os.path.exists(self.collection_file) and os.path.exists(self.doc_store_file):
                self.index = faiss.read_index(self.collection_file)
                with open(self.doc_store_file, 'rb') as f:
                    self.doc_store = pickle.load(f)
                print(f"Loaded Existing FAISS index and doc store from {self.collection_file}")
            else:
                print(f"No Existing FAISS index found, need new collection")
        except Exception as ex:
            print(f"Exception in loading FAISS index: {ex}")
            traceback.print_exc()
            # Fall back to an empty state rather than propagating a corrupt load.
            self.index = None
            self.doc_store = []

    def _save_index(self):
        """Persist both the FAISS index and the parallel doc store."""
        try:
            if self.index is not None:
                faiss.write_index(self.index, self.collection_file)
                with open(self.doc_store_file, "wb") as f:
                    pickle.dump(self.doc_store, f)
                print(f"FAISS index and doc store saved to {self.collection_file}")
        except Exception as ex:
            print(f"Exception in saving FAISS index: {ex}")
            traceback.print_exc()

    def add_chunks(self, documents: List[Dict[str, Any]]) -> List[str]:
        """Embed and add chunk documents; returns the chunk ids added, [] on failure."""
        try:
            if not documents:
                return []

            contents = [doc['content'] for doc in documents]
            embeddings = np.array(self.embedder.embed_texts(contents), dtype='float32')

            if self.index is None:
                # Lazily create the index once the embedding dimension is known.
                dimension = embeddings.shape[1]
                self.index = faiss.IndexFlatL2(dimension)
                print(f"Created New FAISS index with dimension {dimension}")

            self.index.add(embeddings)

            # Keep doc_store positions aligned with index positions.
            for doc in documents:
                self.doc_store.append(doc)

            self._save_index()
            chunk_ids = [doc['metadata']['chunk_id'] for doc in documents]
            print(f"Added {len(chunk_ids)} chunks to FAISS index")
            return chunk_ids
        except Exception as ex:
            print(f"Exception in adding chunks to FAISS: {ex}")
            traceback.print_exc()
            return []

    def get_document_hash_ids(self, document_hash: str) -> List[str]:
        """Return chunk ids whose metadata document_id equals *document_hash*."""
        try:
            found_ids = []
            for doc in self.doc_store:
                if doc['metadata'].get('document_id') == document_hash:
                    found_ids.append(doc['metadata']['chunk_id'])
            return found_ids
        except Exception as ex:
            print(f"Exception in get document hash {ex}")
            traceback.print_exc()
            return []  # bug fix: previously fell off the end and returned None

    def search(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """Return the top-*k* chunks for *query* as {'content','metadata','distance'} dicts."""
        try:
            if self.index is None or self.index.ntotal == 0:
                print(f"FAISS index is empty or not intialized")
                return []
            query_embedding = np.array(self.embedder.embed_query(query), dtype='float32').reshape(1, -1)
            distances, indices = self.index.search(query_embedding, k)

            retrieved_documents = []
            for dist, idx in zip(distances[0], indices[0]):
                # FAISS pads missing results with -1; guard the doc_store lookup.
                if 0 <= idx < len(self.doc_store):
                    doc = self.doc_store[idx]
                    retrieved_documents.append({
                        "content": doc['content'],
                        "metadata": doc['metadata'],
                        "distance": float(dist)})
            print(f"Retrieved {len(retrieved_documents)} documents from FAISS for query: {query}")

            return retrieved_documents
        except Exception as ex:
            print(f"Exception in FAISS db search {ex}")
            traceback.print_exc()
            return []

    def delete_collection(self, collection_name: str):
        """Remove the on-disk files for this collection and reset in-memory state."""
        try:
            if os.path.exists(self.collection_file):
                os.remove(self.collection_file)
            if os.path.exists(self.doc_store_file):
                os.remove(self.doc_store_file)
            self.index = None
            self.doc_store = []
            print(f"FAISS collection files for {collection_name} deleted")
        except Exception as ex:
            # bug fix: message previously printed the literal text "ex" and
            # print_exc was referenced without being called.
            print(f"Error deleting FAISS collection files for {collection_name}: {ex}")
            traceback.print_exc()

    def count_documents(self) -> int:
        """Number of vectors in the index; 0 when uninitialized or on failure."""
        try:
            return self.index.ntotal if self.index is not None else 0
        except Exception as ex:
            print(f"Exception in counting document in FAISS: {ex}")
            traceback.print_exc()
            return 0
142
+
143
+
144
+
145
+
rag_scripts/embedding/vector_db/pinecone_db.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import time
2
+ import traceback
3
+ from typing import List, Dict,Any
4
+ import re
5
+ from pinecone import Pinecone, ServerlessSpec
6
+ from pinecone.exceptions import PineconeException
7
+
8
+ from configuration import Configuration
9
+ from rag_scripts.interfaces import IVector, IEmbedder
10
+
11
class PineconeVectorDB(IVector):
    """IVector implementation backed by a Pinecone serverless index.

    Unlike the Chroma/FAISS backends, chunk content is stored inside each
    vector's metadata (Pinecone has no separate document store).
    """

    def __init__(self, embedder:IEmbedder, db_path: str, collection_name: str="flykite"):
        # db_path is accepted for interface parity but Pinecone is remote;
        # it is not used locally.
        super().__init__(embedder,db_path,collection_name)

        self.api_key = Configuration.PINECONE_API_KEY
        self.cloud = Configuration.PINECONE_CLOUD
        self.region = Configuration.PINECONE_REGION

        if not self.api_key:
            raise ValueError("Pinecone API KEY not provided in configuration")

        self.pc = Pinecone(api_key=self.api_key)
        # Pinecone index names may not contain underscores.
        collection_name = collection_name.replace('_','-')
        self.index_name=collection_name

        print(f"Collection name: {collection_name}")

        # Probe the embedder once to discover the vector dimension.
        self.dim=len(self.embedder.embed_texts(["test"])[0])
        self._create_index()
        self.index=self.pc.Index(self.index_name)

        print(f"Pinecone DB Initialized index= {self.index_name}, cloud= {self.cloud}, region = {self.region}")

    def _create_index(self):
        """Ensure a cosine index of the right dimension exists, recreating on mismatch.

        Blocks (polling every 2s) until the index is usable; raises
        TimeoutError after ~3 minutes of waiting.
        """
        try:
            existing_indexes = self.pc.list_indexes().names()
            if self.index_name in existing_indexes:
                index = self.pc.Index((self.index_name))
                index_info = index.describe_index_stats()
                if index_info.dimension!=self.dim:
                    # Dimension mismatch: drop and recreate below.
                    print(f"Pinecone DB Index already exists {self.index_name} "
                          f"with dimension {index_info.dimension}"
                          f"but the expected dimension is {self.dim}"
                          " so deleting it and recreating again"
                          )
                    self.pc.delete_index(self.index_name)

                    # Wait (up to ~60s) for the deletion to propagate.
                    for _ in range(30):
                        if self.index_name not in self.pc.list_indexes().names():
                            break
                        else:
                            time.sleep(2)


                elif index_info.metric!='cosine':
                    # NOTE(review): describe_index_stats() may not expose a
                    # `metric` attribute (metric lives on describe_index) —
                    # confirm against the installed pinecone client version.
                    self.pc.delete_index(self.index_name)

                    for _ in range(30):
                        if self.index_name in self.pc.list_indexes().names():
                            time.sleep(2)
                        else:
                            break
                else:
                    # Existing index matches dimension and metric; nothing to do.
                    print(f"Pinecone index already exists {self.index_name} ")
                    return

            if self.index_name not in self.pc.list_indexes().names():
                self.pc.create_index(name=self.index_name,
                                     dimension=self.dim,
                                     metric='cosine',
                                     spec=ServerlessSpec(cloud=self.cloud,region=self.region)
                                     )
            # Poll until the (re)created index is reachable with the right
            # dimension; stats calls on a not-yet-ready index raise.
            max_attempts = 90
            for attempt in range(max_attempts):
                try:
                    index = self.pc.Index(self.index_name)
                    index_info = index.describe_index_stats()
                    if index_info.dimension == self.dim:
                        # NOTE(review): message is misleading — at this point
                        # the index was just created/verified, not pre-existing.
                        print(f"Pinecone index {self.index_name} already exists ")
                        return
                except PineconeException:
                    pass
                time.sleep(2)
            raise TimeoutError(f"Pinecone index {self.index_name} does not exist")


        except PineconeException as ex:

            print(f"Exception creating Pinecone index: {ex}")
            traceback.print_exc()
            raise
        except Exception as ex:
            print(f"Exception creating Pinecone index: {ex}")
            traceback.print_exc()
            raise



    def add_chunks(self, documents: List[Dict[str,Any]]) -> List[str]:
        """Embed and upsert chunk documents; returns the chunk ids, [] on failure.

        Chunk content is copied into each vector's metadata under 'content'
        so search() can reconstruct the documents.
        """
        try:
            if not documents:
                return[]
            vectors_info = []
            for doc in documents:
                content = doc['content']
                # One embed call per chunk; batching here would cut round-trips.
                embedding = self.embedder.embed_texts([content])[0]
                chunk_id = doc['metadata']['chunk_id']
                metadata= doc['metadata'].copy()
                metadata['content']=content

                vectors_info.append({
                    "id":chunk_id, "values":embedding, "metadata":metadata
                })

            print(f"Upserting {len(vectors_info)} to pincone index")
            self.index.upsert(vectors=vectors_info)
            chunk_ids = [vec['id'] for vec in vectors_info]
            print(f"Added {len(chunk_ids)} chunks to pinecone index {self.index_name}")
            print(f"Added chunk ids {chunk_ids[:5]}")
            return chunk_ids

        except PineconeException as ex:
            print(f"Exception in adding chunks to pinecone db {ex}")
            traceback.print_exc()
            return []

    def get_document_hash_ids(self, document_hash: str) ->List[str]:
        """Return ids of vectors whose metadata document_id equals *document_hash*.

        Uses a zero-vector query with a metadata filter (Pinecone has no
        pure metadata listing API); capped at 10000 matches.
        """
        try:
            vector_sample = [0.0]*self.dim
            result = self.index.query(
                vector = vector_sample,
                filter = {"document_id": {"$eq": document_hash}},
                top_k = 10000,
                include_metadata=False )

            return [match['id'] for match in result['matches']]

        except PineconeException as ex:
            print(f"Exception getting document hash IDs from pinecone: {ex}")
            traceback.print_exc()
            return []

    def search(self,query: str, k:int=3) -> List[Dict[str,Any]]:
        """Return the top-*k* chunks for *query* as {'content','metadata','distance'} dicts.

        Note: 'distance' carries the Pinecone similarity *score* (cosine,
        higher is better), unlike the L2 distances of the FAISS backend.
        """
        try:
            print(f"searching pinecode index {self.index_name} for query {query}")
            query_embedding = self.embedder.embed_query(query)
            result = self.index.query(
                vector=query_embedding,
                top_k=k,
                include_metadata=True
            )

            retrieved_documents = []
            for match in result['matches']:
                md= match['metadata']
                # Content was stashed in metadata at upsert time; pull it out.
                content = md.pop('content','')

                retrieved_documents.append({
                    "content":content,
                    "metadata":md,
                    "distance":match['score']
                })

            print(f"Retrieved {len(retrieved_documents)} documents from pinecone for query: {query}")

            return retrieved_documents


        except PineconeException as ex:
            print(f"Exception in Pinecone search: {ex}")
            traceback.print_exc()
            return[]

    def delete_collection(self, collection_name: str):
        """Delete the named Pinecone index (best-effort; errors are logged)."""
        try:

            self.pc.delete_index(collection_name)
            print(f"Pinecone index {collection_name} deleted")
        except PineconeException as ex:
            print(f"Exception in deleting pinecone index: {collection_name} {ex}")
            traceback.print_exc()

    def count_documents(self) -> int:
        """Vector count of the index; polls because upserts are eventually consistent.

        Retries up to 19 times with 7s sleeps waiting for a non-zero count,
        so this can block for ~2 minutes on a genuinely empty index.
        """
        try:
            for attempt in range(19):
                status = self.index.describe_index_stats()
                count = status.get('total_vector_count',0)
                if count > 0:
                    return count
                time.sleep(7)
            print(f"No vectors found in index: {self.index_name}")
            return 0

        except PineconeException as ex:
            print(f"EXception in getting document counts: {ex}")
            traceback.print_exc()
            return 0
200
+
201
+
202
+
rag_scripts/evaluation/__pycache__/evaluator.cpython-312.pyc ADDED
Binary file (30.5 kB). View file
 
rag_scripts/evaluation/evaluator.py ADDED
@@ -0,0 +1,525 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import re
4
+ import time
5
+ import json
6
+ import random
7
+ import joblib
8
+ import traceback
9
+ import logging
10
+
11
+ import torch
12
+ import transformers
13
+ from bert_score import score
14
+ from itertools import product
15
+ from datetime import datetime
16
+ from typing import List, Dict, Any
17
+ from configuration import Configuration
18
+ from rag_scripts.llm.llmResponse import GROQLLM
19
+ from pinecone import Pinecone, PineconeException
20
+ from rag_scripts.rag_pipeline import RAGPipeline
21
+ from sentence_transformers import SentenceTransformer, util
22
+ from rag_scripts.embedding.vector_db.faiss_db import FAISSVectorDB
23
+ from rag_scripts.documents_processing.chunking import PyMuPDFChunker
24
+ from rag_scripts.embedding.embedder import SentenceTransformerEmbedder
25
+ from rag_scripts.embedding.vector_db.chroma_db import chromaDBVectorDB
26
+ from rag_scripts.embedding.vector_db.pinecone_db import PineconeVectorDB
27
+ from rag_scripts.interfaces import IDocumentChunker, ILLM, IEmbedder, IRAGPipeline, IVector
28
+
29
+ logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
30
+ logger = logging.getLogger(__name__)
31
+
32
+ VECTOR_DB_CONSTRUCTORS = {
33
+ "chroma": chromaDBVectorDB,
34
+ "faiss": FAISSVectorDB,
35
+ "pinecone": PineconeVectorDB
36
+ }
37
+
38
+ class RAGEvaluator:
39
+ def __init__(self, eval_data_path: str = Configuration.EVAL_DATA_PATH,
40
+ pdf_path: str = Configuration.FULL_PDF_PATH):
41
+ if not os.path.exists(eval_data_path):
42
+ raise FileNotFoundError(f"Evaluation data not found at: {eval_data_path}")
43
+
44
+ if not os.path.exists(pdf_path):
45
+ raise FileNotFoundError(f"PDF document not found at: {pdf_path}")
46
+
47
+ with open(eval_data_path, 'r') as f:
48
+ self.eval_queries = json.load(f)
49
+
50
+ self.pdf_path = pdf_path
51
+ self.embedder = SentenceTransformer(Configuration.DEFAULT_SENTENCE_TRANSFORMER_MODEL)
52
+ logger.info(f"RAG Evaluator initialized with {len(self.eval_queries)} evaluation queries")
53
+
54
+ def _sanitize_collection_name(self, name: str) -> str:
55
+ sanitized = re.sub(r'[^a-z0-9]', '-', name.lower())
56
+ sanitized = re.sub(r'-+', '-', sanitized).strip('-')
57
+ return sanitized[:45].rstrip('-') if len(sanitized) > 45 else sanitized
58
+
59
    def _calculate_retrieval_relevance(self, query: str, retrieved_docs: List[Dict[str, Any]],
                                      expected_answer: str =None) -> float:
        """Mean (over retrieved docs) of each doc's max cosine similarity to the query.

        NOTE(review): *expected_answer* is currently unused — an earlier
        version (since removed) also embedded it as a second reference;
        only the query embedding is compared today.
        Returns 0.0 when nothing was retrieved.
        """
        if not retrieved_docs:
            logger.warning(f"No retrieved documents for query: {query}")
            return 0.0

        query_embedding = self.embedder.encode(query, convert_to_tensor=True, normalize_embeddings=True)
        # Kept as a list so additional reference embeddings can be appended later.
        reference_embedding = [query_embedding]

        score_cosine = []
        for doc in retrieved_docs:
            doc_content = doc.get('content','')
            if not doc_content.strip():
                # Skip empty chunks instead of scoring them as zero.
                continue
            doc_embedding = self.embedder.encode(doc_content, convert_to_tensor=True, normalize_embeddings=True)
            max_sim_to_reference = 0.0
            for ref_embedding in reference_embedding:
                sim = util.cos_sim(ref_embedding, doc_embedding)
                # cos_sim returns a 1x1 tensor; unwrap to a plain float.
                if isinstance(sim,torch.Tensor):
                    sim_value = sim.item()
                else:
                    sim_value = sim
                if sim_value > max_sim_to_reference:
                    max_sim_to_reference = sim_value
            score_cosine.append(max_sim_to_reference)
        # Average across docs (a max-based variant was considered and rejected).
        cosine_score = sum(score_cosine) / len(score_cosine) if score_cosine else 0.0
        logger.info(f"Max cosine score for query '{query}': {cosine_score}")
        return cosine_score
93
+
94
    def _calculate_response_groundedness(self, response: Dict[str,Any], retrieved_docs: List[Dict[str, Any]]) -> float:
        """How well the response 'summary' is supported by the retrieved chunks.

        The summary is split into bullet/paragraph segments; every segment
        (and the full summary) is compared to every chunk by cosine
        similarity, and the single best score is returned — max is used so a
        structured response grounded in any one chunk still scores well.
        Returns 0.0 with no retrieved docs or on error.
        """
        try:
            if not retrieved_docs:
                logger.warning(f"No retrieved documents for groundedness evaluation")
                return 0.0
            response_text = response.get("summary","")
            if not isinstance(response_text,str):
                response_text = str(response_text)

            # Split on filled/hollow bullet characters or blank lines.
            response_segments = re.split(r'[\u25cf\u25cb]\s*|\n\s*\n', response_text.strip())
            response_segments = [seg.strip() for seg in response_segments if seg.strip()]

            groundedness_scores = []
            for doc in retrieved_docs:
                doc_content = doc.get('content','')
                if not doc_content.strip():
                    continue
                doc_embedding = self.embedder.encode(doc_content,
                                                     convert_to_tensor=True,
                                                     normalize_embeddings=True)
                for segment in response_segments:
                    if segment:
                        segment_embedding = self.embedder.encode(segment,
                                                                 convert_to_tensor=True,
                                                                 normalize_embeddings=True)

                        segment_similarity = util.cos_sim(segment_embedding,
                                                          doc_embedding).item()
                        groundedness_scores.append(segment_similarity)

                # Also compare the whole response to this chunk.
                # NOTE(review): re-encoding the full response on every doc
                # iteration is redundant work; the embedding could be hoisted.
                full_response_embedding = self.embedder.encode(response_text,
                                                               convert_to_tensor=True,
                                                               normalize_embeddings=True)
                full_response_similarity = util.cos_sim(full_response_embedding,doc_embedding).item()
                groundedness_scores.append(full_response_similarity)

            groundedness = max(
                groundedness_scores) if groundedness_scores else 0.0  # Use max to handle structured responses
            logger.info(f"Max groundedness score for response: {groundedness:.2f}")
            return groundedness
        except Exception as ex:
            logger.error(f"Exception in calculating groundedness: {ex}")
            traceback.print_exc()
            return 0.0
138
+
139
+ def _calculate_response_relevance(self, query: str, response: Dict[str,Any],expected_answer: str=None) -> float:
140
+ try:
141
+ response_text = response.get("summary","")
142
+ response_text = response_text.replace('\n','')
143
+ if not isinstance(response_text,str) or not response_text.strip():
144
+ return 0.0
145
+
146
+ reference_text = expected_answer if expected_answer is not None and expected_answer.strip() else query
147
+
148
+ refernce_embedding = self.embedder.encode(reference_text,convert_to_tensor=True,normalize_embeddings=True)
149
+ #query_embedding = self.embedder.encode(query, convert_to_tensor=True, normalize_embeddings=True)
150
+ response_embedding = self.embedder.encode(response_text, convert_to_tensor=True, normalize_embeddings=True)
151
+ similarity = util.cos_sim(refernce_embedding, response_embedding).item()
152
+
153
+ response_length = len(response_text.split())
154
+ #query_length = len(query.split())
155
+ reference_length = len(reference_text.split())
156
+ length_penalty_factor = 1.0
157
+ if response_length > (reference_length*1.5) and response_length >50:
158
+ length_penalty_factor = max(0.8,1.0-(response_length-(reference_length*2))/100.0)
159
+ elif response_length < (reference_length*0.5) and response_length >20:
160
+ length_penalty_factor = max(0.5,response_length/(reference_length*0.5))
161
+
162
+ adjusted_relevance = similarity * length_penalty_factor
163
+ logger.info(
164
+ f"Relevance score for query '{query}': {adjusted_relevance:.2f} (similarity: {similarity:.2f}, penalty: {length_penalty_factor:.2f})")
165
+ return adjusted_relevance
166
+ except Exception as ex:
167
+ logger.error(f"Exception in relevance score calculation: {ex}")
168
+ traceback.print_exc()
169
+ return 0.0
170
+
171
+ def _calculate_bert_score(self, response: Dict[str,Any], reference: str) -> float:
172
+ try:
173
+ response_text = response.get("summary", "")
174
+ if not isinstance(response_text, str):
175
+ logger.warning("Response text is not a string")
176
+ return 0.0
177
+ if not response_text:
178
+ logger.error(f"No retrieved documents for BERT score evaluation")
179
+ return 0.0
180
+ if not reference.strip():
181
+ logger.warning(f"No reference answer provided for BERTScore calculation")
182
+ return 0.0
183
+ transformers.utils.logging.set_verbosity_error()
184
+ _, _, f1 = score([response_text], [reference], lang="en", model_type="roberta-large")
185
+ transformers.utils.logging.set_verbosity_warning()
186
+ bert_score_value = f1.item()
187
+ logger.info(f"BERT score: {bert_score_value:.2f}")
188
+ return bert_score_value
189
+ except Exception as ex:
190
+ logger.error(f"Exception in BERT score calculation: {ex}")
191
+ traceback.print_exc()
192
+ return 0.0
193
+
194
+ def evaluate_response(self, query: str, response: Dict[str,Any], retrieved_docs: List[Dict[str, Any]],
195
+ expected_keywords: List[str] = None, expected_answer: str = None) -> Dict[str, Any]:
196
+ try:
197
+ cosine_score = self._calculate_retrieval_relevance(query, retrieved_docs,expected_answer=expected_answer)
198
+ groundedness = self._calculate_response_groundedness(response, retrieved_docs)
199
+ relevance = self._calculate_response_relevance(query, response,expected_answer=expected_answer)
200
+ bert_score = self._calculate_bert_score(response, expected_answer) if expected_answer else 0.0
201
+
202
+ observations = []
203
+ if groundedness < 0.6:
204
+ observations.append(f"Response may contain ungrounded information (groundedness score: {groundedness:.2f})")
205
+ else:
206
+ observations.append(f"Response is well grounded (score: {groundedness:.2f})")
207
+
208
+ if relevance < 0.7:
209
+ observations.append(f"Response may not fully address the question (relevance score: {relevance:.2f})")
210
+ else:
211
+ observations.append(f"Response is highly relevant (score: {relevance:.2f})")
212
+
213
+ if cosine_score < 0.5:
214
+ observations.append(f"Low similarity between query and retrieved documents")
215
+
216
+ if bert_score < 0.7 and expected_answer:
217
+ observations.append(f"Low BERT score, semantic mismatch with reference answer")
218
+
219
+ return {
220
+ "cosine_score": round(cosine_score, 2),
221
+ "groundedness": round(groundedness, 2),
222
+ "relevance": round(relevance, 2),
223
+ "bert_score": round(bert_score, 2),
224
+ "observations": "; ".join(observations)
225
+ }
226
+
227
+ except Exception as ex:
228
+ logger.error(f"Exception in calculating response scores: {ex}")
229
+ traceback.print_exc()
230
+ return {
231
+ "cosine_score": 0.0,
232
+ "groundedness": 0.0,
233
+ "relevance": 0.0,
234
+ "bert_score": 0.0,
235
+ "observations": f"Error in evaluation: {ex}"
236
+ }
237
+
238
+ def evaluate_combined_params_grid(self,
239
+ chunk_size_to_test: List[int],
240
+ chunk_overlap_to_test: List[int],
241
+ embedding_models_to_test: List[str],
242
+ vector_db_types_to_test: List[str],
243
+ re_ranker_model: List[str],
244
+ llm_model_name: str = Configuration.DEFAULT_GROQ_LLM_MODEL,
245
+ search_type: str = "grid",
246
+ n_iter: int = 50) -> Dict[str, Any]:
247
+ logger.info("\n--- Starting the evaluation of best parameters ---")
248
+ best_score = -1.0
249
+ best_params = {}
250
+ best_metrics = {}
251
+ results = []
252
+
253
+ param_combination = [(c_size, c_overlap, embed_model, db_type,re_ranker)
254
+ for c_size, c_overlap, embed_model, db_type,re_ranker in product(
255
+ chunk_size_to_test, chunk_overlap_to_test,
256
+ embedding_models_to_test, vector_db_types_to_test,re_ranker_model)
257
+ if c_overlap < c_size]
258
+
259
+ param_to_test = (random.sample(param_combination, min(n_iter, len(param_combination)))
260
+ if search_type.lower() == 'random' else param_combination)
261
+ logger.info(f"Testing {len(param_to_test)} {'random' if search_type.lower() == 'random' else 'all'} "
262
+ f"combinations out of {len(param_combination)}")
263
+
264
+ for idx, (c_size, c_overlap, embed_model, db_type,re_ranker) in enumerate(param_to_test, 1):
265
+ logger.info(f"\nIteration {idx}/{len(param_to_test)} \nchunk_size: {c_size} \nchunk_overlap: {c_overlap} "
266
+ f"\nembed_model: {embed_model} \nvector_db: {db_type}")
267
+
268
+ current_params_str = f"Chunk: {c_size}-{c_overlap}- Embed- {embed_model}- DB-{db_type}-{re_ranker}"
269
+ embed_model = embed_model.replace('_', '-')
270
+ temp_collection_name = self._sanitize_collection_name(
271
+ f"{Configuration.COLLECTION_NAME}-{search_type}-{c_size}-{c_overlap}-{embed_model}-{db_type}")
272
+ temp_db_path = os.path.join(Configuration.DATA_DIR, f"{db_type}_temp_{search_type}_{c_size}_{c_overlap}_{embed_model}_{embed_model}")
273
+ os.makedirs(temp_db_path, exist_ok=True)
274
+
275
+ vector_db_instance = None
276
+ try:
277
+ embedder = SentenceTransformerEmbedder(model_name=embed_model)
278
+ db_constructor = VECTOR_DB_CONSTRUCTORS.get(db_type.lower())
279
+ if not db_constructor:
280
+ logger.error(f"Unsupported vector DB type: {db_type}")
281
+ continue
282
+ vector_db_instance = db_constructor(embedder=embedder,
283
+ db_path=temp_db_path,
284
+ collection_name=temp_collection_name)
285
+
286
+
287
+ chunker_instance = PyMuPDFChunker(pdf_path=self.pdf_path,
288
+ chunk_size=c_size,
289
+ chunk_overlap=c_overlap)
290
+
291
+ llm_instance = GROQLLM(model_name=llm_model_name)
292
+
293
+ pipeline = RAGPipeline(
294
+ document_path=self.pdf_path,
295
+ chunker=chunker_instance,
296
+ embedder=embedder,
297
+ vector_db=vector_db_instance,
298
+ llm=llm_instance,
299
+ chunk_size=c_size,
300
+ chunk_overlap=c_overlap,
301
+ embedding_model_name=embed_model,
302
+ llm_model_name=llm_model_name,
303
+ db_path=temp_db_path,
304
+ collection_name=temp_collection_name,
305
+ re_ranker_model_name=re_ranker
306
+ )
307
+
308
+ pipeline.build_index()
309
+ if vector_db_instance.count_documents() == 0:
310
+ logger.warning(f"No documents foundin vector DB after build for "
311
+ f"{current_params_str}. skipping evaluation for this combination.")
312
+ continue
313
+
314
+ cosine_scores = []
315
+ groundedness_scores = []
316
+ relevance_scores = []
317
+ bert_scores = []
318
+
319
+ for eval_item in self.eval_queries:
320
+ query = eval_item['query']
321
+ expected_answer = eval_item.get('expected_answer_snippet', '')
322
+
323
+ expected_keywords = eval_item.get('expected_keywords', [])
324
+ retrieved_docs = pipeline.retrieve_raw_documents(query, k=3)
325
+ response = pipeline.query(query, k=3)
326
+ logger.debug(f"Query: {query} \n Response: {json.dumps(response,indent=2, ensure_ascii=False)}")
327
+ if not expected_answer.strip():
328
+ expected_answer = self._syntesize_raw_reference(retrieved_docs)
329
+
330
+
331
+ eval_result = self.evaluate_response(query, response,
332
+ retrieved_docs,
333
+ expected_answer=expected_answer,
334
+ expected_keywords=expected_keywords)
335
+
336
+ cosine_scores.append(eval_result['cosine_score'])
337
+ groundedness_scores.append(eval_result['groundedness'])
338
+ relevance_scores.append(eval_result['relevance'])
339
+ bert_scores.append(eval_result['bert_score'])
340
+
341
+ average_cosine_score = round(sum(cosine_scores) / len(cosine_scores) if cosine_scores else 0.0, 2)
342
+ average_groundedness = round(sum(groundedness_scores) / len(groundedness_scores) if groundedness_scores else 0.0, 2)
343
+ average_relevance = round(sum(relevance_scores) / len(relevance_scores) if relevance_scores else 0.0, 2)
344
+ average_bert_score = round(sum(bert_scores) / len(bert_scores) if bert_scores and any(bert_scores) else 0.0, 2)
345
+
346
+ average_score = round((0.2* average_cosine_score +
347
+ 0.35* average_groundedness +
348
+ 0.35* average_relevance +
349
+ 0.15*average_bert_score), 2)
350
+
351
+ results.append({
352
+ "iteration": idx,
353
+ "chunk_size": c_size,
354
+ "chunk_overlap": c_overlap,
355
+ "embedding_model": embed_model,
356
+ "vector_db_type": db_type,
357
+ "average_retrieval_relevance_score": average_score,
358
+ "average_cosine_score": average_cosine_score,
359
+ "average_groundedness": average_groundedness,
360
+ "average_relevance": average_relevance,
361
+ "average_bert_score": average_bert_score
362
+ })
363
+
364
+ logger.info(f"Average retrieval relevance score: {average_score}")
365
+ logger.info(f"Average cosine score: {average_cosine_score}")
366
+ logger.info(f"Average groundedness: {average_groundedness}")
367
+ logger.info(f"Average relevance: {average_relevance}")
368
+ logger.info(f"Average BERT score: {average_bert_score}")
369
+
370
+ logger.info(f"Comparing the average score {average_score} with best score {best_score}")
371
+ if average_score > best_score:
372
+ best_score = average_score
373
+ best_params = {
374
+ "iteration":idx,
375
+ "chunk_size": c_size,
376
+ "chunk_overlap": c_overlap,
377
+ "embedding_model": embed_model,
378
+ "vector_db_type": db_type,
379
+ "re_ranker_model":re_ranker
380
+ }
381
+ best_metrics = {
382
+ "average_retrieval_relevance_score": average_score,
383
+ "average_cosine_score": average_cosine_score,
384
+ "average_groundedness": average_groundedness,
385
+ "average_relevance": average_relevance,
386
+ "average_bert_score": average_bert_score,
387
+ "re_ranker_model":re_ranker
388
+ }
389
+
390
+ logger.info(f"Best score: {best_score}")
391
+ logger.info(f"Best params: {best_params}")
392
+
393
+ except (PineconeException, ValueError) as ex:
394
+ logger.error(f"Exception in grid search for chunk_size={c_size}, "
395
+ f"chunk_overlap={c_overlap}, embed_model={embed_model}, vector_db={db_type}: {ex}")
396
+ traceback.print_exc()
397
+ finally:
398
+ if 'vector_db_instance' in locals() and vector_db_instance is not None:
399
+ try:
400
+ vector_db_instance.delete_collection(temp_collection_name)
401
+ if isinstance(vector_db_instance, chromaDBVectorDB):
402
+ del vector_db_instance.client
403
+ time.sleep(5)
404
+ if isinstance(vector_db_instance, PineconeVectorDB):
405
+ pc = Pinecone(api_key=Configuration.PINECONE_API_KEY)
406
+ if temp_collection_name in pc.list_indexes().names():
407
+ pc.delete_index(temp_collection_name)
408
+ logger.info(f"Pinecone index: {temp_collection_name} deleted")
409
+ except Exception as cleanup_ex:
410
+ logger.error(f"Error during cleanup: {cleanup_ex}")
411
+ traceback.print_exc()
412
+
413
+ logger.info("\n---- Evaluation completed ----")
414
+ logger.info(f"Best parameters: {best_params}")
415
+ logger.info(f"Best score: {best_score:.2f}")
416
+ logger.info(f"Total iterations evaluated: {len(results)}")
417
+
418
+ pkl_directory = os.path.join(Configuration.DATA_DIR, "eval_results")
419
+ os.makedirs(pkl_directory, exist_ok=True)
420
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M")
421
+ pkl_file = os.path.join(pkl_directory, f"eval_results_{search_type}_{timestamp}.pkl")
422
+
423
+ try:
424
+ joblib.dump({
425
+ "best_params": best_params,
426
+ "best_score": best_score,
427
+ "results": results,
428
+ "best_metrics": best_metrics
429
+ }, pkl_file)
430
+ logger.info(f"Results saved to {pkl_file}")
431
+ except Exception as ex:
432
+ logger.error(f"Exception in saving results to {pkl_file}: {ex}")
433
+
434
+ return {
435
+ "best_params": best_params,
436
+ "best_score": best_score,
437
+ "results": results,
438
+ "best_metrics": best_metrics,
439
+ "pkl_file": pkl_file
440
+ }
441
+
442
+
443
+ def _evaluate_with_llm(self, query:str, response_summary: str,retrieved_docs_contents: List[str]) -> Dict[str,Any]:
444
+ try:
445
+ context = "\n".join(retrieved_docs_contents)
446
+ system_message = f"""
447
+ You are an expert, Impartial judge for evaluating Retrieval Augmented Generation (RAG) system message.
448
+ Your ONLY task is to output a JSON object containing the evaluation score and reasoning.
449
+ DO NOT include any other text, explanations, conversational remarks or markdown code blocks(```json).
450
+ Strictly adhere to the requested JSON format.
451
+ """
452
+ prompt = f"""
453
+ You are evaluating a RAG system's response.
454
+
455
+ Query: "{query}"
456
+ Retrieved context: --- {context} ---
457
+ RAG Systems Response: "{response_summary}"
458
+
459
+ please provide a score for Groundedness and Relevance on a scale of 1 to 5,
460
+ where 5 is excellent and 1 is very poor.
461
+
462
+ Groundedness: is how the response is supported *only* to the Retrieved context with no hallucinations ?
463
+ 1. Contains significant information not supported by the context or contradicts it.
464
+ 2. Contains some unsupported information.
465
+ 3. Mostly grounded, but might have minor deviations or additions.
466
+ 4. Almost entirely grounded in the context.
467
+ 5. Fully and accurately, using only information from the context.
468
+
469
+ Relevance: how well the response is directly and comprehensively answers the query based on the context ?
470
+ 1. Does not answer the query at all, or answers a different question.
471
+ 2. Addresses the query partially but misses significant part or is off-topic.
472
+ 3. Answer the query reasonably well, but could be more complete or focused.
473
+ 4. Answer the query well, covering most relevant aspects.
474
+ 5. Answer the query completely, accurately and concisely, directly addressing all aspects.
475
+
476
+ output your assessment ONLY in the following JSON format. no other text.
477
+ {{
478
+ "groundedness_score": <int> out of 5,
479
+ "relevance_score": <int> out of 5,
480
+ "reasoning": "Brief explanation for the scores."
481
+ }}
482
+ """
483
+
484
+ #llm_response = GROQLLM.generate_response(prompt=prompt,temperature=0.1,top_p=0.95,max_tokens=1500)
485
+ eval_with_llm = GROQLLM(model_name="llama-3.3-70b-versatile")
486
+ llm_response = eval_with_llm.generate_response(
487
+ prompt=prompt,
488
+ system_message=system_message,
489
+ temperature=0.1,
490
+ top_p=0.95,
491
+ max_tokens=1500)
492
+
493
+ print("\n -- LLM Judge Raw Response --")
494
+ print(llm_response)
495
+ print('-'*50)
496
+
497
+ eval_scores = json.loads(llm_response)
498
+ return {
499
+ "Groundedness score": eval_scores.get("groundedness_score",0),
500
+ "Relevance score": eval_scores.get("relevance_score",0),
501
+ "Reasoning": eval_scores.get("reasoning","")
502
+ }
503
+ except Exception as ex:
504
+ traceback.print_exc()
505
+ #return ""
506
+ return {
507
+ "Groundedness score": 0,
508
+ "Relevance score": 0,
509
+ "Reasoning": f"Exception in LLM evaluation: {ex}"
510
+ }
511
+
512
+
513
+ def _syntesize_raw_reference(self, retrieved_docs: List[Dict[str, Any]]) -> str:
514
+ try:
515
+ if not retrieved_docs:
516
+ return ""
517
+ raw_content_snippets = [doc.get('content','') for doc in retrieved_docs if doc.get('content')]
518
+ raw_answer = " ".join(raw_content_snippets)
519
+
520
+ raw_answer = " ".join(raw_answer.split()).join(sorted(list(set(raw_answer.split()))))
521
+
522
+ return raw_answer
523
+ except Exception as ex:
524
+ traceback.print_exc()
525
+ return ""
rag_scripts/interfaces.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC, abstractmethod
2
+ from typing import List, Dict, Tuple, Any
3
+
4
class IDocumentChunker(ABC):
    """Interface for splitting a source document into retrievable chunks."""

    @abstractmethod
    def hash_document(self) -> str:
        """Return a unique content-based hash identifying the document."""
        pass

    @abstractmethod
    def chunk_documents(self) -> Tuple[str, List[Dict[str, Any]]]:
        '''
        Split the document into chunks.

        Returns:
            Tuple[str, List[Dict[str, Any]]]: the hashed document ID and a list
            of chunk dicts, each containing the chunk content and its metadata.
        '''
        pass
19
+
20
+
21
+
22
class IEmbedder(ABC):
    """Interface for turning text into embedding vectors."""

    @abstractmethod
    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        '''
        Generate embeddings for a list of text strings.

        Returns:
            List[List[float]]: one embedding vector per input text.
        '''
        pass

    @abstractmethod
    def embed_query(self, query: str) -> List[float]:
        '''
        Generate an embedding for a single query string.

        Returns:
            List[float]: a single embedding vector.
        '''
        pass
38
+
39
class IVector(ABC):
    """Interface for a vector database that stores embedded document chunks."""

    def __init__(self, embedder: IEmbedder, db_path: str, collection_name: str):
        # Shared wiring for all backends: the embedder used for chunks/queries,
        # the on-disk storage location, and the logical collection/index name.
        self.embedder = embedder
        self.db_path = db_path
        self.collection_name = collection_name

    @abstractmethod
    def add_chunks(self, documents: List[Dict[str, Any]]) -> List[str]:
        '''
        Add a list of document chunks to the vector database.

        Argument:
            documents: list of chunk dictionaries containing content and metadata
        Returns:
            List[str]: list of stored chunk IDs
        '''
        pass

    @abstractmethod
    def get_document_hash_ids(self, document_hash: str) -> List[str]:

        '''
        Check whether a document is already indexed, based on its hash.

        Returns:
            List[str]: chunk IDs associated with the document hash, or an
            empty list if the document is not present.
        '''
        pass

    @abstractmethod
    def search(self, query: str, k: int=3) -> List[Dict[str,Any]]:
        '''
        Search the vector database for documents relevant to a query.

        Argument:
            query: user query string
            k: number of top-k results to retrieve
        Returns:
            List[Dict]: retrieved chunks, each including content, metadata
            and a distance/similarity score.
        '''
        pass

    @abstractmethod
    def delete_collection(self, collection_name: str):
        '''
        Delete a specified collection or index from the vector database.
        Useful for cleanup during evaluations.
        '''
        pass

    @abstractmethod
    def count_documents(self) -> int:
        '''
        Return the number of documents/chunks in the collection.
        '''
        pass
92
+
93
+
94
class ILLM(ABC):
    """Interface for a text-generation language model."""

    @abstractmethod
    def generate_response(self, prompt: str, system_message: str = None) -> str:
        '''
        Generate a response from the LLM for the given prompt, optionally
        steered by a system message.
        '''
        pass
102
+
103
class IRAGPipeline(ABC):
    """Interface for an end-to-end retrieval-augmented generation pipeline."""

    @abstractmethod
    def build_index(self):
        '''
        Build the document index used for retrieval.
        '''
        pass

    @abstractmethod
    def query(self, user_query: str) -> Dict[str, Any]:
        # NOTE(review): annotation corrected from `str` — the concrete
        # RAGPipeline implementation returns a {"summary", "sources"} dict.
        '''
        Query the RAG pipeline and return the generated answer with sources.
        '''
        pass

    @abstractmethod
    def retrieve_raw_documents(self, user_query: str) -> List[Dict]:
        '''
        Retrieve the raw matching documents without LLM generation.
        '''
        pass
125
+
126
+
127
+
128
+
129
+
rag_scripts/llm/__pycache__/llmResponse.cpython-312.pyc ADDED
Binary file (3.64 kB). View file
 
rag_scripts/llm/llmResponse.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import traceback
3
+ from typing import List, Dict, Any, Optional
4
+ from groq import Groq
5
+
6
+ from rag_scripts.interfaces import ILLM
7
+ from configuration import Configuration
8
+
9
class GROQLLM(ILLM):
    """Groq-hosted chat LLM wrapper implementing the ILLM interface."""

    def __init__(self, api_key: str = Configuration.GROQ_API_KEY,
                 model_name: str = Configuration.DEFAULT_GROQ_LLM_MODEL):
        """Create the Groq client.

        Raises:
            ValueError: when no API key is configured in the environment.
        """
        if not api_key:
            raise ValueError("Groq API not provided in env file")

        self.api_key = api_key
        self.model_name = model_name
        self.client = Groq(api_key=self.api_key)
        # FIX: "HT Policy" -> "HR Policy" typo in the default system prompt.
        self.default_system_message = ("You are a helpful assistant for Flykite Airline HR Policy document queries."
                                       "provide a concise and accurate answers based strictly on the provided context."
                                       "Do not hallucinate or add ungrounded details.")

        print(f"Initialized Groq LLM with model: {self.model_name}")

    def generate_response(self, prompt, system_message: Optional[str] = None,
                          context: Optional[List[Dict]] = None,
                          temperature: float = 0.1,
                          top_p: float = 0.95,
                          max_tokens: int = 1000) -> str:
        """Generate a chat completion, optionally prefixed with retrieved context.

        Args:
            prompt: the user question/instruction.
            system_message: overrides the default system prompt when given.
            context: optional retrieved chunks; their 'content' fields are
                prepended to the prompt as context.
            temperature / top_p / max_tokens: sampling parameters.

        Returns:
            The streamed response text, or "" on failure (callers invoke
            .strip() on the result, so None must never be returned).
        """
        try:
            complete_prompt = prompt
            if context:
                context_text = "\n".join([doc['content'] for doc in context])
                complete_prompt = f"Context:\n{context_text}\n\nQuestion: {prompt}"

            message = [{"role": "system", "content": system_message or self.default_system_message},
                       {"role": "user", "content": complete_prompt}]

            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=message,
                temperature=temperature if temperature is not None else 0.1,
                max_tokens=max_tokens if max_tokens is not None else 1000,
                top_p=top_p if top_p is not None else 0.95,
                stream=True, stop=None
            )

            # Accumulate the streamed delta chunks into a single string.
            response_content = ""
            for chunk in completion:
                if chunk.choices and chunk.choices[0].delta.content:
                    response_content += chunk.choices[0].delta.content
            return response_content.strip()

        except Exception as ex:
            print(f"Exception in LLM response: {ex}")
            traceback.print_exc()
            # BUGFIX: previously fell through returning None, which crashed
            # callers doing `llm_response.strip()` on the result.
            return ""

    def llm_judge(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """Placeholder for LLM-as-judge scoring (not implemented yet).

        BUGFIX: the parameter was misspelled ``selfself``, which silently
        shifted all arguments when called as a bound method.
        """
        try:
            # TODO: implement judge scoring; intentionally a no-op for now.
            pass
        except Exception:
            traceback.print_exc()
rag_scripts/rag_pipeline.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import json
3
+ import traceback
4
+ import math
5
+ from typing import List, Dict, Any
6
+
7
+ from configuration import Configuration
8
+ from rag_scripts.interfaces import IDocumentChunker, IEmbedder, IVector, ILLM, IRAGPipeline
9
+ from sentence_transformers import CrossEncoder
10
+ from rag_scripts.documents_processing.chunking import PyMuPDFChunker
11
+ from rag_scripts.embedding.embedder import SentenceTransformerEmbedder
12
+ from rag_scripts.embedding.vector_db.chroma_db import chromaDBVectorDB
13
+ from rag_scripts.embedding.vector_db.faiss_db import FAISSVectorDB
14
+ from rag_scripts.embedding.vector_db.pinecone_db import PineconeVectorDB
15
+ from rag_scripts.llm.llmResponse import GROQLLM
16
+
17
class RAGPipeline(IRAGPipeline):
    """End-to-end RAG pipeline: chunk -> embed -> store -> retrieve -> generate."""

    def __init__(self,
                 document_path: str = Configuration.FULL_PDF_PATH,
                 chunker: IDocumentChunker = None,
                 embedder: IEmbedder = None,
                 vector_db: IVector = None,
                 llm: ILLM = None,
                 chunk_size: int = Configuration.DEFAULT_CHUNK_SIZE,
                 chunk_overlap: int = Configuration.DEFAULT_CHUNK_OVERLAP,
                 embedding_model_name: str = Configuration.DEFAULT_SENTENCE_TRANSFORMER_MODEL,
                 llm_model_name: str = Configuration.DEFAULT_GROQ_LLM_MODEL,
                 vector_db_type: str = "chroma",
                 db_path: str = None,
                 collection_name: str = None,
                 re_ranker_model_name: str = Configuration.DEFAULT_RERANKER
                 ):
        """Wire up chunker, embedder, vector DB, LLM and optional re-ranker.

        Any component passed in explicitly is used as-is; otherwise a default
        instance is constructed from the configuration values.
        """
        self.document_path = document_path

        self.chunker = chunker if chunker else PyMuPDFChunker(
            pdf_path=self.document_path,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )

        self.embedder = embedder if embedder else SentenceTransformerEmbedder(model_name=embedding_model_name)

        if vector_db:
            self.vector_db = vector_db
        else:
            if not isinstance(vector_db_type, str):
                raise ValueError("vector db type must be string")
            # Pick the default on-disk path for the chosen backend.
            db_path = db_path or (
                Configuration.CHROMA_DB_PATH if vector_db_type.lower() == 'chroma'
                else
                Configuration.FAISS_DB_PATH if vector_db_type.lower() == 'faiss'
                else "")
            collection_name = collection_name or Configuration.COLLECTION_NAME

            if vector_db_type.lower() == "chroma":
                self.vector_db = chromaDBVectorDB(
                    embedder=self.embedder,
                    db_path=db_path,
                    collection_name=collection_name
                )
            elif vector_db_type.lower() == "faiss":
                self.vector_db = FAISSVectorDB(
                    embedder=self.embedder,
                    db_path=db_path,
                    collection_name=collection_name
                )
            elif vector_db_type.lower() == "pinecone":
                self.vector_db = PineconeVectorDB(
                    embedder=self.embedder,
                    db_path=db_path,
                    collection_name=collection_name
                )
            else:
                # FIX: old message claimed only chroma/faiss were supported.
                raise ValueError("Unsupported vector_db_type; supported types are chroma, faiss and pinecone")

        self.llm = llm if llm else GROQLLM(
            api_key=Configuration.GROQ_API_KEY,
            model_name=llm_model_name)

        self.re_ranker = None
        if re_ranker_model_name:
            try:
                # BUGFIX: CrossEncoder takes the model name as its first
                # positional argument; there is no 're_ranker_model_name'
                # keyword, so the old call always raised and disabled re-ranking.
                self.re_ranker = CrossEncoder(re_ranker_model_name)
                print(f"ReRanker model loaded: {re_ranker_model_name}")
            except Exception as ex:
                print(f"ReRanker model could not be loaded: {re_ranker_model_name}")
                self.re_ranker = None

        print("RAG pipeline initialized")

    def build_index(self):
        """Chunk the document and add the chunks to the vector DB (idempotent)."""
        print(f"building index for {self.document_path}")
        try:
            document_hash = self.chunker.hash_document()
            print(f"{document_hash}")
            existing_chunk_ids = self.vector_db.get_document_hash_ids(document_hash)

            # Skip re-indexing when this exact document version is already stored.
            if existing_chunk_ids:
                print(f"Documents {self.document_path} hash: {document_hash[:8]} already present in the vector DB with {len(existing_chunk_ids)} chunks.")
                return

            print(f"Chunking starting")
            document_hash, chunks = self.chunker.chunk_documents()

            if not chunks:
                print(f"No chunks generated for {self.document_path} index not built")
                return

            self.vector_db.add_chunks(chunks)

            print(f"Index built successfully for the {self.document_path} with {len(chunks)}")

        except FileNotFoundError:
            print(f"Exception Document not at {self.document_path}")
            traceback.print_exc()
        except Exception as ex:
            print(f"Exception in build index {ex}")
            traceback.print_exc()

    def retrieve_raw_documents(self, user_query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Retrieve the top-k chunks for a query, cross-encoder re-ranked when available."""
        if self.vector_db.count_documents() == 0:
            print("vector database is empty please build index first")
            return []
        # The vector DB embeds the raw query string internally (see IVector.search).
        query_embedding = user_query
        # Over-retrieve when a re-ranker will prune the candidate set.
        initial_retrieval_k = k * 3 if self.re_ranker else k
        retrieved_docs_with_score = self.vector_db.search(query_embedding, k=initial_retrieval_k)
        if not retrieved_docs_with_score:
            print(f"Retrieval failed for {user_query}")
            return []
        retrieved_docs = [doc for doc in retrieved_docs_with_score]
        if self.re_ranker:
            # BUGFIX: CrossEncoder has no `.rerank()` method, and the old call
            # never passed the query. Score (query, document) pairs with
            # `.predict()` so the ranking actually depends on the query.
            pairs = [(user_query, doc['content']) for doc in retrieved_docs]
            rerank_score = self.re_ranker.predict(pairs)
            doc_with_reRank = []
            for idx, doc in enumerate(retrieved_docs):
                doc_with_reRank.append({
                    'doc': doc,
                    'rerank_score': rerank_score[idx]
                })
            ranked_docs = sorted(doc_with_reRank, key=lambda x: x['rerank_score'], reverse=True)
            final_retrieved_docs = [item['doc'] for item in ranked_docs[:k]]
        else:
            final_retrieved_docs = retrieved_docs
        return final_retrieved_docs

    def query(self, user_query: str, k: int = 3,
              include_metadata: bool = True,
              user_context: Dict[str, Any] = None) -> Dict[str, Any]:
        """Answer a user query from the indexed HR policy.

        Returns:
            Dict with "summary" (the LLM answer) and "sources" (chunk metadata,
            empty when include_metadata is False or nothing was retrieved).
        """
        if not user_query.strip():
            return {"summary": "Enter the query", "sources": []}

        retrieved_docs = self.retrieve_raw_documents(user_query, k)

        if not retrieved_docs:
            return {"summary": "Unable to find relevant information in the documents for the query asked. "
                               "Please refer contact HR department directly or refer HR policy document",
                    "sources": []}

        context_info = []
        metadata_info = []

        for indx, doc in enumerate(retrieved_docs):
            context_info.append(f"Document {indx+1} content: {doc['content']}")
            if include_metadata:
                metadata = doc['metadata']
                metadata_info.append({
                    "document_id": f"DOC {indx+1}",
                    "page": str(metadata.get("page_number", "NA")),
                    "section": metadata.get("section", "NA"),
                    "clause": metadata.get("clause", "NA")})

        context_string = "\n".join(context_info)
        # Default persona used to tailor the answer when none is supplied.
        user_context = user_context or {"role": "general", "location": "chennai", "department": "unknown"}
        context_description = (f" for a {user_context['role']} in"
                               f"{user_context['location']} and {user_context['department']}")

        prompt = (
            f"You are an expert assistant for Flykite Airlines HR Policy queries. "
            f"Answer the question '{user_query}' based solely on the provided context from the Flykite Airlines HR Policy, "
            f"Tailor the answer for a {user_context['role']} in {user_context['location']} and {user_context['department']}. "
            f"Include only the criteria and details that directly address the question, ensuring all relevant points from the context are covered without adding unrelated information or assumptions. "
            f"Present the answer in a concise format using bullet points, a table, or sections for readability, and cite specific sections and clauses from the sources where applicable. "
            f"Cite specific sections and clauses from the sources using the format (source: DOC X, Page: Y, Section: Z, Clause: A) at the end of each relevant point or page"
            f"If the query is ambiguous, ask for clarification. If the context does not fully address the question, state what is known and suggest consulting the full HR Policy or HR department. "
            f"Context: \n{context_string}\n\n"
            f"Answer: ")

        llm_response = self.llm.generate_response(prompt)
        final_response = {"summary": llm_response.strip(),
                          "sources": metadata_info if include_metadata else []
                          }

        return final_response
205
+
206
+
requirements.txt ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ langchain-text-splitters
2
+ sentence-transformers
3
+ chromadb
4
+ pymupdf
5
+ regex
6
+ groq
7
+ streamlit
8
+ huggingface_hub
9
+ faiss-cpu
10
+ pinecone
11
+ bert-score
users.db ADDED
Binary file (12.3 kB). View file