Upload 2 files
App_Function_Libraries/RAG/RAG_Library_2.py
CHANGED
@@ -147,8 +147,6 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
     try:
         # Load embedding provider from config, or fallback to 'openai'
         embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
-
-        # Log the provider used
         logging.debug(f"Using embedding provider: {embedding_provider}")
 
         # Process keywords if provided
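
Note: the fallback keyword used here is standard configparser behavior, so the default only kicks in when the section or key is missing. A minimal sketch of that behavior; the [Embeddings] section and provider key come from the diff, while the ini contents below are invented for illustration:

    import configparser

    # Hypothetical config contents; the real app reads these from its config file.
    ini_text = "[Embeddings]\nprovider = huggingface\n"
    config = configparser.ConfigParser()
    config.read_string(ini_text)

    # Key present: the configured value wins over the fallback
    print(config.get('Embeddings', 'provider', fallback='openai'))  # huggingface
    # Key absent: the fallback is returned instead of raising NoOptionError
    print(config.get('Embeddings', 'model', fallback='unset'))      # unset
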
@@ -164,61 +162,41 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
                     if db_type == "Media DB":
                         relevant_media_ids = fetch_relevant_media_ids(keyword_list)
                         relevant_ids[db_type] = relevant_media_ids
-                        logging.debug(f"enhanced_rag_pipeline - {db_type} relevant media IDs: {relevant_media_ids}")
-
                     elif db_type == "RAG Chat":
-                        conversations, total_pages, total_count = search_conversations_by_keywords(
-
-                        relevant_conversation_ids = [conv['conversation_id'] for conv in conversations]
-                        relevant_ids[db_type] = relevant_conversation_ids
-                        logging.debug(
-                            f"enhanced_rag_pipeline - {db_type} relevant conversation IDs: {relevant_conversation_ids}")
-
+                        conversations, total_pages, total_count = search_conversations_by_keywords(keywords=keyword_list)
+                        relevant_ids[db_type] = [conv['conversation_id'] for conv in conversations]
                     elif db_type == "RAG Notes":
                         notes, total_pages, total_count = get_notes_by_keywords(keyword_list)
-
-                        relevant_ids[db_type] = relevant_note_ids
-                        logging.debug(f"enhanced_rag_pipeline - {db_type} relevant note IDs: {relevant_note_ids}")
-
+                        relevant_ids[db_type] = [note_id for note_id, _, _, _ in notes]
                     elif db_type == "Character Chat":
-
-                        relevant_ids[db_type] = relevant_chat_ids
-                        logging.debug(f"enhanced_rag_pipeline - {db_type} relevant chat IDs: {relevant_chat_ids}")
-
+                        relevant_ids[db_type] = fetch_keywords_for_chats(keyword_list)
                     elif db_type == "Character Cards":
-
-                        relevant_character_ids = fetch_character_ids_by_keywords(keyword_list)
-                        relevant_ids[db_type] = relevant_character_ids
-                        logging.debug(
-                            f"enhanced_rag_pipeline - {db_type} relevant character IDs: {relevant_character_ids}")
-
+                        relevant_ids[db_type] = fetch_character_ids_by_keywords(keyword_list)
                     else:
                         logging.error(f"Unsupported database type: {db_type}")
 
+                    logging.debug(f"enhanced_rag_pipeline - {db_type} relevant IDs: {relevant_ids[db_type]}")
             except Exception as e:
                 logging.error(f"Error fetching relevant IDs: {str(e)}")
         else:
             relevant_ids = None
 
-        #
-
-        relevant_media_ids_dict = {}
+        # Prepare relevant IDs for each database type
+        relevant_ids_dict = {}
         if relevant_ids:
             for db_type in database_types:
-
-
-                # Convert to List[str] if not None
-                relevant_media_ids_dict[db_type] = [str(media_id) for media_id in relevant_media_ids]
+                if db_type in relevant_ids and relevant_ids[db_type]:
+                    relevant_ids_dict[db_type] = [str(id_) for id_ in relevant_ids[db_type]]
                 else:
-
+                    relevant_ids_dict[db_type] = None
         else:
-
+            relevant_ids_dict = {db_type: None for db_type in database_types}
 
         # Perform vector search for all selected databases
         vector_results = []
         for db_type in database_types:
             try:
-                db_relevant_ids =
+                db_relevant_ids = relevant_ids_dict.get(db_type)
                 results = perform_vector_search(query, db_relevant_ids, top_k=fts_top_k)
                 vector_results.extend(results)
                 logging.debug(f"\nenhanced_rag_pipeline - Vector search results for {db_type}: {results}")
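
Net effect of the hunk above: every branch now writes straight into relevant_ids[db_type], one shared debug line replaces the five per-branch ones, and the follow-up loop normalizes each entry to a list of strings, with None standing in for "no keyword filter" (an empty hit list is falsy, so it also collapses to None). A standalone sketch of that normalization, with made-up IDs:

    from typing import Dict, List, Optional

    database_types = ["Media DB", "RAG Chat", "RAG Notes"]
    relevant_ids = {"Media DB": [3, 17], "RAG Chat": []}  # e.g. integer IDs from SQLite

    relevant_ids_dict: Dict[str, Optional[List[str]]] = {}
    for db_type in database_types:
        if db_type in relevant_ids and relevant_ids[db_type]:
            # IDs become strings so every backend's filter list has one type
            relevant_ids_dict[db_type] = [str(id_) for id_ in relevant_ids[db_type]]
        else:
            # Missing or empty -> None, i.e. the vector search runs unfiltered
            relevant_ids_dict[db_type] = None

    print(relevant_ids_dict)
    # {'Media DB': ['3', '17'], 'RAG Chat': None, 'RAG Notes': None}
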
@@ -227,8 +205,8 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
 
         # Perform vector search
         # FIXME
-        vector_results = perform_vector_search(query, relevant_media_ids)
-
+        #vector_results = perform_vector_search(query, relevant_media_ids)
+        #ogging.debug(f"\n\nenhanced_rag_pipeline - Vector search results: {vector_results}")
 
         # Perform full-text search
         #v1
@@ -246,7 +224,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
             except Exception as e:
                 logging.error(f"Error performing full-text search on {db_type}: {str(e)}")
 
-        logging.debug("\n\nenhanced_rag_pipeline - Full-text search results:")
+        #logging.debug("\n\nenhanced_rag_pipeline - Full-text search results:")
         logging.debug(
             "\n\nenhanced_rag_pipeline - Full-text search results:\n" + "\n".join(
                 [str(item) for item in fts_results]) + "\n"
@@ -255,7 +233,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
         # Combine results
         all_results = vector_results + fts_results
 
-        if apply_re_ranking:
+        if apply_re_ranking and all_results:
             logging.debug(f"\nenhanced_rag_pipeline - Applying Re-Ranking")
             # FIXME - add option to use re-ranking at call time
             # FIXME - specify model + add param to modify at call time
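
The guard added in this hunk (and all_results) means the re-ranking branch is skipped outright when vector search and FTS both come back empty, instead of handing an empty candidate list to the ranker. A sketch of the pattern; rerank below is a hypothetical stand-in, not the ranker this repo actually calls:

    def rerank(query: str, results: list) -> list:
        # Hypothetical scorer: rank by term overlap with the query.
        # A real cross-encoder / ranking-model call would go here instead.
        terms = set(query.lower().split())
        return sorted(
            results,
            key=lambda r: len(terms & set(r["content"].lower().split())),
            reverse=True,
        )

    all_results = []  # vector search and FTS both found nothing
    apply_re_ranking = True

    if apply_re_ranking and all_results:  # empty list is falsy, so this short-circuits
        all_results = rerank("sample query", all_results)
    # Without the guard some rankers raise on empty input; with it the
    # pipeline simply falls through to the no-context answer path.
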
@@ -282,7 +260,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
 
         # Extract content from results (top fts_top_k by default)
         context = "\n".join([result['content'] for result in all_results[:fts_top_k]])
-        logging.debug(f"Context length: {len(context)}")
+        #logging.debug(f"Context length: {len(context)}")
         logging.debug(f"Context: {context[:200]}")
 
         # Generate answer using the selected API
@@ -294,10 +272,12 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
                 "answer": "No relevant information based on your query and keywords were found in the database. Your query has been directly passed to the LLM, and here is its answer: \n\n" + answer,
                 "context": "No relevant information based on your query and keywords were found in the database. The only context used was your query: \n\n" + query
             }
+
         # Metrics
         pipeline_duration = time.time() - start_time
         log_histogram("enhanced_rag_pipeline_duration", pipeline_duration, labels={"api_choice": api_choice})
         log_counter("enhanced_rag_pipeline_success", labels={"api_choice": api_choice})
+
         return {
             "answer": answer,
             "context": context
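
The metrics lines follow the usual wall-clock pattern: start_time is captured at pipeline entry, and the elapsed time is logged as a histogram sample plus a success counter. log_histogram and log_counter are this repo's own helpers (imported elsewhere in the module); the stand-ins below only mirror the call shape used in the diff:

    import time

    def log_histogram(name, value, labels=None):  # stand-in for the repo helper
        print(f"histogram {name}={value:.4f} labels={labels}")

    def log_counter(name, labels=None):           # stand-in for the repo helper
        print(f"counter {name}+=1 labels={labels}")

    start_time = time.time()
    # ... pipeline work (searches, re-ranking, LLM call) would run here ...
    pipeline_duration = time.time() - start_time

    log_histogram("enhanced_rag_pipeline_duration", pipeline_duration,
                  labels={"api_choice": "openai"})
    log_counter("enhanced_rag_pipeline_success", labels={"api_choice": "openai"})
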