Spaces:
Running
Running
Upload folder using huggingface_hub
Browse files- .gitattributes +1 -0
- .gitignore +29 -0
- .idea/.gitignore +8 -0
- .idea/Fly-Kite-AirLine-RAG-Project.iml +8 -0
- .idea/inspectionProfiles/profiles_settings.xml +6 -0
- .idea/misc.xml +7 -0
- .idea/modules.xml +8 -0
- .idea/workspace.xml +160 -0
- DOCUMENTS/Flykite_Airlines_HR_Policy.pdf +3 -0
- Dockerfile +18 -0
- HostingIntoHuggingFace.py +73 -0
- README.md +13 -10
- app.py +272 -0
- configuration.py +62 -0
- create_user_db.py +47 -0
- eval_data.json +33 -0
- main.py +541 -0
- rag_scripts/__pycache__/interfaces.cpython-312.pyc +0 -0
- rag_scripts/__pycache__/rag_pipeline.cpython-312.pyc +0 -0
- rag_scripts/documents_processing/__pycache__/chunking.cpython-312.pyc +0 -0
- rag_scripts/documents_processing/chunking.py +134 -0
- rag_scripts/embedding/__pycache__/embedder.cpython-312.pyc +0 -0
- rag_scripts/embedding/embedder.py +53 -0
- rag_scripts/embedding/vector_db/__pycache__/chroma_db.cpython-312.pyc +0 -0
- rag_scripts/embedding/vector_db/__pycache__/faiss_db.cpython-312.pyc +0 -0
- rag_scripts/embedding/vector_db/__pycache__/pinecone_db.cpython-312.pyc +0 -0
- rag_scripts/embedding/vector_db/chroma_db.py +103 -0
- rag_scripts/embedding/vector_db/faiss_db.py +145 -0
- rag_scripts/embedding/vector_db/pinecone_db.py +202 -0
- rag_scripts/evaluation/__pycache__/evaluator.cpython-312.pyc +0 -0
- rag_scripts/evaluation/evaluator.py +525 -0
- rag_scripts/interfaces.py +129 -0
- rag_scripts/llm/__pycache__/llmResponse.cpython-312.pyc +0 -0
- rag_scripts/llm/llmResponse.py +63 -0
- rag_scripts/rag_pipeline.py +206 -0
- requirements.txt +11 -0
- users.db +0 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
DOCUMENTS/Flykite_Airlines_HR_Policy.pdf filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.env
|
| 2 |
+
__pycache__/
|
| 3 |
+
.python
|
| 4 |
+
env/
|
| 5 |
+
venv/
|
| 6 |
+
*.egg-info
|
| 7 |
+
dist/
|
| 8 |
+
build/
|
| 9 |
+
*.pyd
|
| 10 |
+
*.pyo
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
.vscode/
|
| 14 |
+
*.code-workspace
|
| 15 |
+
|
| 16 |
+
DATA/chroma_temp_db_eval/
|
| 17 |
+
DATA/faiss_temp_db_eval/
|
| 18 |
+
DATA/pinecone_temp_db_eval/
|
| 19 |
+
DATA/
|
| 20 |
+
|
| 21 |
+
*.log
|
| 22 |
+
*.key
|
| 23 |
+
*.pem
|
| 24 |
+
|
| 25 |
+
#LLM
|
| 26 |
+
*.bin
|
| 27 |
+
*.h5
|
| 28 |
+
*.pt
|
| 29 |
+
|
.idea/.gitignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Default ignored files
|
| 2 |
+
/shelf/
|
| 3 |
+
/workspace.xml
|
| 4 |
+
# Editor-based HTTP Client requests
|
| 5 |
+
/httpRequests/
|
| 6 |
+
# Datasource local storage ignored files
|
| 7 |
+
/dataSources/
|
| 8 |
+
/dataSources.local.xml
|
.idea/Fly-Kite-AirLine-RAG-Project.iml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<module type="PYTHON_MODULE" version="4">
|
| 3 |
+
<component name="NewModuleRootManager">
|
| 4 |
+
<content url="file://$MODULE_DIR$" />
|
| 5 |
+
<orderEntry type="jdk" jdkName="Python 3.12" jdkType="Python SDK" />
|
| 6 |
+
<orderEntry type="sourceFolder" forTests="false" />
|
| 7 |
+
</component>
|
| 8 |
+
</module>
|
.idea/inspectionProfiles/profiles_settings.xml
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<component name="InspectionProjectProfileManager">
|
| 2 |
+
<settings>
|
| 3 |
+
<option name="USE_PROJECT_PROFILE" value="false" />
|
| 4 |
+
<version value="1.0" />
|
| 5 |
+
</settings>
|
| 6 |
+
</component>
|
.idea/misc.xml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="Black">
|
| 4 |
+
<option name="sdkName" value="Python 3.12" />
|
| 5 |
+
</component>
|
| 6 |
+
<component name="ProjectRootManager" version="2" project-jdk-name="Python 3.12" project-jdk-type="Python SDK" />
|
| 7 |
+
</project>
|
.idea/modules.xml
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="ProjectModuleManager">
|
| 4 |
+
<modules>
|
| 5 |
+
<module fileurl="file://$PROJECT_DIR$/.idea/Fly-Kite-AirLine-RAG-Project.iml" filepath="$PROJECT_DIR$/.idea/Fly-Kite-AirLine-RAG-Project.iml" />
|
| 6 |
+
</modules>
|
| 7 |
+
</component>
|
| 8 |
+
</project>
|
.idea/workspace.xml
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<?xml version="1.0" encoding="UTF-8"?>
|
| 2 |
+
<project version="4">
|
| 3 |
+
<component name="AutoImportSettings">
|
| 4 |
+
<option name="autoReloadType" value="SELECTIVE" />
|
| 5 |
+
</component>
|
| 6 |
+
<component name="ChangeListManager">
|
| 7 |
+
<list default="true" id="43d0b45c-4ac8-44fd-8129-96b2dd008826" name="Changes" comment="" />
|
| 8 |
+
<option name="SHOW_DIALOG" value="false" />
|
| 9 |
+
<option name="HIGHLIGHT_CONFLICTS" value="true" />
|
| 10 |
+
<option name="HIGHLIGHT_NON_ACTIVE_CHANGELIST" value="false" />
|
| 11 |
+
<option name="LAST_RESOLUTION" value="IGNORE" />
|
| 12 |
+
</component>
|
| 13 |
+
<component name="FileTemplateManagerImpl">
|
| 14 |
+
<option name="RECENT_TEMPLATES">
|
| 15 |
+
<list>
|
| 16 |
+
<option value="Python Script" />
|
| 17 |
+
</list>
|
| 18 |
+
</option>
|
| 19 |
+
</component>
|
| 20 |
+
<component name="FlaskConsoleOptions" custom-start-script="import sys; print('Python %s on %s' % (sys.version, sys.platform)); sys.path.extend([WORKING_DIR_AND_PYTHON_PATHS]) from flask.cli import ScriptInfo, NoAppException for module in ["main.py", "wsgi.py", "app.py"]: try: locals().update(ScriptInfo(app_import_path=module, create_app=None).load_app().make_shell_context()); print("\nFlask App: %s" % app.import_name); break except NoAppException: pass">
|
| 21 |
+
<envs>
|
| 22 |
+
<env key="FLASK_APP" value="app" />
|
| 23 |
+
</envs>
|
| 24 |
+
<option name="myCustomStartScript" value="import sys; print('Python %s on %s' % (sys.version, sys.platform)); sys.path.extend([WORKING_DIR_AND_PYTHON_PATHS]) from flask.cli import ScriptInfo, NoAppException for module in ["main.py", "wsgi.py", "app.py"]: try: locals().update(ScriptInfo(app_import_path=module, create_app=None).load_app().make_shell_context()); print("\nFlask App: %s" % app.import_name); break except NoAppException: pass" />
|
| 25 |
+
<option name="myEnvs">
|
| 26 |
+
<map>
|
| 27 |
+
<entry key="FLASK_APP" value="app" />
|
| 28 |
+
</map>
|
| 29 |
+
</option>
|
| 30 |
+
</component>
|
| 31 |
+
<component name="ProjectColorInfo">{
|
| 32 |
+
"associatedIndex": 1
|
| 33 |
+
}</component>
|
| 34 |
+
<component name="ProjectId" id="33s38NvoyHdVLzaauzWfPb9TnxM" />
|
| 35 |
+
<component name="ProjectViewState">
|
| 36 |
+
<option name="hideEmptyMiddlePackages" value="true" />
|
| 37 |
+
<option name="showLibraryContents" value="true" />
|
| 38 |
+
</component>
|
| 39 |
+
<component name="PropertiesComponent">{
|
| 40 |
+
"keyToString": {
|
| 41 |
+
"ModuleVcsDetector.initialDetectionPerformed": "true",
|
| 42 |
+
"Python.HRPolicyApp.executor": "Run",
|
| 43 |
+
"Python.create_user_db.executor": "Run",
|
| 44 |
+
"RunOnceActivity.ShowReadmeOnStart": "true",
|
| 45 |
+
"RunOnceActivity.TerminalTabsStorage.copyFrom.TerminalArrangementManager.252": "true",
|
| 46 |
+
"node.js.detected.package.eslint": "true",
|
| 47 |
+
"node.js.detected.package.tslint": "true",
|
| 48 |
+
"node.js.selected.package.eslint": "(autodetect)",
|
| 49 |
+
"node.js.selected.package.tslint": "(autodetect)",
|
| 50 |
+
"nodejs_package_manager_path": "npm",
|
| 51 |
+
"settings.editor.selected.configurable": "preferences.pluginManager",
|
| 52 |
+
"vue.rearranger.settings.migration": "true"
|
| 53 |
+
}
|
| 54 |
+
}</component>
|
| 55 |
+
<component name="SharedIndexes">
|
| 56 |
+
<attachedChunks>
|
| 57 |
+
<set>
|
| 58 |
+
<option value="bundled-js-predefined-d6986cc7102b-3aa1da707db6-JavaScript-PY-252.26830.99" />
|
| 59 |
+
<option value="bundled-python-sdk-164cda30dcd9-0af03a5fa574-com.jetbrains.pycharm.pro.sharedIndexes.bundled-PY-252.26830.99" />
|
| 60 |
+
</set>
|
| 61 |
+
</attachedChunks>
|
| 62 |
+
</component>
|
| 63 |
+
<component name="TaskManager">
|
| 64 |
+
<task active="true" id="Default" summary="Default task">
|
| 65 |
+
<changelist id="43d0b45c-4ac8-44fd-8129-96b2dd008826" name="Changes" comment="" />
|
| 66 |
+
<created>1760091804577</created>
|
| 67 |
+
<option name="number" value="Default" />
|
| 68 |
+
<option name="presentableId" value="Default" />
|
| 69 |
+
<updated>1760091804577</updated>
|
| 70 |
+
<workItem from="1760091890212" duration="859000" />
|
| 71 |
+
<workItem from="1760093438967" duration="2452000" />
|
| 72 |
+
<workItem from="1760096239520" duration="21456000" />
|
| 73 |
+
<workItem from="1760158047614" duration="33908000" />
|
| 74 |
+
<workItem from="1760276025638" duration="8385000" />
|
| 75 |
+
<workItem from="1760361998945" duration="294000" />
|
| 76 |
+
<workItem from="1760696201916" duration="16790000" />
|
| 77 |
+
<workItem from="1760716047458" duration="13878000" />
|
| 78 |
+
<workItem from="1760764472535" duration="4802000" />
|
| 79 |
+
<workItem from="1760802622597" duration="4499000" />
|
| 80 |
+
<workItem from="1760844577587" duration="42402000" />
|
| 81 |
+
</task>
|
| 82 |
+
<servers />
|
| 83 |
+
</component>
|
| 84 |
+
<component name="TypeScriptGeneratedFilesManager">
|
| 85 |
+
<option name="version" value="3" />
|
| 86 |
+
</component>
|
| 87 |
+
<component name="XDebuggerManager">
|
| 88 |
+
<breakpoint-manager>
|
| 89 |
+
<breakpoints>
|
| 90 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
| 91 |
+
<url>file://$PROJECT_DIR$/rag_scripts/rag_pipeline.py</url>
|
| 92 |
+
<option name="timeStamp" value="2" />
|
| 93 |
+
</line-breakpoint>
|
| 94 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
| 95 |
+
<url>file://$PROJECT_DIR$/rag_scripts/evaluation/evaluator.py</url>
|
| 96 |
+
<line>245</line>
|
| 97 |
+
<option name="timeStamp" value="3" />
|
| 98 |
+
</line-breakpoint>
|
| 99 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
| 100 |
+
<url>file://$PROJECT_DIR$/rag_scripts/documents_processing/chunking.py</url>
|
| 101 |
+
<line>49</line>
|
| 102 |
+
<option name="timeStamp" value="4" />
|
| 103 |
+
</line-breakpoint>
|
| 104 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
| 105 |
+
<url>file://$PROJECT_DIR$/rag_scripts/evaluation/evaluator.py</url>
|
| 106 |
+
<line>396</line>
|
| 107 |
+
<option name="timeStamp" value="6" />
|
| 108 |
+
</line-breakpoint>
|
| 109 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
| 110 |
+
<url>file://$PROJECT_DIR$/rag_scripts/evaluation/evaluator.py</url>
|
| 111 |
+
<line>163</line>
|
| 112 |
+
<option name="timeStamp" value="10" />
|
| 113 |
+
</line-breakpoint>
|
| 114 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
| 115 |
+
<url>file://$PROJECT_DIR$/rag_scripts/rag_pipeline.py</url>
|
| 116 |
+
<line>151</line>
|
| 117 |
+
<option name="timeStamp" value="11" />
|
| 118 |
+
</line-breakpoint>
|
| 119 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
| 120 |
+
<url>file://$PROJECT_DIR$/rag_scripts/llm/llmResponse.py</url>
|
| 121 |
+
<line>28</line>
|
| 122 |
+
<option name="timeStamp" value="12" />
|
| 123 |
+
</line-breakpoint>
|
| 124 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
| 125 |
+
<url>file://$PROJECT_DIR$/main.py</url>
|
| 126 |
+
<line>237</line>
|
| 127 |
+
<option name="timeStamp" value="13" />
|
| 128 |
+
</line-breakpoint>
|
| 129 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
| 130 |
+
<url>file://$PROJECT_DIR$/main.py</url>
|
| 131 |
+
<line>296</line>
|
| 132 |
+
<option name="timeStamp" value="14" />
|
| 133 |
+
</line-breakpoint>
|
| 134 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
| 135 |
+
<url>file://$PROJECT_DIR$/rag_scripts/evaluation/evaluator.py</url>
|
| 136 |
+
<line>451</line>
|
| 137 |
+
<option name="timeStamp" value="15" />
|
| 138 |
+
</line-breakpoint>
|
| 139 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
| 140 |
+
<url>file://$PROJECT_DIR$/rag_scripts/llm/llmResponse.py</url>
|
| 141 |
+
<line>27</line>
|
| 142 |
+
<option name="timeStamp" value="16" />
|
| 143 |
+
</line-breakpoint>
|
| 144 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
| 145 |
+
<url>file://$PROJECT_DIR$/rag_scripts/llm/llmResponse.py</url>
|
| 146 |
+
<option name="timeStamp" value="17" />
|
| 147 |
+
</line-breakpoint>
|
| 148 |
+
<line-breakpoint enabled="true" suspend="THREAD" type="python-line">
|
| 149 |
+
<url>file://$PROJECT_DIR$/main.py</url>
|
| 150 |
+
<line>248</line>
|
| 151 |
+
<option name="timeStamp" value="18" />
|
| 152 |
+
</line-breakpoint>
|
| 153 |
+
</breakpoints>
|
| 154 |
+
</breakpoint-manager>
|
| 155 |
+
</component>
|
| 156 |
+
<component name="com.intellij.coverage.CoverageDataManagerImpl">
|
| 157 |
+
<SUITE FILE_PATH="coverage/Fly_Kite_AirLine_RAG_Project$HRPolicyApp.coverage" NAME="HRPolicyApp Coverage Results" MODIFIED="1760718510975" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
| 158 |
+
<SUITE FILE_PATH="coverage/Fly_Kite_AirLine_RAG_Project$create_user_db.coverage" NAME="create_user_db Coverage Results" MODIFIED="1760724265707" SOURCE_PROVIDER="com.intellij.coverage.DefaultCoverageFileProvider" RUNNER="coverage.py" COVERAGE_BY_TEST_ENABLED="false" COVERAGE_TRACING_ENABLED="false" WORKING_DIRECTORY="$PROJECT_DIR$" />
|
| 159 |
+
</component>
|
| 160 |
+
</project>
|
DOCUMENTS/Flykite_Airlines_HR_Policy.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:1a5cdeec566a672375800654a4d004e4c3b7907578bd8c177ac5755e1cdb6cd0
|
| 3 |
+
size 262937
|
Dockerfile
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM python:3.12-slim-bookworm
|
| 2 |
+
LABEL authors="karthikeyan"
|
| 3 |
+
|
| 4 |
+
RUN apt-get update && apt-get install -y \
|
| 5 |
+
build-essential \
|
| 6 |
+
&& rm -rf /var/lib/apt/lists/*
|
| 7 |
+
|
| 8 |
+
WORKDIR /app
|
| 9 |
+
|
| 10 |
+
COPY . .
|
| 11 |
+
|
| 12 |
+
#COPY requirements.txt .
|
| 13 |
+
RUN pip install --no-cache-dir -r requirements.txt
|
| 14 |
+
RUN mkdir -p data/TunedDB data/ChromaDB data/FIASS_DB
|
| 15 |
+
|
| 16 |
+
EXPOSE 7860
|
| 17 |
+
|
| 18 |
+
CMD ["streamlit", "run", "app.py", "--server.port=7860", "--server.address=0.0.0.0", "--server.enableXsrfProtection=false"]
|
HostingIntoHuggingFace.py
ADDED
|
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import inspect
|
| 3 |
+
import traceback
|
| 4 |
+
from huggingface_hub import HfApi, create_repo, login,hf_hub_download
|
| 5 |
+
from huggingface_hub.utils import RepositoryNotFoundError
|
| 6 |
+
from configuration import Configuration
|
| 7 |
+
class HostingInHuggingFace:
|
| 8 |
+
def __init__(self):
|
| 9 |
+
self.base_path = Configuration.PROJECT_ROOT
|
| 10 |
+
self.hf_token = Configuration.HF_TOKEN
|
| 11 |
+
self.repo_id = 'jpkarthikeyan/FlyKiteAirlines'
|
| 12 |
+
|
| 13 |
+
def CreatingSpaceInHF(self):
|
| 14 |
+
print(f"Function Name {inspect.currentframe().f_code.co_name}")
|
| 15 |
+
api = HfApi()
|
| 16 |
+
try:
|
| 17 |
+
print(f"Checking for {self.repo_id} is correct or not")
|
| 18 |
+
api.repo_info(repo_id = self.repo_id,
|
| 19 |
+
repo_type='space',
|
| 20 |
+
token = self.hf_token)
|
| 21 |
+
print(f"Space {self.repo_id} already exists")
|
| 22 |
+
except RepositoryNotFoundError:
|
| 23 |
+
create_repo(repo_id=self.repo_id,
|
| 24 |
+
repo_type='space',
|
| 25 |
+
space_sdk='docker',
|
| 26 |
+
private=False,
|
| 27 |
+
token=self.hf_token)
|
| 28 |
+
print(f"Space created in {self.repo_id}")
|
| 29 |
+
except Exception as ex:
|
| 30 |
+
print(f"Exception in creating space {ex}")
|
| 31 |
+
traceback.print_exc()
|
| 32 |
+
finally:
|
| 33 |
+
print('-'*50)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def UploadDeploymentFile(self):
|
| 37 |
+
print(f"Function Name {inspect.currentframe().f_code.co_name}")
|
| 38 |
+
try:
|
| 39 |
+
api = HfApi(token=self.hf_token)
|
| 40 |
+
#directory_to_upload = os.path.join(self.base_path,'Deployment')
|
| 41 |
+
directory_to_upload = self.base_path
|
| 42 |
+
print(f"Directory to upload {directory_to_upload} into HF Space {self.repo_id}")
|
| 43 |
+
api.upload_folder(repo_id=self.repo_id, folder_path=directory_to_upload,
|
| 44 |
+
repo_type='space')
|
| 45 |
+
print(f"Successfully upload {directory_to_upload} into {self.repo_id}")
|
| 46 |
+
return True
|
| 47 |
+
except Exception as ex:
|
| 48 |
+
print(f"Exception occured {ex}")
|
| 49 |
+
print(traceback.print_exc())
|
| 50 |
+
return False
|
| 51 |
+
finally:
|
| 52 |
+
print('-'*50)
|
| 53 |
+
|
| 54 |
+
def ToRunPipeline(self):
|
| 55 |
+
try:
|
| 56 |
+
self.CreatingSpaceInHF()
|
| 57 |
+
if self.UploadDeploymentFile():
|
| 58 |
+
print('Deployment pipeline completed')
|
| 59 |
+
return True
|
| 60 |
+
else:
|
| 61 |
+
print('Deployment pipeline failed')
|
| 62 |
+
return False
|
| 63 |
+
except Exception as ex:
|
| 64 |
+
print(f"Exception occured {ex}")
|
| 65 |
+
print(traceback.print_exc())
|
| 66 |
+
return False
|
| 67 |
+
finally:
|
| 68 |
+
print('-'*50)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
if __name__ == '__main__':
|
| 72 |
+
hosting = HostingInHuggingFace()
|
| 73 |
+
hosting.ToRunPipeline()
|
README.md
CHANGED
|
@@ -1,10 +1,13 @@
|
|
| 1 |
-
---
|
| 2 |
-
title:
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
-
sdk: docker
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
title: FlyKite Airlines HR Policy
|
| 3 |
+
emoji: ✈️ 🤗
|
| 4 |
+
colorFrom: blue
|
| 5 |
+
colorTo: green
|
| 6 |
+
sdk: docker
|
| 7 |
+
sdk_version: 3.9
|
| 8 |
+
app_file: app.py
|
| 9 |
+
app_type: streamlit
|
| 10 |
+
pinned: false
|
| 11 |
+
license: mit
|
| 12 |
+
---
|
| 13 |
+
The streamlit app is frontend for RAG App for FlyKite Airlines HR Policy
|
app.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
import subprocess
|
| 3 |
+
import os
|
| 4 |
+
import json
|
| 5 |
+
import sqlite3
|
| 6 |
+
import hashlib
|
| 7 |
+
from typing import Optional, Dict, Any
|
| 8 |
+
|
| 9 |
+
# Assuming Configuration.py and other rag_scripts are in the same directory or accessible via sys.path
|
| 10 |
+
# You might need to adjust sys.path if your project structure is different
|
| 11 |
+
import sys
|
| 12 |
+
# Make sure this path is correct for your setup
|
| 13 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))
|
| 14 |
+
from configuration import Configuration
|
| 15 |
+
|
| 16 |
+
# --- Helper functions (from your main.py or adapted) ---
|
| 17 |
+
|
| 18 |
+
# Removed create_user_db() as per your request
|
| 19 |
+
|
| 20 |
+
def authenticate_user(username, password) -> Optional[Dict[str,str]]:
|
| 21 |
+
hashed_password = hashlib.sha256(password.encode()).hexdigest()
|
| 22 |
+
conn = sqlite3.connect('users.db') # Ensure 'users.db' is accessible
|
| 23 |
+
cursor = conn.cursor()
|
| 24 |
+
cursor.execute(
|
| 25 |
+
"SELECT username, jobrole, department, location FROM users WHERE username = ? AND password = ?",
|
| 26 |
+
(username, hashed_password)
|
| 27 |
+
)
|
| 28 |
+
user = cursor.fetchone()
|
| 29 |
+
conn.close()
|
| 30 |
+
if user:
|
| 31 |
+
return {"username": user[0], "role": user[1], "department": user[2], "location": user[3]}
|
| 32 |
+
return None
|
| 33 |
+
|
| 34 |
+
def run_rag_job(job_type: str, **kwargs) -> Dict[str, Any]:
|
| 35 |
+
"""
|
| 36 |
+
Executes a RAG job using the main.py script as a subprocess.
|
| 37 |
+
Returns the output as a dictionary (assuming main.py prints JSON).
|
| 38 |
+
"""
|
| 39 |
+
cmd = [sys.executable, 'main.py', '--job', job_type]
|
| 40 |
+
for key, value in kwargs.items():
|
| 41 |
+
if value is not None and value != '':
|
| 42 |
+
if isinstance(value, bool):
|
| 43 |
+
if value:
|
| 44 |
+
cmd.append(f'--{key.replace("_", "-")}')
|
| 45 |
+
else:
|
| 46 |
+
cmd.extend([f'--{key.replace("_", "-")}', str(value)])
|
| 47 |
+
|
| 48 |
+
# Inject user context if available
|
| 49 |
+
if 'user_context' in st.session_state and st.session_state.user_context:
|
| 50 |
+
user_ctx_str = json.dumps(st.session_state.user_context)
|
| 51 |
+
cmd.extend(['--user-context', user_ctx_str])
|
| 52 |
+
|
| 53 |
+
st.write(f"Running command: `{' '.join(cmd)}`") # For debugging
|
| 54 |
+
try:
|
| 55 |
+
process = subprocess.run(cmd, capture_output=True, text=True, check=True)
|
| 56 |
+
# st.success(f"Job '{job_type}' completed successfully.")
|
| 57 |
+
# Attempt to parse JSON output. main.py currently prints to stdout directly.
|
| 58 |
+
# We need to capture the last JSON block.
|
| 59 |
+
output_lines = process.stdout.strip().split('\n')
|
| 60 |
+
json_output = {}
|
| 61 |
+
for line in reversed(output_lines):
|
| 62 |
+
try:
|
| 63 |
+
json_output = json.loads(line)
|
| 64 |
+
break
|
| 65 |
+
except json.JSONDecodeError:
|
| 66 |
+
continue
|
| 67 |
+
return json_output
|
| 68 |
+
except subprocess.CalledProcessError as e:
|
| 69 |
+
st.error(f"Error running RAG job '{job_type}': {e.stderr}")
|
| 70 |
+
return {"error": e.stderr}
|
| 71 |
+
except Exception as e:
|
| 72 |
+
st.error(f"An unexpected error occurred: {e}")
|
| 73 |
+
return {"error": str(e)}
|
| 74 |
+
|
| 75 |
+
# --- Streamlit UI Components ---
|
| 76 |
+
|
| 77 |
+
def home_page():
|
| 78 |
+
st.title("Welcome to Flykite RAG System")
|
| 79 |
+
|
| 80 |
+
if 'logged_in' not in st.session_state:
|
| 81 |
+
st.session_state.logged_in = False
|
| 82 |
+
if 'user_info' not in st.session_state:
|
| 83 |
+
st.session_state.user_info = None
|
| 84 |
+
|
| 85 |
+
if not st.session_state.logged_in:
|
| 86 |
+
st.subheader("Login")
|
| 87 |
+
with st.form("login_form"):
|
| 88 |
+
username = st.text_input("Username")
|
| 89 |
+
password = st.text_input("Password", type="password")
|
| 90 |
+
login_button = st.form_submit_button("Login")
|
| 91 |
+
|
| 92 |
+
if login_button:
|
| 93 |
+
user_data = authenticate_user(username, password)
|
| 94 |
+
if user_data:
|
| 95 |
+
st.session_state.logged_in = True
|
| 96 |
+
st.session_state.user_info = user_data
|
| 97 |
+
st.session_state.user_context = {
|
| 98 |
+
"role": user_data['role'],
|
| 99 |
+
"department": user_data['department'],
|
| 100 |
+
"location": user_data['location']
|
| 101 |
+
}
|
| 102 |
+
st.success(f"Logged in as {user_data['username']} ({user_data['role']})")
|
| 103 |
+
st.rerun() # Changed from st.experimental_rerun()
|
| 104 |
+
else:
|
| 105 |
+
st.error("Invalid username or password.")
|
| 106 |
+
else:
|
| 107 |
+
st.write(f"You are logged in as **{st.session_state.user_info['username']}** (Role: **{st.session_state.user_info['role']}**)")
|
| 108 |
+
if st.button("Logout"):
|
| 109 |
+
st.session_state.logged_in = False
|
| 110 |
+
st.session_state.user_info = None
|
| 111 |
+
st.session_state.user_context = None
|
| 112 |
+
st.rerun() # Changed from st.experimental_rerun()
|
| 113 |
+
|
| 114 |
+
def admin_page():
|
| 115 |
+
st.title("Admin Dashboard")
|
| 116 |
+
st.write(f"Logged in as: {st.session_state.user_info['username']} (Role: {st.session_state.user_info['role']})")
|
| 117 |
+
|
| 118 |
+
if st.session_state.user_info and st.session_state.user_info['role'] == 'admin':
|
| 119 |
+
st.header("RAG Hypertuning")
|
| 120 |
+
st.info("Run hyperparameter tuning to find the best RAG configuration and build a tuned index.")
|
| 121 |
+
|
| 122 |
+
with st.form("hypertune_form"):
|
| 123 |
+
st.write("Hypertuning parameters (default values from main.py if not overridden):")
|
| 124 |
+
llm_model_ht = st.selectbox("LLM Model for Hypertuning Evaluation",
|
| 125 |
+
options=["mixtral-8x7b-32768", "llama2-70b-4096", "gemma-7b-it"],
|
| 126 |
+
key="llm_model_ht_select")
|
| 127 |
+
n_iter = st.number_input("Number of Hyper-tuning Iterations (n_iter)", min_value=1, value=1, step=1, key="n_iter_ht")
|
| 128 |
+
|
| 129 |
+
hypertune_button = st.form_submit_button("Run Hypertune Job")
|
| 130 |
+
|
| 131 |
+
if hypertune_button:
|
| 132 |
+
st.write("Starting RAG Hypertuning. This may take a while...")
|
| 133 |
+
with st.spinner("Running hypertuning..."):
|
| 134 |
+
result = run_rag_job('eval-hypertune', llm_model=llm_model_ht, n_iter=n_iter)
|
| 135 |
+
if result and "error" not in result:
|
| 136 |
+
st.success("Hypertuning completed and tuned index built!")
|
| 137 |
+
st.json(result)
|
| 138 |
+
else:
|
| 139 |
+
st.error("Hypertuning failed.")
|
| 140 |
+
|
| 141 |
+
st.header("RAG Testing")
|
| 142 |
+
st.info("Test the RAG pipeline with a specific query, optionally using the tuned database.")
|
| 143 |
+
|
| 144 |
+
with st.form("rag_test_form"):
|
| 145 |
+
test_query = st.text_area("Enter a test query for the RAG system:",
|
| 146 |
+
value="What is the policy on annual leave?",
|
| 147 |
+
key="test_query_input")
|
| 148 |
+
use_tuned_db = st.checkbox("Use Tuned RAG Database (if hypertuned previously)", value=True, key="use_tuned_db_checkbox")
|
| 149 |
+
display_raw = st.checkbox("Display Raw Retrieved Documents only (no LLM)", key="display_raw_docs_checkbox")
|
| 150 |
+
k_value = st.slider("Number of documents to retrieve (k)", min_value=1, max_value=10, value=5, key="k_value_slider")
|
| 151 |
+
|
| 152 |
+
test_rag_button = st.form_submit_button("Run RAG Test Query")
|
| 153 |
+
|
| 154 |
+
if test_rag_button:
|
| 155 |
+
st.write("Running RAG test query...")
|
| 156 |
+
with st.spinner("Getting RAG response..."):
|
| 157 |
+
# Pass the LLM model configured in the backend's Configuration if not specifically chosen for testing
|
| 158 |
+
# Or, make a selectbox for it in the admin testing section too.
|
| 159 |
+
llm_for_search = st.session_state.get('llm_model_ht_select', Configuration.DEFAULT_GROQ_LLM_MODEL)
|
| 160 |
+
result = run_rag_job('search',
|
| 161 |
+
query=test_query,
|
| 162 |
+
use_tuned=use_tuned_db,
|
| 163 |
+
raw=display_raw,
|
| 164 |
+
k=k_value,
|
| 165 |
+
llm_model=llm_for_search) # Ensure LLM model is passed
|
| 166 |
+
|
| 167 |
+
if result and "error" not in result:
|
| 168 |
+
st.success("RAG Test Query Completed!")
|
| 169 |
+
st.subheader("RAG Response:")
|
| 170 |
+
if display_raw:
|
| 171 |
+
st.json(result)
|
| 172 |
+
else:
|
| 173 |
+
if 'response' in result and 'summary' in result['response']:
|
| 174 |
+
st.write(result['response']['summary'])
|
| 175 |
+
if 'sources' in result['response'] and result['response']['sources']:
|
| 176 |
+
st.subheader("Sources:")
|
| 177 |
+
for source in result['response']['sources']:
|
| 178 |
+
st.markdown(f"- **Document ID:** {source.get('document_id', 'N/A')}, **Page:** {source.get('page', 'N/A')}, **Section:** {source.get('section', 'N/A')}, **Clause:** {source.get('clause', 'N/A')}")
|
| 179 |
+
else:
|
| 180 |
+
st.json(result)
|
| 181 |
+
if 'evaluation' in result:
|
| 182 |
+
st.subheader("Evaluation Results:")
|
| 183 |
+
st.json(result['evaluation'])
|
| 184 |
+
else:
|
| 185 |
+
st.error("RAG test query failed.")
|
| 186 |
+
else:
|
| 187 |
+
st.warning("You do not have administrative privileges to view this page.")
|
| 188 |
+
if st.button("Go to User Page"):
|
| 189 |
+
st.session_state.page = "User"
|
| 190 |
+
st.rerun() # Changed from st.experimental_rerun()
|
| 191 |
+
|
| 192 |
+
def user_page():
    """Render the end-user query page for the Flykite HR policy assistant."""

    def _render_result(payload, show_raw):
        # Pretty-print either the raw retrieval payload or the LLM summary.
        if show_raw:
            st.json(payload)
            return
        body = payload.get('response') if isinstance(payload, dict) else None
        if body and 'summary' in body:
            st.markdown(body['summary'])
            if body.get('sources'):
                st.subheader("Sources:")
                for source in body['sources']:
                    st.markdown(f"- **Document ID:** {source.get('document_id', 'N/A')}, **Page:** {source.get('page', 'N/A')}, **Section:** {source.get('section', 'N/A')}, **Clause:** {source.get('clause', 'N/A')}")
        else:
            st.json(payload)

    st.title("Flykite HR Policy Query")
    st.write(f"Logged in as: {st.session_state.user_info['username']} (Role: {st.session_state.user_info['role']})")

    st.info("Ask any question about the Flykite Airlines HR policy document.")

    with st.form("user_query_form"):
        query_text = st.text_area("Your Query:", height=100, key="user_query_input")
        answer_mode = st.radio("Choose Response Type:",
                               options=["LLM Tuned Response (RAG + LLM)", "RAG Raw Response (Retrieved Docs Only)"],
                               index=0, key="response_type_radio")
        top_k = st.slider("Number of documents to consider (k)", min_value=1, max_value=10, value=5, key="k_value_user_slider")

        submitted = st.form_submit_button("Get Answer")

    if submitted and query_text:
        st.subheader("Response:")
        with st.spinner("Fetching answer..."):
            want_raw = (answer_mode == "RAG Raw Response (Retrieved Docs Only)")
            # Default LLM from configuration; a hypertuned model choice could
            # be wired in here once main.py persists it with best_params.
            chosen_llm = Configuration.DEFAULT_GROQ_LLM_MODEL
            outcome = run_rag_job('search',
                                  query=query_text,
                                  raw=want_raw,
                                  k=top_k,
                                  use_tuned=True,  # user page always prefers the tuned index
                                  llm_model=chosen_llm)

            if outcome and "error" not in outcome:
                _render_result(outcome, want_raw)
            else:
                st.error("Failed to get a response. Please try again.")
    elif submitted and not query_text:
        st.warning("Please enter a query.")
|
| 239 |
+
|
| 240 |
+
# --- Main Application Logic ---
def main_app():
    """Entry point: seeds session state and routes between Home/Admin/User pages."""
    st.sidebar.title("Navigation")

    # Seed session-state defaults on the very first run.
    if 'logged_in' not in st.session_state:
        st.session_state.logged_in = False
    if 'page' not in st.session_state:
        st.session_state.page = "Home"

    if not st.session_state.logged_in:
        # Unauthenticated visitors always land on the login/home page.
        st.session_state.page = "Home"
        home_page()
        return

    # Sidebar navigation for authenticated users; admins get an extra page.
    def _nav_button(label, target):
        st.sidebar.button(label, on_click=lambda: st.session_state.update(page=target))

    is_admin = st.session_state.user_info and st.session_state.user_info['role'] == 'admin'
    _nav_button("Home", "Home")
    if is_admin:
        _nav_button("Admin Dashboard", "Admin")
    _nav_button("User Query", "User")

    # Dispatch to the page handler selected in session state.
    routes = {"Home": home_page, "Admin": admin_page, "User": user_page}
    handler = routes.get(st.session_state.page)
    if handler is not None:
        handler()


if __name__ == "__main__":
    main_app()
|
configuration.py
ADDED
|
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from dotenv import load_dotenv
|
| 3 |
+
|
| 4 |
+
load_dotenv()
|
| 5 |
+
|
| 6 |
+
class Configuration:
    """Central configuration for the RAG application.

    Values come from environment variables (loaded via python-dotenv at module
    import). Every value used in an `int()` conversion or `os.path.join()` now
    has a fallback: the original code called `int(os.getenv(...))` and
    `os.path.join(..., os.getenv(...))` without defaults, which raised
    TypeError at import time whenever the .env file was missing.
    """

    # API credentials (placeholder defaults keep local dev imports working).
    GROQ_API_KEY = os.getenv('GROQ_API_KEY', 'default_groq_key')
    OPEN_API_KEY = os.getenv('OPENAI_API_KEY', 'default_openai_key')
    HF_TOKEN = os.getenv('HF_TOKEN', 'default_hf_token')

    # Chunking / model defaults. String fallbacks so int() always succeeds.
    DEFAULT_CHUNK_SIZE = int(os.getenv('DEFAULT_CHUNK_SIZE', '1024'))
    DEFAULT_CHUNK_OVERLAP = int(os.getenv('DEFAULT_CHUNK_OVERLAP', '200'))
    DEFAULT_SENTENCE_TRANSFORMER_MODEL = os.getenv('DEFAULT_SENTENCE_TRANSFORMER_MODEL', 'all-MiniLM-L6-v2')
    DEFAULT_GROQ_LLM_MODEL = os.getenv('DEFAULT_GROQ_LLM_MODEL', 'llama-3.3-70b-versatile')
    DEFAULT_RERANKER = os.getenv('DEFAULT_RERANKER')  # may be None; callers fall back explicitly

    # Project layout: when this file lives under src/, the root is its parent.
    PROJECT_ROOT_BASE = os.path.abspath(os.path.dirname(__file__))
    print(f"PROJECT_ROOT {PROJECT_ROOT_BASE}")

    if os.path.basename(PROJECT_ROOT_BASE) == "src":
        PROJECT_ROOT = os.path.abspath(os.path.join(PROJECT_ROOT_BASE, '..'))
    else:
        PROJECT_ROOT = PROJECT_ROOT_BASE

    DOCUMENTS_DIR = os.path.join(PROJECT_ROOT, 'DOCUMENTS')
    print(f"DOCUMENTS_DIR {DOCUMENTS_DIR}")
    DATA_DIR = os.path.join(PROJECT_ROOT, 'DATA')

    # Default matches the PDF shipped in DOCUMENTS/ — TODO confirm against .env.
    PDF_FILE_NAME = os.getenv('PDF_FILE_NAME', 'Flykite_Airlines_HR_Policy.pdf')
    FULL_PDF_PATH = os.path.join(DOCUMENTS_DIR, PDF_FILE_NAME)
    print(f"FULL_PDF_PATH: {FULL_PDF_PATH} ")

    CHROMA_DB_PATH = os.path.join(DATA_DIR, os.getenv('CHROMA_DB_PATH', 'chroma_db'))
    FAISS_DB_PATH = os.path.join(DATA_DIR, os.getenv('FAISS_DB_PATH', 'faiss_db'))

    # NOTE(review): fallback collection name is a guess — verify against .env.
    COLLECTION_NAME = os.getenv('COLLECTION_NAME', 'flykite-hr-policy')
    EVAL_DATA_PATH = os.path.join(PROJECT_ROOT, os.getenv('EVAL_DATA_PATH', 'eval_data.json'))

    PINECONE_API_KEY = os.getenv('PINECONE_API_KEY')
    PINECONE_CLOUD = os.getenv('PINECONE_CLOUD', 'aws')
    PINECONE_REGION = os.getenv('PINECONE_REGION', 'us-east-1')

    # Ensure the on-disk layout exists up front.
    os.makedirs(DATA_DIR, exist_ok=True)
    os.makedirs(DOCUMENTS_DIR, exist_ok=True)
    os.makedirs(CHROMA_DB_PATH, exist_ok=True)
    os.makedirs(FAISS_DB_PATH, exist_ok=True)

    if not os.path.exists(FULL_PDF_PATH):
        print(f"PDF not found in {FULL_PDF_PATH}")
    else:
        print(f"PDF file found in {FULL_PDF_PATH}")

    if not os.path.exists(EVAL_DATA_PATH):
        print(f"eval json file not found in {EVAL_DATA_PATH}")
    else:
        print(f"eval file found in {EVAL_DATA_PATH}")
create_user_db.py
ADDED
|
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import sqlite3
|
| 2 |
+
import hashlib
|
| 3 |
+
import traceback
|
| 4 |
+
|
| 5 |
+
from argon2 import hash_password
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class CreateUserDB():
    """Creates and seeds the SQLite user database used for app login.

    NOTE(review): passwords are stored as unsalted SHA-256 digests, which is
    weak for production; a salted KDF (argon2/bcrypt) would be preferable, but
    the scheme is kept so existing stored hashes remain valid.
    """

    def encrypttion_pwd(self, password):
        """Return the hex SHA-256 digest of *password* (method name kept for callers)."""
        return hashlib.sha256(password.encode()).hexdigest()

    def setup_database(self, db_path='users.db'):
        """Create the users table if absent and seed it with default accounts.

        Args:
            db_path: path to the SQLite database file. Defaults to 'users.db'
                in the current working directory (original behavior); now a
                parameter so tests/deployments can target another location.
        """
        conn = None  # initialized first so `finally` cannot hit an unbound name
        try:
            conn = sqlite3.connect(db_path)
            cursor = conn.cursor()
            cursor.execute('''
                CREATE TABLE IF NOT EXISTS users (
                    username TEXT PRIMARY KEY,
                    password TEXT,
                    jobrole TEXT,
                    department TEXT,
                    location TEXT)
            ''')
            cursor.execute("SELECT COUNT(*) FROM users")
            if cursor.fetchone()[0] == 0:
                # Seed accounts: one admin plus five role-tagged users.
                users = [('admin', self.encrypttion_pwd('admin'), 'admin', 'admin', 'chennai'),
                         ('user1', self.encrypttion_pwd('user'), 'user', 'manager', 'chennai'),
                         ('user2', self.encrypttion_pwd('user'), 'user', 'pilot', 'chennai'),
                         ('user3', self.encrypttion_pwd('user'), 'user', 'engineer', 'chennai'),
                         ('user4', self.encrypttion_pwd('user'), 'user', 'cabin crew', 'chennai'),
                         ('user5', self.encrypttion_pwd('user'), 'user', 'ground staff', 'chennai')]
                cursor.executemany("INSERT INTO users VALUES(?,?,?,?,?)", users)
                conn.commit()
            else:
                print("Database already contains users")

        except Exception as ex:
            traceback.print_exc()
            print(f"Exception occurred: {ex}")
        finally:
            if conn:
                conn.close()


if __name__ == "__main__":
    db_obj = CreateUserDB()
    db_obj.setup_database()
|
eval_data.json
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[
|
| 2 |
+
{
|
| 3 |
+
"query":"what are the Criteria for Probation Extension?",
|
| 4 |
+
"expected_keywords":["60%", "PIP", "5 working days", "7 calendar days","measurable target"],
|
| 5 |
+
"expected_answer_snippet":"Extensions are granted only if: a. The employee has achieved at least 60% of their probationary objectives. b. A written PIP with measurable targets is issued within 5 working days of the original probation end date. c. The Department Head and HR Manager both sign off on the extension request. Employees will be notified of extensions in writing at least 7 calendar days before the probation end date."
|
| 6 |
+
},
|
| 7 |
+
{
|
| 8 |
+
"query":"Explain Exit Procedure & Timeline of FlyKite airline",
|
| 9 |
+
"expected_keywords":["3 working days", "within 7 working days", "Exit interview","no later than 5 working days"],
|
| 10 |
+
"expected_answer_snippet":"Company property (ID badge, uniforms, devices) must be returned on or before the last working day. Clearance forms must be signed by Finance, IT, and Admin within 3 working days post-termination. Final salary settlement is processed within 7 working days of clearance completion. Exit interview: Mandatory and to be scheduled no later than 5 working days after last working day. Failure to attend delays issuance of relieving letter."
|
| 11 |
+
|
| 12 |
+
},
|
| 13 |
+
{
|
| 14 |
+
"query":"Elaborate Allowable Expenses and Reimbursement Procedures",
|
| 15 |
+
"expected_keywords":["Eligibility","Exclusions", "Submission Deadlines", "Appeals","directly work-related","itemized","within 15 calendar days","7 working days"],
|
| 16 |
+
"expected_answer_snippet":"1. Eligibility: Expenses must be directly work-related and supported by itemized receipts. Per diem limits: ₹1,200/day domestic, ₹4,000/day international. Exclusions: Alcohol, entertainment unrelated to work, and non-economy travel (unless pre-approved) are not reimbursable. Submission Deadlines: Claims must be fi led within 15 calendar days of incurring expense. Appeals: Appeal must be submitted within 7 working days of claim denial with"
|
| 17 |
+
},
|
| 18 |
+
{
|
| 19 |
+
"query": "What are the documents required for special leave approval ?",
|
| 20 |
+
"expected_keywords": ["Death certificate", "funeral", "notice","obituary","court summons", "jury duty letter"],
|
| 21 |
+
"expected_answer_snippet": "Special leave is approved only upon submission of: 1. Death certificate, funeral notice, or obituary (bereavement). 2. Official court summons or jury duty letter (legal obligations). 3. Medical certificate from a registered practitioner (emergency care). 4. Government-issued disaster report or evacuation notice (natural disasters). 5. All documents must be submitted within 5 working days of returning to duty."
|
| 22 |
+
},
|
| 23 |
+
{
|
| 24 |
+
"query":"explain in detail about the Attendance and Absence Management of the flykite airlne hr policy",
|
| 25 |
+
"expected_keywords":["Core Hours","Notification","Consequences"],
|
| 26 |
+
"expected_answer_snippet": "1. Core Hours: 9:30 AM – 6:00 PM IST. 2. Notification: Planned absence: Email supervisor at least 1 day before. Emergency absence: Call within 1 hour of shift start. 3. Consequences: 3 unreported absences in 60 days → Written warning. 5 unreported absences in 90 days → Termination review."
|
| 27 |
+
},
|
| 28 |
+
{
|
| 29 |
+
"query":"Explain the compensation and termination policy of the flykite airline hr policy",
|
| 30 |
+
"expected_keywords":["Salary review cycle", "Mid-year adjustments", "Benefits eligiblity","Termination Impact"],
|
| 31 |
+
"expected_answer_snippet":"1. Salary Review Cycle: April annually. 2. Mid-Year Adjustments: Only on written approval from CEO & CFO. 3. Benefits Eligibility Health insurance starts after 30 days service (full-time). Retirement plan enrollment: Within 60 days of confirmation. 4. Termination Impact: Benefits end on last working day unless law mandates otherwise."
|
| 32 |
+
}
|
| 33 |
+
]
|
main.py
ADDED
|
@@ -0,0 +1,541 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import json
|
| 4 |
+
import argparse
|
| 5 |
+
import warnings
|
| 6 |
+
import traceback
|
| 7 |
+
import logging
|
| 8 |
+
import chromadb
|
| 9 |
+
import hashlib
|
| 10 |
+
import sqlite3
|
| 11 |
+
import regex as re
|
| 12 |
+
from pinecone import Pinecone
|
| 13 |
+
from typing import Optional, Dict, Any
|
| 14 |
+
from sentence_transformers import SentenceTransformer, util
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
os.environ["TF_CPP_MIN_LOG_LEVEL"]="3"
|
| 18 |
+
warnings.filterwarnings("ignore")
|
| 19 |
+
|
| 20 |
+
sys.path.insert(0,os.path.abspath(os.path.join(os.path.dirname(__file__),'src')))
|
| 21 |
+
from sentence_transformers import SentenceTransformer
|
| 22 |
+
from configuration import Configuration
|
| 23 |
+
from rag_scripts.rag_pipeline import RAGPipeline
|
| 24 |
+
from rag_scripts.documents_processing.chunking import PyMuPDFChunker
|
| 25 |
+
from rag_scripts.embedding.embedder import SentenceTransformerEmbedder
|
| 26 |
+
from rag_scripts.embedding.vector_db.chroma_db import chromaDBVectorDB
|
| 27 |
+
from rag_scripts.embedding.vector_db.faiss_db import FAISSVectorDB
|
| 28 |
+
from rag_scripts.embedding.vector_db.pinecone_db import PineconeVectorDB
|
| 29 |
+
from rag_scripts.llm.llmResponse import GROQLLM
|
| 30 |
+
from rag_scripts.evaluation.evaluator import RAGEvaluator
|
| 31 |
+
|
| 32 |
+
class RAGOperations:
|
| 33 |
+
VALID_VECTOR_DB = {'chroma','faiss','pinecone'}
|
| 34 |
+
|
| 35 |
+
@staticmethod
|
| 36 |
+
def check_db(vector_db_type: str, db_path: str, collection_name: str) -> bool:
|
| 37 |
+
try:
|
| 38 |
+
if vector_db_type not in RAGOperations.VALID_VECTOR_DB:
|
| 39 |
+
print(f"Invalid Vector DB: {vector_db_type}")
|
| 40 |
+
raise
|
| 41 |
+
if vector_db_type.lower() == 'pinecone':
|
| 42 |
+
pc = Pinecone(api_key=Configuration.PINECONE_API_KEY)
|
| 43 |
+
return collection_name in pc.list_indexes().names()
|
| 44 |
+
elif vector_db_type.lower() == 'chroma':
|
| 45 |
+
return os.path.exists(db_path) and os.listdir(db_path)
|
| 46 |
+
elif vector_db_type.lower() == "faiss":
|
| 47 |
+
faiss_index_file = os.path.join(db_path,f"{collection_name}.faiss")
|
| 48 |
+
faiss_doc_store_file = os.path.join(db_path,f"{collection_name}_docs.pkl")
|
| 49 |
+
return os.path.exists(faiss_index_file) and os.path.exists(faiss_doc_store_file)
|
| 50 |
+
except Exception as ex:
|
| 51 |
+
print(f"Exception in checking {vector_db_type} existence")
|
| 52 |
+
traceback.print_exc()
|
| 53 |
+
return False
|
| 54 |
+
|
| 55 |
+
@staticmethod
|
| 56 |
+
def get_pipeline_params(args: argparse.Namespace, use_tuned: bool = False) -> Dict[str,Any]:
|
| 57 |
+
try:
|
| 58 |
+
best_param_path = os.path.join(Configuration.DATA_DIR,'best_params.json')
|
| 59 |
+
params = {
|
| 60 |
+
'document_path':Configuration.FULL_PDF_PATH,
|
| 61 |
+
'chunk_size':args.chunk_size,
|
| 62 |
+
'chunk_overlap':args.chunk_overlap,
|
| 63 |
+
'embedding_model_name':args.embedding_model,
|
| 64 |
+
'vector_db_type':args.vector_db_type,
|
| 65 |
+
'llm_model_name':args.llm_model,
|
| 66 |
+
'db_path': None,
|
| 67 |
+
'collection_name': Configuration.COLLECTION_NAME,
|
| 68 |
+
'vector_db': None,
|
| 69 |
+
'temperature': args.temperature,
|
| 70 |
+
'top_p':args.top_p,
|
| 71 |
+
'max_tokens':args.max_tokens,
|
| 72 |
+
're_ranker_model':args.re_ranker_model
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
if os.path.exists(best_param_path):
|
| 76 |
+
with open(best_param_path,'rb') as f:
|
| 77 |
+
best_params = json.load(f)
|
| 78 |
+
print(f"Best params: {best_params} from the file {best_param_path}")
|
| 79 |
+
|
| 80 |
+
params.update({
|
| 81 |
+
'vector_db_type': best_params.get('vector_db_type',params['vector_db_type']),
|
| 82 |
+
'embedding_model_name': best_params.get('embedding_model',params['embedding_model_name']),
|
| 83 |
+
'chunk_overlap': best_params.get('chunk_overlap',params['chunk_overlap']),
|
| 84 |
+
'chunk_size': best_params.get('chunk_size',params['chunk_size']) ,
|
| 85 |
+
're_ranker_model': best_params.get('re_ranker_model',params['re_ranker_model']) })
|
| 86 |
+
use_tuned = True
|
| 87 |
+
|
| 88 |
+
if use_tuned:
|
| 89 |
+
tuned_db_type = params['vector_db_type']
|
| 90 |
+
params['db_path'] = os.path.join(Configuration.DATA_DIR,'TunedDB',tuned_db_type) if tuned_db_type != 'pinecone' else ""
|
| 91 |
+
params['collection_name'] = 'tuned-'+Configuration.COLLECTION_NAME
|
| 92 |
+
if tuned_db_type in ['chroma','faiss']:
|
| 93 |
+
os.makedirs(params['db_path'],exist_ok=True)
|
| 94 |
+
print(f"Tuned db path: {params['db_path']}")
|
| 95 |
+
else:
|
| 96 |
+
params['db_path'] = ( Configuration.CHROMA_DB_PATH if params['vector_db_type'] == 'chroma'
|
| 97 |
+
else Configuration.FAISS_DB_PATH if params['vector_db_type'] == 'faiss'
|
| 98 |
+
else "")
|
| 99 |
+
if params['vector_db_type'] in ['chroma', 'faiss']:
|
| 100 |
+
os.makedirs(params['db_path'],exist_ok=True)
|
| 101 |
+
print(f"Created directory for {params['vector_db_type']} at {params['db_path']}")
|
| 102 |
+
|
| 103 |
+
return params
|
| 104 |
+
except Exception as ex:
|
| 105 |
+
print(f"Exception in get_pipeline_params: {ex}")
|
| 106 |
+
traceback.print_exc()
|
| 107 |
+
sys.exit(1)
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
@staticmethod
|
| 111 |
+
def check_embedding_dimension(vector_db_type: str,db_path: str,
|
| 112 |
+
collection_name: str, embedding_model: str) -> bool:
|
| 113 |
+
if vector_db_type !='chroma':
|
| 114 |
+
return True
|
| 115 |
+
try:
|
| 116 |
+
client = chromadb.PersistentClient(path=db_path)
|
| 117 |
+
collection = client.get_collection(collection_name)
|
| 118 |
+
model = SentenceTransformer(embedding_model)
|
| 119 |
+
sample_embedding = model.encode(["test"])[0]
|
| 120 |
+
try:
|
| 121 |
+
expected_dim = collection._embedding_function.dim
|
| 122 |
+
except AttributeError:
|
| 123 |
+
peek_result = collection.peek(limit=1)
|
| 124 |
+
if 'embedding' in peek_result and peek_result['embedding']:
|
| 125 |
+
expected_dim = len(peek_result['embedding'][0])
|
| 126 |
+
else:
|
| 127 |
+
return False
|
| 128 |
+
actual_dim = len(sample_embedding)
|
| 129 |
+
print(f"Expected dimension: {expected_dim} Actual dimension: {actual_dim}")
|
| 130 |
+
return expected_dim == actual_dim
|
| 131 |
+
except Exception as ex:
|
| 132 |
+
print(f"Error checking embedding dimension: {ex}")
|
| 133 |
+
return False
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
@staticmethod
|
| 137 |
+
def initialize_pipeline(params: dict[str,Any]) -> RAGPipeline:
|
| 138 |
+
try:
|
| 139 |
+
embedder = SentenceTransformerEmbedder(model_name=params['embedding_model_name'])
|
| 140 |
+
chunkerObj = PyMuPDFChunker(
|
| 141 |
+
pdf_path=params['document_path'],
|
| 142 |
+
chunk_size=params['chunk_size'],
|
| 143 |
+
chunk_overlap=params['chunk_overlap'])
|
| 144 |
+
llm_model = params['llm_model_name']
|
| 145 |
+
vector_db = None
|
| 146 |
+
if params['vector_db_type'] == 'chroma':
|
| 147 |
+
vector_db = chromaDBVectorDB(embedder=embedder,
|
| 148 |
+
db_path=params['db_path'],
|
| 149 |
+
collection_name=params['collection_name'])
|
| 150 |
+
elif params['vector_db_type'] == 'faiss':
|
| 151 |
+
vector_db = FAISSVectorDB(embedder=embedder,
|
| 152 |
+
db_path=params['db_path'],
|
| 153 |
+
collection_name=params['collection_name'] )
|
| 154 |
+
elif params['vector_db_type'] == 'pinecone':
|
| 155 |
+
vector_db = PineconeVectorDB(embedder=embedder,
|
| 156 |
+
db_path=params['db_path'],
|
| 157 |
+
collection_name=params['collection_name'])
|
| 158 |
+
else:
|
| 159 |
+
raise ValueError(f"Unknown vector_db_type: {params['vector_db_type']}")
|
| 160 |
+
|
| 161 |
+
return RAGPipeline( document_path=params['document_path'],
|
| 162 |
+
chunker=chunkerObj, embedder=embedder,
|
| 163 |
+
vector_db=vector_db,
|
| 164 |
+
llm=GROQLLM(model_name= llm_model),
|
| 165 |
+
re_ranker_model_name=params['re_ranker_model'] if params['re_ranker_model'] else Configuration.DEFAULT_RERANKER,)
|
| 166 |
+
except Exception as ex:
|
| 167 |
+
print(f"Exception in pipeline initialize: {ex}")
|
| 168 |
+
traceback.print_exc()
|
| 169 |
+
sys.exit(1)
|
| 170 |
+
|
| 171 |
+
@staticmethod
|
| 172 |
+
def run_build_job(args: argparse.Namespace) -> None:
|
| 173 |
+
try:
|
| 174 |
+
params = RAGOperations.get_pipeline_params(args)
|
| 175 |
+
pipeline = RAGOperations.initialize_pipeline(params)
|
| 176 |
+
pipeline.build_index()
|
| 177 |
+
print(f"RAG Build JOB completed")
|
| 178 |
+
except Exception as ex:
|
| 179 |
+
print(f"Exception in run build job: {ex}")
|
| 180 |
+
traceback.print_exc()
|
| 181 |
+
sys.exit(1)
|
| 182 |
+
|
| 183 |
+
|
| 184 |
+
    @staticmethod
    def run_search_job(args: argparse.Namespace, user_info: Dict[str, str]) -> None:
        """Answer a single user query via the RAG pipeline.

        Builds or validates the vector index as needed, then either summarizes
        raw retrieved documents (``args.raw``) or queries the LLM, evaluates
        the answer against eval_data.json when the query is known, and prints
        and returns the combined JSON payload.

        Args:
            args: CLI namespace (query, k, raw, use_rag, use_tuned, ...).
            user_info: dict with 'role', 'location' and 'department' keys used
                to build the user context passed to the LLM.
        """
        try:
            params = RAGOperations.get_pipeline_params(args, use_tuned=args.use_tuned)
            vector_db_type = params['vector_db_type']
            db_path = params['db_path']
            collection_name = params['collection_name']

            pipeline = RAGOperations.initialize_pipeline(params)
            db_exists = RAGOperations.check_db(vector_db_type, db_path, collection_name)

            if args.use_rag:
                # (Re)build the index when it is missing, empty, or its stored
                # embedding dimension no longer matches the embedding model.
                if not db_exists:
                    pipeline.build_index()
                elif pipeline.vector_db.count_documents() == 0:
                    pipeline.build_index()
                elif not RAGOperations.check_embedding_dimension(vector_db_type, db_path,
                                                                 collection_name, params['embedding_model_name']):
                    print(f"Embedding dimension mismatch. rebuilding the index")
                    pipeline.vector_db.delete_collection(collection_name)
                    pipeline.build_index()

                else:
                    print(f"Using existing {vector_db_type} database with collection: {collection_name}")

            if pipeline.vector_db.count_documents() == 0:
                # Nothing to search against even after a rebuild attempt.
                print(f"No Documents found in vector database after re-build")
                sys.exit(1)

            evaluator = RAGEvaluator(eval_data_path=Configuration.EVAL_DATA_PATH,
                                     pdf_path=Configuration.FULL_PDF_PATH)

            # Take the query from the CLI, falling back to interactive input.
            user_query = args.query if args.query else (
                input("Enter your Query: "))
            if user_query.lower() == 'exit':
                return
            user_context = {"role": user_info['role'],
                            "location": user_info['location'],
                            "department": user_info['department']}

            # Look up ground-truth evaluation data for this query, if present.
            expected_answers = None
            expected_keywords = []
            query_found = False
            try:
                with open(Configuration.EVAL_DATA_PATH, 'r') as f:
                    eval_data = json.load(f)
                for item in eval_data:
                    if item.get('query').strip().lower() == user_query.strip().lower():
                        expected_keywords = item.get('expected_keywords', [])
                        expected_answers = item.get('expected_answer_snippet', "")
                        query_found = True
                        break
                if not expected_keywords and not expected_answers:
                    print(f"No evaluation data found for query in json")
            except Exception as ex:
                print(f"No json file : {ex}")
            retrieved_documents = []
            if args.raw:
                # Raw mode: skip the LLM; over-fetch (k*2), then keep only the
                # passages most similar to the query.
                retrieved_documents = pipeline.retrieve_raw_documents(
                    user_query, k=args.k * 2)
                print("Raw documents retrieved")
                print(json.dumps(retrieved_documents, indent=4))
                if not retrieved_documents:
                    response = {"summary": "No relevant documents found",
                                "sources": []}
                else:
                    # Rank retrieved docs by cosine similarity to the query.
                    query_embedding = evaluator.embedder.encode(user_query,
                                                                convert_to_tensor=True, normalize_embeddings=True)
                    similarities = [(doc, util.cos_sim(query_embedding,
                                                       evaluator.embedder.encode(doc['content'],
                                                                                 convert_to_tensor=True,
                                                                                 normalize_embeddings=True)).item())
                                    for doc in retrieved_documents]
                    similarities.sort(key=lambda x: x[1], reverse=True)
                    top_docs = similarities[:min(3, len(similarities))]
                    truncated_content = []
                    for doc, sim in top_docs:
                        # Within each document keep at most the two paragraphs
                        # most similar to the query (similarity >= 0.3).
                        content_paragraphs = re.split(r'\n\s*\n', doc['content'].strip())
                        para_sims = [(para, util.cos_sim(query_embedding,
                                                         evaluator.embedder.encode(para.strip(), convert_to_tensor=True,
                                                                                   normalize_embeddings=True)).item())
                                     for para in content_paragraphs if para.strip()]
                        para_sims.sort(key=lambda x: x[1], reverse=True)

                        top_paras = [para for para, para_sim in para_sims[:2] if para_sim >= 0.3]
                        if len(top_paras) < 1:  # Fallback to at least one paragraph
                            top_paras = [para for para, _ in para_sims[:1]]
                        truncated_content.append('\n\n'.join(top_paras))
                    response = {
                        "summary": "\n".join(truncated_content),
                        "sources": [{"document_id": f"DOC {idx+1}",
                                     "page": str(doc['metadata'].get("page_number", "NA")),
                                     "section": doc['metadata'].get("section", "NA"),
                                     "clause": doc['metadata'].get("clause", "NA")}
                                    for idx, (doc, _) in enumerate(top_docs)]}

            else:
                # RAG + LLM mode: the pipeline generates the answer; documents
                # are re-retrieved for the evaluation step below.
                print("LLM+RAG")
                response = pipeline.query(user_query, k=args.k,
                                          include_metadata=True,
                                          user_context=user_context
                                          )
                retrieved_documents = pipeline.retrieve_raw_documents(
                    user_query, k=args.k)

            final_expected_answer = expected_answers if expected_answers is not None else ""
            additional_eval_metrices = {}
            if not query_found:
                # Unknown query: synthesize a reference from the retrieved
                # docs and score the answer with the LLM-as-judge evaluator.
                print(f"No query found in eval_Data.json: {user_query}")
                raw_reference_for_score = evaluator._syntesize_raw_reference(retrieved_documents)
                if not final_expected_answer.strip():
                    final_expected_answer = raw_reference_for_score

                retrieved_documents_content = [doc.get('content', '') for doc in retrieved_documents]
                llm_as_judge = evaluator._evaluate_with_llm(user_query, response.get('summary', ''), retrieved_documents_content)
                if llm_as_judge:
                    additional_eval_metrices.update(llm_as_judge)
                    output = {"query": user_query, "response": response, "evaluation": additional_eval_metrices}
                    print(json.dumps(output, indent=4))
                    return json.dumps(output)
                else:
                    # Judge returned nothing; emit the (falsy) result as-is.
                    output = {"query": user_query, "response": response, "evaluation": llm_as_judge}
                    print(json.dumps(output, indent=4))
                    return json.dumps(output)

            else:
                # Known query: full keyword/snippet evaluation against the
                # ground truth loaded from eval_data.json.
                eval_result = evaluator.evaluate_response(user_query, response, retrieved_documents,
                                                          expected_keywords, expected_answers)
                output = {"query": user_query, "response": response, "evaluation": eval_result}
                print(json.dumps(output, indent=2, ensure_ascii=False))

                return json.dumps(output)

        except Exception as ex:
            print(f"Exception in run search job {ex}")
            traceback.print_exc()
|
| 328 |
+
|
| 329 |
+
@staticmethod
def run_hypertune_job(args: argparse.Namespace) -> None:
    """Run a hyper-parameter grid/random search over chunking, embedding and
    re-ranking options, persist the winning parameters to best_params.json,
    then build a tuned index using those parameters.

    Args:
        args: parsed CLI arguments; only ``args.llm_model`` is consumed here.
    """
    try:
        evaluator = RAGEvaluator(eval_data_path=Configuration.EVAL_DATA_PATH,
                                 pdf_path=Configuration.FULL_PDF_PATH)

        result = evaluator.evaluate_combined_params_grid(
            chunk_size_to_test=[512, 1024, 2048],
            chunk_overlap_to_test=[100, 200, 400],
            embedding_models_to_test=["all-MiniLM-L6-v2",
                                      "all-mpnet-base-v2",
                                      "paraphrase-MiniLM-L3-v2",
                                      "multi-qa-mpnet-base-dot-v1"],
            vector_db_types_to_test=['pinecone'],
            llm_model_name=args.llm_model,
            re_ranker_model=["cross-encoder/ms-marco-MiniLM-L-6-v2",
                             "cross-encoder/ms-marco-TinyBERT-L-2"],
            search_type='random', n_iter=1)

        best_parameter = result['best_params']

        # Persist the winning parameter set so later runs can use --use-tuned.
        best_param_path = os.path.join(Configuration.DATA_DIR, 'best_params.json')
        with open(best_param_path, 'w') as f:
            json.dump(best_parameter, f, indent=4)

        tuned_db = best_parameter['vector_db_type']
        tuned_path = os.path.join(Configuration.DATA_DIR, 'TunedDB', tuned_db)
        # Pinecone is a hosted service, so no local directory is needed.
        if tuned_db != 'pinecone':
            os.makedirs(tuned_path, exist_ok=True)
        tuned_collection_name = "tuned-" + Configuration.COLLECTION_NAME

        tuned_params = {
            'document_path': Configuration.FULL_PDF_PATH,
            'chunk_size': best_parameter.get('chunk_size', Configuration.DEFAULT_CHUNK_SIZE),
            'chunk_overlap': best_parameter.get('chunk_overlap', Configuration.DEFAULT_CHUNK_OVERLAP),
            'embedding_model_name': best_parameter.get('embedding_model', Configuration.DEFAULT_SENTENCE_TRANSFORMER_MODEL),
            'vector_db_type': tuned_db,
            'llm_model_name': args.llm_model,
            'db_path': tuned_path if tuned_db != 'pinecone' else "",
            'collection_name': tuned_collection_name,
            'vector_db': None,
            # BUG FIX: the original first looked up a non-existent 're_ranker'
            # key and then unconditionally overwrote the entry with a second
            # 're_ranker_model' lookup; a single lookup has the same net effect.
            're_ranker_model': best_parameter.get('re_ranker_model', Configuration.DEFAULT_RERANKER),
        }

        tuned_pipeline = RAGOperations.initialize_pipeline(tuned_params)
        tuned_pipeline.build_index()

    except Exception as ex:
        print(f"Exception in hypertune: {ex} ")
        traceback.print_exc()
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
@staticmethod
def run_llm_with_prompt(args: argparse.Namespace, run_type: str) -> None:
    """Answer a user query with the LLM alone (no retrieval step).

    When ``run_type == 'prompting'`` the query is wrapped in a structured
    instruction prompt plus a system message; any other value ('llm') sends
    the raw query. The response is evaluated against the expectations stored
    in eval_data.json when an entry for this query exists.

    Args:
        args: parsed CLI arguments (temperature/top_p/max_tokens/use_tuned).
        run_type: 'prompting' or 'llm'.
    """
    try:
        params = RAGOperations.get_pipeline_params(args,
                                                   use_tuned=args.use_tuned)
        pipeline = RAGOperations.initialize_pipeline(params)

        evaluator = RAGEvaluator(eval_data_path=Configuration.EVAL_DATA_PATH,
                                 pdf_path=Configuration.FULL_PDF_PATH)

        # BUG FIX: the original message contained "the the", "Structured your
        # response" and the self-contradictory "state that clearly and
        # speculation" (the intent was to FORBID speculation).
        system_message = (
            "You are an expert assistant for Flykite Airlines HR Policy Queries."
            "Provide concise, accurate and policy-specific answers based solely on the provided context."
            "Structure your response clearly, using bullet points, newlines if applicable. "
            "If the context lacks information, state that clearly and avoid speculation."
        ) if run_type == 'prompting' else None

        user_query = input("Enter your query: ")
        expected_answer = None
        expected_keywords = []
        try:
            with open(Configuration.EVAL_DATA_PATH, 'r') as f:
                eval_data = json.load(f)
            # BUG FIX: the original broke out of the loop on the FIRST entry,
            # so every query was scored against the first item's expectations.
            # Only use expectations that belong to the query actually asked.
            for item in eval_data:
                if item.get('query', '').strip().lower() == user_query.strip().lower():
                    expected_answer = item.get('expected_answer_snippet', "")
                    expected_keywords = item.get('expected_keywords', [])
                    break
        except Exception as ex:
            print(f"Error loading eval_data.json for query {user_query}: {ex}")

        if run_type == 'prompting':
            prompt = (
                "You are an expert assistant for Flykite Airlines HR Policy Queries."
                "Answer the following question with a structured response, using bullet points or sections where applicable"
                "Base your answer solely on the query and avoid hallucination"
                f"Question: \n {user_query} \n"
                "Answer: ")
        else:
            prompt = user_query

        response = pipeline.llm.generate_response(
            prompt=prompt,
            system_message=system_message,
            temperature=args.temperature,
            top_p=args.top_p,
            max_tokens=args.max_tokens
        )
        # LLM-only mode performs no retrieval, so there is no context to cite.
        retrieved_documents = []

        eval_result = evaluator.evaluate_response(user_query,
                                                  response,
                                                  retrieved_documents,
                                                  expected_keywords,
                                                  expected_answer)

        # BUG FIX: keys were "summary: " / "source: " (trailing colon+space),
        # inconsistent with the 'summary'/'source' keys used by the RAG path.
        output = {"query": user_query,
                  "response": {
                      "summary": response.strip(),
                      "source": ["LLM Response Not RAG loaded"]},
                  "evaluation": eval_result}

        print(json.dumps(output, indent=2))

    except Exception as ex:
        print(f"Exception in LLm_prompting response: {ex}")
        traceback.print_exc()
        sys.exit(1)
|
| 463 |
+
|
| 464 |
+
@staticmethod
def login() -> Dict[str, str]:
    """Prompt for credentials and return the matching user profile.

    Exits the process when the credentials are invalid or the user database
    cannot be read, so callers can rely on receiving a populated dict with
    'username', 'role', 'department' and 'location' keys.
    """
    username = input("Enter your username: ")
    password = input("Enter your password: ")

    # NOTE(security): unsalted SHA-256 is weak for password storage; consider
    # hashlib.pbkdf2_hmac or bcrypt. Kept as-is for users.db compatibility.
    hashed_password = hashlib.sha256(password.encode()).hexdigest()
    try:
        conn = sqlite3.connect('users.db')
        try:
            cursor = conn.cursor()
            cursor.execute(
                "SELECT username,jobrole,department,location FROM users WHERE username = ? AND password = ?",
                (username, hashed_password)
            )
            user = cursor.fetchone()
        finally:
            # Close the connection even when the query raises.
            conn.close()

        if user:
            return {"username": user[0], "role": user[1], "department": user[2], "location": user[3]}
        print("Invalid username or password")
        sys.exit(1)

    except sqlite3.Error as ex:
        # BUG FIX: the original silently returned False, which violated the
        # Dict[str, str] return type and crashed callers later on
        # user_info['username']; fail fast with a message instead.
        print(f"Database error during login: {ex}")
        sys.exit(1)
|
| 489 |
+
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
|
| 493 |
+
|
| 494 |
+
|
| 495 |
+
def main():
    """CLI entry point: authenticate the user, parse arguments and dispatch
    to the requested job (build, search, hypertune, or LLM-only modes)."""

    user_info = RAGOperations.login()
    # BUG FIX: the original concatenated the two f-strings without a
    # separator, producing e.g. "Logged in as aliceHR at Chennai".
    print(f"Logged in as {user_info['username']}, "
          f"{user_info['department']} at {user_info['location']}")

    parser = argparse.ArgumentParser(description='RAG PIPELINE OPERATIONS')

    parser.add_argument('--job', type=str, required=True,
                        choices=['rag-build', 'search', 'eval-hypertune', 'llm', 'prompting'],
                        help='job to execute build RAG index, search and hypertune')

    parser.add_argument('--raw', action='store_true',
                        help='display raw retrieved documents in json format')
    parser.add_argument('--chunk_size', type=int, default=Configuration.DEFAULT_CHUNK_SIZE, help="Document chunking")
    parser.add_argument('--chunk_overlap', type=int, default=Configuration.DEFAULT_CHUNK_OVERLAP, help="Document overlap")
    parser.add_argument('--embedding_model', type=str, default=Configuration.DEFAULT_SENTENCE_TRANSFORMER_MODEL, help="Embedding model")
    parser.add_argument('--vector_db_type', type=str, default="chroma", choices=['chroma', 'faiss', 'pinecone'], help='vector database type')
    parser.add_argument('--llm_model', type=str, default=Configuration.DEFAULT_GROQ_LLM_MODEL)
    parser.add_argument('--use-tuned', action='store_true', help='Use the tuned DB for search')
    parser.add_argument('--k', type=int, default=5, help='number of doc to retrieve')
    parser.add_argument('--query', type=str, default=None, help='query from user')
    parser.add_argument('--temperature', type=float, default=0.1, help='LLM temperature for response generation')
    parser.add_argument('--top_p', type=float, default=0.95, help='LLM top_p for response generation')
    parser.add_argument('--max_tokens', type=int, default=1000, help='LLM max tokens for response generation')
    parser.add_argument('--use-rag', type=lambda x: x.lower() == 'true', default=True, help='llm or prompting')
    parser.add_argument('--user-context', type=str, default=None, help='user context as JSON e.g., {"role": "admin", "location":"chennai", "designation":"engineer"}')
    parser.add_argument('--fine-tune', action='store_true', help='use fine-tuned model')
    parser.add_argument('--n-iter', type=int, default=10, help='number of iterations')
    parser.add_argument('--re_ranker_model', type=str, default=Configuration.DEFAULT_RERANKER, help="ReRanking the retrieval documents")
    args = parser.parse_args()

    if args.job == 'eval-hypertune' and user_info['role'] != 'admin':
        # BUG FIX: the original only printed the warning and then fell
        # through to the dispatch below, so non-admin users could still run
        # the hypertune job. Actually block them.
        print("Access denied: Hypertune execution is forbidden")
        sys.exit(1)

    if args.job == 'rag-build':
        RAGOperations.run_build_job(args)
    elif args.job == 'search':
        RAGOperations.run_search_job(args, user_info)
    elif args.job == 'eval-hypertune':
        RAGOperations.run_hypertune_job(args)
    elif args.job in ['llm', 'prompting']:
        RAGOperations.run_llm_with_prompt(args, args.job)
|
| 538 |
+
|
| 539 |
+
|
| 540 |
+
# Script entry point: run the CLI only when executed directly, not on import.
if __name__ == '__main__':
    main()
|
rag_scripts/__pycache__/interfaces.cpython-312.pyc
ADDED
|
Binary file (5.93 kB). View file
|
|
|
rag_scripts/__pycache__/rag_pipeline.cpython-312.pyc
ADDED
|
Binary file (10.4 kB). View file
|
|
|
rag_scripts/documents_processing/__pycache__/chunking.cpython-312.pyc
ADDED
|
Binary file (8.42 kB). View file
|
|
|
rag_scripts/documents_processing/chunking.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import hashlib
|
| 3 |
+
import traceback
|
| 4 |
+
import pymupdf as fitz
|
| 5 |
+
import regex as re
|
| 6 |
+
from typing import List, Dict, Tuple, Any
|
| 7 |
+
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
| 8 |
+
|
| 9 |
+
from configuration import Configuration
|
| 10 |
+
from rag_scripts.interfaces import IDocumentChunker
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class PyMuPDFChunker(IDocumentChunker):
    """Split a PDF into cleaned, overlapping text chunks.

    Each chunk carries metadata (document hash, source file, page number,
    chunk id, best-effort section/clause headings) for downstream indexing.
    """

    def __init__(self, pdf_path: str, chunk_size: int = Configuration.DEFAULT_CHUNK_SIZE,
                 chunk_overlap: int = Configuration.DEFAULT_CHUNK_OVERLAP):
        """Validate the PDF path and configure the recursive text splitter.

        Raises:
            FileNotFoundError: if ``pdf_path`` does not exist.
        """
        if not os.path.exists(pdf_path):
            raise FileNotFoundError(f"PDF file not found at: {pdf_path}")

        self.pdf_path = pdf_path
        self.chunk_size = chunk_size
        self.chunk_overlap = chunk_overlap
        # NOTE(review): several separators look like regex lookbehinds, but
        # RecursiveCharacterTextSplitter treats separators as LITERAL strings
        # unless is_separator_regex=True is passed — confirm intended behavior.
        self.text_splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size,
            chunk_overlap=self.chunk_overlap,
            length_function=len,
            separators=["\n\n", "\n", "(?<=\. )\n",
                        "(?<=[a-z0-9]\.)", "(?<=\? )", "(?<=\! )",
                        " ", ""])

        print(f"Initialized PyMuPDFChunker for {os.path.basename(pdf_path)} with chunk_size = {chunk_size} and chunk overlap = {chunk_overlap}")

    def _clean_text(self, text: str) -> str:
        """Strip bullets, smart quotes, odd dashes and non-ASCII noise from
        extracted page text; falls back to a bare strip() on failure."""
        try:
            text = text.replace('\u25cf', '').replace('\u2022', '')
            text = text.replace('\u201c', '').replace('\u201d', '')
            text = text.replace('\u2013', '-').replace('\u2014', '-').replace('\u2015', '-')
            text = re.sub(r'\n\s*\n', ' ', text)          # collapse blank lines
            text = re.sub(r' {2,}', ' ', text)            # collapse runs of spaces
            text = text.replace('\\nb', '')
            text = '\n'.join([line.strip() for line in text.split('\n')])
            text = re.sub(r'[^\x20-\x7E\t\n\r]', '', text)  # drop non-printable/ non-ASCII
            return text.strip()
        except Exception as ex:
            traceback.print_exc()
            return text.strip()

    def hash_document(self):
        """Return the SHA-256 hex digest of the PDF file, read in 8 KiB blocks.

        Raises:
            ValueError: if the file cannot be read/hashed.
        """
        try:
            hasher = hashlib.sha256()
            with open(self.pdf_path, 'rb') as fl:
                while chunk := fl.read(8192):
                    hasher.update(chunk)
            return hasher.hexdigest()
        except Exception as ex:
            print(f"Exception hashing PDF {self.pdf_path}: {ex}")
            traceback.print_exc()
            raise ValueError(f"Failed to hash PDF: {ex}")

    def chunk_documents(self) -> Tuple[str, List[Dict[str, Any]]]:
        """Extract, clean and split every page of the PDF.

        Returns:
            (document_hash, chunks) where each chunk is a dict with 'content'
            and 'metadata' keys. On failure returns (hash, []).
        """
        try:
            doc_hash = self.hash_document()
            doc_chunks = []

            with fitz.open(self.pdf_path) as document:
                for idx, page in enumerate(document):
                    page_text = page.get_text()
                    if page_text.strip():
                        # Headings are extracted from the RAW text, before
                        # cleaning strips the markers they rely on.
                        section = self._extract_section(page_text)
                        clause = self._extract_clause(page_text)
                        page_text = self._clean_text(page_text)
                        chunked_page = self.text_splitter.split_text(page_text)
                        for chunk_idx, chunk_content in enumerate(chunked_page):
                            if not chunk_content.strip():
                                continue
                            doc_chunks.append({
                                "content": chunk_content.strip(),
                                "metadata": {
                                    "document_id": doc_hash,
                                    "source_file": os.path.basename(self.pdf_path),
                                    "page_number": idx + 1,
                                    "chunk_id": f"{doc_hash} - {idx + 1} -{chunk_idx}",
                                    "section": section or "Unknown Section",
                                    "clause": clause or "Unknown Clause",
                                    "chunk_index_on_page": chunk_idx
                                }
                            })
            if not doc_chunks:
                print(f"No text or chunks extracted from {self.pdf_path} after cleaning")
                return doc_hash, []
            else:
                # BUG FIX: the original printed len(doc_hash) — the length of
                # the 64-char hash string — instead of the chunk count.
                print(f"success Document chunked {self.pdf_path} into {len(doc_chunks)} chunks")
                return doc_hash, doc_chunks

        except Exception as ex:
            print(f"Exception in document chunking {ex}")
            traceback.print_exc()
            # NOTE(review): hash_document() may raise again here if hashing
            # itself was the failure, escaping this handler.
            return self.hash_document(), []

    def _extract_section(self, text: str) -> str:
        """Best-effort section heading: a numbered/lettered heading line,
        else the first plausible title line, else None."""
        match_major = re.search(r'^(?:[IVX]+\.?\s+|[A-Z]\.?\s+|[0-9]+\.?\s+)(.+)', text, re.MULTILINE)
        if match_major:
            # NOTE(review): group(0) keeps the numbering prefix in the result;
            # group(1) would return just the title — confirm which is wanted.
            return match_major.group(0)

        match_firstline = re.search(r'^\s*([A-Za-z0-9][\w\s,&\'-]+?)\s*$', text, re.MULTILINE)
        if match_firstline:
            return match_firstline.group(1).strip()

        return None

    def _extract_clause(self, text) -> str:
        """Best-effort clause text: the first bulleted/numbered item (merged
        with the next item when very short), else the first paragraph, else None."""
        match = re.search(r'^(?:(?:•|●|-|\*|\d+\.|\([a-z]\)|\([A-Z]\)|\w\))\s*)(.+?)(?=\n(?:•|●|-|\*|\d+\.|\([a-z]\)|\([A-Z]\)|\w\))|\n\n|\Z)', text, re.MULTILINE | re.DOTALL)
        if match:
            clause = match.group(1).strip()
            # Very short items are likely labels; append the following item.
            if len(clause.split()) < 10:
                next_match = re.search(
                    r'^(?:(?:•|●|-|\*|\d+\.|\([a-z]\)|\([A-Z]\)|\w\))\s*)(.+?)(?=\n(?:•|●|-|\*|\d+\.|\([a-z]\)|\([A-Z]\)|\w\))|\n\n|\Z)',
                    text[match.end():], re.MULTILINE | re.DOTALL)
                if next_match:
                    clause += " " + next_match.group(0).strip()
            return clause

        match_para = re.search(r'^(?!#|\s*$).*?\n\n', text, re.DOTALL)
        if match_para:
            return match_para.group(0).strip()
        return None
|
| 133 |
+
|
| 134 |
+
|
rag_scripts/embedding/__pycache__/embedder.cpython-312.pyc
ADDED
|
Binary file (3.5 kB). View file
|
|
|
rag_scripts/embedding/embedder.py
ADDED
|
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import traceback
|
| 3 |
+
from typing import List
|
| 4 |
+
from sentence_transformers import SentenceTransformer
|
| 5 |
+
from traits.trait_types import self
|
| 6 |
+
|
| 7 |
+
from configuration import Configuration
|
| 8 |
+
from rag_scripts.interfaces import IEmbedder
|
| 9 |
+
|
| 10 |
+
class SentenceTransformerEmbedder(IEmbedder):
    """Embeds texts and queries with a SentenceTransformer model."""

    def __init__(self, model_name: str = Configuration.DEFAULT_SENTENCE_TRANSFORMER_MODEL):
        """Load the named model; re-raises on failure so callers fail fast."""
        self.model_name = model_name
        try:
            self.model = SentenceTransformer(self.model_name)
            print(f'Sentence Transformer loaded {self.model_name}')
        except Exception as ex:
            print(f"Exception in loading Sentence Transformer {self.model_name}")
            traceback.print_exc()
            raise

    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Return one embedding vector per input text.

        On failure returns a list of empty lists (one per input) rather than
        raising, keeping the output length aligned with the input.
        """
        try:
            return self.model.encode(texts).tolist()
        except Exception as ex:
            print(f"Exception in embedding: {ex}")
            traceback.print_exc()
            return [[] for _ in texts]

    def embed_query(self, query: str) -> List[float]:
        """Return the embedding vector for a single query string ([] on failure)."""
        try:
            print(f"Embedding query: {query}")
            return self.model.encode(query).tolist()
        except Exception as ex:
            print(f"Exception in query embedding: {ex}")
            traceback.print_exc()
            return []

    def rank(self, Query: str, documents: List[str]) -> List[float]:
        """Score each document against Query; returns one score per document.

        BUG FIX: the first parameter was misspelled ``selfself``, so the
        body's ``self`` resolved to a stray module-level import
        (``from traits.trait_types import self``) instead of the instance,
        making this method unusable.
        """
        if not documents:
            return []
        try:
            sentence_pairs = [[Query, doc] for doc in documents]
            # NOTE(review): SentenceTransformer exposes .encode, not .predict;
            # .predict belongs to CrossEncoder — confirm which model type this
            # class is expected to wrap when ranking.
            scores = self.model.predict(sentence_pairs)
            return scores.tolist()
        except Exception as ex:
            print(f"Exception in ranking: {ex}")
            traceback.print_exc()
            return []
|
| 52 |
+
|
| 53 |
+
|
rag_scripts/embedding/vector_db/__pycache__/chroma_db.cpython-312.pyc
ADDED
|
Binary file (6.96 kB). View file
|
|
|
rag_scripts/embedding/vector_db/__pycache__/faiss_db.cpython-312.pyc
ADDED
|
Binary file (9.43 kB). View file
|
|
|
rag_scripts/embedding/vector_db/__pycache__/pinecone_db.cpython-312.pyc
ADDED
|
Binary file (10.7 kB). View file
|
|
|
rag_scripts/embedding/vector_db/chroma_db.py
ADDED
|
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import chromadb
|
| 2 |
+
import traceback
|
| 3 |
+
from typing import List, Dict, Any
|
| 4 |
+
from chromadb.api.types import EmbeddingFunction
|
| 5 |
+
from configuration import Configuration
|
| 6 |
+
from rag_scripts.interfaces import IVector
|
| 7 |
+
from rag_scripts.interfaces import IEmbedder
|
| 8 |
+
from rag_scripts.embedding.embedder import SentenceTransformerEmbedder
|
| 9 |
+
|
| 10 |
+
class chromaDBEmbeddingFunction(EmbeddingFunction):
    """Adapter exposing an IEmbedder through Chroma's EmbeddingFunction protocol."""

    def __init__(self, embedder: IEmbedder):
        # The wrapped embedder does the actual text -> vector work.
        self.embedder = embedder

    def __call__(self, texts: List[str]) -> List[List[float]]:
        """Embed every input text, returning one vector per text."""
        embed = self.embedder.embed_texts
        return embed(texts)
|
| 17 |
+
|
| 18 |
+
class chromaDBVectorDB(IVector):
    """Vector store backed by a persistent local ChromaDB collection."""

    def __init__(self, embedder: IEmbedder, db_path: str = Configuration.CHROMA_DB_PATH, collection_name: str = Configuration.COLLECTION_NAME):
        super().__init__(embedder, db_path, collection_name)

        self.client = chromadb.PersistentClient(path=self.db_path)
        # Wrap our embedder so Chroma can call it for both adds and queries.
        self.chroma_embed_function = chromaDBEmbeddingFunction(self.embedder)
        self.collection = self._get_or_create_collection()
        print(f"Chroma DB intialized path = {self.db_path}, collection = {self.collection_name}")

    def _get_or_create_collection(self):
        """Return the named collection, creating it on first use."""
        return self.client.get_or_create_collection(
            name=self.collection_name,
            embedding_function=self.chroma_embed_function)

    def add_chunks(self, documents: List[Dict[str, Any]]) -> List[str]:
        """Add chunk dicts ({'content', 'metadata'}) to the collection.

        Returns the list of added chunk ids ([] on failure or empty input).
        """
        try:
            if not documents:
                return []
            ids = [doc['metadata']['chunk_id'] for doc in documents]
            contents = [doc['content'] for doc in documents]
            metadatas = [doc['metadata'] for doc in documents]

            self.collection.add(documents=contents, metadatas=metadatas, ids=ids)

            print(f"Added {len(ids)} chunks to chroma db")
            return ids
        except Exception as ex:
            print(f"Exception in adding chunks to chromaDB: {ex}")
            # BUG FIX: print_exc was referenced without parentheses, so the
            # traceback was never printed.
            traceback.print_exc()
            return []

    def get_document_hash_ids(self, document_hash: str) -> List[str]:
        """Return ids of chunks whose metadata matches ``document_hash``
        (used to detect whether a document is already indexed)."""
        try:
            result = self.collection.get(where={"document_id": document_hash}, limit=1)
            return result['ids'] if result and result['ids'] else []
        except Exception as ex:
            print(f"Exception getting document hash {document_hash} from chromadb {ex}")
            # BUG FIX: print_exc was referenced without parentheses.
            traceback.print_exc()
            return []

    def search(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """Return the top-k chunks for ``query`` as dicts with
        'content', 'metadata' and 'distance' keys ([] on failure)."""
        try:
            search_result = self.collection.query(
                query_texts=[query],
                n_results=k,
                include=['documents', 'metadatas', 'distances']
            )
            retrieved_documents = []
            if search_result and search_result['documents'] and search_result['metadatas']:
                # Chroma nests results per query; we sent one query, index 0.
                for indx in range(len(search_result['documents'][0])):
                    document_content = search_result['documents'][0][indx]
                    document_metadata = search_result['metadatas'][0][indx]
                    document_distance = search_result['distances'][0][indx]

                    retrieved_documents.append({
                        "content": document_content,
                        "metadata": document_metadata,
                        "distance": document_distance})
            print(f"Rettrieved {len(retrieved_documents)} documents from chroma DB for query: '{query}'")

            return retrieved_documents
        except Exception as ex:
            print(f"Exception in Chroma db search: {ex}")
            traceback.print_exc()
            return []

    def delete_collection(self, collection_name: str):
        """Drop the named collection, then recreate our working collection."""
        try:
            self.client.delete_collection(name=collection_name)
            print(f"Chroma DB collection {collection_name} deleted")
            self.collection = self._get_or_create_collection()
        except Exception as ex:
            print(f"Exception in deleting the chroma db collection: {collection_name}")
            traceback.print_exc()

    def count_documents(self) -> int:
        """Return the number of stored chunks (0 on failure)."""
        try:
            return self.collection.count()
        except Exception as ex:
            print(f"Exception in counting documents in chroma db: {ex}")
            traceback.print_exc()
            return 0
|
| 102 |
+
|
| 103 |
+
|
rag_scripts/embedding/vector_db/faiss_db.py
ADDED
|
@@ -0,0 +1,145 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import faiss
|
| 3 |
+
import pickle
|
| 4 |
+
import traceback
|
| 5 |
+
import numpy as np
|
| 6 |
+
from typing import List, Dict, Any
|
| 7 |
+
|
| 8 |
+
from configuration import Configuration
|
| 9 |
+
from rag_scripts.interfaces import IVector, IEmbedder
|
| 10 |
+
from rag_scripts.embedding.embedder import SentenceTransformerEmbedder
|
| 11 |
+
|
| 12 |
+
class FAISSVectorDB(IVector):
    """Vector store backed by a FAISS L2 index persisted to disk, with a
    parallel pickled doc store holding chunk content and metadata."""

    def __init__(self, embedder: IEmbedder, db_path: str = Configuration.FAISS_DB_PATH,
                 collection_name: str = Configuration.COLLECTION_NAME):
        super().__init__(embedder, db_path, collection_name)
        self.index = None
        # doc_store[i] holds the chunk whose embedding is row i of the index.
        self.doc_store: List[Dict[str, Any]] = []

        self.collection_file = os.path.join(self.db_path, f"{self.collection_name}.faiss")
        self.doc_store_file = os.path.join(self.db_path, f"{self.collection_name}_docs.pkl")

        os.makedirs(self.db_path, exist_ok=True)

        self._load_index()

        print(f"FAISSDB Initalized: path ='{self.db_path}, collection ='{self.collection_name}")

    def _load_index(self):
        """Load a previously saved index + doc store pair, if both exist."""
        try:
            if os.path.exists(self.collection_file) and os.path.exists(self.doc_store_file):
                self.index = faiss.read_index(self.collection_file)
                # NOTE(security): pickle.load on an untrusted file can execute
                # arbitrary code; the doc store must come from a trusted path.
                with open(self.doc_store_file, 'rb') as f:
                    self.doc_store = pickle.load(f)
                print(f"Loaded Existing FAISS index and doc store from {self.collection_file}")
            else:
                print(f"No Existing FAISS index found, need new collection")
        except Exception as ex:
            print(f"Exception in loading FAISS index: {ex}")
            traceback.print_exc()
            self.index = None
            self.doc_store = []

    def _save_index(self):
        """Persist the index and doc store; a no-op when no index exists."""
        try:
            if self.index is not None:
                faiss.write_index(self.index, self.collection_file)
                with open(self.doc_store_file, "wb") as f:
                    pickle.dump(self.doc_store, f)
                print(f"FAISS index and doc store saved to {self.collection_file}")

        except Exception as ex:
            print(f"Exception in saving FAISS index: {ex}")
            traceback.print_exc()

    def add_chunks(self, documents: List[Dict[str, Any]]) -> List[str]:
        """Embed and add chunk dicts; lazily creates the index on first add.

        Returns the added chunk ids ([] on failure or empty input).
        """
        try:
            if not documents:
                return []

            contents = [doc['content'] for doc in documents]
            embeddings = np.array(self.embedder.embed_texts(contents), dtype='float32')

            if self.index is None:
                # Dimension comes from the embedder's first batch.
                dimension = embeddings.shape[1]
                self.index = faiss.IndexFlatL2(dimension)
                print(f"Created New FAISS index with dimension {dimension}")

            self.index.add(embeddings)

            # Keep the doc store aligned with the index rows.
            for doc in documents:
                self.doc_store.append(doc)

            self._save_index()
            chunk_ids = [doc['metadata']['chunk_id'] for doc in documents]
            print(f"Added {len(chunk_ids)} chunks to FAISS index")
            return chunk_ids
        except Exception as ex:
            print(f"Exception in adding chunks to FAISS: {ex}")
            traceback.print_exc()
            return []

    def get_document_hash_ids(self, document_hash: str) -> List[str]:
        """Return ids of stored chunks belonging to ``document_hash``."""
        try:
            found_ids = []
            for doc in self.doc_store:
                if doc['metadata'].get('document_id') == document_hash:
                    found_ids.append(doc['metadata']['chunk_id'])

            return found_ids
        except Exception as ex:
            print(f"Exception in get document hash {ex}")
            # BUG FIX: the original fell off the end of the except block and
            # implicitly returned None; callers iterate the result.
            traceback.print_exc()
            return []

    def search(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """Return the top-k chunks for ``query`` as dicts with 'content',
        'metadata' and 'distance' keys ([] on failure or empty index)."""
        try:
            if self.index is None or self.index.ntotal == 0:
                print(f"FAISS index is empty or not intialized")
                return []
            query_embedding = np.array(self.embedder.embed_query(query), dtype='float32').reshape(1, -1)
            distances, indices = self.index.search(query_embedding, k)

            retrieved_documents = []
            for dist, idx in zip(distances[0], indices[0]):
                # FAISS pads missing results with -1; bounds-check the row.
                if 0 <= idx < len(self.doc_store):
                    doc = self.doc_store[idx]
                    retrieved_documents.append({
                        "content": doc['content'],
                        "metadata": doc['metadata'],
                        "distance": float(dist)})
            print(f"Retrieved {len(retrieved_documents)} documents from FAISS for query: {query}")

            return retrieved_documents

        except Exception as ex:
            print(f"Exception in FAISS db search {ex}")
            traceback.print_exc()
            return []

    def delete_collection(self, collection_name: str):
        """Remove the on-disk index/doc-store files and reset in-memory state."""
        try:
            if os.path.exists(self.collection_file):
                os.remove(self.collection_file)
            if os.path.exists(self.doc_store_file):
                os.remove(self.doc_store_file)
            self.index = None
            self.doc_store = []
            print(f"FAISS collection files for {collection_name} deleted")

        except Exception as ex:
            # BUG FIX: the message had a literal 'ex' instead of {ex}, and
            # print_exc was referenced without being called.
            print(f"Error deleting FAISS collection files for {collection_name}: {ex}")
            traceback.print_exc()

    def count_documents(self) -> int:
        """Return the number of vectors in the index (0 when absent/failed)."""
        try:
            return self.index.ntotal if self.index is not None else 0
        except Exception as ex:
            print(f"Exception in counting document in FAISS: {ex}")
            traceback.print_exc()
            return 0
|
| 142 |
+
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
|
rag_scripts/embedding/vector_db/pinecone_db.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import time
|
| 2 |
+
import traceback
|
| 3 |
+
from typing import List, Dict,Any
|
| 4 |
+
import re
|
| 5 |
+
from pinecone import Pinecone, ServerlessSpec
|
| 6 |
+
from pinecone.exceptions import PineconeException
|
| 7 |
+
|
| 8 |
+
from configuration import Configuration
|
| 9 |
+
from rag_scripts.interfaces import IVector, IEmbedder
|
| 10 |
+
|
| 11 |
+
class PineconeVectorDB(IVector):
    """Pinecone-backed implementation of the IVector store interface.

    Wraps a serverless Pinecone index: creation/recreation on init,
    chunk upserts, metadata-filtered lookups, cosine-similarity search,
    and index deletion. Credentials and region come from Configuration.
    """

    def __init__(self, embedder: IEmbedder, db_path: str, collection_name: str = "flykite"):
        """Connect to Pinecone and ensure a matching index exists.

        Args:
            embedder: Embedding backend; probed once to learn the vector dimension.
            db_path: Accepted for interface parity with the local stores; Pinecone
                is remote, so no local path is used here.
            collection_name: Logical collection name; underscores are swapped for
                hyphens to satisfy Pinecone index-naming rules.

        Raises:
            ValueError: If no Pinecone API key is configured.
        """
        super().__init__(embedder,db_path,collection_name)

        self.api_key = Configuration.PINECONE_API_KEY
        self.cloud = Configuration.PINECONE_CLOUD
        self.region = Configuration.PINECONE_REGION

        if not self.api_key:
            raise ValueError("Pinecone API KEY not provided in configuration")

        self.pc = Pinecone(api_key=self.api_key)
        # Pinecone index names must not contain underscores.
        collection_name = collection_name.replace('_','-')
        self.index_name=collection_name

        print(f"Collection name: {collection_name}")

        # Embed a throwaway string once to discover the embedding dimension.
        self.dim=len(self.embedder.embed_texts(["test"])[0])
        self._create_index()
        self.index=self.pc.Index(self.index_name)

        print(f"Pinecone DB Initialized index= {self.index_name}, cloud= {self.cloud}, region = {self.region}")

    def _create_index(self):
        """Create the index, recreating it first if dimension/metric mismatch.

        An existing index with the wrong dimension (or a non-cosine metric)
        is deleted and a fresh one created; a compatible index is reused.
        Blocks (polling every 2s, up to 90 attempts) until the new index
        reports the expected dimension.

        Raises:
            TimeoutError: If the index never becomes ready.
            PineconeException / Exception: Re-raised after logging.
        """
        try:
            existing_indexes = self.pc.list_indexes().names()
            if self.index_name in existing_indexes:
                index = self.pc.Index((self.index_name))
                index_info = index.describe_index_stats()
                if index_info.dimension!=self.dim:
                    print(f"Pinecone DB Index already exists {self.index_name} "
                          f"with dimension {index_info.dimension}"
                          f"but the expected dimension is {self.dim}"
                          " so deleting it and recreating again"
                          )
                    self.pc.delete_index(self.index_name)

                    # while self.index_name in self.pc.list_indexes().names():
                    #     time.sleep(2)
                    # Poll (max ~60s) until the deletion is visible before recreating.
                    for _ in range(30):
                        if self.index_name not in self.pc.list_indexes().names():
                            break
                        else:
                            time.sleep(2)

                elif index_info.metric!='cosine':
                    # Wrong similarity metric: drop and recreate below as cosine.
                    self.pc.delete_index(self.index_name)

                    for _ in range(30):
                        if self.index_name in self.pc.list_indexes().names():
                            time.sleep(2)
                        else:
                            break
                else:
                    # Dimension and metric both match — reuse as-is.
                    print(f"Pinecone index already exists {self.index_name} ")
                    return

            if self.index_name not in self.pc.list_indexes().names():
                self.pc.create_index(name=self.index_name,
                                     dimension=self.dim,
                                     metric='cosine',
                                     spec=ServerlessSpec(cloud=self.cloud,region=self.region)
                                     )
            # Wait for the (re)created index to be queryable with the right
            # dimension; describe_index_stats raises while it is provisioning.
            max_attempts = 90
            for attempt in range(max_attempts):
                try:
                    index = self.pc.Index(self.index_name)
                    index_info = index.describe_index_stats()
                    if index_info.dimension == self.dim:
                        print(f"Pinecone index {self.index_name} already exists ")
                        return
                except PineconeException:
                    pass
                time.sleep(2)
            raise TimeoutError(f"Pinecone index {self.index_name} does not exist")

        except PineconeException as ex:
            print(f"Exception creating Pinecone index: {ex}")
            traceback.print_exc()
            raise
        except Exception as ex:
            print(f"Exception creating Pinecone index: {ex}")
            traceback.print_exc()
            raise

    def add_chunks(self, documents: List[Dict[str,Any]]) -> List[str]:
        """Upsert document chunks into the index.

        Each chunk's text is embedded and stored under its
        ``metadata['chunk_id']``; the raw text is copied into the vector's
        metadata under the ``content`` key so search can return it.

        Args:
            documents: Dicts with ``content`` and a ``metadata`` dict that
                must contain ``chunk_id``.

        Returns:
            The list of upserted chunk ids, or ``[]`` on Pinecone errors.
        """
        try:
            if not documents:
                return[]
            vectors_info = []
            for doc in documents:
                content = doc['content']
                embedding = self.embedder.embed_texts([content])[0]
                chunk_id = doc['metadata']['chunk_id']
                # Copy so the caller's metadata dict is not mutated.
                metadata= doc['metadata'].copy()
                metadata['content']=content

                vectors_info.append({
                    "id":chunk_id, "values":embedding, "metadata":metadata
                })

            # NOTE(review): this upserts all vectors in a single call; for very
            # large document sets Pinecone request-size limits may require
            # batching — confirm expected corpus size.
            print(f"Upserting {len(vectors_info)} to pincone index")
            self.index.upsert(vectors=vectors_info)
            chunk_ids = [vec['id'] for vec in vectors_info]
            print(f"Added {len(chunk_ids)} chunks to pinecone index {self.index_name}")
            print(f"Added chunk ids {chunk_ids[:5]}")
            return chunk_ids

        except PineconeException as ex:
            print(f"Exception in adding chunks to pinecone db {ex}")
            traceback.print_exc()
            return []

    def get_document_hash_ids(self, document_hash: str) ->List[str]:
        """Return ids of all vectors whose ``document_id`` metadata matches.

        Uses a zero-vector query purely as a carrier for the metadata filter
        (top_k=10000 caps how many ids can be recovered per call).
        """
        try:
            vector_sample = [0.0]*self.dim
            result = self.index.query(
                vector = vector_sample,
                filter = {"document_id": {"$eq": document_hash}},
                top_k = 10000,
                include_metadata=False )

            return [match['id'] for match in result['matches']]

        except PineconeException as ex:
            print(f"Exception getting document hash IDs from pinecone: {ex}")
            traceback.print_exc()
            return []

    def search(self,query: str, k:int=3) -> List[Dict[str,Any]]:
        """Return the top-k chunks for *query* by cosine similarity.

        Returns:
            Dicts with ``content`` (popped out of the stored metadata),
            the remaining ``metadata``, and ``distance`` — note this is the
            Pinecone similarity *score* (higher is better), not a distance.
            Returns ``[]`` on Pinecone errors.
        """
        try:
            print(f"searching pinecode index {self.index_name} for query {query}")
            query_embedding = self.embedder.embed_query(query)
            result = self.index.query(
                vector=query_embedding,
                top_k=k,
                include_metadata=True
            )

            retrieved_documents = []
            for match in result['matches']:
                md= match['metadata']
                # The chunk text was stashed in metadata by add_chunks.
                content = md.pop('content','')

                retrieved_documents.append({
                    "content":content,
                    "metadata":md,
                    "distance":match['score']
                })

            print(f"Retrieved {len(retrieved_documents)} documents from pinecone for query: {query}")

            return retrieved_documents

        except PineconeException as ex:
            print(f"Exception in Pinecone search: {ex}")
            traceback.print_exc()
            return[]

    def delete_collection(self, collection_name: str):
        """Delete the named Pinecone index; errors are logged, not raised."""
        try:
            self.pc.delete_index(collection_name)
            print(f"Pinecone index {collection_name} deleted")
        except PineconeException as ex:
            print(f"Exception in deleting pinecone index: {collection_name} {ex}")
            traceback.print_exc()

    def count_documents(self) -> int:
        """Return the index's vector count, polling while it is still 0.

        Pinecone stats lag behind upserts, so this retries (19 attempts,
        7s apart — worst case ~2 minutes) before concluding the index is
        genuinely empty.
        """
        try:
            for attempt in range(19):
                status = self.index.describe_index_stats()
                count = status.get('total_vector_count',0)
                if count > 0:
                    return count
                time.sleep(7)
            print(f"No vectors found in index: {self.index_name}")
            return 0

        except PineconeException as ex:
            print(f"EXception in getting document counts: {ex}")
            traceback.print_exc()
            return 0
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
|
rag_scripts/evaluation/__pycache__/evaluator.cpython-312.pyc
ADDED
|
Binary file (30.5 kB). View file
|
|
|
rag_scripts/evaluation/evaluator.py
ADDED
|
@@ -0,0 +1,525 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import time
|
| 5 |
+
import json
|
| 6 |
+
import random
|
| 7 |
+
import joblib
|
| 8 |
+
import traceback
|
| 9 |
+
import logging
|
| 10 |
+
|
| 11 |
+
import torch
|
| 12 |
+
import transformers
|
| 13 |
+
from bert_score import score
|
| 14 |
+
from itertools import product
|
| 15 |
+
from datetime import datetime
|
| 16 |
+
from typing import List, Dict, Any
|
| 17 |
+
from configuration import Configuration
|
| 18 |
+
from rag_scripts.llm.llmResponse import GROQLLM
|
| 19 |
+
from pinecone import Pinecone, PineconeException
|
| 20 |
+
from rag_scripts.rag_pipeline import RAGPipeline
|
| 21 |
+
from sentence_transformers import SentenceTransformer, util
|
| 22 |
+
from rag_scripts.embedding.vector_db.faiss_db import FAISSVectorDB
|
| 23 |
+
from rag_scripts.documents_processing.chunking import PyMuPDFChunker
|
| 24 |
+
from rag_scripts.embedding.embedder import SentenceTransformerEmbedder
|
| 25 |
+
from rag_scripts.embedding.vector_db.chroma_db import chromaDBVectorDB
|
| 26 |
+
from rag_scripts.embedding.vector_db.pinecone_db import PineconeVectorDB
|
| 27 |
+
from rag_scripts.interfaces import IDocumentChunker, ILLM, IEmbedder, IRAGPipeline, IVector
|
| 28 |
+
|
| 29 |
+
# Module-wide logging: timestamped INFO-level messages for evaluation runs.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Maps a vector-DB type string (as used in evaluation parameter grids) to the
# concrete IVector implementation that backs it; looked up case-insensitively
# via db_type.lower() in evaluate_combined_params_grid.
VECTOR_DB_CONSTRUCTORS = {
    "chroma": chromaDBVectorDB,
    "faiss": FAISSVectorDB,
    "pinecone": PineconeVectorDB
}
|
| 37 |
+
|
| 38 |
+
class RAGEvaluator:
|
| 39 |
+
def __init__(self, eval_data_path: str = Configuration.EVAL_DATA_PATH,
             pdf_path: str = Configuration.FULL_PDF_PATH):
    """Load the evaluation query set and the scoring embedder.

    Args:
        eval_data_path: JSON file containing the evaluation queries.
        pdf_path: Source PDF the candidate RAG pipelines will index.

    Raises:
        FileNotFoundError: If either input file does not exist.
    """
    for label, path in (("Evaluation data", eval_data_path), ("PDF document", pdf_path)):
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} not found at: {path}")

    with open(eval_data_path, 'r') as handle:
        self.eval_queries = json.load(handle)

    self.pdf_path = pdf_path
    # Sentence-transformer used for all cosine-similarity based metrics.
    self.embedder = SentenceTransformer(Configuration.DEFAULT_SENTENCE_TRANSFORMER_MODEL)
    logger.info(f"RAG Evaluator initialized with {len(self.eval_queries)} evaluation queries")
|
| 53 |
+
|
| 54 |
+
def _sanitize_collection_name(self, name: str) -> str:
|
| 55 |
+
sanitized = re.sub(r'[^a-z0-9]', '-', name.lower())
|
| 56 |
+
sanitized = re.sub(r'-+', '-', sanitized).strip('-')
|
| 57 |
+
return sanitized[:45].rstrip('-') if len(sanitized) > 45 else sanitized
|
| 58 |
+
|
| 59 |
+
def _calculate_retrieval_relevance(self, query: str, retrieved_docs: List[Dict[str, Any]],
                                   expected_answer: str =None) -> float:
    """Mean best-cosine-similarity between the query and each retrieved doc.

    Each non-empty document is embedded and compared against the query
    embedding; per-document scores are floored at 0.0 and averaged.
    The expected_answer parameter is accepted but not currently used in
    the score.

    Returns 0.0 when nothing was retrieved or no doc had content.
    """
    if not retrieved_docs:
        logger.warning(f"No retrieved documents for query: {query}")
        return 0.0

    query_embedding = self.embedder.encode(query, convert_to_tensor=True, normalize_embeddings=True)
    reference_embeddings = [query_embedding]

    per_doc_best = []
    for doc in retrieved_docs:
        text = doc.get('content', '')
        if not text.strip():
            continue
        doc_vec = self.embedder.encode(text, convert_to_tensor=True, normalize_embeddings=True)
        best = 0.0
        for ref_vec in reference_embeddings:
            sim = util.cos_sim(ref_vec, doc_vec)
            value = sim.item() if isinstance(sim, torch.Tensor) else sim
            best = max(best, value)
        per_doc_best.append(best)

    cosine_score = sum(per_doc_best) / len(per_doc_best) if per_doc_best else 0.0
    logger.info(f"Max cosine score for query '{query}': {cosine_score}")
    return cosine_score
|
| 93 |
+
|
| 94 |
+
def _calculate_response_groundedness(self, response: Dict[str,Any], retrieved_docs: List[Dict[str, Any]]) -> float:
    """Estimate how well the answer is supported by the retrieved chunks.

    The answer ("summary") is split into bullet/paragraph segments; each
    segment — and the full answer — is compared by cosine similarity
    against each retrieved chunk, and the single best similarity across
    all pairs is returned as the groundedness score.

    Returns:
        The maximum similarity found, or 0.0 when there are no retrieved
        docs, no scorable content, or on any error.
    """
    try:
        if not retrieved_docs:
            logger.warning(f"No retrieved documents for groundedness evaluation")
            return 0.0
        response_text = response.get("summary","")
        if not isinstance(response_text,str):
            response_text = str(response_text)

        # Split on bullet glyphs (● U+25CF / ○ U+25CB) or blank lines so each
        # point of a structured answer is scored on its own.
        response_segments = re.split(r'[\u25cf\u25cb]\s*|\n\s*\n', response_text.strip())
        response_segments = [seg.strip() for seg in response_segments if seg.strip()]

        groundedness_scores = []
        for doc in retrieved_docs:
            doc_content = doc.get('content','')
            if not doc_content.strip():
                continue
            doc_embedding = self.embedder.encode(doc_content,
                                                 convert_to_tensor=True,
                                                 normalize_embeddings=True)
            for segment in response_segments:
                if segment:
                    segment_embedding = self.embedder.encode(segment,
                                                             convert_to_tensor=True,
                                                             normalize_embeddings=True)

                    segment_similarity = util.cos_sim(segment_embedding,
                                                      doc_embedding).item()
                    groundedness_scores.append(segment_similarity)

            # Also score the full answer against this chunk.
            # NOTE(review): the full response is re-encoded on every loop pass
            # even though it never changes — presumably this belongs once per
            # document (or hoisted above the loop); confirm intended placement.
            full_response_embedding = self.embedder.encode(response_text,
                                                           convert_to_tensor=True,
                                                           normalize_embeddings=True)
            full_response_similarity = util.cos_sim(full_response_embedding,doc_embedding).item()
            groundedness_scores.append(full_response_similarity)

        groundedness = max(
            groundedness_scores) if groundedness_scores else 0.0  # Use max to handle structured responses
        logger.info(f"Max groundedness score for response: {groundedness:.2f}")
        return groundedness
    except Exception as ex:
        logger.error(f"Exception in calculating groundedness: {ex}")
        traceback.print_exc()
        return 0.0
|
| 138 |
+
|
| 139 |
+
def _calculate_response_relevance(self, query: str, response: Dict[str,Any],expected_answer: str=None) -> float:
    """Semantic similarity of the answer to its reference, length-adjusted.

    The reference is the expected answer when one is provided (non-empty),
    otherwise the query itself. A mild penalty is applied when the answer
    is much longer (floor 0.8) or much shorter (floor 0.5) than the
    reference.

    Returns 0.0 for empty/non-string answers or on any error.
    """
    try:
        answer = response.get("summary","")
        answer = answer.replace('\n','')
        if not isinstance(answer,str) or not answer.strip():
            return 0.0

        # Compare against the gold answer when available; else the query.
        target = query if expected_answer is None or not expected_answer.strip() else expected_answer

        target_vec = self.embedder.encode(target, convert_to_tensor=True, normalize_embeddings=True)
        answer_vec = self.embedder.encode(answer, convert_to_tensor=True, normalize_embeddings=True)
        similarity = util.cos_sim(target_vec, answer_vec).item()

        n_answer = len(answer.split())
        n_target = len(target.split())
        penalty = 1.0
        if n_answer > (n_target*1.5) and n_answer >50:
            # Verbose answers: dock up to 20%, scaled by overshoot past 2x.
            penalty = max(0.8,1.0-(n_answer-(n_target*2))/100.0)
        elif n_answer < (n_target*0.5) and n_answer >20:
            # Terse answers: scale down proportionally, floored at 0.5.
            penalty = max(0.5,n_answer/(n_target*0.5))

        adjusted_relevance = similarity * penalty
        logger.info(
            f"Relevance score for query '{query}': {adjusted_relevance:.2f} (similarity: {similarity:.2f}, penalty: {penalty:.2f})")
        return adjusted_relevance
    except Exception as ex:
        logger.error(f"Exception in relevance score calculation: {ex}")
        traceback.print_exc()
        return 0.0
|
| 170 |
+
|
| 171 |
+
def _calculate_bert_score(self, response: Dict[str,Any], reference: str) -> float:
    """BERTScore F1 of the answer ("summary") against *reference*.

    Uses the English roberta-large model via the bert_score package.

    Returns:
        The F1 value, or 0.0 for a non-string/empty answer, an empty
        reference, or on any error.
    """
    try:
        response_text = response.get("summary", "")
        if not isinstance(response_text, str):
            logger.warning("Response text is not a string")
            return 0.0
        if not response_text:
            logger.error(f"No retrieved documents for BERT score evaluation")
            return 0.0
        if not reference.strip():
            logger.warning(f"No reference answer provided for BERTScore calculation")
            return 0.0
        # Temporarily silence HF transformers chatter around the scoring call.
        transformers.utils.logging.set_verbosity_error()
        _, _, f1 = score([response_text], [reference], lang="en", model_type="roberta-large")
        transformers.utils.logging.set_verbosity_warning()
        bert_score_value = f1.item()
        logger.info(f"BERT score: {bert_score_value:.2f}")
        return bert_score_value
    except Exception as ex:
        logger.error(f"Exception in BERT score calculation: {ex}")
        traceback.print_exc()
        return 0.0
|
| 193 |
+
|
| 194 |
+
def evaluate_response(self, query: str, response: Dict[str,Any], retrieved_docs: List[Dict[str, Any]],
                      expected_keywords: List[str] = None, expected_answer: str = None) -> Dict[str, Any]:
    """Score one RAG answer on four axes and summarize the findings.

    Args:
        query: The user question that was answered.
        response: Pipeline response dict; "summary" holds the answer text.
        retrieved_docs: Chunks the retriever returned for this query.
        expected_keywords: Accepted for interface parity; not used here.
        expected_answer: Optional gold answer; enables the BERT score.

    Returns:
        Dict with rounded "cosine_score", "groundedness", "relevance",
        "bert_score" and a "; "-joined "observations" string; all scores
        are 0.0 on error.
    """
    try:
        cosine_score = self._calculate_retrieval_relevance(query, retrieved_docs, expected_answer=expected_answer)
        groundedness = self._calculate_response_groundedness(response, retrieved_docs)
        relevance = self._calculate_response_relevance(query, response, expected_answer=expected_answer)
        # BERTScore is only meaningful against a non-empty reference answer.
        bert_score = self._calculate_bert_score(response, expected_answer) if expected_answer else 0.0

        notes = [
            f"Response may contain ungrounded information (groundedness score: {groundedness:.2f})"
            if groundedness < 0.6
            else f"Response is well grounded (score: {groundedness:.2f})",
            f"Response may not fully address the question (relevance score: {relevance:.2f})"
            if relevance < 0.7
            else f"Response is highly relevant (score: {relevance:.2f})",
        ]
        if cosine_score < 0.5:
            notes.append("Low similarity between query and retrieved documents")
        if bert_score < 0.7 and expected_answer:
            notes.append("Low BERT score, semantic mismatch with reference answer")

        return {
            "cosine_score": round(cosine_score, 2),
            "groundedness": round(groundedness, 2),
            "relevance": round(relevance, 2),
            "bert_score": round(bert_score, 2),
            "observations": "; ".join(notes),
        }

    except Exception as ex:
        logger.error(f"Exception in calculating response scores: {ex}")
        traceback.print_exc()
        return {
            "cosine_score": 0.0,
            "groundedness": 0.0,
            "relevance": 0.0,
            "bert_score": 0.0,
            "observations": f"Error in evaluation: {ex}"
        }
|
| 237 |
+
|
| 238 |
+
def evaluate_combined_params_grid(self,
                                  chunk_size_to_test: List[int],
                                  chunk_overlap_to_test: List[int],
                                  embedding_models_to_test: List[str],
                                  vector_db_types_to_test: List[str],
                                  re_ranker_model: List[str],
                                  llm_model_name: str = Configuration.DEFAULT_GROQ_LLM_MODEL,
                                  search_type: str = "grid",
                                  n_iter: int = 50) -> Dict[str, Any]:
    """Grid/random-search RAG hyperparameters and pick the best combination.

    For every (chunk_size, chunk_overlap, embedding_model, vector_db,
    re_ranker) combination (overlap must be smaller than chunk size), a
    temporary pipeline is built and indexed, every evaluation query is
    scored via evaluate_response, and the weighted average
    (0.2*cosine + 0.35*groundedness + 0.35*relevance + 0.15*bert)
    decides the winner. Temporary collections are cleaned up per
    iteration; all results are also persisted with joblib.

    Args:
        chunk_size_to_test / chunk_overlap_to_test: Chunking candidates.
        embedding_models_to_test: Sentence-transformer model names.
        vector_db_types_to_test: Keys into VECTOR_DB_CONSTRUCTORS.
        re_ranker_model: Re-ranker model names to try.
        llm_model_name: LLM used for answer generation in every run.
        search_type: "grid" tests all combinations; "random" samples n_iter.
        n_iter: Sample size when search_type == "random".

    Returns:
        Dict with "best_params", "best_score", per-iteration "results",
        "best_metrics", and the "pkl_file" path the results were saved to.
    """
    logger.info("\n--- Starting the evaluation of best parameters ---")
    best_score = -1.0
    best_params = {}
    best_metrics = {}
    results = []

    # Only keep combinations where the overlap is strictly smaller than the
    # chunk size (an overlap >= chunk size is meaningless).
    param_combination = [(c_size, c_overlap, embed_model, db_type,re_ranker)
                         for c_size, c_overlap, embed_model, db_type,re_ranker in product(
                             chunk_size_to_test, chunk_overlap_to_test,
                             embedding_models_to_test, vector_db_types_to_test,re_ranker_model)
                         if c_overlap < c_size]

    param_to_test = (random.sample(param_combination, min(n_iter, len(param_combination)))
                     if search_type.lower() == 'random' else param_combination)
    logger.info(f"Testing {len(param_to_test)} {'random' if search_type.lower() == 'random' else 'all'} "
                f"combinations out of {len(param_combination)}")

    for idx, (c_size, c_overlap, embed_model, db_type,re_ranker) in enumerate(param_to_test, 1):
        logger.info(f"\nIteration {idx}/{len(param_to_test)} \nchunk_size: {c_size} \nchunk_overlap: {c_overlap} "
                    f"\nembed_model: {embed_model} \nvector_db: {db_type}")

        current_params_str = f"Chunk: {c_size}-{c_overlap}- Embed- {embed_model}- DB-{db_type}-{re_ranker}"
        # Hyphenate the model name so it is safe inside collection names.
        embed_model = embed_model.replace('_', '-')
        temp_collection_name = self._sanitize_collection_name(
            f"{Configuration.COLLECTION_NAME}-{search_type}-{c_size}-{c_overlap}-{embed_model}-{db_type}")
        # NOTE(review): the path template repeats {embed_model} twice — the
        # second occurrence was presumably meant to be another field (db_type?);
        # harmless but worth confirming.
        temp_db_path = os.path.join(Configuration.DATA_DIR, f"{db_type}_temp_{search_type}_{c_size}_{c_overlap}_{embed_model}_{embed_model}")
        os.makedirs(temp_db_path, exist_ok=True)

        vector_db_instance = None
        try:
            embedder = SentenceTransformerEmbedder(model_name=embed_model)
            db_constructor = VECTOR_DB_CONSTRUCTORS.get(db_type.lower())
            if not db_constructor:
                logger.error(f"Unsupported vector DB type: {db_type}")
                continue
            vector_db_instance = db_constructor(embedder=embedder,
                                                db_path=temp_db_path,
                                                collection_name=temp_collection_name)

            chunker_instance = PyMuPDFChunker(pdf_path=self.pdf_path,
                                              chunk_size=c_size,
                                              chunk_overlap=c_overlap)

            llm_instance = GROQLLM(model_name=llm_model_name)

            # Assemble a throwaway pipeline for this parameter combination.
            pipeline = RAGPipeline(
                document_path=self.pdf_path,
                chunker=chunker_instance,
                embedder=embedder,
                vector_db=vector_db_instance,
                llm=llm_instance,
                chunk_size=c_size,
                chunk_overlap=c_overlap,
                embedding_model_name=embed_model,
                llm_model_name=llm_model_name,
                db_path=temp_db_path,
                collection_name=temp_collection_name,
                re_ranker_model_name=re_ranker
            )

            pipeline.build_index()
            if vector_db_instance.count_documents() == 0:
                logger.warning(f"No documents foundin vector DB after build for "
                               f"{current_params_str}. skipping evaluation for this combination.")
                continue

            cosine_scores = []
            groundedness_scores = []
            relevance_scores = []
            bert_scores = []

            # Score every evaluation query against this pipeline.
            for eval_item in self.eval_queries:
                query = eval_item['query']
                expected_answer = eval_item.get('expected_answer_snippet', '')

                expected_keywords = eval_item.get('expected_keywords', [])
                retrieved_docs = pipeline.retrieve_raw_documents(query, k=3)
                response = pipeline.query(query, k=3)
                logger.debug(f"Query: {query} \n Response: {json.dumps(response,indent=2, ensure_ascii=False)}")
                # Without a gold snippet, synthesize a reference from the
                # retrieved chunks.
                # NOTE(review): _syntesize_raw_reference is defined elsewhere
                # in this class (note the spelling) — confirm it exists.
                if not expected_answer.strip():
                    expected_answer = self._syntesize_raw_reference(retrieved_docs)

                eval_result = self.evaluate_response(query, response,
                                                     retrieved_docs,
                                                     expected_answer=expected_answer,
                                                     expected_keywords=expected_keywords)

                cosine_scores.append(eval_result['cosine_score'])
                groundedness_scores.append(eval_result['groundedness'])
                relevance_scores.append(eval_result['relevance'])
                bert_scores.append(eval_result['bert_score'])

            average_cosine_score = round(sum(cosine_scores) / len(cosine_scores) if cosine_scores else 0.0, 2)
            average_groundedness = round(sum(groundedness_scores) / len(groundedness_scores) if groundedness_scores else 0.0, 2)
            average_relevance = round(sum(relevance_scores) / len(relevance_scores) if relevance_scores else 0.0, 2)
            average_bert_score = round(sum(bert_scores) / len(bert_scores) if bert_scores and any(bert_scores) else 0.0, 2)

            # Weighted composite score used to rank combinations.
            average_score = round((0.2* average_cosine_score +
                                   0.35* average_groundedness +
                                   0.35* average_relevance +
                                   0.15*average_bert_score), 2)

            results.append({
                "iteration": idx,
                "chunk_size": c_size,
                "chunk_overlap": c_overlap,
                "embedding_model": embed_model,
                "vector_db_type": db_type,
                "average_retrieval_relevance_score": average_score,
                "average_cosine_score": average_cosine_score,
                "average_groundedness": average_groundedness,
                "average_relevance": average_relevance,
                "average_bert_score": average_bert_score
            })

            logger.info(f"Average retrieval relevance score: {average_score}")
            logger.info(f"Average cosine score: {average_cosine_score}")
            logger.info(f"Average groundedness: {average_groundedness}")
            logger.info(f"Average relevance: {average_relevance}")
            logger.info(f"Average BERT score: {average_bert_score}")

            logger.info(f"Comparing the average score {average_score} with best score {best_score}")
            if average_score > best_score:
                best_score = average_score
                best_params = {
                    "iteration":idx,
                    "chunk_size": c_size,
                    "chunk_overlap": c_overlap,
                    "embedding_model": embed_model,
                    "vector_db_type": db_type,
                    "re_ranker_model":re_ranker
                }
                best_metrics = {
                    "average_retrieval_relevance_score": average_score,
                    "average_cosine_score": average_cosine_score,
                    "average_groundedness": average_groundedness,
                    "average_relevance": average_relevance,
                    "average_bert_score": average_bert_score,
                    "re_ranker_model":re_ranker
                }

                logger.info(f"Best score: {best_score}")
                logger.info(f"Best params: {best_params}")

        except (PineconeException, ValueError) as ex:
            logger.error(f"Exception in grid search for chunk_size={c_size}, "
                         f"chunk_overlap={c_overlap}, embed_model={embed_model}, vector_db={db_type}: {ex}")
            traceback.print_exc()
        finally:
            # Best-effort teardown of the per-iteration collection/index so
            # runs do not leak local files or remote Pinecone indexes.
            if 'vector_db_instance' in locals() and vector_db_instance is not None:
                try:
                    vector_db_instance.delete_collection(temp_collection_name)
                    if isinstance(vector_db_instance, chromaDBVectorDB):
                        # Drop the client and give Chroma time to release file locks.
                        del vector_db_instance.client
                        time.sleep(5)
                    if isinstance(vector_db_instance, PineconeVectorDB):
                        pc = Pinecone(api_key=Configuration.PINECONE_API_KEY)
                        if temp_collection_name in pc.list_indexes().names():
                            pc.delete_index(temp_collection_name)
                            logger.info(f"Pinecone index: {temp_collection_name} deleted")
                except Exception as cleanup_ex:
                    logger.error(f"Error during cleanup: {cleanup_ex}")
                    traceback.print_exc()

    logger.info("\n---- Evaluation completed ----")
    logger.info(f"Best parameters: {best_params}")
    logger.info(f"Best score: {best_score:.2f}")
    logger.info(f"Total iterations evaluated: {len(results)}")

    # Persist the full result set for later inspection.
    pkl_directory = os.path.join(Configuration.DATA_DIR, "eval_results")
    os.makedirs(pkl_directory, exist_ok=True)
    timestamp = datetime.now().strftime("%Y%m%d_%H%M")
    pkl_file = os.path.join(pkl_directory, f"eval_results_{search_type}_{timestamp}.pkl")

    try:
        joblib.dump({
            "best_params": best_params,
            "best_score": best_score,
            "results": results,
            "best_metrics": best_metrics
        }, pkl_file)
        logger.info(f"Results saved to {pkl_file}")
    except Exception as ex:
        logger.error(f"Exception in saving results to {pkl_file}: {ex}")

    return {
        "best_params": best_params,
        "best_score": best_score,
        "results": results,
        "best_metrics": best_metrics,
        "pkl_file": pkl_file
    }
|
| 441 |
+
|
| 442 |
+
|
| 443 |
+
def _evaluate_with_llm(self, query: str, response_summary: str, retrieved_docs_contents: List[str]) -> Dict[str, Any]:
    """Score a RAG answer with an LLM judge for groundedness and relevance.

    Args:
        query: The user question that was asked.
        response_summary: The RAG system's generated answer.
        retrieved_docs_contents: Contents of the retrieved context chunks.

    Returns:
        Dict with "Groundedness score" and "Relevance score" (ints, 0 on
        failure) and "Reasoning" (str with the judge's explanation, or the
        exception message on failure).
    """
    try:
        context = "\n".join(retrieved_docs_contents)
        system_message = f"""
        You are an expert, Impartial judge for evaluating Retrieval Augmented Generation (RAG) system message.
        Your ONLY task is to output a JSON object containing the evaluation score and reasoning.
        DO NOT include any other text, explanations, conversational remarks or markdown code blocks(```json).
        Strictly adhere to the requested JSON format.
        """
        prompt = f"""
        You are evaluating a RAG system's response.

        Query: "{query}"
        Retrieved context: --- {context} ---
        RAG Systems Response: "{response_summary}"

        please provide a score for Groundedness and Relevance on a scale of 1 to 5,
        where 5 is excellent and 1 is very poor.

        Groundedness: is how the response is supported *only* to the Retrieved context with no hallucinations ?
        1. Contains significant information not supported by the context or contradicts it.
        2. Contains some unsupported information.
        3. Mostly grounded, but might have minor deviations or additions.
        4. Almost entirely grounded in the context.
        5. Fully and accurately, using only information from the context.

        Relevance: how well the response is directly and comprehensively answers the query based on the context ?
        1. Does not answer the query at all, or answers a different question.
        2. Addresses the query partially but misses significant part or is off-topic.
        3. Answer the query reasonably well, but could be more complete or focused.
        4. Answer the query well, covering most relevant aspects.
        5. Answer the query completely, accurately and concisely, directly addressing all aspects.

        output your assessment ONLY in the following JSON format. no other text.
        {{
        "groundedness_score": <int> out of 5,
        "relevance_score": <int> out of 5,
        "reasoning": "Brief explanation for the scores."
        }}
        """

        eval_with_llm = GROQLLM(model_name="llama-3.3-70b-versatile")
        llm_response = eval_with_llm.generate_response(
            prompt=prompt,
            system_message=system_message,
            temperature=0.1,
            top_p=0.95,
            max_tokens=1500)

        print("\n -- LLM Judge Raw Response --")
        print(llm_response)
        print('-' * 50)

        # Bug fix: `json.loads` on the raw response crashed whenever the judge
        # wrapped its output in markdown fences or added stray text despite the
        # instructions (or returned None). Extract the first {...} span first.
        raw = llm_response or ""
        start, end = raw.find("{"), raw.rfind("}")
        if start == -1 or end <= start:
            raise ValueError(f"No JSON object found in judge response: {raw!r}")
        eval_scores = json.loads(raw[start:end + 1])
        return {
            "Groundedness score": eval_scores.get("groundedness_score", 0),
            "Relevance score": eval_scores.get("relevance_score", 0),
            "Reasoning": eval_scores.get("reasoning", "")
        }
    except Exception as ex:
        traceback.print_exc()
        # Zero scores signal "evaluation failed" without aborting the run.
        return {
            "Groundedness score": 0,
            "Relevance score": 0,
            "Reasoning": f"Exception in LLM evaluation: {ex}"
        }
|
| 511 |
+
|
| 512 |
+
|
| 513 |
+
def _syntesize_raw_reference(self, retrieved_docs: List[Dict[str, Any]]) -> str:
|
| 514 |
+
try:
|
| 515 |
+
if not retrieved_docs:
|
| 516 |
+
return ""
|
| 517 |
+
raw_content_snippets = [doc.get('content','') for doc in retrieved_docs if doc.get('content')]
|
| 518 |
+
raw_answer = " ".join(raw_content_snippets)
|
| 519 |
+
|
| 520 |
+
raw_answer = " ".join(raw_answer.split()).join(sorted(list(set(raw_answer.split()))))
|
| 521 |
+
|
| 522 |
+
return raw_answer
|
| 523 |
+
except Exception as ex:
|
| 524 |
+
traceback.print_exc()
|
| 525 |
+
return ""
|
rag_scripts/interfaces.py
ADDED
|
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from typing import List, Dict, Tuple, Any
|
| 3 |
+
|
| 4 |
+
class IDocumentChunker(ABC):
    """Interface for splitting a source document into indexable chunks."""

    @abstractmethod
    def hash_document(self) -> str:
        """Generate a unique hash identifying the document."""

    @abstractmethod
    def chunk_documents(self) -> Tuple[str, List[Dict[str, Any]]]:
        """Split the document into chunks.

        Returns:
            Tuple[str, List[Dict[str, Any]]]: the hashed document ID, and a
            list of dicts each holding a chunk's content and metadata.
        """
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class IEmbedder(ABC):
    """Interface for text embedding backends."""

    @abstractmethod
    def embed_texts(self, texts: List[str]) -> List[List[float]]:
        """Generate embeddings for a list of text strings.

        Returns:
            List[List[float]]: one embedding vector per input text.
        """

    @abstractmethod
    def embed_query(self, query: str) -> List[float]:
        """Generate an embedding for a single query string.

        Returns:
            List[float]: a single embedding vector.
        """
|
| 38 |
+
|
| 39 |
+
class IVector(ABC):
    """Interface for a vector-database backend used by the RAG pipeline."""

    def __init__(self, embedder: IEmbedder, db_path: str, collection_name: str):
        # Shared state every concrete backend needs.
        self.embedder = embedder
        self.db_path = db_path
        self.collection_name = collection_name

    @abstractmethod
    def add_chunks(self, documents: List[Dict[str, Any]]) -> List[str]:
        """Add a list of document chunks to the vector database.

        Args:
            documents: chunk dicts containing content and metadata.

        Returns:
            List[str]: the IDs of the stored chunks.
        """

    @abstractmethod
    def get_document_hash_ids(self, document_hash: str) -> List[str]:
        """Check whether a document already exists, keyed by its hash.

        Returns:
            List[str]: chunk IDs associated with the document hash, or an
            empty list if the document is not present.
        """

    @abstractmethod
    def search(self, query: str, k: int = 3) -> List[Dict[str, Any]]:
        """Search the vector database for documents relevant to a query.

        Args:
            query: user query string.
            k: number of top-k results to retrieve.

        Returns:
            List[Dict]: retrieved chunks, each with content, metadata and
            distance.
        """

    @abstractmethod
    def delete_collection(self, collection_name: str):
        """Delete a specified collection or index from the vector database.

        Useful for clean-up during evaluations.
        """

    @abstractmethod
    def count_documents(self) -> int:
        """Return the number of documents/chunks in the collection."""
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class ILLM(ABC):
    """Interface for large-language-model backends."""

    @abstractmethod
    def generate_response(self, prompt: str, system_message: str = None) -> str:
        """Generate a response from the LLM for the given prompt."""
|
| 102 |
+
|
| 103 |
+
class IRAGPipeline(ABC):
    """Interface for an end-to-end Retrieval-Augmented-Generation pipeline."""

    @abstractmethod
    def build_index(self):
        """Build the document index for the RAG pipeline."""

    @abstractmethod
    def query(self, user_query: str) -> str:
        """Run a query through the RAG pipeline."""

    @abstractmethod
    def retrieve_raw_documents(self, user_query: str) -> List[Dict]:
        """Retrieve the raw documents without LLM generation."""
|
| 125 |
+
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
|
rag_scripts/llm/__pycache__/llmResponse.cpython-312.pyc
ADDED
|
Binary file (3.64 kB). View file
|
|
|
rag_scripts/llm/llmResponse.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import traceback
|
| 3 |
+
from typing import List, Dict, Any, Optional
|
| 4 |
+
from groq import Groq
|
| 5 |
+
|
| 6 |
+
from rag_scripts.interfaces import ILLM
|
| 7 |
+
from configuration import Configuration
|
| 8 |
+
|
| 9 |
+
class GROQLLM(ILLM):
    """Groq-backed LLM wrapper implementing the ILLM interface."""

    def __init__(self, api_key: str = Configuration.GROQ_API_KEY,
                 model_name: str = Configuration.DEFAULT_GROQ_LLM_MODEL):
        """Create a Groq client.

        Raises:
            ValueError: if no API key is configured.
        """
        if not api_key:
            raise ValueError("Groq API not provided in env file")

        self.api_key = api_key
        self.model_name = model_name
        self.client = Groq(api_key=self.api_key)
        # Bug fix: "HT Policy" -> "HR Policy" (typo in the default system prompt).
        self.default_system_message = ("You are a helpful assistant for Flykite Airline HR Policy document queries."
                                       "provide a concise and accurate answers based strictly on the provided context."
                                       "Do not hallucinate or add ungrounded details.")

        print(f"Initialized Groq LLM with model: {self.model_name}")

    def generate_response(self, prompt, system_message: Optional[str] = None,
                          context: Optional[List[Dict]] = None,
                          temperature: float = 0.1,
                          top_p: float = 0.95,
                          max_tokens: int = 1000) -> str:
        """Generate a chat completion, streaming and concatenating the chunks.

        Args:
            prompt: the user question.
            system_message: overrides the default system prompt when given.
            context: optional retrieved chunks; when present they are folded
                into the prompt as a "Context:" block.
            temperature/top_p/max_tokens: sampling parameters.

        Returns:
            The full response text, or "" on failure.
        """
        try:
            complete_prompt = prompt
            if context:
                context_text = "\n".join([doc['content'] for doc in context])
                complete_prompt = f"Context:\n{context_text}\n\nQuestion: {prompt}"

            message = [{"role": "system", "content": system_message or self.default_system_message},
                       {"role": "user", "content": complete_prompt}]

            completion = self.client.chat.completions.create(
                model=self.model_name,
                messages=message,
                temperature=temperature if temperature is not None else 0.1,
                max_tokens=max_tokens if max_tokens is not None else 1000,
                top_p=top_p if top_p is not None else 0.95,
                stream=True, stop=None
            )

            # Accumulate the streamed deltas into one string.
            response_content = ""
            for chunk in completion:
                if chunk.choices and chunk.choices[0].delta.content:
                    response_content += chunk.choices[0].delta.content
            return response_content.strip()

        except Exception as ex:
            print(f"Exception in LLM response: {ex}")
            traceback.print_exc()
            # Bug fix: previously fell through and returned None (despite the
            # declared -> str), which broke callers doing `.strip()` on the
            # result. Return an empty string instead.
            return ""

    def llm_judge(self, prompt: str, **kwargs) -> Dict[str, Any]:
        """Placeholder for LLM-as-judge scoring (not implemented yet).

        Bug fix: the first parameter was misspelled `selfself`, which made
        this uncallable as a normal bound method. Returns an empty dict to
        match the declared return type.
        """
        return {}
|
| 63 |
+
|
rag_scripts/rag_pipeline.py
ADDED
|
@@ -0,0 +1,206 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import traceback
|
| 4 |
+
import math
|
| 5 |
+
from typing import List, Dict, Any
|
| 6 |
+
|
| 7 |
+
from configuration import Configuration
|
| 8 |
+
from rag_scripts.interfaces import IDocumentChunker, IEmbedder, IVector, ILLM, IRAGPipeline
|
| 9 |
+
from sentence_transformers import CrossEncoder
|
| 10 |
+
from rag_scripts.documents_processing.chunking import PyMuPDFChunker
|
| 11 |
+
from rag_scripts.embedding.embedder import SentenceTransformerEmbedder
|
| 12 |
+
from rag_scripts.embedding.vector_db.chroma_db import chromaDBVectorDB
|
| 13 |
+
from rag_scripts.embedding.vector_db.faiss_db import FAISSVectorDB
|
| 14 |
+
from rag_scripts.embedding.vector_db.pinecone_db import PineconeVectorDB
|
| 15 |
+
from rag_scripts.llm.llmResponse import GROQLLM
|
| 16 |
+
|
| 17 |
+
class RAGPipeline(IRAGPipeline):
    """End-to-end RAG pipeline: chunking, embedding, retrieval, reranking, generation."""

    def __init__(self,
                 document_path: str = Configuration.FULL_PDF_PATH,
                 chunker: IDocumentChunker = None,
                 embedder: IEmbedder = None,
                 vector_db: IVector = None,
                 llm: ILLM = None,
                 chunk_size: int = Configuration.DEFAULT_CHUNK_SIZE,
                 chunk_overlap: int = Configuration.DEFAULT_CHUNK_OVERLAP,
                 embedding_model_name: str = Configuration.DEFAULT_SENTENCE_TRANSFORMER_MODEL,
                 llm_model_name: str = Configuration.DEFAULT_GROQ_LLM_MODEL,
                 vector_db_type: str = "chroma",
                 db_path: str = None,
                 collection_name: str = None,
                 re_ranker_model_name: str = Configuration.DEFAULT_RERANKER
                 ):
        """Wire up the pipeline components, constructing defaults for any
        component (chunker/embedder/vector DB/LLM/reranker) not injected.

        Raises:
            ValueError: if vector_db_type is not a string or names an
                unsupported backend.
        """
        self.document_path = document_path

        self.chunker = chunker if chunker else PyMuPDFChunker(
            pdf_path=self.document_path,
            chunk_size=chunk_size,
            chunk_overlap=chunk_overlap
        )

        self.embedder = embedder if embedder else SentenceTransformerEmbedder(model_name=embedding_model_name)

        if vector_db:
            self.vector_db = vector_db
        else:
            if not isinstance(vector_db_type, str):
                raise ValueError("vector db type must be string")
            # Pinecone is remote, so it gets no local db_path default.
            db_path = db_path or (
                Configuration.CHROMA_DB_PATH if vector_db_type.lower() == 'chroma'
                else
                Configuration.FAISS_DB_PATH if vector_db_type.lower() == 'faiss'
                else "")
            collection_name = collection_name or Configuration.COLLECTION_NAME

            if vector_db_type.lower() == "chroma":
                self.vector_db = chromaDBVectorDB(
                    embedder=self.embedder,
                    db_path=db_path,
                    collection_name=collection_name
                )
            elif vector_db_type.lower() == "faiss":
                self.vector_db = FAISSVectorDB(
                    embedder=self.embedder,
                    db_path=db_path,
                    collection_name=collection_name
                )
            elif vector_db_type.lower() == "pinecone":
                self.vector_db = PineconeVectorDB(
                    embedder=self.embedder,
                    db_path=db_path,
                    collection_name=collection_name
                )
            else:
                # Bug fix: message previously omitted pinecone ("suppots chroma or faiss db").
                raise ValueError("RAG application supports 'chroma', 'faiss' or 'pinecone' vector DBs")

        self.llm = llm if llm else GROQLLM(
            api_key=Configuration.GROQ_API_KEY,
            model_name=llm_model_name)

        self.re_ranker = None
        if re_ranker_model_name:
            try:
                # Bug fix: CrossEncoder was called with a nonexistent keyword
                # argument `re_ranker_model_name=`; its parameter is the model
                # name passed positionally (`model_name`), so the reranker
                # always failed to load.
                self.re_ranker = CrossEncoder(re_ranker_model_name)
                print(f"ReRanker model loaded: {re_ranker_model_name}")
            except Exception as ex:
                print(f"ReRanker model could not be loaded: {re_ranker_model_name}")
                self.re_ranker = None

        print("RAG pipeline initialized")

    def build_index(self):
        """Chunk the document and add it to the vector DB, skipping work if the
        document hash is already indexed."""
        print(f"building index for {self.document_path}")
        try:
            document_hash = self.chunker.hash_document()
            print(f"{document_hash}")
            existing_chunk_ids = self.vector_db.get_document_hash_ids(document_hash)

            if existing_chunk_ids:
                print(f"Documents {self.document_path} hash: {document_hash[:8]} already present in the vector DB with {len(existing_chunk_ids)} chunks.")
                return

            print(f"Chunking starting")
            document_hash, chunks = self.chunker.chunk_documents()

            if not chunks:
                print(f"No chunks generated for {self.document_path} index not built")
                return

            self.vector_db.add_chunks(chunks)

            print(f"Index built successfully for the {self.document_path} with {len(chunks)}")

        except FileNotFoundError:
            print(f"Exception Document not at {self.document_path}")
            traceback.print_exc()
        except Exception as ex:
            print(f"Exception in build index {ex}")
            traceback.print_exc()

    def retrieve_raw_documents(self, user_query: str, k: int = 5) -> List[Dict[str, Any]]:
        """Retrieve the top-k chunks for a query, optionally cross-encoder reranked.

        Retrieves 3*k candidates when a reranker is available, then keeps the
        k highest-scoring (query, passage) pairs.
        """
        if self.vector_db.count_documents() == 0:
            print("vector database is empty please build index first")
            return []
        # The vector DB embeds the query internally, so the raw text is passed.
        initial_retrieval_k = k * 3 if self.re_ranker else k
        retrieved_docs_with_score = self.vector_db.search(user_query, k=initial_retrieval_k)
        if not retrieved_docs_with_score:
            print(f"Retrieval failed for {user_query}")
            return []
        retrieved_docs = list(retrieved_docs_with_score)
        if self.re_ranker:
            # Bug fix: the previous code called `self.re_ranker.rerank(...)`,
            # a method CrossEncoder does not have, and passed only the
            # documents. Cross-encoders score (query, passage) pairs via
            # `predict`.
            pairs = [(user_query, doc['content']) for doc in retrieved_docs]
            rerank_scores = self.re_ranker.predict(pairs)
            doc_with_reRank = [
                {'doc': doc, 'rerank_score': rerank_scores[idx]}
                for idx, doc in enumerate(retrieved_docs)
            ]
            ranked_docs = sorted(doc_with_reRank, key=lambda x: x['rerank_score'], reverse=True)
            final_retrieved_docs = [item['doc'] for item in ranked_docs[:k]]
        else:
            final_retrieved_docs = retrieved_docs
        return final_retrieved_docs

    def query(self, user_query: str, k: int = 3,
              include_metadata: bool = True,
              user_context: Dict[str, Any] = None) -> Dict[str, Any]:
        """Answer a user query from the indexed HR policy document.

        Args:
            user_query: the question to answer.
            k: number of chunks to ground the answer on.
            include_metadata: whether to return source citations.
            user_context: role/location/department used to tailor the answer.

        Returns:
            Dict with "summary" (the LLM answer) and "sources" (citation dicts).
        """
        if not user_query.strip():
            return {"summary": "Enter the query", "sources": []}

        retrieved_docs = self.retrieve_raw_documents(user_query, k)

        if not retrieved_docs:
            return {"summary": "Unable to find relevant information in the documents for the query asked. "
                               "Please refer contact HR department directly or refer HR policy document",
                    "sources": []}

        context_info = []
        metadata_info = []

        for indx, doc in enumerate(retrieved_docs):
            context_info.append(f"Document {indx+1} content: {doc['content']}")
            if include_metadata:
                metadata = doc['metadata']
                metadata_info.append({
                    "document_id": f"DOC {indx+1}",
                    "page": str(metadata.get("page_number", "NA")),
                    "section": metadata.get("section", "NA"),
                    "clause": metadata.get("clause", "NA")})

        context_string = "\n".join(context_info)
        user_context = user_context or {"role": "general", "location": "chennai", "department": "unknown"}

        # Bug fix: "soruce" -> "source" in the citation-format instruction.
        prompt = (
            f"You are an expert assistant for Flykite Airlines HR Policy queries. "
            f"Answer the question '{user_query}' based solely on the provided context from the Flykite Airlines HR Policy, "
            f"Tailor the answer for a {user_context['role']} in {user_context['location']} and {user_context['department']}. "
            f"Include only the criteria and details that directly address the question, ensuring all relevant points from the context are covered without adding unrelated information or assumptions. "
            f"Present the answer in a concise format using bullet points, a table, or sections for readability, and cite specific sections and clauses from the sources where applicable. "
            f"Cite specific sections and clauses from the sources using the format (source: DOC X, Page: Y, Section: Z, Clause: A) at the end of each relevant point or page"
            f"If the query is ambiguous, ask for clarification. If the context does not fully address the question, state what is known and suggest consulting the full HR Policy or HR department. "
            f"Context: \n{context_string}\n\n"
            f"Answer: ")

        # Guard against a failed generation returning None/"".
        llm_response = self.llm.generate_response(prompt) or ""
        final_response = {"summary": llm_response.strip(),
                          "sources": metadata_info if include_metadata else []
                          }

        return final_response
|
| 205 |
+
|
| 206 |
+
|
requirements.txt
ADDED
|
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
langchain-text-splitters
|
| 2 |
+
sentence-transformers
|
| 3 |
+
chromadb
|
| 4 |
+
pymupdf
|
| 5 |
+
regex
|
| 6 |
+
groq
|
| 7 |
+
streamlit
|
| 8 |
+
huggingface_hub
|
| 9 |
+
faiss-cpu
|
| 10 |
+
pinecone
|
| 11 |
+
bert-score
|
users.db
ADDED
|
Binary file (12.3 kB). View file
|
|
|