# File size: 8,798 Bytes
# 3718631
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
"""
RAG Retrieval Utilities for gprMax Documentation
Provides search and retrieval functions for the vector database
"""

import logging
from pathlib import Path
from typing import List, Dict, Any, Optional, Tuple
import json

import chromadb
from dataclasses import dataclass

# Module-level logger, named after this module via __name__.
logger = logging.getLogger(__name__)


@dataclass
class SearchResult:
    """One retrieved document chunk together with its score and metadata."""
    # Text of the retrieved chunk.
    text: str
    # Score attached to this hit by the code that built the result.
    score: float
    # Metadata mapping for the chunk; __str__ reads its "source" key.
    metadata: Dict[str, Any]

    def __str__(self) -> str:
        """Return a one-line summary: score, source, and the first 100 characters of text."""
        origin = self.metadata.get('source', 'Unknown')
        preview = self.text[:100]
        # The trailing ellipsis is appended unconditionally, even for short texts.
        return f"[Score: {self.score:.3f}] {origin}: {preview}..."


# Removed QwenEmbeddingModel class - using ChromaDB's default embedding


class GprMaxRAGRetriever:
    """Retriever for gprMax documentation RAG database.

    Opens a persistent ChromaDB database directory, loads one collection from
    it, and offers query helpers that return ``SearchResult`` objects or
    pre-formatted context strings.
    """

    # Number of records fetched per page when scanning the whole collection
    # in search_by_file().
    _SCAN_BATCH_SIZE = 1000

    def __init__(self, db_path: Optional[Path] = None):
        """
        Open the database and load the documentation collection.

        Args:
            db_path: Directory of the persistent ChromaDB database. Defaults to
                a ``chroma_db`` directory next to this file. A plain string
                path is accepted as well.

        Raises:
            ValueError: If ``db_path`` does not exist, or if the collection
                cannot be loaded (the original error is chained as the cause).
        """
        if db_path is None:
            db_path = Path(__file__).parent / "chroma_db"
        else:
            # Normalise so callers may pass a str; a Path is returned unchanged.
            db_path = Path(db_path)

        if not db_path.exists():
            raise ValueError(f"Database path {db_path} does not exist. Run generate_db.py first.")

        self.db_path = db_path

        # Load the optional metadata file stored alongside the database.
        metadata_path = db_path / "metadata.json"
        if metadata_path.exists():
            # Explicit encoding so the read does not depend on the locale.
            with open(metadata_path, 'r', encoding="utf-8") as f:
                self.metadata = json.load(f)
        else:
            self.metadata = {}

        # Initialize ChromaDB client
        self.client = chromadb.PersistentClient(path=str(db_path))

        # Get collection; the name comes from metadata.json when present.
        self.collection_name = self.metadata.get("collection_name", "gprmax_docs_v1")
        try:
            print(f"[RAG] Loading collection: {self.collection_name}")
            self.collection = self.client.get_collection(self.collection_name)
            doc_count = self.collection.count()
            print(f"[RAG] Loaded collection: {self.collection_name} with {doc_count} documents")
            logger.info("Loaded collection: %s with %s documents", self.collection_name, doc_count)
        except Exception as e:
            print(f"[RAG] ERROR loading collection: {e}")
            # Chain the cause so the underlying ChromaDB traceback is preserved.
            raise ValueError(f"Failed to load collection {self.collection_name}: {e}") from e

    def search(
        self,
        query: str,
        k: int = 10,
        threshold: float = 0.0,
        filter_metadata: Optional[Dict[str, Any]] = None
    ) -> List[SearchResult]:
        """
        Search for relevant documents

        Args:
            query: Search query text
            k: Number of results to request from the collection
            threshold: Minimum similarity score; hits scoring below it are dropped
            filter_metadata: Optional metadata filters passed as the ``where`` clause

        Returns:
            List of SearchResult objects, in the order returned by the collection

        Raises:
            Exception: Any error raised by the collection query is logged and re-raised.
        """
        # Search in ChromaDB (it generates the query embedding from the text).
        try:
            results = self.collection.query(
                query_texts=[query],
                n_results=k,
                where=filter_metadata if filter_metadata else None,
                include=["documents", "metadatas", "distances"]
            )
        except Exception as e:
            logger.error("ChromaDB query failed: %s", e)
            raise

        # One query text was sent, so only the first result batch is relevant.
        doc_batches = results.get("documents") or [[]]
        documents = doc_batches[0] or []
        logger.info("ChromaDB query returned: %d results", len(documents))
        if not documents:
            return []

        metadatas = results["metadatas"][0]
        distances = results["distances"][0]

        search_results = []
        for doc, meta, dist in zip(documents, metadatas, distances):
            # Convert distance to a similarity score. Cosine distance lies in
            # [0, 2], so this maps it onto [0, 1].
            # NOTE(review): this presumes the collection was built with cosine
            # space; with another metric the score is not bounded to [0, 1] --
            # confirm against the script that generates the database.
            score = 1.0 - (dist / 2.0)

            if score >= threshold:
                search_results.append(SearchResult(
                    text=doc,
                    score=score,
                    # Records stored without metadata come back as None;
                    # normalise so consumers can always call .get().
                    metadata=meta or {}
                ))

        return search_results

    def get_context(
        self,
        query: str,
        k: int = 3,
        max_context_length: int = 2000,
        format_as_markdown: bool = True
    ) -> str:
        """
        Get formatted context for a query

        Args:
            query: Search query
            k: Number of documents to retrieve
            max_context_length: Maximum total length of document text. Only the
                document text counts toward this budget, not the markdown
                headers and fences added around it.
            format_as_markdown: Format output as markdown

        Returns:
            Formatted context string, or a fixed "not found" message when the
            search yields nothing
        """
        results = self.search(query, k=k)

        if not results:
            return "No relevant documentation found."

        context_parts = []
        total_length = 0

        for i, result in enumerate(results, 1):
            if total_length >= max_context_length:
                break

            # Truncate the last document so the text budget is not exceeded.
            text = result.text
            remaining = max_context_length - total_length
            if len(text) > remaining:
                text = text[:remaining]

            if format_as_markdown:
                source = result.metadata.get("source", "Unknown")
                context_parts.append(
                    f"### Document {i} (Source: {source}, Score: {result.score:.3f})\n"
                    f"```\n{text}\n```\n"
                )
            else:
                context_parts.append(text)

            total_length += len(text)

        return "\n".join(context_parts)

    def get_relevant_files(self, query: str, k: int = 5) -> List[str]:
        """
        Get list of relevant source files for a query

        Args:
            query: Search query
            k: Number of documents to retrieve before de-duplicating sources

        Returns:
            Sorted list of the distinct, non-empty ``source`` metadata values
        """
        results = self.search(query, k=k)

        # Extract unique source files, skipping hits without a source.
        sources = {
            result.metadata.get("source")
            for result in results
            if result.metadata.get("source")
        }
        return sorted(sources)

    def search_by_file(self, file_pattern: str, k: int = 10) -> List[SearchResult]:
        """
        Search for documents from specific files

        The match is a case-insensitive substring test against each record's
        ``source`` metadata, done client-side. The collection is scanned page
        by page, so records beyond the first page are considered too.

        Args:
            file_pattern: Substring to look for in the ``source`` metadata
            k: Maximum number of results to return

        Returns:
            Up to ``k`` SearchResult objects, each with a fixed score of 1.0
        """
        if k <= 0:
            return []

        needle = file_pattern.lower()
        filtered_results = []
        offset = 0

        while True:
            page = self.collection.get(
                limit=self._SCAN_BATCH_SIZE,
                offset=offset,
                include=["documents", "metadatas"]
            )
            documents = page.get("documents") or []
            if not documents:
                break
            metadatas = page.get("metadatas") or [None] * len(documents)

            for doc, meta in zip(documents, metadatas):
                meta = meta or {}  # records without metadata come back as None
                source = meta.get("source") or ""
                if needle in str(source).lower():
                    filtered_results.append(SearchResult(
                        text=doc,
                        score=1.0,  # No score for direct retrieval
                        metadata=meta
                    ))
                    # Stop scanning as soon as enough matches are collected.
                    if len(filtered_results) >= k:
                        return filtered_results

            # A short page means the end of the collection was reached.
            if len(documents) < self._SCAN_BATCH_SIZE:
                break
            offset += len(documents)

        return filtered_results

    def get_stats(self) -> Dict[str, Any]:
        """
        Get database statistics

        Returns:
            Dict with the live document count, the database path, the
            collection name, and build settings read from metadata.json
            ("Unknown" for any that are missing)
        """
        return {
            "total_documents": self.collection.count(),
            "database_path": str(self.db_path),
            "collection_name": self.collection_name,
            "embedding_model": self.metadata.get("embedding_model", "Unknown"),
            "created_at": self.metadata.get("created_at", "Unknown"),
            "chunk_size": self.metadata.get("chunk_size", "Unknown"),
            "chunk_overlap": self.metadata.get("chunk_overlap", "Unknown")
        }


