|
|
|
|
|
""" |
|
|
Query Expansion System for CogniChat RAG Application

This module implements advanced query expansion techniques to improve retrieval quality:
- QueryAnalyzer: Extracts intent, entities, and keywords
- QueryRephraser: Generates natural language variations
- MultiQueryExpander: Creates diverse query formulations
- MultiHopReasoner: Connects concepts across documents
- FallbackStrategies: Handles edge cases gracefully

Author: CogniChat Team
Date: October 19, 2025
|
|
""" |
|
|
|
|
|
import re |
|
|
from typing import List, Dict, Any, Optional |
|
|
from dataclasses import dataclass |
|
|
from enum import Enum |
|
|
|
|
|
|
|
|
class QueryStrategy(Enum):
    """
    Expansion depth a caller can request.

    Each member's value is its lowercase name, so a member can be looked up
    from plain text with ``QueryStrategy("quick")``.
    """

    # Original query plus a single synonym-based rewrite.
    QUICK = "quick"
    # Adds the expanded and simplified rewrites on top of QUICK.
    BALANCED = "balanced"
    # Adds the keyword-focused and domain-context rewrites on top of BALANCED.
    COMPREHENSIVE = "comprehensive"
|
|
|
|
|
|
|
|
@dataclass
class QueryAnalysis:
    """
    Structured result of analysing a single query.

    Instances are produced by ``QueryAnalyzer.analyze`` and consumed by the
    rephrasing, multi-hop and fallback helpers in this module.
    """

    # Intent label: one of the analyzer's intent pattern names, or 'general'.
    intent: str
    # Capitalized words and double-quoted phrases found in the query.
    entities: List[str]
    # Lowercased content words left after stop-word filtering.
    keywords: List[str]
    # Coarse size bucket: 'simple', 'medium' or 'complex'.
    complexity: str
    # Technical domain name when one was detected, otherwise None.
    domain: Optional[str] = None
