# Cogni-Chat-document-reader-v2 / query_expansion.py
"""
Query Expansion System for CogniChat RAG Application
This module implements advanced query expansion techniques to improve retrieval quality:
- QueryAnalyzer: Extracts intent, entities, and keywords
- QueryRephraser: Generates natural language variations
- MultiQueryExpander: Creates diverse query formulations
- MultiHopReasoner: Connects concepts across documents
- FallbackStrategies: Handles edge cases gracefully
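Typical usage (illustrative):
    from query_expansion import expand_query_simple
    variations = expand_query_simple("What is vector search?", strategy="balanced")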
Author: CogniChat Team
Date: October 19, 2025
"""
import re
from typing import List, Dict, Any, Optional
from dataclasses import dataclass
from enum import Enum
class QueryStrategy(Enum):
"""Query expansion strategies with different complexity levels."""
QUICK = "quick" # 2 queries - fast, minimal expansion
BALANCED = "balanced" # 3-4 queries - good balance
COMPREHENSIVE = "comprehensive" # 5-6 queries - maximum coverage
@dataclass
class QueryAnalysis:
"""Results from query analysis."""
intent: str # definition, how_to, comparison, explanation, listing, example, or general
entities: List[str] # Named entities extracted
keywords: List[str] # Important keywords
complexity: str # simple, medium, complex
domain: Optional[str] = None # Technical domain if detected
@dataclass
class ExpandedQuery:
"""Container for expanded query variations."""
original: str
variations: List[str]
strategy_used: QueryStrategy
analysis: QueryAnalysis
class QueryAnalyzer:
"""
Analyzes queries to extract intent, entities, and key information.
Uses lightweight pattern matching and heuristics; an optional LLM can be supplied for more advanced analysis.
"""
def __init__(self, llm=None):
"""
Initialize QueryAnalyzer.
Args:
llm: Optional LangChain LLM for advanced analysis
"""
self.llm = llm
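# Regex patterns keyed by intent; e.g. "What is a vector index?" would match 'definition'.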
self.intent_patterns = {
'definition': r'\b(what is|define|meaning of|definition)\b',
'how_to': r'\b(how to|how do|how can|steps to)\b',
'comparison': r'\b(compare|difference|versus|vs|better than)\b',
'explanation': r'\b(why|explain|reason|cause)\b',
'listing': r'\b(list|enumerate|what are|types of)\b',
'example': r'\b(example|instance|sample|case)\b',
}
def analyze(self, query: str) -> QueryAnalysis:
"""
Analyze query to extract intent, entities, and keywords.
Args:
query: User's original query
Returns:
QueryAnalysis object with extracted information
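Example (illustrative; entity extraction is heuristic, so only the deterministic fields are shown):
>>> analysis = QueryAnalyzer().analyze("How do I debug a Python function?")
>>> analysis.intent
'how_to'
>>> analysis.domain
'programming'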
"""
query_lower = query.lower()
# Detect intent
intent = self._detect_intent(query_lower)
# Extract entities (simplified - can be enhanced with NER)
entities = self._extract_entities(query)
# Extract keywords
keywords = self._extract_keywords(query)
# Assess complexity
complexity = self._assess_complexity(query, entities, keywords)
# Detect domain
domain = self._detect_domain(query_lower)
return QueryAnalysis(
intent=intent,
entities=entities,
keywords=keywords,
complexity=complexity,
domain=domain
)
def _detect_intent(self, query_lower: str) -> str:
"""Detect query intent using pattern matching."""
for intent, pattern in self.intent_patterns.items():
if re.search(pattern, query_lower):
return intent
return 'general'
def _extract_entities(self, query: str) -> List[str]:
"""Extract named entities (simplified version)."""
# Look for capitalized words (potential entities)
words = query.split()
entities = []
for word in words:
    # Strip surrounding punctuation so "Python?" becomes "Python"
    token = word.strip('.,!?;:()"\'')
    # Skip empty tokens and common question words at sentence start
    if token and token[0].isupper() and token.lower() not in ['what', 'how', 'why', 'when', 'where', 'which', 'who']:
        entities.append(token)
# Look for quoted terms
quoted = re.findall(r'"([^"]+)"', query)
entities.extend(quoted)
return list(set(entities))
def _extract_keywords(self, query: str) -> List[str]:
"""Extract important keywords from query."""
# Remove stop words (simplified list)
stop_words = {
'a', 'an', 'the', 'is', 'are', 'was', 'were', 'be', 'been',
'what', 'how', 'why', 'when', 'where', 'which', 'who',
'do', 'does', 'did', 'can', 'could', 'should', 'would',
'in', 'on', 'at', 'to', 'for', 'of', 'with', 'by'
}
# Split and filter
words = re.findall(r'\b\w+\b', query.lower())
keywords = [w for w in words if w not in stop_words and len(w) > 2]
return keywords[:10] # Limit to top 10
def _assess_complexity(self, query: str, entities: List[str], keywords: List[str]) -> str:
"""Assess query complexity."""
word_count = len(query.split())
entity_count = len(entities)
keyword_count = len(keywords)
# Simple scoring
score = word_count + (entity_count * 2) + (keyword_count * 1.5)
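# e.g. 7 words, 2 entities, 3 keywords -> 7 + 4 + 4.5 = 15.5, which falls in the 'medium' band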
if score < 15:
return 'simple'
elif score < 30:
return 'medium'
else:
return 'complex'
def _detect_domain(self, query_lower: str) -> Optional[str]:
"""Detect technical domain if present."""
domains = {
'programming': ['code', 'function', 'class', 'variable', 'algorithm', 'debug'],
'data_science': ['model', 'dataset', 'training', 'prediction', 'accuracy'],
'machine_learning': ['neural', 'network', 'learning', 'ai', 'deep learning'],
'web': ['html', 'css', 'javascript', 'api', 'frontend', 'backend'],
'database': ['sql', 'query', 'database', 'table', 'index'],
'security': ['encryption', 'authentication', 'vulnerability', 'attack'],
}
for domain, keywords in domains.items():
    # Match whole words to avoid false positives (e.g. 'ai' matching inside 'explain')
    if any(re.search(r'\b' + re.escape(kw) + r'\b', query_lower) for kw in keywords):
        return domain
return None
class QueryRephraser:
"""
Generates natural language variations of queries using multiple strategies.
"""
def __init__(self, llm=None):
"""
Initialize QueryRephraser.
Args:
llm: LangChain LLM for generating variations
"""
self.llm = llm
def generate_variations(
self,
query: str,
analysis: QueryAnalysis,
strategy: QueryStrategy = QueryStrategy.BALANCED
) -> List[str]:
"""
Generate query variations based on strategy.
Args:
query: Original query
analysis: Query analysis results
strategy: Expansion strategy to use
Returns:
List of query variations
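Example (illustrative; wording follows the heuristics implemented below):
>>> rephraser = QueryRephraser()
>>> analysis = QueryAnalyzer().analyze("How do I fix this error?")
>>> rephraser.generate_variations("How do I fix this error?", analysis)
['How do I fix this error?', 'How do i resolve this error?', 'Step-by-step guide on how do i fix this error?', 'fix this error']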
"""
variations = [query] # Always include original
if strategy == QueryStrategy.QUICK:
# Just add synonym variation
variations.append(self._synonym_variation(query, analysis))
elif strategy == QueryStrategy.BALANCED:
# Add synonym, expanded, and simplified versions
variations.append(self._synonym_variation(query, analysis))
variations.append(self._expanded_variation(query, analysis))
variations.append(self._simplified_variation(query, analysis))
elif strategy == QueryStrategy.COMPREHENSIVE:
# Add all variations
variations.append(self._synonym_variation(query, analysis))
variations.append(self._expanded_variation(query, analysis))
variations.append(self._simplified_variation(query, analysis))
variations.append(self._keyword_focused(query, analysis))
variations.append(self._context_variation(query, analysis))
# Add one more: alternate phrasing
if analysis.intent in ['how_to', 'explanation']:
variations.append(f"Guide to {' '.join(analysis.keywords[:3])}")
# Remove duplicates and None values
variations = [v for v in variations if v]
return list(dict.fromkeys(variations)) # Preserve order, remove dupes
def _synonym_variation(self, query: str, analysis: QueryAnalysis) -> str:
"""Generate variation using synonyms."""
# Common synonym replacements
synonyms = {
'error': 'issue',
'problem': 'issue',
'fix': 'resolve',
'use': 'utilize',
'create': 'generate',
'make': 'create',
'get': 'retrieve',
'show': 'display',
'find': 'locate',
'explain': 'describe',
}
words = query.lower().split()
for i, word in enumerate(words):
if word in synonyms:
words[i] = synonyms[word]
break  # Replace only one word so the variation stays natural
return ' '.join(words).capitalize()
def _expanded_variation(self, query: str, analysis: QueryAnalysis) -> str:
"""Generate expanded version with more detail."""
if analysis.intent == 'definition':
return f"Detailed explanation and definition of {' '.join(analysis.keywords)}"
elif analysis.intent == 'how_to':
return f"Step-by-step guide on {query.lower()}"
elif analysis.intent == 'comparison':
return f"Comprehensive comparison: {query}"
else:
# Add qualifying words
return f"Detailed information about {query.lower()}"
def _simplified_variation(self, query: str, analysis: QueryAnalysis) -> str:
"""Generate simplified version focusing on core concepts."""
# Use just the keywords
if len(analysis.keywords) >= 2:
return ' '.join(analysis.keywords[:3])
return query
def _keyword_focused(self, query: str, analysis: QueryAnalysis) -> str:
"""Create keyword-focused variation for BM25."""
keywords = analysis.keywords + analysis.entities
return ' '.join(keywords[:5])
def _context_variation(self, query: str, analysis: QueryAnalysis) -> str:
"""Add contextual information if domain detected."""
if analysis.domain:
return f"{query} in {analysis.domain} context"
return query
class MultiQueryExpander:
"""
Main query expansion orchestrator that combines analysis and rephrasing.
"""
def __init__(self, llm=None):
"""
Initialize MultiQueryExpander.
Args:
llm: LangChain LLM for advanced expansions
"""
self.analyzer = QueryAnalyzer(llm)
self.rephraser = QueryRephraser(llm)
def expand(
self,
query: str,
strategy: QueryStrategy = QueryStrategy.BALANCED,
max_queries: int = 6
) -> ExpandedQuery:
"""
Expand query into multiple variations.
Args:
query: Original user query
strategy: Expansion strategy
max_queries: Maximum number of queries to generate
Returns:
ExpandedQuery object with all variations
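Example (illustrative):
>>> expander = MultiQueryExpander()
>>> result = expander.expand("What is gradient descent?", strategy=QueryStrategy.QUICK)
>>> result.original
'What is gradient descent?'
>>> result.strategy_used.value
'quick'
>>> len(result.variations) >= 1
True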
"""
# Analyze query
analysis = self.analyzer.analyze(query)
# Generate variations
variations = self.rephraser.generate_variations(query, analysis, strategy)
# Limit to max_queries
variations = variations[:max_queries]
return ExpandedQuery(
original=query,
variations=variations,
strategy_used=strategy,
analysis=analysis
)
class MultiHopReasoner:
"""
Implements multi-hop reasoning to connect concepts across documents.
Useful for complex queries that require information from multiple sources.
"""
def __init__(self, llm=None):
"""
Initialize MultiHopReasoner.
Args:
llm: LangChain LLM for reasoning
"""
self.llm = llm
def generate_sub_queries(self, query: str, analysis: QueryAnalysis) -> List[str]:
"""
Break complex query into sub-queries for multi-hop reasoning.
Args:
query: Original complex query
analysis: Query analysis
Returns:
List of sub-queries
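Example (illustrative; the exact sub-queries depend on heuristic entity extraction):
>>> reasoner = MultiHopReasoner()
>>> analysis = QueryAnalyzer().analyze("Compare REST and GraphQL APIs")
>>> sub_queries = reasoner.generate_sub_queries("Compare REST and GraphQL APIs", analysis)
>>> sub_queries[0]
'Compare REST and GraphQL APIs'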
"""
sub_queries = [query]
# For comparison queries, create separate queries for each entity
if analysis.intent == 'comparison' and len(analysis.entities) >= 2:
for entity in analysis.entities[:2]:
sub_queries.append(f"Information about {entity}")
elif analysis.intent == 'comparison' and len(analysis.keywords) >= 2:
# Fallback: use keywords if no entities found
for keyword in analysis.keywords[:2]:
sub_queries.append(f"Information about {keyword}")
# For how-to queries, break into steps
if analysis.intent == 'how_to' and len(analysis.keywords) >= 2:
main_topic = ' '.join(analysis.keywords[:2])
sub_queries.append(f"Prerequisites for {main_topic}")
sub_queries.append(f"Steps to {main_topic}")
# For complex questions, create focused sub-queries
if analysis.complexity == 'complex' and len(analysis.keywords) > 3:
# Create queries focusing on different keyword groups
mid = len(analysis.keywords) // 2
sub_queries.append(' '.join(analysis.keywords[:mid]))
sub_queries.append(' '.join(analysis.keywords[mid:]))
return sub_queries[:5] # Limit to 5 sub-queries
class FallbackStrategies:
"""
Implements fallback strategies for queries that don't retrieve good results.
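Example (illustrative):
>>> FallbackStrategies.simplify_query("What is the specific role of an index?")
'specific role of index?'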
"""
@staticmethod
def simplify_query(query: str) -> str:
"""Simplify query by removing modifiers and focusing on core terms."""
# Remove question words
query = re.sub(r'\b(what|how|why|when|where|which|who|can|could|should|would)\b', '', query, flags=re.IGNORECASE)
# Remove common phrases
query = re.sub(r'\b(is|are|was|were|be|been|the|a|an)\b', '', query, flags=re.IGNORECASE)
# Clean up extra spaces
query = re.sub(r'\s+', ' ', query).strip()
return query
@staticmethod
def broaden_query(query: str, analysis: QueryAnalysis) -> str:
"""Broaden query to increase recall."""
# Remove specific constraints
query = re.sub(r'\b(specific|exactly|precisely|only|just)\b', '', query, flags=re.IGNORECASE)
# Add general terms
if analysis.keywords:
return f"{analysis.keywords[0]} overview"
return query
@staticmethod
def focus_entities(analysis: QueryAnalysis) -> str:
"""Create entity-focused query as fallback."""
if analysis.entities:
return ' '.join(analysis.entities)
elif analysis.keywords:
return ' '.join(analysis.keywords[:3])
return ""
# Convenience function for easy integration
def expand_query_simple(
query: str,
strategy: str = "balanced",
llm=None
) -> List[str]:
"""
Simple function to expand a query without dealing with classes.
Args:
query: User's query to expand
strategy: "quick", "balanced", or "comprehensive"
llm: Optional LangChain LLM
Returns:
List of expanded query variations
Example:
>>> queries = expand_query_simple("How do I debug Python code?", strategy="balanced")
>>> print(queries)
['How do I debug Python code?', 'How do i debug python code?', 'Step-by-step guide on how do i debug python code?', 'debug python code']
"""
expander = MultiQueryExpander(llm=llm)
strategy_enum = QueryStrategy(strategy)
expanded = expander.expand(query, strategy=strategy_enum)
return expanded.variations
# Example usage and testing
if __name__ == "__main__":
# Example 1: Simple query expansion
print("=" * 60)
print("Example 1: Simple Query Expansion")
print("=" * 60)
query = "What is machine learning?"
queries = expand_query_simple(query, strategy="balanced")
print(f"\nOriginal: {query}")
print(f"\nExpanded queries ({len(queries)}):")
for i, q in enumerate(queries, 1):
print(f" {i}. {q}")
# Example 2: Complex query with full analysis
print("\n" + "=" * 60)
print("Example 2: Complex Query with Analysis")
print("=" * 60)
expander = MultiQueryExpander()
query = "How do I compare the performance of different neural network architectures?"
result = expander.expand(query, strategy=QueryStrategy.COMPREHENSIVE)
print(f"\nOriginal: {result.original}")
print(f"\nAnalysis:")
print(f" Intent: {result.analysis.intent}")
print(f" Entities: {result.analysis.entities}")
print(f" Keywords: {result.analysis.keywords}")
print(f" Complexity: {result.analysis.complexity}")
print(f" Domain: {result.analysis.domain}")
print(f"\nExpanded queries ({len(result.variations)}):")
for i, q in enumerate(result.variations, 1):
print(f" {i}. {q}")
# Example 3: Multi-hop reasoning
print("\n" + "=" * 60)
print("Example 3: Multi-Hop Reasoning")
print("=" * 60)
reasoner = MultiHopReasoner()
analyzer = QueryAnalyzer()
query = "Compare Python and Java for web development"
analysis = analyzer.analyze(query)
sub_queries = reasoner.generate_sub_queries(query, analysis)
print(f"\nOriginal: {query}")
print(f"\nSub-queries for multi-hop reasoning:")
for i, sq in enumerate(sub_queries, 1):
print(f" {i}. {sq}")
# Example 4: Fallback strategies
print("\n" + "=" * 60)
print("Example 4: Fallback Strategies")
print("=" * 60)
query = "What is the specific difference between supervised and unsupervised learning?"
analysis = analyzer.analyze(query)