def create_retriever(db_path: Optional[Path] = None) -> GprMaxRAGRetriever:
    """Build and return a GprMaxRAGRetriever.

    Args:
        db_path: Database directory forwarded to the retriever; ``None`` lets
            the retriever pick its own default location.

    Returns:
        A newly constructed GprMaxRAGRetriever.
    """
    retriever = GprMaxRAGRetriever(db_path=db_path)
    return retriever


if __name__ == "__main__":
    # Demo entry point: run one query through the retriever and print
    # the database stats, the raw hits, and the formatted context.
    import sys

    # Use the command-line words as the query, or fall back to a sample one.
    cli_words = sys.argv[1:]
    query = " ".join(cli_words) if cli_words else "How to create a source in gprMax?"
    divider = "-" * 80

    print(f"Testing retriever with query: '{query}'")
    print(divider)

    try:
        retriever = create_retriever()

        # Database statistics
        print(f"Database stats: {retriever.get_stats()}")
        print(divider)

        # Raw search hits
        hits = retriever.search(query, k=3)
        print(f"Found {len(hits)} results:")
        for rank, hit in enumerate(hits, 1):
            print(f"\n{rank}. {hit}")

        # Formatted context
        print("\n" + "=" * 80)
        print("Formatted context:")
        print(retriever.get_context(query, k=3))

    except Exception as exc:
        # Any failure is reported on stdout and turned into exit status 1.
        print(f"Error: {exc}")
        sys.exit(1)