|
|
|
|
|
|
|
|
@dataclass |
|
|
class ExpandedQuery: |
|
|
"""Container for expanded query variations.""" |
|
|
original: str |
|
|
variations: List[str] |
|
|
strategy_used: QueryStrategy |
|
|
analysis: QueryAnalysis |
|
|
|
|
|
|
|
|
class QueryAnalyzer:
    """
    Analyzes queries to extract intent, entities, keywords, complexity and domain.

    Every step is rule-based (regular expressions and fixed word lists). The
    optional ``llm`` is only stored on the instance; no method of this class
    calls it.
    """

    # Words dropped during keyword extraction.
    _STOP_WORDS = frozenset({
        'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been',
        'what', 'how', 'why', 'when', 'where', 'which', 'who',
        'do', 'does', 'did', 'can', 'could', 'should', 'would',
        'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by',
    })

    # Capitalized tokens that are never entities: stop words (typically a
    # sentence-initial "What"/"Is"/"The") and the first-person pronoun.
    _NON_ENTITY_WORDS = _STOP_WORDS | frozenset({'i', "i'm", "i've", "i'd", "i'll"})

    # Punctuation stripped from token edges before the capitalization test.
    # Only sentence punctuation, so names like "C++", "C#" or ".NET" survive.
    _LEADING_PUNCTUATION = '"\'([{'
    _TRAILING_PUNCTUATION = '.,;:!?"\')]}'

    # Domain name -> trigger terms; domains are tried in this order.
    _DOMAIN_TERMS = {
        'programming': ['code', 'function', 'class', 'variable', 'algorithm', 'debug'],
        'data_science': ['model', 'dataset', 'training', 'prediction', 'accuracy'],
        'machine_learning': ['neural', 'network', 'learning', 'ai', 'deep learning'],
        'web': ['html', 'css', 'javascript', 'api', 'frontend', 'backend'],
        'database': ['sql', 'query', 'database', 'table', 'index'],
        'security': ['encryption', 'authentication', 'vulnerability', 'attack'],
    }

    def __init__(self, llm=None):
        """
        Initialize QueryAnalyzer.

        Args:
            llm: Optional LangChain LLM for advanced analysis
        """
        self.llm = llm
        # Intent name -> regex; the first pattern that matches wins, so the
        # insertion order of this dict is the priority order.
        self.intent_patterns = {
            'definition': r'\b(what is|define|meaning of|definition)\b',
            'how_to': r'\b(how to|how do|how can|steps to)\b',
            'comparison': r'\b(compare|difference|versus|vs|better than)\b',
            'explanation': r'\b(why|explain|reason|cause)\b',
            'listing': r'\b(list|enumerate|what are|types of)\b',
            'example': r'\b(example|instance|sample|case)\b',
        }

    def analyze(self, query: str) -> QueryAnalysis:
        """
        Analyze query to extract intent, entities, and keywords.

        Args:
            query: User's original query

        Returns:
            QueryAnalysis object with extracted information
        """
        query_lower = query.lower()

        # Entities need the original casing; intent and domain work on the
        # lowercased text.
        entities = self._extract_entities(query)
        keywords = self._extract_keywords(query)

        return QueryAnalysis(
            intent=self._detect_intent(query_lower),
            entities=entities,
            keywords=keywords,
            complexity=self._assess_complexity(query, entities, keywords),
            domain=self._detect_domain(query_lower),
        )

    def _detect_intent(self, query_lower: str) -> str:
        """Return the first intent whose pattern matches, or 'general'."""
        for intent, pattern in self.intent_patterns.items():
            if re.search(pattern, query_lower):
                return intent
        return 'general'

    def _extract_entities(self, query: str) -> List[str]:
        """
        Extract candidate named entities (simplified heuristic).

        A whitespace-separated token counts as an entity when, after edge
        punctuation is stripped, it starts with an uppercase letter and is
        not a stop word or first-person pronoun. The first token is also
        skipped when it is an intent cue word such as "Compare" or
        "Explain", because its capital letter comes from sentence grammar.
        Double-quoted phrases are added as entities too.

        Returns:
            Unique entities in order of first appearance.
        """
        entities = []

        for position, raw in enumerate(query.split()):
            token = raw.lstrip(self._LEADING_PUNCTUATION).rstrip(self._TRAILING_PUNCTUATION)
            if not token or not token[0].isupper():
                continue

            lowered = token.lower()
            if lowered in self._NON_ENTITY_WORDS:
                continue
            if position == 0 and self._detect_intent(lowered) != 'general':
                continue

            entities.append(token)

        entities.extend(re.findall(r'"([^"]+)"', query))

        # dict.fromkeys de-duplicates while keeping first-seen order; a set
        # would make the order (and callers' entities[:2] slices) vary
        # between runs.
        return list(dict.fromkeys(entities))

    def _extract_keywords(self, query: str) -> List[str]:
        """Return up to 10 lowercased non-stop-word tokens longer than 2 chars."""
        words = re.findall(r'\b\w+\b', query.lower())
        keywords = [w for w in words if w not in self._STOP_WORDS and len(w) > 2]
        return keywords[:10]

    def _assess_complexity(self, query: str, entities: List[str], keywords: List[str]) -> str:
        """
        Bucket the query as 'simple', 'medium' or 'complex'.

        The score is the word count plus 2 per entity plus 1.5 per keyword;
        below 15 is simple, below 30 is medium, anything else is complex.
        """
        word_count = len(query.split())
        score = word_count + (len(entities) * 2) + (len(keywords) * 1.5)

        if score < 15:
            return 'simple'
        elif score < 30:
            return 'medium'
        else:
            return 'complex'

    def _detect_domain(self, query_lower: str) -> Optional[str]:
        """
        Detect technical domain if present.

        Terms are matched as whole words (an optional plural "s"/"es" is
        tolerated) so that short terms such as 'ai' or 'table' no longer
        fire inside unrelated words like "explain" or "stable".

        Returns:
            Name of the first domain with a matching term, or None.
        """
        for domain, terms in self._DOMAIN_TERMS.items():
            alternatives = '|'.join(re.escape(term) for term in terms)
            if re.search(r'\b(?:' + alternatives + r')(?:s|es)?\b', query_lower):
                return domain

        return None
