|
|
import numpy as np |
|
|
import faiss |
|
|
from typing import List, Dict, Optional |
|
|
from sentence_transformers import SentenceTransformer |
|
|
from models.schemas import RAGSearchResult |
|
|
from config.settings import settings |
|
|
from core.multilingual_manager import MultilingualManager |
|
|
|
|
|
class EnhancedRAGSystem: |
|
|
def __init__(self): |
|
|
self.documents: List[str] = [] |
|
|
self.metadatas: List[Dict] = [] |
|
|
self.embeddings: Optional[np.ndarray] = None |
|
|
self.index: Optional[faiss.Index] = None |
|
|
|
|
|
|
|
|
self.multilingual_manager = MultilingualManager() |
|
|
self.current_dimension = settings.EMBEDDING_DIMENSION |
|
|
|
|
|
self._initialize_sample_data() |
|
|
|
|
|
def _initialize_sample_data(self): |
|
|
"""Khởi tạo dữ liệu mẫu""" |
|
|
|
|
|
vietnamese_data = [ |
|
|
"Rau xanh cung cấp nhiều vitamin và chất xơ tốt cho sức khỏe", |
|
|
"Trái cây tươi chứa nhiều vitamin C và chất chống oxy hóa", |
|
|
"Cá hồi giàu omega-3 tốt cho tim mạch và trí não", |
|
|
"Nước rất quan trọng cho cơ thể, nên uống ít nhất 2 lít mỗi ngày", |
|
|
"Hà Nội là thủ đô của Việt Nam, nằm ở miền Bắc", |
|
|
"Thành phố Hồ Chí Minh là thành phố lớn nhất Việt Nam", |
|
|
"Việt Nam có khí hậu nhiệt đới gió mùa với 4 mùa rõ rệt" |
|
|
] |
|
|
|
|
|
|
|
|
english_data = [ |
|
|
"Green vegetables provide many vitamins and fiber that are good for health", |
|
|
"Fresh fruits contain lots of vitamin C and antioxidants", |
|
|
"Salmon is rich in omega-3 which is good for heart and brain", |
|
|
"Water is very important for the body, should drink at least 2 liters per day", |
|
|
"London is the capital of England and the United Kingdom", |
|
|
"New York City is the most populous city in the United States", |
|
|
"The United States has diverse climate zones from tropical to arctic" |
|
|
] |
|
|
|
|
|
|
|
|
vietnamese_metadatas = [ |
|
|
{"type": "nutrition", "source": "sample", "language": "vi"}, |
|
|
{"type": "nutrition", "source": "sample", "language": "vi"}, |
|
|
{"type": "nutrition", "source": "sample", "language": "vi"}, |
|
|
{"type": "health", "source": "sample", "language": "vi"}, |
|
|
{"type": "geography", "source": "sample", "language": "vi"}, |
|
|
{"type": "geography", "source": "sample", "language": "vi"}, |
|
|
{"type": "geography", "source": "sample", "language": "vi"} |
|
|
] |
|
|
|
|
|
|
|
|
english_metadatas = [ |
|
|
{"type": "nutrition", "source": "sample", "language": "en"}, |
|
|
{"type": "nutrition", "source": "sample", "language": "en"}, |
|
|
{"type": "nutrition", "source": "sample", "language": "en"}, |
|
|
{"type": "health", "source": "sample", "language": "en"}, |
|
|
{"type": "geography", "source": "sample", "language": "en"}, |
|
|
{"type": "geography", "source": "sample", "language": "en"}, |
|
|
{"type": "geography", "source": "sample", "language": "en"} |
|
|
] |
|
|
|
|
|
|
|
|
self.add_documents(vietnamese_data, vietnamese_metadatas) |
|
|
self.add_documents(english_data, english_metadatas) |
|
|
|
|
|
def add_documents(self, documents: List[str], metadatas: List[Dict] = None): |
|
|
"""Thêm documents vào database - ĐÃ SỬA LỖI""" |
|
|
print(f"🔄 RAG System: Bắt đầu thêm {len(documents)} documents...") |
|
|
|
|
|
if not documents: |
|
|
print("❌ RAG System: Không có documents để thêm") |
|
|
return |
|
|
|
|
|
|
|
|
if metadatas is None: |
|
|
metadatas = [{} for _ in documents] |
|
|
print("📝 Tạo metadata mặc định") |
|
|
elif len(metadatas) != len(documents): |
|
|
print(f"⚠️ Metadata length mismatch: {len(metadatas)} vs {len(documents)}") |
|
|
|
|
|
new_metadatas = [] |
|
|
for i in range(len(documents)): |
|
|
if i < len(metadatas): |
|
|
new_metadatas.append(metadatas[i]) |
|
|
else: |
|
|
new_metadatas.append({"source": "upload", "language": "vi"}) |
|
|
metadatas = new_metadatas |
|
|
|
|
|
|
|
|
valid_documents = [] |
|
|
valid_metadatas = [] |
|
|
|
|
|
for i, doc in enumerate(documents): |
|
|
if doc and isinstance(doc, str) and len(doc.strip()) > 5: |
|
|
valid_documents.append(doc.strip()) |
|
|
valid_metadatas.append(metadatas[i] if i < len(metadatas) else {}) |
|
|
else: |
|
|
print(f"⚠️ Bỏ qua document {i}: không hợp lệ") |
|
|
|
|
|
print(f"📊 Documents hợp lệ: {len(valid_documents)}/{len(documents)}") |
|
|
|
|
|
if not valid_documents: |
|
|
print("❌ Không có documents hợp lệ để thêm") |
|
|
return |
|
|
|
|
|
|
|
|
new_embeddings_list = [] |
|
|
successful_embeddings = 0 |
|
|
|
|
|
for i, doc in enumerate(valid_documents): |
|
|
try: |
|
|
language = valid_metadatas[i].get('language', 'vi') |
|
|
embedding_model = self.multilingual_manager.get_embedding_model(language) |
|
|
|
|
|
if embedding_model is None: |
|
|
print(f"⚠️ Không có embedding model cho document {i}") |
|
|
continue |
|
|
|
|
|
|
|
|
doc_embedding = embedding_model.encode([doc]) |
|
|
new_embeddings_list.append(doc_embedding[0]) |
|
|
successful_embeddings += 1 |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ Lỗi embedding document {i}: {e}") |
|
|
|
|
|
print(f"📊 Embeddings thành công: {successful_embeddings}/{len(valid_documents)}") |
|
|
|
|
|
if not new_embeddings_list: |
|
|
print("❌ Không tạo được embeddings nào") |
|
|
return |
|
|
|
|
|
|
|
|
try: |
|
|
new_embeddings = np.array(new_embeddings_list) |
|
|
print(f"✅ Embedding matrix shape: {new_embeddings.shape}") |
|
|
except Exception as e: |
|
|
print(f"❌ Lỗi tạo embedding matrix: {e}") |
|
|
return |
|
|
|
|
|
|
|
|
old_doc_count = len(self.documents) |
|
|
|
|
|
if self.embeddings is None: |
|
|
|
|
|
self.embeddings = new_embeddings |
|
|
self.documents = valid_documents |
|
|
self.metadatas = valid_metadatas |
|
|
print("✅ Khởi tạo RAG system lần đầu") |
|
|
else: |
|
|
|
|
|
try: |
|
|
|
|
|
if self.embeddings.shape[1] != new_embeddings.shape[1]: |
|
|
print(f"⚠️ Dimension mismatch: {self.embeddings.shape[1]} vs {new_embeddings.shape[1]}") |
|
|
print("🔄 Tạo system mới do dimension không khớp") |
|
|
self.embeddings = new_embeddings |
|
|
self.documents = valid_documents |
|
|
self.metadatas = valid_metadatas |
|
|
else: |
|
|
|
|
|
self.embeddings = np.vstack([self.embeddings, new_embeddings]) |
|
|
self.documents.extend(valid_documents) |
|
|
self.metadatas.extend(valid_metadatas) |
|
|
print("✅ Đã thêm vào system hiện có") |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ Lỗi khi thêm vào system: {e}") |
|
|
return |
|
|
|
|
|
|
|
|
self._update_faiss_index() |
|
|
|
|
|
new_doc_count = len(self.documents) |
|
|
print(f"🎉 THÀNH CÔNG: Đã thêm {new_doc_count - old_doc_count} documents mới") |
|
|
print(f"📊 Tổng documents: {new_doc_count}") |
|
|
|
|
|
def _update_faiss_index(self): |
|
|
"""Cập nhật FAISS index với embeddings hiện tại""" |
|
|
if self.embeddings is None or len(self.embeddings) == 0: |
|
|
return |
|
|
|
|
|
try: |
|
|
dimension = self.embeddings.shape[1] |
|
|
self.index = faiss.IndexFlatIP(dimension) |
|
|
|
|
|
|
|
|
faiss.normalize_L2(self.embeddings) |
|
|
self.index.add(self.embeddings.astype(np.float32)) |
|
|
|
|
|
print(f"✅ Đã cập nhật FAISS index với dimension {dimension}") |
|
|
except Exception as e: |
|
|
print(f"❌ Lỗi cập nhật FAISS index: {e}") |
|
|
|
|
|
def semantic_search(self, query: str, top_k: int = None) -> List[RAGSearchResult]: |
|
|
"""Tìm kiếm ngữ nghĩa với model phù hợp theo ngôn ngữ""" |
|
|
if top_k is None: |
|
|
top_k = settings.TOP_K_RESULTS |
|
|
|
|
|
if not self.documents or self.index is None: |
|
|
return self._fallback_keyword_search(query, top_k) |
|
|
|
|
|
|
|
|
query_language = self.multilingual_manager.detect_language(query) |
|
|
embedding_model = self.multilingual_manager.get_embedding_model(query_language) |
|
|
|
|
|
if embedding_model is None: |
|
|
return self._fallback_keyword_search(query, top_k) |
|
|
|
|
|
try: |
|
|
|
|
|
query_embedding = embedding_model.encode([query]) |
|
|
|
|
|
|
|
|
faiss.normalize_L2(query_embedding) |
|
|
|
|
|
|
|
|
similarities, indices = self.index.search( |
|
|
query_embedding.astype(np.float32), |
|
|
min(top_k, len(self.documents)) |
|
|
) |
|
|
|
|
|
results = [] |
|
|
for i, (similarity, idx) in enumerate(zip(similarities[0], indices[0])): |
|
|
if idx < len(self.documents): |
|
|
results.append(RAGSearchResult( |
|
|
id=str(idx), |
|
|
text=self.documents[idx], |
|
|
similarity=float(similarity), |
|
|
metadata=self.metadatas[idx] if idx < len(self.metadatas) else {} |
|
|
)) |
|
|
|
|
|
|
|
|
filtered_results = self._filter_by_language_relevance(results, query_language) |
|
|
|
|
|
print(f"🔍 Tìm kiếm '{query[:50]}...' (ngôn ngữ: {query_language}) - Tìm thấy {len(filtered_results)} kết quả") |
|
|
return filtered_results |
|
|
|
|
|
except Exception as e: |
|
|
print(f"❌ Lỗi tìm kiếm ngữ nghĩa: {e}") |
|
|
return self._fallback_keyword_search(query, top_k) |
|
|
|
|
|
def _filter_by_language_relevance(self, results: List[RAGSearchResult], query_language: str) -> List[RAGSearchResult]: |
|
|
"""Lọc kết quả theo độ liên quan ngôn ngữ""" |
|
|
if not results: |
|
|
return results |
|
|
|
|
|
|
|
|
for result in results: |
|
|
doc_language = result.metadata.get('language', 'vi') |
|
|
if doc_language == query_language: |
|
|
|
|
|
result.similarity = min(result.similarity * 1.2, 1.0) |
|
|
|
|
|
|
|
|
results.sort(key=lambda x: x.similarity, reverse=True) |
|
|
return results |
|
|
|
|
|
def _fallback_keyword_search(self, query: str, top_k: int) -> List[RAGSearchResult]: |
|
|
"""Tìm kiếm dự phòng dựa trên từ khóa""" |
|
|
query_lower = query.lower() |
|
|
results = [] |
|
|
|
|
|
for i, doc in enumerate(self.documents): |
|
|
score = 0 |
|
|
doc_language = self.metadatas[i].get('language', 'vi') if i < len(self.metadatas) else 'vi' |
|
|
query_language = self.multilingual_manager.detect_language(query) |
|
|
|
|
|
|
|
|
if doc_language == query_language: |
|
|
score += 0.5 |
|
|
|
|
|
|
|
|
for word in query_lower.split(): |
|
|
if len(word) > 2 and word in doc.lower(): |
|
|
score += 1 |
|
|
|
|
|
if score > 0: |
|
|
results.append(RAGSearchResult( |
|
|
id=str(i), |
|
|
text=doc, |
|
|
similarity=min(score / 5, 1.0), |
|
|
metadata=self.metadatas[i] if i < len(self.metadatas) else {} |
|
|
)) |
|
|
|
|
|
results.sort(key=lambda x: x.similarity, reverse=True) |
|
|
return results[:top_k] |
|
|
|
|
|
def get_collection_stats(self) -> Dict: |
|
|
"""Lấy thống kê collection với thông tin đa ngôn ngữ""" |
|
|
language_stats = {} |
|
|
for metadata in self.metadatas: |
|
|
lang = metadata.get('language', 'unknown') |
|
|
language_stats[lang] = language_stats.get(lang, 0) + 1 |
|
|
|
|
|
return { |
|
|
'total_documents': len(self.documents), |
|
|
'embedding_count': len(self.embeddings) if self.embeddings is not None else 0, |
|
|
'embedding_dimension': self.current_dimension, |
|
|
'language_distribution': language_stats, |
|
|
'name': 'multilingual_rag_system', |
|
|
'status': 'active', |
|
|
'has_embeddings': self.embeddings is not None |
|
|
} |
|
|
|