Spaces:
Runtime error
Runtime error
Upload RAG_Library_2.py
Browse files
App_Function_Libraries/RAG/RAG_Library_2.py
CHANGED
|
@@ -128,75 +128,74 @@ search_functions = {
|
|
| 128 |
|
| 129 |
# RAG Search with keyword filtering
|
| 130 |
# FIXME - Update each called function to support modifiable top-k results
|
| 131 |
-
def enhanced_rag_pipeline(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 132 |
"""
|
| 133 |
Perform keyword-filtered vector and full-text search across the specified database types, optionally re-ranking the combined results.
|
| 134 |
|
| 135 |
Args:
|
| 136 |
query: Search query string
|
| 137 |
api_choice: API to use for generating the response
|
| 138 |
-
fts_top_k: Maximum number of results to return
|
| 139 |
keywords: Optional comma-separated string of keywords used to filter results
|
| 140 |
-
|
|
|
|
|
|
|
| 141 |
|
| 142 |
Returns:
|
| 143 |
Dictionary containing search results with content
|
| 144 |
"""
|
| 145 |
log_counter("enhanced_rag_pipeline_attempt", labels={"api_choice": api_choice})
|
| 146 |
start_time = time.time()
|
|
|
|
| 147 |
try:
|
| 148 |
# Load embedding provider from config, or fallback to 'openai'
|
| 149 |
embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
|
| 150 |
logging.debug(f"Using embedding provider: {embedding_provider}")
|
| 151 |
|
| 152 |
-
#
|
| 153 |
-
|
| 154 |
-
logging.debug(f"\n\nenhanced_rag_pipeline - Keywords: {keyword_list}")
|
| 155 |
|
| 156 |
-
|
|
|
|
|
|
|
|
|
|
| 157 |
|
| 158 |
-
# Fetch relevant IDs based on keywords if keywords are provided
|
| 159 |
-
if keyword_list:
|
| 160 |
try:
|
| 161 |
for db_type in database_types:
|
| 162 |
if db_type == "Media DB":
|
| 163 |
-
|
| 164 |
-
relevant_ids[db_type] =
|
| 165 |
elif db_type == "RAG Chat":
|
| 166 |
-
conversations,
|
| 167 |
-
relevant_ids[db_type] = [conv['conversation_id'] for conv in conversations]
|
| 168 |
elif db_type == "RAG Notes":
|
| 169 |
-
notes,
|
| 170 |
-
relevant_ids[db_type] = [note_id for note_id, _, _, _ in notes]
|
| 171 |
elif db_type == "Character Chat":
|
| 172 |
-
relevant_ids[db_type] = fetch_keywords_for_chats(keyword_list)
|
| 173 |
elif db_type == "Character Cards":
|
| 174 |
-
relevant_ids[db_type] = fetch_character_ids_by_keywords(keyword_list)
|
| 175 |
else:
|
| 176 |
logging.error(f"Unsupported database type: {db_type}")
|
| 177 |
|
| 178 |
logging.debug(f"enhanced_rag_pipeline - {db_type} relevant IDs: {relevant_ids[db_type]}")
|
| 179 |
except Exception as e:
|
| 180 |
logging.error(f"Error fetching relevant IDs: {str(e)}")
|
|
|
|
| 181 |
else:
|
| 182 |
-
relevant_ids = None
|
| 183 |
-
|
| 184 |
-
# Prepare relevant IDs for each database type
|
| 185 |
-
relevant_ids_dict = {}
|
| 186 |
-
if relevant_ids:
|
| 187 |
-
for db_type in database_types:
|
| 188 |
-
if db_type in relevant_ids and relevant_ids[db_type]:
|
| 189 |
-
relevant_ids_dict[db_type] = [str(id_) for id_ in relevant_ids[db_type]]
|
| 190 |
-
else:
|
| 191 |
-
relevant_ids_dict[db_type] = None
|
| 192 |
-
else:
|
| 193 |
-
relevant_ids_dict = {db_type: None for db_type in database_types}
|
| 194 |
|
| 195 |
-
# Perform vector search
|
| 196 |
vector_results = []
|
| 197 |
for db_type in database_types:
|
| 198 |
try:
|
| 199 |
-
db_relevant_ids =
|
| 200 |
results = perform_vector_search(query, db_relevant_ids, top_k=fts_top_k)
|
| 201 |
vector_results.extend(results)
|
| 202 |
logging.debug(f"\nenhanced_rag_pipeline - Vector search results for {db_type}: {results}")
|
|
@@ -217,7 +216,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
|
|
| 217 |
fts_results = []
|
| 218 |
for db_type in database_types:
|
| 219 |
try:
|
| 220 |
-
db_relevant_ids = relevant_ids.get(db_type)
|
| 221 |
db_results = perform_full_text_search(query, db_type, db_relevant_ids, fts_top_k)
|
| 222 |
fts_results.extend(db_results)
|
| 223 |
logging.debug(f"enhanced_rag_pipeline - FTS results for {db_type}: {db_results}")
|
|
@@ -233,12 +232,12 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
|
|
| 233 |
# Combine results
|
| 234 |
all_results = vector_results + fts_results
|
| 235 |
|
|
|
|
|
|
|
|
|
|
| 236 |
if apply_re_ranking and all_results:
|
| 237 |
logging.debug(f"\nenhanced_rag_pipeline - Applying Re-Ranking")
|
| 238 |
-
|
| 239 |
-
# FIXME - specify model + add param to modify at call time
|
| 240 |
-
# FIXME - add option to set a custom top X results
|
| 241 |
-
# You can specify a model if necessary, e.g., model_name="ms-marco-MiniLM-L-12-v2"
|
| 242 |
if all_results:
|
| 243 |
ranker = Ranker()
|
| 244 |
|
|
@@ -273,7 +272,7 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
|
|
| 273 |
"context": "No relevant information based on your query and keywords were found in the database. The only context used was your query: \n\n" + query
|
| 274 |
}
|
| 275 |
|
| 276 |
-
#
|
| 277 |
pipeline_duration = time.time() - start_time
|
| 278 |
log_histogram("enhanced_rag_pipeline_duration", pipeline_duration, labels={"api_choice": api_choice})
|
| 279 |
log_counter("enhanced_rag_pipeline_success", labels={"api_choice": api_choice})
|
|
@@ -284,7 +283,6 @@ def enhanced_rag_pipeline(query: str, api_choice: str, keywords: str = None, fts
|
|
| 284 |
}
|
| 285 |
|
| 286 |
except Exception as e:
|
| 287 |
-
# Metrics
|
| 288 |
log_counter("enhanced_rag_pipeline_error", labels={"api_choice": api_choice, "error": str(e)})
|
| 289 |
logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
|
| 290 |
logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
|
|
|
|
| 128 |
|
| 129 |
# RAG Search with keyword filtering
|
| 130 |
# FIXME - Update each called function to support modifiable top-k results
|
| 131 |
+
def enhanced_rag_pipeline(
|
| 132 |
+
query: str,
|
| 133 |
+
api_choice: str,
|
| 134 |
+
keywords: Optional[str] = None,
|
| 135 |
+
fts_top_k: int = 10,
|
| 136 |
+
apply_re_ranking: bool = True,
|
| 137 |
+
database_types: List[str] = ["Media DB"]
|
| 138 |
+
) -> Dict[str, Any]:
|
| 139 |
"""
|
| 140 |
Perform keyword-filtered vector and full-text search across the specified database types, optionally re-ranking the combined results.
|
| 141 |
|
| 142 |
Args:
|
| 143 |
query: Search query string
|
| 144 |
api_choice: API to use for generating the response
|
|
|
|
| 145 |
keywords: Optional comma-separated string of keywords used to filter results
|
| 146 |
+
fts_top_k: Maximum number of results to return
|
| 147 |
+
apply_re_ranking: Whether to apply re-ranking to results
|
| 148 |
+
database_types: List of database types to search
|
| 149 |
|
| 150 |
Returns:
|
| 151 |
Dictionary containing search results with content
|
| 152 |
"""
|
| 153 |
log_counter("enhanced_rag_pipeline_attempt", labels={"api_choice": api_choice})
|
| 154 |
start_time = time.time()
|
| 155 |
+
|
| 156 |
try:
|
| 157 |
# Load embedding provider from config, or fallback to 'openai'
|
| 158 |
embedding_provider = config.get('Embeddings', 'provider', fallback='openai')
|
| 159 |
logging.debug(f"Using embedding provider: {embedding_provider}")
|
| 160 |
|
| 161 |
+
# Initialize relevant IDs dictionary
|
| 162 |
+
relevant_ids: Dict[str, Optional[List[str]]] = {}
|
|
|
|
| 163 |
|
| 164 |
+
# Process keywords if provided
|
| 165 |
+
if keywords:
|
| 166 |
+
keyword_list = [k.strip().lower() for k in keywords.split(',')]
|
| 167 |
+
logging.debug(f"enhanced_rag_pipeline - Keywords: {keyword_list}")
|
| 168 |
|
|
|
|
|
|
|
| 169 |
try:
|
| 170 |
for db_type in database_types:
|
| 171 |
if db_type == "Media DB":
|
| 172 |
+
media_ids = fetch_relevant_media_ids(keyword_list)
|
| 173 |
+
relevant_ids[db_type] = [str(id_) for id_ in media_ids]
|
| 174 |
elif db_type == "RAG Chat":
|
| 175 |
+
conversations, _, _ = search_conversations_by_keywords(keywords=keyword_list)
|
| 176 |
+
relevant_ids[db_type] = [str(conv['conversation_id']) for conv in conversations]
|
| 177 |
elif db_type == "RAG Notes":
|
| 178 |
+
notes, _, _ = get_notes_by_keywords(keyword_list)
|
| 179 |
+
relevant_ids[db_type] = [str(note_id) for note_id, _, _, _ in notes]
|
| 180 |
elif db_type == "Character Chat":
|
| 181 |
+
relevant_ids[db_type] = [str(id_) for id_ in fetch_keywords_for_chats(keyword_list)]
|
| 182 |
elif db_type == "Character Cards":
|
| 183 |
+
relevant_ids[db_type] = [str(id_) for id_ in fetch_character_ids_by_keywords(keyword_list)]
|
| 184 |
else:
|
| 185 |
logging.error(f"Unsupported database type: {db_type}")
|
| 186 |
|
| 187 |
logging.debug(f"enhanced_rag_pipeline - {db_type} relevant IDs: {relevant_ids[db_type]}")
|
| 188 |
except Exception as e:
|
| 189 |
logging.error(f"Error fetching relevant IDs: {str(e)}")
|
| 190 |
+
relevant_ids = {db_type: None for db_type in database_types}
|
| 191 |
else:
|
| 192 |
+
relevant_ids = {db_type: None for db_type in database_types}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 193 |
|
| 194 |
+
# Perform vector search
|
| 195 |
vector_results = []
|
| 196 |
for db_type in database_types:
|
| 197 |
try:
|
| 198 |
+
db_relevant_ids = relevant_ids.get(db_type)
|
| 199 |
results = perform_vector_search(query, db_relevant_ids, top_k=fts_top_k)
|
| 200 |
vector_results.extend(results)
|
| 201 |
logging.debug(f"\nenhanced_rag_pipeline - Vector search results for {db_type}: {results}")
|
|
|
|
| 216 |
fts_results = []
|
| 217 |
for db_type in database_types:
|
| 218 |
try:
|
| 219 |
+
db_relevant_ids = relevant_ids.get(db_type)
|
| 220 |
db_results = perform_full_text_search(query, db_type, db_relevant_ids, fts_top_k)
|
| 221 |
fts_results.extend(db_results)
|
| 222 |
logging.debug(f"enhanced_rag_pipeline - FTS results for {db_type}: {db_results}")
|
|
|
|
| 232 |
# Combine results
|
| 233 |
all_results = vector_results + fts_results
|
| 234 |
|
| 235 |
+
# FIXME - specify model + add param to modify at call time
|
| 236 |
+
# You can specify a model if necessary, e.g., model_name="ms-marco-MiniLM-L-12-v2"
|
| 237 |
+
# Apply re-ranking if enabled and results exist
|
| 238 |
if apply_re_ranking and all_results:
|
| 239 |
logging.debug(f"\nenhanced_rag_pipeline - Applying Re-Ranking")
|
| 240 |
+
|
|
|
|
|
|
|
|
|
|
| 241 |
if all_results:
|
| 242 |
ranker = Ranker()
|
| 243 |
|
|
|
|
| 272 |
"context": "No relevant information based on your query and keywords were found in the database. The only context used was your query: \n\n" + query
|
| 273 |
}
|
| 274 |
|
| 275 |
+
# Log metrics
|
| 276 |
pipeline_duration = time.time() - start_time
|
| 277 |
log_histogram("enhanced_rag_pipeline_duration", pipeline_duration, labels={"api_choice": api_choice})
|
| 278 |
log_counter("enhanced_rag_pipeline_success", labels={"api_choice": api_choice})
|
|
|
|
| 283 |
}
|
| 284 |
|
| 285 |
except Exception as e:
|
|
|
|
| 286 |
log_counter("enhanced_rag_pipeline_error", labels={"api_choice": api_choice, "error": str(e)})
|
| 287 |
logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
|
| 288 |
logging.error(f"Error in enhanced_rag_pipeline: {str(e)}")
|