# memory_updated.py
import re, time, hashlib, asyncio, os
from collections import defaultdict, deque
from typing import List, Dict, Optional
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
from google import genai # must be configured in app.py and imported globally
import logging
from models.summarizer import summarizer
_LLM_SMALL = "gemini-2.5-flash-lite-preview-06-17"
# Load embedding model - use standard model that downloads automatically
EMBED = SentenceTransformer("all-MiniLM-L6-v2", device="cpu")
logger = logging.getLogger("rag-agent")
logging.basicConfig(level=logging.INFO, format="%(asctime)s — %(name)s — %(levelname)s — %(message)s", force=True) # Change INFO to DEBUG for full-ctx JSON loader
api_key = os.getenv("FlashAPI")
client = genai.Client(api_key=api_key)
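# Overview (added comment): MemoryManager keeps two tiers of per-user memory —
#   * STM: a small deque of topic-tagged summaries of recent exchanges, used for chat continuity.
#   * LTM: a FAISS inner-product index over summarized chunks, used for semantic (RAG) retrieval.
# Both tiers are merged/deduplicated on write and capped per user.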
class MemoryManager:
    def __init__(self, max_users=1000, history_per_user=20, max_chunks=60):
        # STM: recent conversation summaries (topic + summary), up to `history_per_user` entries
        self.stm_summaries = defaultdict(lambda: deque(maxlen=history_per_user))  # deque of {topic, text, vec, timestamp, used}
        # Legacy raw cache (kept for compatibility if needed)
        self.text_cache = defaultdict(lambda: deque(maxlen=history_per_user))
        # LTM: semantic chunk store (approx. 3 chunks x 20 rounds)
        self.chunk_index = defaultdict(self._new_index)  # user_id -> FAISS index
        self.chunk_meta = defaultdict(list)              # user_id -> list[{text, tag, vec, timestamp, used}]
        self.user_queue = deque(maxlen=max_users)        # LRU of users
        self.max_chunks = max_chunks                     # hard cap per user
        self.chunk_cache = {}                            # hash(query + response) -> [chunks]
    # ---------- Public API ----------
    def add_exchange(self, user_id: str, query: str, response: str, lang: str = "EN"):
        self._touch_user(user_id)
        query, response = (query or "").strip(), (response or "").strip()
        # Keep raw record (optional)
        self.text_cache[user_id].append((query, response))
        if not response:
            return []
        # Avoid re-chunking identical responses
        cache_key = hashlib.md5((query + response).encode()).hexdigest()
        if cache_key in self.chunk_cache:
            chunks = self.chunk_cache[cache_key]
        else:
            chunks = self.chunk_response(response, lang, question=query)
            self.chunk_cache[cache_key] = chunks
        # Update STM with merging/deduplication
        for chunk in chunks:
            self._upsert_stm(user_id, chunk, lang)
        # Update LTM with merging/deduplication
        self._upsert_ltm(user_id, chunks, lang)
        return chunks
    def get_relevant_chunks(self, user_id: str, query: str, top_k: int = 3, min_sim: float = 0.30) -> List[str]:
        """Return texts of chunks whose cosine similarity ≥ min_sim."""
        if self.chunk_index[user_id].ntotal == 0:
            return []
        # Encode the query
        qvec = self._embed(query)
        sims, idxs = self.chunk_index[user_id].search(np.array([qvec]), k=top_k)
        results = []
        # Score hits with a recency decay so the most recent chat is prioritized
        # (idx == -1 marks empty FAISS slots when fewer than top_k chunks exist)
        for sim, idx in zip(sims[0], idxs[0]):
            if 0 <= idx < len(self.chunk_meta[user_id]) and sim >= min_sim:
                chunk = self.chunk_meta[user_id][idx]
                chunk["used"] += 1  # increment usage counter
                # Decay function: halves the raw similarity after ~5 minutes
                age_sec = time.time() - chunk["timestamp"]
                decay = 1.0 / (1.0 + age_sec / 300)
                score = sim * decay * (1 + 0.1 * chunk["used"])
                # Keep the chunk with its combined score
                results.append((score, chunk))
        # Best-scored chunks first
        results.sort(key=lambda x: x[0], reverse=True)
        # logger.info(f"[Memory] RAG Retrieved Topic: {results}")  # Inspect vector data
        return [f"### Topic: {c['tag']}\n{c['text']}" for _, c in results]
    def get_recent_chat_history(self, user_id: str, num_turns: int = 5) -> List[Dict]:
        """
        Get the most recent short-term memory summaries.
        Returns: a list of entries containing only the summarized bot context.
        """
        if user_id not in self.stm_summaries:
            return []
        recent = list(self.stm_summaries[user_id])[-num_turns:]
        formatted = []
        for entry in recent:
            formatted.append({
                "user": "",  # STM stores only the summarized bot side of the exchange
                "bot": f"Topic: {entry['topic']}\n{entry['text']}",
                "timestamp": entry.get("timestamp", time.time())
            })
        return formatted
    def get_context(self, user_id: str, num_turns: int = 5) -> str:
        # Prefer STM summaries
        history = self.get_recent_chat_history(user_id, num_turns=num_turns)
        return "\n".join(h["bot"] for h in history)
    def get_contextual_chunks(self, user_id: str, current_query: str, lang: str = "EN") -> str:
        """
        Use NVIDIA Llama to summarize the relevant context drawn from both recent history and RAG chunks.
        This preserves conversational continuity while handing the main LLM a concise summary.
        """
        # Gather both types of context
        recent_history = self.get_recent_chat_history(user_id, num_turns=5)
        rag_chunks = self.get_relevant_chunks(user_id, current_query, top_k=3)
        logger.info(f"[Contextual] Retrieved {len(recent_history)} recent history items")
        logger.info(f"[Contextual] Retrieved {len(rag_chunks)} RAG chunks")
        # Return an empty string if no context is found
        if not recent_history and not rag_chunks:
            logger.info("[Contextual] No context found, returning empty string")
            return ""
        # Prepare context for summarization
        context_parts = []
        # Add recent chat history
        if recent_history:
            history_text = "\n".join([
                f"User: {item['user']}\nBot: {item['bot']}"
                for item in recent_history
            ])
            context_parts.append(f"Recent conversation history:\n{history_text}")
        # Add RAG chunks
        if rag_chunks:
            rag_text = "\n".join(rag_chunks)
            context_parts.append(f"Semantically relevant historical cooking information:\n{rag_text}")
        # Combine all context
        full_context = "\n\n".join(context_parts)
        # Use the summarizer to create a concise summary
        try:
            summary = summarizer.summarize_text(full_context, max_length=300)
            logger.info(f"[Contextual] Generated summary using NVIDIA Llama: {len(summary)} characters")
            return summary
        except Exception as e:
            logger.error(f"[Contextual] Summarization failed: {e}")
            return full_context[:500] + "..." if len(full_context) > 500 else full_context
    def chunk_response(self, response: str, lang: str, question: str = "") -> List[Dict]:
        """
        Use NVIDIA Llama to chunk and summarize a response by cooking topic.
        Returns: [{"tag": ..., "text": ...}, ...]
        """
        if not response:
            return []
        try:
            # Use the summarizer to chunk and summarize
            chunks = summarizer.chunk_response(response, max_chunk_size=500)
            # Convert to the expected format
            result_chunks = []
            for chunk in chunks:
                # Extract a topic from the chunk (first sentence or key cooking terms)
                topic = self._extract_topic_from_chunk(chunk)
                result_chunks.append({
                    "tag": topic,
                    "text": chunk
                })
            logger.info(f"[Memory] 📦 NVIDIA Llama summarized {len(result_chunks)} chunks")
            return result_chunks
        except Exception as e:
            logger.error(f"[Memory] NVIDIA Llama chunking failed: {e}")
            # Fall back to simple chunking
            return self._fallback_chunking(response)
    def _extract_topic_from_chunk(self, chunk: str) -> str:
        """Extract a concise topic from a chunk."""
        # Use the first sentence, truncated, as the topic label
        sentences = chunk.split('.')
        if sentences and sentences[0].strip():
            first_sentence = sentences[0].strip()
            if len(first_sentence) > 50:
                first_sentence = first_sentence[:50] + "..."
            return first_sentence
        return "Cooking Information"
    def _fallback_chunking(self, response: str) -> List[Dict]:
        """Fallback chunking used when NVIDIA Llama fails."""
        # Simple sentence-based chunking, capped at roughly 300 characters per chunk
        sentences = re.split(r'[.!?]+', response)
        chunks = []
        current_chunk = ""
        for sentence in sentences:
            sentence = sentence.strip()
            if not sentence:
                continue
            if len(current_chunk) + len(sentence) > 300:
                if current_chunk:
                    chunks.append({
                        "tag": "Cooking Information",
                        "text": current_chunk.strip()
                    })
                current_chunk = sentence + ". "
            else:
                current_chunk += sentence + ". "
        if current_chunk:
            chunks.append({
                "tag": "Cooking Information",
                "text": current_chunk.strip()
            })
        return chunks
    # ---------- Private Methods ----------
    def _touch_user(self, user_id: str):
        """Update the LRU queue."""
        if user_id in self.user_queue:
            self.user_queue.remove(user_id)
        self.user_queue.append(user_id)
    def _new_index(self):
        """Create a new FAISS inner-product index."""
        return faiss.IndexFlatIP(384)  # all-MiniLM-L6-v2 produces 384-dim embeddings
    def _upsert_stm(self, user_id: str, chunk: Dict, lang: str):
        """Update short-term memory, merging near-duplicate topics."""
        topic = chunk["tag"]
        text = chunk["text"]
        # Check for similar topics already in STM
        for entry in self.stm_summaries[user_id]:
            if self._topics_similar(topic, entry["topic"]):
                # Merge with the existing entry
                entry["text"] = summarizer.summarize_text(
                    f"{entry['text']}\n{text}",
                    max_length=200
                )
                entry["vec"] = self._embed(f"{entry['topic']} {entry['text']}")  # keep the stored embedding in sync with the merged text
                entry["timestamp"] = time.time()
                return
        # Add a new entry
        self.stm_summaries[user_id].append({
            "topic": topic,
            "text": text,
            "vec": self._embed(f"{topic} {text}"),
            "timestamp": time.time(),
            "used": 0
        })
    def _upsert_ltm(self, user_id: str, chunks: List[Dict], lang: str):
        """Update long-term memory, merging near-duplicate chunks."""
        for chunk in chunks:
            # Check for a similar chunk already in LTM
            similar_idx = self._find_similar_chunk(user_id, chunk["text"])
            if similar_idx is not None:
                # Merge with the existing chunk (only the text is merged; the FAISS entry keeps the original embedding)
                existing = self.chunk_meta[user_id][similar_idx]
                merged_text = summarizer.summarize_text(
                    f"{existing['text']}\n{chunk['text']}",
                    max_length=300
                )
                existing["text"] = merged_text
                existing["timestamp"] = time.time()
            else:
                # Add a new chunk, evicting the oldest one if the per-user cap is reached
                if len(self.chunk_meta[user_id]) >= self.max_chunks:
                    self._remove_oldest_chunk(user_id)
                vec = self._embed(chunk["text"])
                self.chunk_index[user_id].add(np.array([vec]))
                self.chunk_meta[user_id].append({
                    "text": chunk["text"],
                    "tag": chunk["tag"],
                    "vec": vec,
                    "timestamp": time.time(),
                    "used": 0
                })
    def _topics_similar(self, topic1: str, topic2: str) -> bool:
        """Check whether two topics are similar."""
        # Simple similarity check based on shared words
        words1 = set(topic1.lower().split())
        words2 = set(topic2.lower().split())
        intersection = words1.intersection(words2)
        return len(intersection) >= 2
    def _find_similar_chunk(self, user_id: str, text: str) -> Optional[int]:
        """Find a highly similar chunk in LTM, or None."""
        if not self.chunk_meta[user_id]:
            return None
        text_vec = self._embed(text)
        sims, idxs = self.chunk_index[user_id].search(np.array([text_vec]), k=3)
        for sim, idx in zip(sims[0], idxs[0]):
            if idx >= 0 and sim > 0.8:  # high similarity threshold; -1 marks empty FAISS slots
                return int(idx)
        return None
    def _remove_oldest_chunk(self, user_id: str):
        """Remove the oldest chunk from LTM."""
        if not self.chunk_meta[user_id]:
            return
        # Find the oldest chunk
        oldest_idx = min(range(len(self.chunk_meta[user_id])),
                         key=lambda i: self.chunk_meta[user_id][i]["timestamp"])
        # Remove it from the metadata
        self.chunk_meta[user_id].pop(oldest_idx)
        # FAISS flat indexes don't support direct removal, so rebuild the index
        self._rebuild_index(user_id)
    def _rebuild_index(self, user_id: str):
        """Rebuild the FAISS index after a removal."""
        if not self.chunk_meta[user_id]:
            self.chunk_index[user_id] = self._new_index()
            return
        vectors = [chunk["vec"] for chunk in self.chunk_meta[user_id]]
        self.chunk_index[user_id] = self._new_index()
        self.chunk_index[user_id].add(np.array(vectors))
    @staticmethod
    def _embed(text: str):
        vec = EMBED.encode(text, convert_to_numpy=True)
        # L2-normalise so inner product on IndexFlatIP behaves as cosine similarity,
        # and cast to float32 since FAISS expects float32 inputs
        return (vec / (np.linalg.norm(vec) + 1e-9)).astype(np.float32)
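# Minimal usage sketch (added; not part of the original app flow — the user id and texts
# below are illustrative assumptions, and in the real deployment this class is expected
# to be driven from app.py):
if __name__ == "__main__":
    mm = MemoryManager()
    # Store one exchange; chunking and summarization are delegated to models.summarizer.
    mm.add_exchange(
        "demo-user",
        "How do I sear a steak?",
        "Pat the steak dry, season generously, and sear in a very hot pan for 2-3 minutes per side."
    )
    # Retrieve semantically related chunks for a follow-up question.
    print(mm.get_relevant_chunks("demo-user", "What pan temperature should I aim for?"))
    # Or get a condensed context string to prepend to the next LLM prompt.
    print(mm.get_contextual_chunks("demo-user", "What pan temperature should I aim for?"))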