|
|
|
|
|
|
|
|
class QueryRephraser:
    """
    Generates natural language variations of queries using multiple strategies.

    All variations are built from templates and a small synonym table. The
    optional ``llm`` is stored on the instance but not called by any method
    of this class.
    """

    # Word -> replacement used by the synonym variation.
    _SYNONYMS = {
        'error': 'issue',
        'problem': 'issue',
        'fix': 'resolve',
        'use': 'utilize',
        'create': 'generate',
        'make': 'create',
        'get': 'retrieve',
        'show': 'display',
        'find': 'locate',
        'explain': 'describe',
    }

    # Whole-word, case-insensitive match for any key of _SYNONYMS.
    _SYNONYM_PATTERN = re.compile(r'\b(?:' + '|'.join(_SYNONYMS) + r')\b', re.IGNORECASE)

    def __init__(self, llm=None):
        """
        Initialize QueryRephraser.

        Args:
            llm: LangChain LLM for generating variations
        """
        self.llm = llm

    def generate_variations(
        self,
        query: str,
        analysis: QueryAnalysis,
        strategy: QueryStrategy = QueryStrategy.BALANCED
    ) -> List[str]:
        """
        Generate query variations based on strategy.

        Args:
            query: Original query
            analysis: Query analysis results
            strategy: Expansion strategy to use

        Returns:
            List of unique, non-empty variations in generation order. The
            original query comes first when it is non-empty.
        """
        if strategy == QueryStrategy.QUICK:
            builders = (self._synonym_variation,)
        elif strategy == QueryStrategy.BALANCED:
            builders = (
                self._synonym_variation,
                self._expanded_variation,
                self._simplified_variation,
            )
        elif strategy == QueryStrategy.COMPREHENSIVE:
            builders = (
                self._synonym_variation,
                self._expanded_variation,
                self._simplified_variation,
                self._keyword_focused,
                self._context_variation,
            )
        else:
            # Unknown strategy: only the original query is returned.
            builders = ()

        variations = [query]
        variations.extend(build(query, analysis) for build in builders)

        # Intent-specific extra for the most thorough strategy. Skipped when
        # there are no keywords, which would yield a dangling "Guide to ".
        if (
            strategy == QueryStrategy.COMPREHENSIVE
            and analysis.intent in ('how_to', 'explanation')
            and analysis.keywords
        ):
            variations.append(f"Guide to {' '.join(analysis.keywords[:3])}")

        # Drop empty strings, then de-duplicate while keeping order.
        return list(dict.fromkeys(v for v in variations if v))

    def _synonym_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """
        Replace the first word that has a synonym, keeping the rest intact.

        The original casing and punctuation of the query are preserved, and
        a capitalized word gets a capitalized replacement. When no word has
        a synonym the query is returned unchanged, so the caller's
        de-duplication drops it instead of keeping a case-mangled copy.
        """
        def _swap(match):
            found = match.group(0)
            replacement = self._SYNONYMS[found.lower()]
            return replacement.capitalize() if found[0].isupper() else replacement

        # count=1: only the first match is swapped, like a single rewrite pass.
        return self._SYNONYM_PATTERN.sub(_swap, query, count=1)

    def _expanded_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Generate expanded version with more detail, chosen by intent."""
        if analysis.intent == 'definition':
            # Fall back to the raw query when no keywords were extracted, so
            # the template never ends in a dangling "of ".
            topic = ' '.join(analysis.keywords) or query
            return f"Detailed explanation and definition of {topic}"
        elif analysis.intent == 'how_to':
            return f"Step-by-step guide on {query.lower()}"
        elif analysis.intent == 'comparison':
            return f"Comprehensive comparison: {query}"
        else:
            return f"Detailed information about {query.lower()}"

    def _simplified_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Reduce the query to its first three keywords when at least two exist."""
        if len(analysis.keywords) >= 2:
            return ' '.join(analysis.keywords[:3])
        return query

    def _keyword_focused(self, query: str, analysis: QueryAnalysis) -> str:
        """Create keyword-focused variation for BM25 (first five keywords/entities)."""
        keywords = analysis.keywords + analysis.entities
        return ' '.join(keywords[:5])

    def _context_variation(self, query: str, analysis: QueryAnalysis) -> str:
        """Append the detected domain as context; unchanged query when none."""
        if analysis.domain:
            return f"{query} in {analysis.domain} context"
        return query
