import os
import sys
import json
import streamlit as st
import warnings
import traceback
import logs
import chromadb
import hashlib
import sqlite3
import regex as re
from pinecone import Pinecone
from typing import Optional, Dict, Any
from sentence_transformers import util

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
warnings.filterwarnings("ignore")

sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), 'src')))
from sentence_transformers import SentenceTransformer
from configuration import Configuration
from rag_scripts.rag_pipeline import RAGPipeline
from rag_scripts.documents_processing.chunking import PyMuPDFChunker
from rag_scripts.embedding.embedder import SentenceTransformerEmbedder
from rag_scripts.embedding.vector_db.chroma_db import chromaDBVectorDB
from rag_scripts.embedding.vector_db.faiss_db import FAISSVectorDB
from rag_scripts.embedding.vector_db.pinecone_db import PineconeVectorDB
from rag_scripts.llm.llmResponse import GROQLLM
from rag_scripts.evaluation.evaluator import RAGEvaluator


class RAGOperations:
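    """Static helpers that wire the RAG pipeline into the CLI and Streamlit UI:
    index building, search, hyperparameter tuning, and user authentication."""
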
    VALID_VECTOR_DB = {'chroma', 'faiss', 'pinecone'}

    @staticmethod
    def check_db(vector_db_type: str, db_path: str, collection_name: str) -> bool:
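        """Return True if the target collection/index already exists in the
        given vector store (pinecone, chroma, or faiss)."""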
        try:
            if vector_db_type.lower() not in RAGOperations.VALID_VECTOR_DB:
                logs.logger.info(f"Invalid Vector DB: {vector_db_type}")
                raise ValueError(f"Invalid vector DB type: {vector_db_type}")
            if vector_db_type.lower() == 'pinecone':
                pc = Pinecone(api_key=Configuration.PINECONE_API_KEY)
                return collection_name in pc.list_indexes().names()
            elif vector_db_type.lower() == 'chroma':
                if not os.path.exists(db_path):
                    return False
                client = chromadb.PersistentClient(path=db_path)
                try:
                    client.get_collection(collection_name)
                    return True
                except Exception:
                    return False
            elif vector_db_type.lower() == "faiss":
                faiss_index_file = os.path.join(db_path, f"{collection_name}.faiss")
                faiss_doc_store_file = os.path.join(db_path, f"{collection_name}_docs.pkl")
                return os.path.exists(faiss_index_file) and os.path.exists(faiss_doc_store_file)
        except Exception as ex:
            traceback.print_exc()
            logs.logger.info(f"Exception in checking {vector_db_type} existence")
            return False

    @staticmethod
    def get_pipeline_params(chunk_size: Optional[int] = None,
                            chunk_overlap: Optional[int] = None,
                            embedding_model: Optional[str] = None,
                            vector_db_type: Optional[str] = None,
                            llm_model: Optional[str] = None,
                            temperature: Optional[float] = None,
                            top_p: Optional[float] = None,
                            max_tokens: Optional[int] = None,
                            re_ranker_model: Optional[str] = None,
                            use_tuned: bool = False) -> Dict[str, Any]:
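        """Assemble the pipeline parameter dict, applying defaults and, when
        use_tuned is set, the values persisted in best_params.json."""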
        try:
            best_param_path = os.path.join(Configuration.DATA_DIR, 'best_params.json')
            params = {
                'document_path': Configuration.FULL_PDF_PATH,
                'chunk_size': chunk_size if chunk_size is not None else Configuration.DEFAULT_CHUNK_SIZE,
                'chunk_overlap': chunk_overlap if chunk_overlap is not None else Configuration.DEFAULT_CHUNK_OVERLAP,
                'embedding_model_name': embedding_model if embedding_model is not None else Configuration.DEFAULT_SENTENCE_TRANSFORMER_MODEL,
                'vector_db_type': vector_db_type if vector_db_type is not None else "chroma",
                'llm_model_name': llm_model if llm_model is not None else Configuration.DEFAULT_GROQ_LLM_MODEL,
                'db_path': None,
                'collection_name': Configuration.COLLECTION_NAME,
                'vector_db': None,
                'temperature': temperature if temperature is not None else 0.1,
                'top_p': top_p if top_p is not None else .95,
                'max_tokens': max_tokens if max_tokens is not None else 1500,
                're_ranker_model': re_ranker_model if re_ranker_model is not None else Configuration.DEFAULT_RERANKER,
            }

            if use_tuned and os.path.exists(best_param_path):
                with open(best_param_path, 'r') as f:
                    best_params = json.load(f)
                logs.logger.info(f"Best params: {best_params} from the file {best_param_path}")

                params.update({
                    'vector_db_type': best_params.get('vector_db_type', params['vector_db_type']),
                    'embedding_model_name': best_params.get('embedding_model', params['embedding_model_name']),
                    'chunk_overlap': best_params.get('chunk_overlap', params['chunk_overlap']),
                    'chunk_size': best_params.get('chunk_size', params['chunk_size']),
                    're_ranker_model': best_params.get('re_ranker_model', params['re_ranker_model'])})

            if use_tuned:
                tuned_db_type = params['vector_db_type']
                params['db_path'] = os.path.join(Configuration.DATA_DIR, 'TunedDB',
                                                 tuned_db_type) if tuned_db_type != 'pinecone' else ""
                params['collection_name'] = 'tuned-' + Configuration.COLLECTION_NAME
                if tuned_db_type in ['chroma', 'faiss']:
                    os.makedirs(params['db_path'], exist_ok=True)
                logs.logger.info(f"Tuned db path: {params['db_path']}")
            else:
                params['db_path'] = (Configuration.CHROMA_DB_PATH if params['vector_db_type'] == 'chroma'
                                     else Configuration.FAISS_DB_PATH if params['vector_db_type'] == 'faiss'
                                     else "")
                if params['vector_db_type'] in ['chroma', 'faiss']:
                    os.makedirs(params['db_path'], exist_ok=True)
                    logs.logger.info(f"Created directory for {params['vector_db_type']} at {params['db_path']}")

            return params
        except Exception as ex:
            logs.logger.info(f"Exception in get_pipeline_params: {ex}")
            traceback.print_exc()

    @staticmethod
    def check_embedding_dimension(vector_db_type: str, db_path: str,
                                  collection_name: str, embedding_model: str) -> bool:
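        """Return True when the stored Chroma collection's embedding dimension
        matches the given embedding model (non-Chroma stores are not checked)."""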
        if vector_db_type != 'chroma':
            return True
        try:
            client = chromadb.PersistentClient(path=db_path)
            collection = client.get_collection(collection_name)
            model = SentenceTransformer(embedding_model)
            sample_embedding = model.encode(["test"])[0]
            try:
                # Chroma does not expose the collection dimension directly; try a
                # private attribute first, then fall back to peeking a stored vector.
                expected_dim = collection._embedding_function.dim
            except AttributeError:
                peek_result = collection.peek(limit=1)
                embeddings = peek_result.get('embeddings')
                if embeddings is not None and len(embeddings) > 0:
                    expected_dim = len(embeddings[0])
                else:
                    return False
            actual_dim = len(sample_embedding)
            logs.logger.info(f"Expected dimension: {expected_dim} Actual dimension: {actual_dim}")
            return expected_dim == actual_dim
        except Exception as ex:
            logs.logger.info(f"Error checking embedding dimension: {ex}")
            return False

    @staticmethod
    def initialize_pipeline(params: Dict[str, Any]) -> RAGPipeline:
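        """Construct a RAGPipeline (chunker, embedder, vector DB, LLM, re-ranker)
        from a params dict produced by get_pipeline_params."""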
        try:
            embedder = SentenceTransformerEmbedder(model_name=params['embedding_model_name'])
            chunkerObj = PyMuPDFChunker(
                pdf_path=params['document_path'],
                chunk_size=params['chunk_size'],
                chunk_overlap=params['chunk_overlap'])
            llm_model = params['llm_model_name']
            vector_db = None
            if params['vector_db_type'] == 'chroma':
                vector_db = chromaDBVectorDB(embedder=embedder,
                                             db_path=params['db_path'],
                                             collection_name=params['collection_name'])
            elif params['vector_db_type'] == 'faiss':
                vector_db = FAISSVectorDB(embedder=embedder,
                                          db_path=params['db_path'],
                                          collection_name=params['collection_name'])
            elif params['vector_db_type'] == 'pinecone':
                vector_db = PineconeVectorDB(embedder=embedder,
                                             db_path=params['db_path'],
                                             collection_name=params['collection_name'])
            else:
                raise ValueError(f"Unknown vector_db_type: {params['vector_db_type']}")

            return RAGPipeline(document_path=params['document_path'],
                               chunker=chunkerObj, embedder=embedder,
                               vector_db=vector_db,
                               llm=GROQLLM(model_name=llm_model),
                               re_ranker_model_name=params['re_ranker_model'] if params[
                                   're_ranker_model'] else Configuration.DEFAULT_RERANKER, )
        except Exception as ex:
            logs.logger.info(f"Exception in pipeline initialize: {ex}")
            traceback.print_exc()
            raise

    @staticmethod
    def run_build_job(chunk_size: Optional[int] = None,
                      chunk_overlap: Optional[int] = None,
                      embedding_model: Optional[str] = None,
                      vector_db_type: Optional[str] = None,
                      llm_model: Optional[str] = None,
                      temperature: Optional[float] = None,
                      top_p: Optional[float] = None,
                      max_tokens: Optional[int] = None,
                      re_ranker_model: Optional[str] = None,
                      use_tuned: bool = False) -> None:
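        """Resolve pipeline parameters and (re)build the vector index."""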
        try:
            params = RAGOperations.get_pipeline_params(chunk_size=chunk_size,
                                                       chunk_overlap=chunk_overlap,
                                                       embedding_model=embedding_model,
                                                       vector_db_type=vector_db_type,
                                                       llm_model=llm_model,
                                                       temperature=temperature,
                                                       top_p=top_p,
                                                       max_tokens=max_tokens,
                                                       re_ranker_model=re_ranker_model,
                                                       use_tuned=use_tuned)

            pipeline = RAGOperations.initialize_pipeline(params)
            pipeline.build_index()
            logs.logger.info(f"RAG Build JOB completed")
        except Exception as ex:
            logs.logger.info(f"Exception in run build job: {ex}")
            traceback.print_exc()
            raise

    @staticmethod
    def run_search_job(query: Optional[str] = None,
                       k: int = 5, raw: bool = False,
                       use_tuned: bool = False,
                       llm_model: Optional[str] = None,
                       user_context: Optional[Dict[str, str]] = None,
                       temperature: Optional[float] = None,
                       top_p: Optional[float] = None,
                       max_tokens: Optional[int] = None,
                       chunk_size: Optional[int] = None,
                       chunk_overlap: Optional[int] = None,
                       embedding_model: Optional[str] = None,
                       vector_db_type: Optional[str] = None,
                       re_ranker_model: Optional[str] = None,
                       use_rag: bool = True) -> Dict[str, Any]:
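        """Answer a query either with raw retrieval or the full RAG+LLM pipeline,
        rebuilding the index if it is missing, empty, or dimension-mismatched,
        then evaluate the response and return the result dict."""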
        try:
            params = RAGOperations.get_pipeline_params(chunk_size=chunk_size,
                                                       chunk_overlap=chunk_overlap,
                                                       embedding_model=embedding_model,
                                                       vector_db_type=vector_db_type,
                                                       llm_model=llm_model,
                                                       temperature=temperature,
                                                       top_p=top_p,
                                                       max_tokens=max_tokens,
                                                       re_ranker_model=re_ranker_model,
                                                       use_tuned=use_tuned)
            vector_db_type = params['vector_db_type']
            db_path = params['db_path']
            collection_name = params['collection_name']

            pipeline = RAGOperations.initialize_pipeline(params)
            db_exists = RAGOperations.check_db(vector_db_type, db_path, collection_name)

            if use_rag:
                if not db_exists:
                    pipeline.build_index()
                elif pipeline.vector_db.count_documents() == 0:
                    pipeline.build_index()
                elif not RAGOperations.check_embedding_dimension(vector_db_type, db_path,
                                                                 collection_name, params['embedding_model_name']):
                    logs.logger.info(f"Embedding dimension mismatch. rebuilding the index")
                    pipeline.vector_db.delete_collection(collection_name)
                    pipeline.build_index()

                else:
                    logs.logger.info(f"Using existing {vector_db_type} database with collection: {collection_name}")

                if pipeline.vector_db.count_documents() == 0:
                    logs.logger.info("No documents found in vector database after rebuild")
                    raise RuntimeError("Vector database is empty after rebuild")

            evaluator = RAGEvaluator(eval_data_path=Configuration.EVAL_DATA_PATH,
                                     pdf_path=Configuration.FULL_PDF_PATH)

            user_query = query if query else input("Enter your Query: ")
            if user_query.lower() == 'exit':
                return

            expected_answers = None
            expected_keywords = []
            query_found = False
            try:
                with open(Configuration.EVAL_DATA_PATH, 'r') as f:
                    eval_data = json.load(f)
                for item in eval_data:
                    if item.get('query', '').strip().lower() == user_query.strip().lower():
                        expected_keywords = item.get('expected_keywords', [])
                        expected_answers = item.get('expected_answer_snippet', "")
                        query_found = True
                        break
                if not expected_keywords and not expected_answers:
                    logs.logger.info(f"No evaluation data found for query in json")
            except Exception as ex:
                logs.logger.info(f"No json file : {ex}")
            retrieved_documents = []
            if raw:
                retrieved_documents = pipeline.retrieve_raw_documents(
                    user_query, k=k * 2)
                logs.logger.info("Raw documents retrieved")
                logs.logger.info(json.dumps(retrieved_documents, indent=4))
                if not retrieved_documents:
                    response = {"summary": "No relevant documents found",
                                "sources": []}
                else:

                    query_embedding = evaluator.embedder.encode(user_query,
                                                                convert_to_tensor=True, normalize_embeddings=True)
                    similarities = [(doc, util.cos_sim(query_embedding,
                                                       evaluator.embedder.encode(doc['content'],
                                                                                 convert_to_tensor=True,
                                                                                 normalize_embeddings=True)).item())
                                    for doc in retrieved_documents]
                    similarities.sort(key=lambda x: x[1], reverse=True)

                    top_docs = similarities[:min(3, len(similarities))]

                    truncated_content = []
                    for doc, sim in top_docs:
                        content_paragraphs = re.split(r'\n\s*\n', doc['content'].strip())
                        para_sims = [(para, util.cos_sim(query_embedding,
                                                         evaluator.embedder.encode(para.strip(), convert_to_tensor=True,
                                                                                   normalize_embeddings=True)).item())
                                     for para in content_paragraphs if para.strip()]
                        para_sims.sort(key=lambda x: x[1], reverse=True)

                        top_paras = [para for para, para_sim in para_sims[:2] if para_sim >= 0.3]
                        if len(top_paras) < 1:  # Fallback to at least one paragraph
                            top_paras = [para for para, _ in para_sims[:1]]
                        truncated_content.append('\n\n'.join(top_paras))

                    response = {
                        "summary": "\n".join(truncated_content),
                        "sources": [{"document_id": f"DOC {idx + 1}",
                                     "page": str(doc['metadata'].get("page_number", "NA")),
                                     "section": doc['metadata'].get("section", "NA"),
                                     "clause": doc['metadata'].get("clause", "NA")}
                                    for idx, (doc, _) in enumerate(top_docs)]}

            else:
                logs.logger.info("LLM+RAG")
                response = pipeline.query(user_query, k=k,
                                          include_metadata=True,
                                          user_context=user_context
                                          )
                retrieved_documents = pipeline.retrieve_raw_documents(
                    user_query, k=k)

            final_expected_answer = expected_answers if expected_answers is not None else ""
            additional_eval_metrices = {}
            if not query_found:
                logs.logger.info(f"No query found in eval_Data.json: {user_query}")
                raw_reference_for_score = evaluator._syntesize_raw_reference(retrieved_documents)
                if not final_expected_answer.strip():
                    final_expected_answer = raw_reference_for_score

                retrieved_documents_content = [doc.get('content', '') for doc in retrieved_documents]
                llm_as_judge = evaluator._evaluate_with_llm(user_query,
                                                            response.get('summary', ''),
                                                            retrieved_documents_content)
                output = {"query": user_query, "response": response, "evaluation": llm_as_judge}
                logs.logger.info(json.dumps(output, indent=4))
                return output

            else:
                eval_result = evaluator.evaluate_response(user_query, response, retrieved_documents,
                                                          expected_keywords, expected_answers)
                output = {"query": user_query, "response": response, "evaluation": eval_result}
                logs.logger.info(json.dumps(output, indent=2, ensure_ascii=False))
                return output

        except Exception as ex:
            logs.logger.info(f"Exception in run search job {ex}")
            traceback.print_exc()

    @staticmethod
    def run_llm_with_prompt(run_type: str,
                            temperature: float = 0.1,
                            top_p: float = 0.95,
                            max_tokens: int = 1500) -> Dict[str, Any]:
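        """Query the LLM directly (optionally with a structured system prompt),
        bypassing retrieval, and evaluate the response against eval data."""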
        try:
            params = RAGOperations.get_pipeline_params()
            pipeline = RAGOperations.initialize_pipeline(params)

            evaluator = RAGEvaluator(eval_data_path=Configuration.EVAL_DATA_PATH,
                                     pdf_path=Configuration.FULL_PDF_PATH)

            system_message = (
                "You are an expert assistant for Flykite Airlines HR Policy Queries. "
                "Provide concise, accurate and policy-specific answers based solely on the provided context. "
                "Structure your response clearly, using bullet points and newlines where applicable. "
                "If the context lacks information, state that clearly and avoid speculation."
            ) if run_type == 'prompting' else None

            user_query = input("Enter your query: ")
            expected_answer = None
            expected_keywords = []
            try:
                with open(Configuration.EVAL_DATA_PATH, 'r') as f:
                    eval_data = json.load(f)
                for item in eval_data:
                    if item.get('query', '').strip().lower() == user_query.strip().lower():
                        expected_answer = item.get('expected_answer_snippet', "")
                        expected_keywords = item.get('expected_keywords', [])
                        break
            except Exception as ex:
                logs.logger.info(f"Error loading eval_data.json for query {user_query}: {ex}")

            if run_type == 'prompting':
                prompt = (
                    f"You are an expert assistant for Flykite Airlines HR Policy Queries. "
                    f"Answer the following question with a structured response, using bullet points or sections where applicable. "
                    f"Base your answer solely on the query and avoid hallucination.\n"
                    f"Question:\n{user_query}\n"
                    f"Answer: ")

            else:
                prompt = user_query

            response = pipeline.llm.generate_response(
                prompt=prompt,
                system_message=system_message,
                temperature=temperature,
                top_p=top_p,
                max_tokens=max_tokens
            )
            retrieved_documents = []

            eval_result = evaluator.evaluate_response(user_query,
                                                      response,
                                                      retrieved_documents,
                                                      expected_keywords,
                                                      expected_answer)

            output = {"query": user_query,
                      "response": {
                          "summary: ": response.strip(),
                          "source: ": ["LLM Response Not RAG loaded"]},
                      "evaluation": eval_result}

            logs.logger.info(json.dumps(output, indent=2))
            return output

        except Exception as ex:
            logs.logger.info(f"Exception in LLm_prompting response: {ex}")
            traceback.print_exc()
            return {"error": str(ex)}

    @staticmethod
    def login() -> Dict[str, str]:
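        """Interactive CLI login: prompt for credentials, verify them against
        users.db, and return the user's profile (exits on failure)."""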
        username = input("Enter your username: ")
        password = input("Enter your password: ")

        hashed_password = hashlib.sha256(password.encode()).hexdigest()
        try:
            conn = sqlite3.connect('users.db')
            cursor = conn.cursor()
            cursor.execute(
                "SELECT username,jobrole,department,location FROM users WHERE username = ? AND password = ?",
                (username, hashed_password)
            )
            user = cursor.fetchone()
            logs.logger.info(f"{user}")
            conn.close()
            if user:
                return {"username": user[0], "role": user[1], "department": user[2], "location": user[3]}
            else:
                logs.logger.info("Invalid username or password")
                sys.exit(1)

        except sqlite3.Error as ex:
            logs.logger.info(f"Database error during login: {ex}")
            sys.exit(1)

    @staticmethod
    def authenticate_user(username: str, password: str) -> Optional[Dict[str, str]]:
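        """Verify credentials against users.db and return the user's profile,
        or None when the credentials do not match."""
        # NOTE: unsalted SHA-256 is weak for password storage; a salted KDF such
        # as bcrypt or argon2 would be safer. Kept as-is to match the existing
        # users.db schema.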
        hashed_password = hashlib.sha256(password.encode()).hexdigest()
        conn = sqlite3.connect('users.db')
        cursor = conn.cursor()
        cursor.execute(
            "SELECT username, jobrole, department, location FROM users WHERE username = ? AND password = ?",
            (username, hashed_password)
        )
        user = cursor.fetchone()
        conn.close()
        if user:
            return {"username": user[0], "role": user[1], "department": user[2], "location": user[3]}
        return None

    @staticmethod
    def home_page():
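        """Landing page: render the login form, or the logged-in summary with a
        logout button once authenticated."""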
        st.title("Welcome to Flykite RAG System")

        if 'logged_in' not in st.session_state:
            st.session_state.logged_in = False
        if 'user_info' not in st.session_state:
            st.session_state.user_info = None

        if not st.session_state.logged_in:
            st.subheader("Login")
            with st.form("login_form"):
                username = st.text_input("Username")
                password = st.text_input("Password", type="password")
                login_button = st.form_submit_button("Login")

                if login_button:
                    user_data = RAGOperations.authenticate_user(username, password)
                    if user_data:
                        st.session_state.logged_in = True
                        st.session_state.user_info = user_data
                        st.session_state.user_context = {
                            "role": user_data['role'],
                            "department": user_data['department'],
                            "location": user_data['location']
                        }
                        st.success(f"Logged in as {user_data['username']} ({user_data['role']})")
                        # No rerun needed here, the main_app will handle navigation
                        st.session_state.page = "User" if user_data['role'] != 'admin' else "Admin"
                        st.rerun()
                    else:
                        st.error("Invalid username or password.")
        else:
            st.write(
                f"You are logged in as **{st.session_state.user_info['username']}** (Role: **{st.session_state.user_info['role']}**)")
            if st.button("Logout"):
                st.session_state.logged_in = False
                st.session_state.user_info = None
                st.session_state.user_context = None
                st.session_state.page = "Home"  # Redirect to home on logout
                st.rerun()

    @staticmethod
    def admin_page():
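        """Admin dashboard: forms for RAG hyperparameter tuning and for test
        queries against the (optionally tuned) index."""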
        st.title("Admin Dashboard")
        st.write(f"Logged in as: {st.session_state.user_info['username']} (Role: {st.session_state.user_info['role']})")

        if st.session_state.user_info and st.session_state.user_info['role'] == 'admin':
            st.header("RAG Hypertuning")
            st.info("Run hyperparameter tuning to find the best RAG configuration and build a tuned index.")

            with st.form("hypertune_form"):
                st.write("Hypertuning parameters:")

                llm_model_ht = st.selectbox("LLM Model for Hypertuning Evaluation",
                                            options=["llama-3.3-70b-versatile", "llama-3.1-8b-instant"],
                                            index=["llama-3.3-70b-versatile", "llama-3.1-8b-instant"].index(
                                                Configuration.DEFAULT_GROQ_LLM_MODEL) if Configuration.DEFAULT_GROQ_LLM_MODEL in [
                                                "llama-3.3-70b-versatile", "llama-3.1-8b-instant"] else 0,
                                            key="llm_model_ht_select")

                # New inputs for hyperparameter tuning
                st.subheader("Hyperparameter Ranges/Options:")

                chunk_sizes = st.multiselect("Chunk Sizes to Test (e.g., 256, 512, 1024)",
                                             options=[512, 1024,2048],
                                             default=[512],
                                             key="chunk_sizes_ht")
                chunk_overlaps = st.multiselect("Chunk Overlaps to Test (e.g., 50, 100, 200)",
                                                options=[150,200,400],
                                                default=[150],
                                                key="chunk_overlaps_ht")
                embedding_models = st.multiselect("Embedding Models to Test",
                                                  options=["all-MiniLM-L6-v2", "all-mpnet-base-v2",
                                                           "paraphrase-MiniLM-L3-v2", "multi-qa-mpnet-base-dot-v1"],
                                                  default=["all-MiniLM-L6-v2", "all-mpnet-base-v2"],
                                                  key="embedding_models_ht")
                re_ranker_models = st.multiselect("Re-ranker Models to Test",
                                                  options=["cross-encoder/ms-marco-MiniLM-L-6-v2",
                                                           "cross-encoder/ms-marco-TinyBERT-L-2", "None"],
                                                  default=["cross-encoder/ms-marco-MiniLM-L-6-v2"],
                                                  key="re_ranker_models_ht")
                vector_db_types = st.multiselect("Vector DB Types to Test",
                                                 options=['chroma', 'faiss', 'pinecone'],
                                                 default=['chroma'],
                                                 key="vector_db_types_ht")

                search_type = st.radio("Hypertuning Search Type",
                                       options=["random", "grid"],
                                       index=0,  # Default to random
                                       key="search_type_ht")

                n_iter = st.number_input("Number of Hyper-tuning Iterations (for Random Search)",
                                         min_value=1, value=3, step=1,
                                         help="Only applicable for 'Random' search type.",
                                         key="n_iter_ht")

                hypertune_button = st.form_submit_button("Run Hypertune Job")

                if hypertune_button:
                    if not chunk_sizes or not chunk_overlaps or not embedding_models or not re_ranker_models or not vector_db_types:
                        st.error("Please select at least one option for all hyperparameter categories.")
                    else:
                        # Handle 'None' for re-ranker model: remove "None" string and pass None object if needed
                        final_re_ranker_models = [
                            None if model == "None" else model for model in re_ranker_models
                        ]

                        st.write("Starting RAG Hypertuning. This may take a while...")
                        with st.spinner("Running hypertuning..."):
                            try:
                                result = RAGOperations.run_hypertune_job(
                                    llm_model=llm_model_ht,
                                    chunk_size_to_test=chunk_sizes,
                                    chunk_overlap_to_test=chunk_overlaps,
                                    embedding_models_to_test=embedding_models,
                                    re_ranker_model=final_re_ranker_models,
                                    vector_db_types_to_test=vector_db_types,
                                    search_type=search_type,
                                    n_iter=n_iter if search_type == "random" else None  # n_iter only for random search
                                )
                                if result and "error" not in result:
                                    st.success("Hypertuning completed and tuned index built!")
                                    st.subheader("Best Parameters Found:")
                                    st.json(result.get('best_params', {}))
                                    if 'best_score' in result:
                                        st.write(f"Best Score: {result['best_score']:.4f}")
                                    if 'best_metrics' in result:
                                        st.subheader("Best Metrics:")
                                        st.json(result['best_metrics'])
                                else:
                                    st.error(f"Hypertuning failed: {result.get('error', 'Unknown error')}")
                            except Exception as e:
                                st.error(f"An unexpected error occurred during hypertuning: {e}")
                                st.exception(e)  # Display full traceback in Streamlit

            st.header("RAG Testing")
            st.info("Test the RAG pipeline with a specific query, optionally using the tuned database.")

            with st.form("rag_test_form"):
                test_query = st.text_area("Enter a test query for the RAG system:",
                                          value="What is the policy on annual leave?",
                                          key="test_query_input")
                use_tuned_db = st.checkbox("Use Tuned RAG Database (if hypertuned previously)", value=True,
                                           key="use_tuned_db_checkbox")
                display_raw = st.checkbox("Display Raw Retrieved Documents only (no LLM)",
                                          key="display_raw_docs_checkbox")
                k_value = st.slider("Number of documents to retrieve (k)", min_value=1, max_value=10, value=5,
                                    key="k_value_slider")

                test_rag_button = st.form_submit_button("Run RAG Test Query")

                if test_rag_button:
                    st.write("Running RAG test query...")
                    with st.spinner("Getting RAG response..."):
                        try:
                            result = RAGOperations.run_search_job(
                                query=test_query,
                                k=k_value,
                                raw=display_raw,
                                use_tuned=use_tuned_db,
                                llm_model=st.session_state.get('llm_model_ht_select',
                                                               Configuration.DEFAULT_GROQ_LLM_MODEL),
                                user_context=st.session_state.user_context
                            )
                            if result and "error" not in result:
                                st.success("RAG Test Query Completed!")
                                st.subheader("RAG Response:")
                                if display_raw:
                                    st.json(result.get('response', {}))
                                else:
                                    response_data = result.get('response', {})
                                    if 'summary' in response_data:
                                        st.write(response_data['summary'])
                                        if 'sources' in response_data and response_data['sources']:
                                            st.subheader("Sources:")
                                            for source in response_data['sources']:
                                                if isinstance(source, dict):
                                                    st.markdown(
                                                        f"- **Document ID:** {source.get('document_id', 'N/A')}, **Page:** {source.get('page', 'N/A')}, **Section:** {source.get('section', 'N/A')}, **Clause:** {source.get('clause', 'N/A')}")
                                                else:
                                                    st.markdown(f"- {source}")
                                    else:
                                        st.json(response_data)
                                if 'evaluation' in result:
                                    st.subheader("Evaluation Results:")
                                    st.json(result['evaluation'])
                            else:
                                st.error(f"RAG test query failed: {result.get('error', 'Unknown error')}")
                        except Exception as e:
                            st.error(f"An unexpected error occurred during RAG test: {e}")
                            st.exception(e)
        else:
            st.warning("You do not have administrative privileges to view this page.")
            if st.button("Go to User Page"):
                st.session_state.page = "User"
                st.rerun()

    @staticmethod
    def run_hypertune_job(llm_model: Optional[str] = None,
                          chunk_size_to_test: Optional[list[int]] = None,
                          chunk_overlap_to_test: Optional[list[int]] = None,
                          embedding_models_to_test: Optional[list[str]] = None,
                          vector_db_types_to_test: Optional[list[str]] = None,
                          re_ranker_model: Optional[list[str]] = None,
                          search_type: str = "random",
                          n_iter: Optional[int] = 3) -> Dict[str, Any]:
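        """Search the given hyperparameter grid (grid or random search), persist
        the best parameters to best_params.json, and build a tuned index."""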
        try:
            evaluator = RAGEvaluator(eval_data_path=Configuration.EVAL_DATA_PATH,
                                     pdf_path=Configuration.FULL_PDF_PATH)

            result = evaluator.evaluate_combined_params_grid(
                chunk_size_to_test=chunk_size_to_test if chunk_size_to_test is not None else [512, 1024, 2048],
                chunk_overlap_to_test=chunk_overlap_to_test if chunk_overlap_to_test is not None else [100, 200, 400],
                embedding_models_to_test=embedding_models_to_test if embedding_models_to_test is not None else [
                    "all-MiniLM-L6-v2",
                    "all-mpnet-base-v2",
                    "paraphrase-MiniLM-L3-v2",
                    "multi-qa-mpnet-base-dot-v1"],
                vector_db_types_to_test=vector_db_types_to_test if vector_db_types_to_test is not None else ['chroma'],
                llm_model_name=llm_model,
                re_ranker_model=re_ranker_model if re_ranker_model is not None else [
                    "cross-encoder/ms-marco-MiniLM-L-6-v2",
                    "cross-encoder/ms-marco-TinyBERT-L-2"],
                search_type=search_type,
                n_iter=n_iter)

            best_parameter = result['best_params']

            best_param_path = os.path.join(Configuration.DATA_DIR, 'best_params.json')

            with open(best_param_path, 'w') as f:
                json.dump(best_parameter, f, indent=4)

            tuned_db = best_parameter['vector_db_type']
            tuned_path = os.path.join(Configuration.DATA_DIR, 'TunedDB', tuned_db)
            if tuned_db != 'pinecone':
                os.makedirs(tuned_path, exist_ok=True)
            tuned_collection_name = "tuned-" + Configuration.COLLECTION_NAME

            tuned_params = {
                'document_path': Configuration.FULL_PDF_PATH,
                'chunk_size': best_parameter.get('chunk_size', Configuration.DEFAULT_CHUNK_SIZE),
                'chunk_overlap': best_parameter.get('chunk_overlap', Configuration.DEFAULT_CHUNK_OVERLAP),
                'embedding_model_name': best_parameter.get('embedding_model',
                                                           Configuration.DEFAULT_SENTENCE_TRANSFORMER_MODEL),
                'vector_db_type': tuned_db,
                'llm_model_name': llm_model,
                'db_path': tuned_path if tuned_db != 'pinecone' else "",
                'collection_name': tuned_collection_name,
                'vector_db': None,
                're_ranker_model': best_parameter.get('re_ranker_model', Configuration.DEFAULT_RERANKER)
            }

            tuned_pipeline = RAGOperations.initialize_pipeline(tuned_params)
            tuned_pipeline.build_index()

            return result

        except Exception as ex:
            logs.logger.info(f"Exception in hypertune: {ex} ")
            traceback.print_exc()
            return {"error": str(ex)}  # Return error for Streamlit to display

    @staticmethod
    def user_page():
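        """User-facing query page: RAG+LLM or raw-retrieval answers against the
        tuned index when one is available."""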
        st.title("Flykite HR Policy Query")
        st.write(f"Logged in as: {st.session_state.user_info['username']} (Role: {st.session_state.user_info['role']})")

        st.info("Ask any question about the Flykite Airlines HR policy document.")

        with st.form("user_query_form"):
            user_query = st.text_area("Your Query:", height=100, key="user_query_input")
            response_type = st.radio("Choose Response Type:",
                                     options=["LLM Tuned Response (RAG + LLM)",
                                              "RAG Raw Response (Retrieved Docs Only)"],
                                     index=0, key="response_type_radio")
            k_value_user = st.slider("Number of documents to consider (k)", min_value=1, max_value=10, value=5,
                                     key="k_value_user_slider")

            submit_query_button = st.form_submit_button("Get Answer")

            if submit_query_button and user_query:
                st.subheader("Response:")
                with st.spinner("Fetching answer..."):
                    try:
                        display_raw = (response_type == "RAG Raw Response (Retrieved Docs Only)")
                        # Direct call to RAGOperations.run_search_job
                        result = RAGOperations.run_search_job(
                            query=user_query,
                            raw=display_raw,
                            k=k_value_user,
                            use_tuned=True,  # User page always uses tuned if available
                            user_context=st.session_state.user_context  # Pass user context
                        )

                        if result and "error" not in result:
                            response_data = result.get('response', {})
                            evaluation = result.get('evaluation',{})
                            if display_raw:
                                st.json(response_data)  # Raw output from main.py is already formatted
                            else:
                                if 'summary' in response_data:
                                    st.markdown(response_data['summary'])
                                    if 'sources' in response_data and response_data['sources']:
                                        st.subheader("Sources:")
                                        for source in response_data['sources']:
                                            if isinstance(source, dict):
                                                st.markdown(
                                                    f"- **Document ID:** {source.get('document_id', 'N/A')}, **Page:** {source.get('page', 'N/A')}, **Section:** {source.get('section', 'N/A')}, **Clause:** {source.get('clause', 'N/A')}")
                                            else:  # Fallback for raw string sources
                                                st.markdown(f"- {source}")
                                else:
                                    st.json(response_data)
                            if evaluation:
                                #st.markdown(f"**Evaluation Results:**  **Groundedness Score**  {evaluation.get('Groundedness score', 'N/A')}, **Relevance Score:** {evaluation.get('Relevance score', 'N/A')}, **Reasoning** {evaluation.get('Reasoning', 'N/A')}")
                                st.json(evaluation)
                        else:
                            st.error(
                                f"Failed to get a response: {result.get('error', 'Unknown error')}. Please try again.")
                    except Exception as e:
                        st.error(f"An unexpected error occurred during user query: {e}")
                        st.error(traceback.format_exc())
            elif submit_query_button and not user_query:
                st.warning("Please enter a query.")


def main_app():
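    """Top-level Streamlit router: tracks login state and dispatches to the
    Home, Admin, or User page via sidebar navigation."""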
    st.sidebar.title("Navigation")
    if 'logged_in' not in st.session_state:
        st.session_state.logged_in = False
    if 'page' not in st.session_state:
        st.session_state.page = "Home"

    if not st.session_state.logged_in:
        st.session_state.page = "Home"
        RAGOperations.home_page()
    else:

        st.sidebar.button("Home", on_click=lambda: st.session_state.update(page="Home"))
        if st.session_state.user_info and st.session_state.user_info['role'] == 'admin':
            st.sidebar.button("Admin Dashboard", on_click=lambda: st.session_state.update(page="Admin"))
        st.sidebar.button("User Query", on_click=lambda: st.session_state.update(page="User"))

        if st.session_state.page == "Home":
            RAGOperations.home_page()
        elif st.session_state.page == "Admin":
            RAGOperations.admin_page()
        elif st.session_state.page == "User":
            RAGOperations.user_page()



if __name__ == "__main__":
    main_app()