# memory_updated.py
import re, time, hashlib, asyncio, os
from collections import defaultdict, deque
from typing import List, Dict, Optional
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from google import genai  # must be configured in app.py and imported globally
import logging
from models.summarizer import summarizer

_LLM_SMALL = "gemini-2.5-flash-lite-preview-06-17"

# Load embedding model - use a standard model that downloads automatically
EMBED = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")

logger = logging.getLogger("rag-agent")
logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(name)s — %(levelname)s — %(message)s", force=True)  # Change INFO to DEBUG for full-ctx JSON loader

api_key = os.getenv("FlashAPI")
client = genai.Client(api_key=api_key)


class MemoryManager:
    def __init__(self, max_users=1000, history_per_user=20, max_chunks=60):
        # STM: recent conversation summaries (topic + summary), up to history_per_user entries
        self.stm_summaries = defaultdict(lambda: deque(maxlen=history_per_user))  # deque of {topic, text, vec, timestamp, used}
        # Legacy raw cache (kept for compatibility if needed)
        self.text_cache = defaultdict(lambda: deque(maxlen=history_per_user))
        # LTM: semantic chunk store (approx. 3 chunks x 20 rounds)
        self.chunk_index = defaultdict(self._new_index)   # user_id -> FAISS index
        self.chunk_meta = defaultdict(list)               # user_id -> list[{text, tag, vec, timestamp, used}]
        self.user_queue = deque(maxlen=max_users)         # LRU of users
        self.max_chunks = max_chunks                      # hard cap per user
        self.chunk_cache = {}                             # hash(query+resp) -> [chunks]

    # ---------- Public API ----------
    def add_exchange(self, user_id: str, query: str, response: str, lang: str = "EN"):
        self._touch_user(user_id)
        # Keep raw record (optional)
        self.text_cache[user_id].append(((query or "").strip(), (response or "").strip()))
        if not response:
            return []
        # Avoid re-chunking an identical response
        cache_key = hashlib.md5(((query or "") + response).encode()).hexdigest()
        if cache_key in self.chunk_cache:
            chunks = self.chunk_cache[cache_key]
        else:
            chunks = self.chunk_response(response, lang, question=query)
            self.chunk_cache[cache_key] = chunks
        # Update STM with merging/deduplication
        for chunk in chunks:
            self._upsert_stm(user_id, chunk, lang)
        # Update LTM with merging/deduplication
        self._upsert_ltm(user_id, chunks, lang)
        return chunks

    def get_relevant_chunks(self, user_id: str, query: str, top_k: int = 3, min_sim: float = 0.30) -> List[str]:
        """Return texts of chunks whose cosine similarity ≥ min_sim."""
        if self.chunk_index[user_id].ntotal == 0:
            return []
        # Encode the query
        qvec = self._embed(query)
        sims, idxs = self.chunk_index[user_id].search(np.array([qvec]), k=top_k)
        results = []
        # Score each hit with a recency decay to prioritize the most recent chat
        # (FAISS pads results with idx = -1 when the index holds fewer than k vectors)
        for sim, idx in zip(sims[0], idxs[0]):
            if 0 <= idx < len(self.chunk_meta[user_id]) and sim >= min_sim:
                chunk = self.chunk_meta[user_id][idx]
                chunk["used"] += 1  # increment usage count
                # Recency decay
                age_sec = time.time() - chunk["timestamp"]
                decay = 1.0 / (1.0 + age_sec / 300)  # 5-min half-life
                score = sim * decay * (1 + 0.1 * chunk["used"])
                # Keep the chunk together with its score
                results.append((score, chunk))
        # Sort results best-scored first
        results.sort(key=lambda x: x[0], reverse=True)
        # logger.info(f"[Memory] RAG Retrieved Topic: {results}")  # Inspect vector data
        return [f"### Topic: {c['tag']}\n{c['text']}" for _, c in results]
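
    # Worked example of the scoring above (hypothetical numbers, not from a real run):
    #   sim = 0.60, age_sec = 300 -> decay = 1 / (1 + 300/300) = 0.5
    #   used = 2 after increment  -> score = 0.60 * 0.5 * (1 + 0.1 * 2) = 0.36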

    def get_recent_chat_history(self, user_id: str, num_turns: int = 5) -> List[Dict]:
        """
        Get the most recent short-term memory summaries.
        Returns: a list of entries containing only the summarized bot context.
        """
        if user_id not in self.stm_summaries:
            return []
        recent = list(self.stm_summaries[user_id])[-num_turns:]
        formatted = []
        for entry in recent:
            formatted.append({
                "user": "",
                "bot": f"Topic: {entry['topic']}\n{entry['text']}",
                "timestamp": entry.get("timestamp", time.time())
            })
        return formatted

    def get_context(self, user_id: str, num_turns: int = 5) -> str:
        # Prefer STM summaries
        history = self.get_recent_chat_history(user_id, num_turns=num_turns)
        return "\n".join(h["bot"] for h in history)

    def get_contextual_chunks(self, user_id: str, current_query: str, lang: str = "EN") -> str:
        """
        Use NVIDIA Llama to summarize relevant context from both recent history and RAG chunks.
        This preserves conversational continuity while keeping the context concise for the main LLM.
        """
        # Get both types of context
        recent_history = self.get_recent_chat_history(user_id, num_turns=5)
        rag_chunks = self.get_relevant_chunks(user_id, current_query, top_k=3)
        logger.info(f"[Contextual] Retrieved {len(recent_history)} recent history items")
        logger.info(f"[Contextual] Retrieved {len(rag_chunks)} RAG chunks")
        # Return an empty string if no context is found
        if not recent_history and not rag_chunks:
            logger.info("[Contextual] No context found, returning empty string")
            return ""
        # Prepare context for summarization
        context_parts = []
        # Add recent chat history
        if recent_history:
            history_text = "\n".join(
                f"User: {item['user']}\nBot: {item['bot']}"
                for item in recent_history
            )
            context_parts.append(f"Recent conversation history:\n{history_text}")
        # Add RAG chunks
        if rag_chunks:
            rag_text = "\n".join(rag_chunks)
            context_parts.append(f"Semantically relevant historical cooking information:\n{rag_text}")
        # Combine all context
        full_context = "\n\n".join(context_parts)
        # Use the summarizer to create a concise summary
        try:
            summary = summarizer.summarize_text(full_context, max_length=300)
            logger.info(f"[Contextual] Generated summary using NVIDIA Llama: {len(summary)} characters")
            return summary
        except Exception as e:
            logger.error(f"[Contextual] Summarization failed: {e}")
            # On failure, fall back to the raw context, truncated to 500 characters
            return full_context[:500] + "..." if len(full_context) > 500 else full_context

    def chunk_response(self, response: str, lang: str, question: str = "") -> List[Dict]:
        """
        Use NVIDIA Llama to chunk and summarize a response by cooking topic.
        Returns: [{"tag": ..., "text": ...}, ...]
        """
        if not response:
            return []
        try:
            # Use the summarizer to chunk and summarize
            chunks = summarizer.chunk_response(response, max_chunk_size=500)
            # Convert to the expected format
            result_chunks = []
            for chunk in chunks:
                # Extract a topic from the chunk (first sentence or key cooking terms)
                topic = self._extract_topic_from_chunk(chunk)
                result_chunks.append({
                    "tag": topic,
                    "text": chunk
                })
            logger.info(f"[Memory] 📦 NVIDIA Llama summarized {len(result_chunks)} chunks")
            return result_chunks
        except Exception as e:
            logger.error(f"[Memory] NVIDIA Llama chunking failed: {e}")
            # Fall back to simple chunking
            return self._fallback_chunking(response)

    def _extract_topic_from_chunk(self, chunk: str) -> str:
        """Extract a concise topic from a chunk."""
        # Use the first sentence, truncated to 50 characters; guard against an
        # empty first sentence (str.split always returns a non-empty list)
        first_sentence = chunk.split('.')[0].strip()
        if first_sentence:
            if len(first_sentence) > 50:
                first_sentence = first_sentence[:50] + "..."
            return first_sentence
        return "Cooking Information"

    def _fallback_chunking(self, response: str) -> List[Dict]:
        """Fallback chunking when NVIDIA Llama fails."""
        # Simple sentence-based chunking, capping chunks at roughly 300 characters
        sentences = re.split(r'[.!?]+', response)
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
            if len(current_chunk) + len(sentence) > 300:
                if current_chunk:
                    chunks.append({
                        "tag": "Cooking Information",
                        "text": current_chunk.strip()
                    })
                current_chunk = sentence + ". "
            else:
                current_chunk += sentence + ". "
        if current_chunk:
            chunks.append({
                "tag": "Cooking Information",
                "text": current_chunk.strip()
            })
        return chunks
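
    # Worked example of the packing above (hypothetical sentence lengths):
    #   sentences of 120, 150, and 140 chars -> the first two pack into one
    #   ~275-char chunk; the third starts a new chunk since 275 + 140 > 300.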

    # ---------- Private Methods ----------
    def _touch_user(self, user_id: str):
        """Move the user to the back of the LRU queue."""
        if user_id in self.user_queue:
            self.user_queue.remove(user_id)
        self.user_queue.append(user_id)

    def _new_index(self):
        """Create a new FAISS index."""
        return faiss.IndexFlatIP(384)  # all-MiniLM-L6-v2 produces 384-dim embeddings
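
    # Design note: _embed() L2-normalizes every vector, so inner-product search
    # on IndexFlatIP is equivalent to cosine similarity.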

    def _upsert_stm(self, user_id: str, chunk: Dict, lang: str):
        """Update short-term memory with merging/deduplication."""
        topic = chunk["tag"]
        text = chunk["text"]
        # Check for similar topics already in STM
        for entry in self.stm_summaries[user_id]:
            if self._topics_similar(topic, entry["topic"]):
                # Merge with the existing entry
                entry["text"] = summarizer.summarize_text(
                    f"{entry['text']}\n{text}",
                    max_length=200
                )
                entry["timestamp"] = time.time()
                return
        # Add a new entry
        self.stm_summaries[user_id].append({
            "topic": topic,
            "text": text,
            "vec": self._embed(f"{topic} {text}"),
            "timestamp": time.time(),
            "used": 0
        })

    def _upsert_ltm(self, user_id: str, chunks: List[Dict], lang: str):
        """Update long-term memory with merging/deduplication."""
        for chunk in chunks:
            # Check for a similar chunk already in LTM
            similar_idx = self._find_similar_chunk(user_id, chunk["text"])
            if similar_idx is not None:
                # Merge with the existing chunk (note: the stored FAISS vector is
                # left unchanged, so retrieval still matches the pre-merge text)
                existing = self.chunk_meta[user_id][similar_idx]
                merged_text = summarizer.summarize_text(
                    f"{existing['text']}\n{chunk['text']}",
                    max_length=300
                )
                existing["text"] = merged_text
                existing["timestamp"] = time.time()
            else:
                # Add a new chunk, evicting the oldest if at the per-user cap
                if len(self.chunk_meta[user_id]) >= self.max_chunks:
                    self._remove_oldest_chunk(user_id)
                vec = self._embed(chunk["text"])
                self.chunk_index[user_id].add(np.array([vec]))
                self.chunk_meta[user_id].append({
                    "text": chunk["text"],
                    "tag": chunk["tag"],
                    "vec": vec,
                    "timestamp": time.time(),
                    "used": 0
                })

    def _topics_similar(self, topic1: str, topic2: str) -> bool:
        """Check whether two topics are similar."""
        # Simple similarity check based on common words (>= 2 shared)
        words1 = set(topic1.lower().split())
        words2 = set(topic2.lower().split())
        intersection = words1.intersection(words2)
        return len(intersection) >= 2
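
    # Example (hypothetical topics): "Thickening tomato sauce" vs. "Tomato sauce
    # basics" share {"tomato", "sauce"}, so they are treated as the same topic.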

    def _find_similar_chunk(self, user_id: str, text: str) -> Optional[int]:
        """Find a similar chunk in LTM, or return None."""
        if not self.chunk_meta[user_id]:
            return None
        text_vec = self._embed(text)
        sims, idxs = self.chunk_index[user_id].search(np.array([text_vec]), k=3)
        for sim, idx in zip(sims[0], idxs[0]):
            if idx >= 0 and sim > 0.8:  # high similarity threshold; skip FAISS padding (-1)
                return int(idx)
        return None

    def _remove_oldest_chunk(self, user_id: str):
        """Remove the oldest chunk from LTM."""
        if not self.chunk_meta[user_id]:
            return
        # Find the oldest chunk
        oldest_idx = min(range(len(self.chunk_meta[user_id])),
                         key=lambda i: self.chunk_meta[user_id][i]["timestamp"])
        # Remove it from the metadata
        self.chunk_meta[user_id].pop(oldest_idx)
        # Note: FAISS doesn't support direct removal from IndexFlatIP, so rebuild the index
        self._rebuild_index(user_id)

    def _rebuild_index(self, user_id: str):
        """Rebuild the FAISS index after a removal."""
        if not self.chunk_meta[user_id]:
            self.chunk_index[user_id] = self._new_index()
            return
        vectors = [chunk["vec"] for chunk in self.chunk_meta[user_id]]
        self.chunk_index[user_id] = self._new_index()
        self.chunk_index[user_id].add(np.array(vectors))

    def _embed(self, text: str) -> np.ndarray:
        # Note: self is required here; this is called as self._embed() throughout
        vec = EMBED.encode(text, convert_to_numpy=True)
        # L2-normalise for cosine similarity on IndexFlatIP
        return vec / (np.linalg.norm(vec) + 1e-9)
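

# --- Minimal usage sketch (not part of the original module; assumes the
# FlashAPI env var is set and models.summarizer is reachable, since
# add_exchange() triggers real summarizer calls) ---
if __name__ == "__main__":
    memory = MemoryManager()
    memory.add_exchange(
        "user-1",
        "How do I thicken a pan sauce?",
        "Simmer the sauce to reduce it, or whisk in a cornstarch slurry off the heat."
    )
    print(memory.get_context("user-1"))
    print(memory.get_relevant_chunks("user-1", "thickening a sauce", top_k=2))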