|
|
|
|
|
|
|
|
class MultiQueryExpander:
    """
    Orchestrates query expansion: analyse the query, rephrase it, cap the
    number of results and package everything into an ExpandedQuery.
    """

    def __init__(self, llm=None):
        """
        Build the analyzer and rephraser this expander delegates to.

        Args:
            llm: LangChain LLM handed to both collaborators
        """
        self.analyzer = QueryAnalyzer(llm)
        self.rephraser = QueryRephraser(llm)

    def expand(
        self,
        query: str,
        strategy: QueryStrategy = QueryStrategy.BALANCED,
        max_queries: int = 6
    ) -> ExpandedQuery:
        """
        Produce up to ``max_queries`` variations of ``query``.

        Args:
            query: Original user query
            strategy: Expansion strategy passed to the rephraser
            max_queries: Upper bound on the number of variations kept

        Returns:
            ExpandedQuery holding the original query, the kept variations,
            the strategy and the analysis they were derived from
        """
        query_analysis = self.analyzer.analyze(query)
        candidates = self.rephraser.generate_variations(query, query_analysis, strategy)

        # Variations arrive in generation order, so slicing keeps the
        # earliest ones.
        kept = candidates[:max_queries]

        return ExpandedQuery(
            original=query,
            variations=kept,
            strategy_used=strategy,
            analysis=query_analysis,
        )
|
|
|
|
|
|
|
|
class MultiHopReasoner:
    """
    Splits a query into smaller sub-queries so that information can be
    gathered from several sources and combined.

    The decomposition is template-based; the stored ``llm`` is not called by
    any method of this class.
    """

    def __init__(self, llm=None):
        """
        Store the optional LLM handle.

        Args:
            llm: LangChain LLM for reasoning
        """
        self.llm = llm

    def generate_sub_queries(self, query: str, analysis: QueryAnalysis) -> List[str]:
        """
        Derive sub-queries from the query's intent and complexity.

        Args:
            query: Original complex query
            analysis: Query analysis

        Returns:
            At most five sub-queries; the original query is always first
        """
        hops = [query]
        intent = analysis.intent
        terms = analysis.keywords

        # Comparison: look up each compared subject on its own, preferring
        # entities and falling back to keywords.
        if intent == 'comparison':
            if len(analysis.entities) >= 2:
                subjects = analysis.entities[:2]
            elif len(terms) >= 2:
                subjects = terms[:2]
            else:
                subjects = []
            hops.extend(f"Information about {subject}" for subject in subjects)

        # How-to: ask for prerequisites and steps of the main topic.
        if intent == 'how_to' and len(terms) >= 2:
            topic = ' '.join(terms[:2])
            hops += [f"Prerequisites for {topic}", f"Steps to {topic}"]

        # Complex query: split the keyword list into two halves.
        if analysis.complexity == 'complex' and len(terms) > 3:
            half = len(terms) // 2
            hops += [' '.join(terms[:half]), ' '.join(terms[half:])]

        return hops[:5]
|
|
|
|
|
|
|
|
class FallbackStrategies:
    """
    Implements fallback strategies for queries that don't retrieve good results.

    All methods are static and return a replacement query string.
    """

    # Words that narrow a query; broadening removes or skips them.
    _NARROWING_MODIFIERS = ('specific', 'exactly', 'precisely', 'only', 'just')

    @staticmethod
    def simplify_query(query: str) -> str:
        """
        Simplify query by removing modifiers and focusing on core terms.

        Question/modal words and common articles/copulas are removed as whole
        words (case-insensitive), then whitespace is collapsed.

        Returns:
            The reduced query; may be empty if every word was removed.
        """
        query = re.sub(
            r'\b(what|how|why|when|where|which|who|can|could|should|would)\b',
            '', query, flags=re.IGNORECASE,
        )
        query = re.sub(r'\b(is|are|was|were|be|been|the|a|an)\b', '', query, flags=re.IGNORECASE)

        # Removing words leaves gaps; collapse them to single spaces.
        return re.sub(r'\s+', ' ', query).strip()

    @staticmethod
    def broaden_query(query: str, analysis: QueryAnalysis) -> str:
        """
        Broaden query to increase recall.

        Returns ``"<keyword> overview"`` for the first keyword that is not
        itself a narrowing modifier. Keywords come from the unmodified
        query, so the first one can be e.g. 'specific', which must not
        become the topic. Without a usable keyword, the query is returned
        with narrowing modifiers removed and whitespace collapsed.
        """
        modifiers = FallbackStrategies._NARROWING_MODIFIERS

        for keyword in analysis.keywords or ():
            if keyword.lower() not in modifiers:
                return f"{keyword} overview"

        pattern = r'\b(' + '|'.join(modifiers) + r')\b'
        stripped = re.sub(pattern, '', query, flags=re.IGNORECASE)
        # Same whitespace clean-up as simplify_query, so no double spaces remain.
        return re.sub(r'\s+', ' ', stripped).strip()

    @staticmethod
    def focus_entities(analysis: QueryAnalysis) -> str:
        """
        Create entity-focused query as fallback.

        Returns:
            All entities joined by spaces; otherwise the first three
            keywords; otherwise an empty string.
        """
        if analysis.entities:
            return ' '.join(analysis.entities)
        elif analysis.keywords:
            return ' '.join(analysis.keywords[:3])
        return ""
