# Spaces: Sleeping
# File size: 10,254 Bytes
# 9f203f4 6f12b05
# ────────────────────────────── memo/sessions.py ──────────────────────────────
"""
Conversation Session Management
Handles conversation session tracking, context switching detection,
and conversation insights.
"""
import re
import time
from typing import Dict, Any, Tuple, Optional
from utils.logger import get_logger
# Module logger used by SessionManager and get_session_manager for info,
# warning and error messages (all prefixed with "[SESSION_MANAGER]").
logger = get_logger("SESSION_MANAGER", __name__)
class SessionManager:
    """
    Manages conversation sessions and tracks conversation state.

    Sessions live in memory in ``conversation_sessions``, keyed by user id.
    Nothing in this class evicts idle sessions; ``clear_session`` is the only
    removal path.
    """

    # A message counts as a continuation when it arrives within this many
    # seconds of the previous message and uses the same conversation mode.
    CONTINUATION_WINDOW_SECONDS = 1800  # 30 minutes
    # A detected topic change is only acted on above this confidence.
    CONTEXT_SWITCH_THRESHOLD = 0.7

    # Word tokenizer for the keyword-overlap fallback detector.
    _WORD_RE = re.compile(r'\b\w+\b')

    _CONTEXT_SWITCH_MODEL = {"provider": "nvidia", "model": "meta/llama-3.1-8b-instruct"}
    _CONTEXT_SWITCH_SYSTEM_PROMPT = (
        "You are an expert at detecting context switches in conversations.\n"
        "Given two consecutive questions, determine if the user has switched to a completely different topic or context.\n"
        "Consider:\n"
        "- Different subject matter\n"
        "- Different intent or goal\n"
        "- No logical connection between questions\n"
        "- Change in conversation direction\n"
        'Respond with a JSON object: {"is_context_switch": true/false, "confidence": 0.0-1.0}'
    )

    def __init__(self):
        self.conversation_sessions = {}  # user_id -> mutable session state dict
        self.context_cache = {}  # user_id -> recently cached context (only cleared by this class)

    def get_or_create_session(self, user_id: str, question: str, conversation_mode: str) -> Dict[str, Any]:
        """Get or create the conversation session for ``user_id``.

        Every call represents one incoming message. ``message_count`` is
        therefore 1 right after creation and grows by one on each later call.

        Args:
            user_id: Identifier of the user that owns the session.
            question: The incoming question (currently unused; kept for API
                compatibility).
            conversation_mode: Mode of the incoming message. A message is a
                continuation only if its mode matches the previous message's.

        Returns:
            The live (mutable) session dict stored for the user.
        """
        current_time = time.time()
        session = self.conversation_sessions.get(user_id)

        if session is None:
            session = {
                "session_id": f"{user_id}_{int(current_time)}",
                "start_time": current_time,
                "last_activity": current_time,
                # Count the message that opened the session. This keeps
                # message_count consistent with later calls and stops
                # update_session from dividing by zero on the first message.
                "message_count": 1,
                "context_switches": 0,
                "depth": 0,
                "enhancement_rate": 0.0,
                "conversation_mode": conversation_mode,
                "last_question": "",
                "is_continuation": False,
            }
            self.conversation_sessions[user_id] = session
            return session

        time_since_last = current_time - session["last_activity"]
        session["is_continuation"] = (
            time_since_last < self.CONTINUATION_WINDOW_SECONDS
            and session["conversation_mode"] == conversation_mode
        )
        # Remember the latest mode. The next message is then compared with
        # this one, not with the mode the session was first created in.
        session["conversation_mode"] = conversation_mode
        session["last_activity"] = current_time
        session["message_count"] += 1
        return session

    def update_session(self, user_id: str, original_question: str,
                       enhanced_input: str, context_used: bool):
        """Record the outcome of processing a message for ``user_id``.

        Stores the question and increments ``depth``. It then recomputes
        ``enhancement_rate`` as the number of messages where context was used,
        divided by ``message_count``. Unknown users are ignored.

        Args:
            user_id: Identifier of the session owner.
            original_question: The question as originally asked.
            enhanced_input: The context-enhanced input (currently unused).
            context_used: Whether context was applied to this message.
        """
        session = self.conversation_sessions.get(user_id)
        if session is None:
            return

        session["last_question"] = original_question
        session["depth"] += 1

        total_enhancements = session.get("total_enhancements", 0)
        if context_used:
            total_enhancements += 1
        session["total_enhancements"] = total_enhancements
        # Guard against a zero count so this can never raise ZeroDivisionError.
        session["enhancement_rate"] = total_enhancements / max(session["message_count"], 1)

    async def detect_context_switch(self, user_id: str, new_question: str,
                                    nvidia_rotator=None) -> Dict[str, Any]:
        """Detect whether ``new_question`` switches topic from the user's last question.

        When a switch is reported above ``CONTEXT_SWITCH_THRESHOLD``, the
        following happens:

        - the user's ``context_cache`` entry is dropped;
        - the session's ``context_switches`` counter is incremented;
        - ``last_context_switch`` is stamped with the current time.

        Returns:
            ``{"is_context_switch": bool, "confidence": float}``. On a switch,
            ``"switch_count"`` is added. On failure, ``"error"`` is added.
        """
        try:
            session_info = self.conversation_sessions.get(user_id, {})
            if not session_info:
                return {"is_context_switch": False, "confidence": 0.0}

            is_switch, confidence = await self._detect_context_switch(
                session_info.get("last_question", ""), new_question, nvidia_rotator
            )

            if is_switch and confidence > self.CONTEXT_SWITCH_THRESHOLD:
                # Clear recent context cache for a fresh start on the new topic.
                self.context_cache.pop(user_id, None)
                session_info["context_switches"] = session_info.get("context_switches", 0) + 1
                session_info["last_context_switch"] = time.time()
                logger.info(f"[SESSION_MANAGER] Context switch detected for user {user_id} (confidence: {confidence:.2f})")
                return {
                    "is_context_switch": True,
                    "confidence": confidence,
                    "switch_count": session_info["context_switches"]
                }

            return {"is_context_switch": False, "confidence": confidence}
        except Exception as e:
            logger.error(f"[SESSION_MANAGER] Context switch detection failed: {e}")
            return {"is_context_switch": False, "confidence": 0.0, "error": str(e)}

    def get_conversation_insights(self, user_id: str) -> Dict[str, Any]:
        """Summarise the user's current session.

        Returns ``{"status": "no_active_session"}`` if no session exists.
        Otherwise it returns duration, counts, last activity, depth and
        enhancement rate.
        """
        try:
            session_info = self.conversation_sessions.get(user_id, {})
            if not session_info:
                return {"status": "no_active_session"}

            return {
                "session_duration": time.time() - session_info.get("start_time", time.time()),
                "message_count": session_info.get("message_count", 0),
                "context_switches": session_info.get("context_switches", 0),
                "last_activity": session_info.get("last_activity", 0),
                "conversation_depth": session_info.get("depth", 0),
                "enhancement_rate": session_info.get("enhancement_rate", 0.0)
            }
        except Exception as e:
            logger.error(f"[SESSION_MANAGER] Failed to get conversation insights: {e}")
            return {"error": str(e)}

    def clear_session(self, user_id: str):
        """Drop the session and any cached context for ``user_id`` (no-op if absent)."""
        self.conversation_sessions.pop(user_id, None)
        self.context_cache.pop(user_id, None)

    # ────────────────────────────── Private Helper Methods ──────────────────────────────

    async def _detect_context_switch(self, last_question: str, new_question: str,
                                     nvidia_rotator) -> Tuple[bool, float]:
        """Return ``(is_switch, confidence)`` for two consecutive questions.

        If ``nvidia_rotator`` is given, the model is asked first. If that call
        fails or its reply cannot be parsed, the keyword-overlap heuristic is
        used instead. Empty questions never count as a switch.
        """
        try:
            if not last_question or not new_question:
                return False, 0.0

            if nvidia_rotator:
                try:
                    from utils.api.router import generate_answer_with_model

                    user_prompt = (
                        f"PREVIOUS QUESTION: {last_question}\n"
                        f"CURRENT QUESTION: {new_question}\n"
                        "Is this a context switch?"
                    )
                    response = await generate_answer_with_model(
                        selection=dict(self._CONTEXT_SWITCH_MODEL),
                        system_prompt=self._CONTEXT_SWITCH_SYSTEM_PROMPT,
                        user_prompt=user_prompt,
                        gemini_rotator=None,
                        nvidia_rotator=nvidia_rotator
                    )
                    parsed = self._parse_context_switch_response(response)
                    if parsed is not None:
                        return parsed
                    logger.warning("[SESSION_MANAGER] Unparseable context switch response; using keyword fallback")
                except Exception as e:
                    logger.warning(f"[SESSION_MANAGER] Context switch detection failed: {e}")

            # Fallback: simple keyword-based detection
            return self._simple_context_switch_detection(last_question, new_question)
        except Exception as e:
            logger.warning(f"[SESSION_MANAGER] Context switch detection failed: {e}")
            return False, 0.0

    @staticmethod
    def _parse_context_switch_response(response: Any) -> Optional[Tuple[bool, float]]:
        """Parse a model verdict into ``(is_switch, confidence)``.

        The parser tolerates the following:

        - Surrounding prose or markdown code fences, by decoding the span from
          the first ``{`` to the last ``}``.
        - Booleans sent as strings ("true"/"false").
        - Numeric strings for confidence. The confidence is clamped into
          [0.0, 1.0].

        Returns ``None`` when the reply cannot be interpreted, so the caller
        can fall back to another detector.
        """
        import json
        import math

        if not isinstance(response, str):
            return None
        match = re.search(r"\{.*\}", response, re.DOTALL)
        if match is None:
            return None
        try:
            data = json.loads(match.group(0))
        except ValueError:  # json.JSONDecodeError subclasses ValueError
            return None
        if not isinstance(data, dict):
            return None

        raw_switch = data.get("is_context_switch", False)
        if isinstance(raw_switch, str):
            # bool("false") would be True, so interpret string answers explicitly.
            is_switch = raw_switch.strip().lower() in ("true", "yes", "1")
        else:
            is_switch = bool(raw_switch)

        try:
            confidence = float(data.get("confidence", 0.0))
        except (TypeError, ValueError):
            return None
        if math.isnan(confidence):
            return None
        return is_switch, min(max(confidence, 0.0), 1.0)

    def _simple_context_switch_detection(self, last_question: str, new_question: str) -> Tuple[bool, float]:
        """Keyword-overlap fallback for context switch detection.

        Computes the Jaccard similarity of the two questions' lower-cased word
        sets. Below 0.1 it reports a switch with confidence ``1 - similarity``.
        Otherwise it reports no switch, with the similarity as the confidence.
        """
        try:
            last_words = set(self._WORD_RE.findall(last_question.lower()))
            new_words = set(self._WORD_RE.findall(new_question.lower()))

            total_unique = len(last_words | new_words)
            if total_unique == 0:
                return False, 0.0

            similarity = len(last_words & new_words) / total_unique

            # Context switch if similarity is very low
            is_switch = similarity < 0.1
            confidence = 1.0 - similarity if is_switch else similarity
            return is_switch, confidence
        except Exception as e:
            logger.warning(f"[SESSION_MANAGER] Simple context switch detection failed: {e}")
            return False, 0.0
# ────────────────────────────── Global Instance ──────────────────────────────
# Lazily created module-level SessionManager shared by every caller of
# get_session_manager().
_session_manager: Optional[SessionManager] = None


def get_session_manager() -> SessionManager:
    """Return the shared SessionManager, constructing it on the first call.

    Only the first call creates the instance and logs an initialization
    message. Later calls return that same instance without logging.
    """
    global _session_manager
    if _session_manager is not None:
        return _session_manager
    _session_manager = SessionManager()
    logger.info("[SESSION_MANAGER] Global session manager initialized")
    return _session_manager
# def reset_session_manager():
# """Reset the global session manager (for testing)"""
# global _session_manager
# _session_manager = None
|