|
|
|
|
|
|
|
|
|
|
|
def expand_query_simple(
    query: str,
    strategy: str = "balanced",
    llm=None
) -> List[str]:
    """
    Simple function to expand a query without dealing with classes.

    Args:
        query: User's query to expand
        strategy: "quick", "balanced", or "comprehensive"; matched
            case-insensitively with surrounding whitespace ignored. A
            QueryStrategy member is accepted as well.
        llm: Optional LangChain LLM

    Returns:
        List of expanded query variations; for a non-empty query the
        original query is the first element

    Raises:
        ValueError: If ``strategy`` does not name a QueryStrategy value

    Example:
        >>> queries = expand_query_simple("How do I debug Python code?", strategy="balanced")
        >>> queries[0]
        'How do I debug Python code?'
    """
    # Enum lookup by value is an exact match, so normalise free-form text
    # first ("Balanced", " quick ") instead of failing on it.
    if isinstance(strategy, str):
        strategy = strategy.strip().lower()
    strategy_enum = QueryStrategy(strategy)

    expander = MultiQueryExpander(llm=llm)
    expanded = expander.expand(query, strategy=strategy_enum)
    return expanded.variations
|
|
|
|
|
|
|
|
|
|
|
if __name__ == "__main__": |
|
|
|
|
|
print("=" * 60) |
|
|
print("Example 1: Simple Query Expansion") |
|
|
print("=" * 60) |
|
|
|
|
|
query = "What is machine learning?" |
|
|
queries = expand_query_simple(query, strategy="balanced") |
|
|
|
|
|
print(f"\nOriginal: {query}") |
|
|
print(f"\nExpanded queries ({len(queries)}):") |
|
|
for i, q in enumerate(queries, 1): |
|
|
print(f" {i}. {q}") |
|
|
|
|
|
|
|
|
print("\n" + "=" * 60) |
|
|
print("Example 2: Complex Query with Analysis") |
|
|
print("=" * 60) |
|
|
|
|
|
expander = MultiQueryExpander() |
|
|
query = "How do I compare the performance of different neural network architectures?" |
|
|
result = expander.expand(query, strategy=QueryStrategy.COMPREHENSIVE) |
|
|
|
|
|
print(f"\nOriginal: {result.original}") |
|
|
print(f"\nAnalysis:") |
|
|
print(f" Intent: {result.analysis.intent}") |
|
|
print(f" Entities: {result.analysis.entities}") |
|
|
print(f" Keywords: {result.analysis.keywords}") |
|
|
print(f" Complexity: {result.analysis.complexity}") |
|
|
print(f" Domain: {result.analysis.domain}") |
|
|
print(f"\nExpanded queries ({len(result.variations)}):") |
|
|
for i, q in enumerate(result.variations, 1): |
|
|
print(f" {i}. {q}") |
|
|
|
|
|
|
|
|
print("\n" + "=" * 60) |
|
|
print("Example 3: Multi-Hop Reasoning") |
|
|
print("=" * 60) |
|
|
|
|
|
reasoner = MultiHopReasoner() |
|
|
analyzer = QueryAnalyzer() |
|
|
|
|
|
query = "Compare Python and Java for web development" |
|
|
analysis = analyzer.analyze(query) |
|
|
sub_queries = reasoner.generate_sub_queries(query, analysis) |
|
|
|
|
|
print(f"\nOriginal: {query}") |
|
|
print(f"\nSub-queries for multi-hop reasoning:") |
|
|
for i, sq in enumerate(sub_queries, 1): |
|
|
print(f" {i}. {sq}") |
|
|
|
|
|
|
|
|
print("\n" + "=" * 60) |
|
|
print("Example 4: Fallback Strategies") |
|
|
print("=" * 60) |
|
|
|
|
|
query = "What is the specific difference between supervised and unsupervised learning?" |
|
|
analysis = analyzer.analyze(query) |
|
|
|
|
|
|
|
|
|