Spaces:
Sleeping
Sleeping
import streamlit as st

# Must be the first Streamlit command
st.set_page_config(
    page_title="02_Chat_Interface",  # Use this format for ordering
    page_icon="💬",
    layout="wide",
)

# Rest of the imports
import pandas as pd
import logging
import sqlite3
from datetime import datetime
import sys
import os

# Add the parent directory to Python path so the project-level modules below
# resolve when this page is run from a sub-directory.
sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# Use absolute imports
from database import DatabaseHandler
from data_processor import DataProcessor
from rag import RAGSystem
from query_rewriter import QueryRewriter
from utils import process_single_video

# Configure logging for stdout only
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
    stream=sys.stdout,
)
logger = logging.getLogger(__name__)
def init_components():
    """Initialize the system components used by the chat page.

    Returns:
        A 4-tuple ``(db_handler, data_processor, rag_system, query_rewriter)``.
        If any constructor raises, the error is logged (with traceback) and
        shown in the UI, and ``(None, None, None, None)`` is returned instead.
        That fallback tuple is still truthy, so callers must check its
        elements rather than the tuple itself.
    """
    try:
        db_handler = DatabaseHandler()
        data_processor = DataProcessor()
        # RAGSystem receives the DataProcessor instance created just above.
        rag_system = RAGSystem(data_processor)
        query_rewriter = QueryRewriter()
    except Exception as e:
        # logger.exception keeps the traceback, which str(e) alone discards.
        logger.exception(f"Error initializing components: {str(e)}")
        st.error(f"Error initializing components: {str(e)}")
        return None, None, None, None
    return db_handler, data_processor, rag_system, query_rewriter
def init_session_state():
    """Seed the per-session defaults used by the chat page.

    Keys already present in ``st.session_state`` are left untouched, so this
    can be called on every script rerun.
    """
    # Each default is produced by a factory so every session gets its own
    # fresh list / set rather than a shared object.
    defaults = (
        ('chat_history', list),
        ('current_video_id', lambda: None),
        ('feedback_given', set),
    )
    for state_key, make_default in defaults:
        if state_key not in st.session_state:
            st.session_state[state_key] = make_default()
def _render_feedback_buttons(db_handler, video_id, chat_id, query, response):
    """Render like/dislike buttons for one question/answer pair.

    Clicking a button stores the rating (1 for like, -1 for dislike) via
    ``db_handler.add_user_feedback``, records the message key in
    ``st.session_state.feedback_given`` (callers use it to hide the buttons
    afterwards) and reruns the script.
    """
    message_key = f"{chat_id}"
    like_col, dislike_col = st.columns(2)
    with like_col:
        if st.button("👍", key=f"like_{message_key}"):
            db_handler.add_user_feedback(
                video_id=video_id,
                chat_id=chat_id,
                query=query,
                response=response,
                feedback=1,
            )
            st.session_state.feedback_given.add(message_key)
            # NOTE(review): st.rerun() restarts the script immediately, so this
            # confirmation is likely replaced before the user can read it.
            st.success("Thank you for your positive feedback!")
            st.rerun()
    with dislike_col:
        if st.button("👎", key=f"dislike_{message_key}"):
            db_handler.add_user_feedback(
                video_id=video_id,
                chat_id=chat_id,
                query=query,
                response=response,
                feedback=-1,
            )
            st.session_state.feedback_given.add(message_key)
            st.success("Thank you for your feedback. We'll work to improve.")
            st.rerun()


def create_chat_interface(db_handler, rag_system, video_id, index_name, rewrite_method, search_method):
    """Render the chat UI for one video, with feedback buttons on each answer.

    Args:
        db_handler: Provides ``get_chat_history`` (rows unpacked as
            ``(id, user_msg, assistant_msg, timestamp)``), ``add_chat_message``
            and ``add_user_feedback``.
        rag_system: Provides ``rewrite_cot``, ``rewrite_react`` and ``query``.
        video_id: Id of the selected video; chat history is scoped to it.
        index_name: Search index forwarded to ``rag_system.query``.
        rewrite_method: "None", "Chain of Thought" or "ReAct".
        search_method: "Hybrid", "Text-only" or "Embedding-only".
    """
    # Reload persisted history only when the selected video changed; otherwise
    # keep the in-session list, which already contains newly sent messages.
    if st.session_state.current_video_id != video_id:
        st.session_state.chat_history = []
        for chat_id, user_msg, asst_msg, timestamp in db_handler.get_chat_history(video_id):
            st.session_state.chat_history.append({
                'id': chat_id,
                'user': user_msg,
                'assistant': asst_msg,
                'timestamp': timestamp,
            })
        st.session_state.current_video_id = video_id

    # Replay the conversation so far.
    for message in st.session_state.chat_history:
        with st.chat_message("user"):
            st.markdown(message['user'])
        with st.chat_message("assistant"):
            st.markdown(message['assistant'])
            if f"{message['id']}" not in st.session_state.feedback_given:
                _render_feedback_buttons(
                    db_handler, video_id, message['id'], message['user'], message['assistant']
                )

    prompt = st.chat_input("Ask a question about the video...")
    if not prompt:
        return

    with st.chat_message("user"):
        st.markdown(prompt)

    with st.chat_message("assistant"):
        with st.spinner("Thinking..."):
            try:
                # Apply query rewriting if selected
                rewritten_query = prompt
                if rewrite_method == "Chain of Thought":
                    rewritten_query, _ = rag_system.rewrite_cot(prompt)
                    st.caption("Rewritten query: " + rewritten_query)
                elif rewrite_method == "ReAct":
                    rewritten_query, _ = rag_system.rewrite_react(prompt)
                    st.caption("Rewritten query: " + rewritten_query)

                # Map UI labels to the identifiers rag_system.query expects;
                # an unknown label raises KeyError, reported below.
                search_method_map = {
                    "Hybrid": "hybrid",
                    "Text-only": "text",
                    "Embedding-only": "embedding",
                }
                response, _ = rag_system.query(
                    rewritten_query,
                    search_method=search_method_map[search_method],
                    index_name=index_name,
                )
                st.markdown(response)

                # Save to database and session state
                chat_id = db_handler.add_chat_message(video_id, prompt, response)
                st.session_state.chat_history.append({
                    'id': chat_id,
                    'user': prompt,
                    'assistant': response,
                    'timestamp': datetime.now(),
                })
            except Exception as e:
                st.error(f"Error generating response: {str(e)}")
                logger.exception(f"Error in chat interface: {str(e)}")
                return

        _render_feedback_buttons(db_handler, video_id, chat_id, prompt, response)
def get_system_status(db_handler, selected_video_id=None):
    """Collect summary statistics from the SQLite catalogue.

    Args:
        db_handler: Object exposing ``db_path``, the SQLite database file.
        selected_video_id: Optional YouTube id; when truthy, the joined
            index/model rows for that video are included.

    Returns:
        A dict with keys ``total_videos``, ``total_indices`` (distinct index
        names), ``models`` (all embedding model names) and ``video_details``
        (list of ``(id, title, channel_name, processed_date, index_name,
        model_name)`` rows, or ``None`` when no video is selected). Returns
        ``None`` if anything in the lookup raises; the error is logged.
    """
    from contextlib import closing

    try:
        # sqlite3's own context manager only commits/rolls back and leaves the
        # connection open; closing() is what actually releases it.
        with closing(sqlite3.connect(db_handler.db_path)) as conn:
            cursor = conn.cursor()

            cursor.execute("SELECT COUNT(*) FROM videos")
            total_videos = cursor.fetchone()[0]

            cursor.execute("SELECT COUNT(DISTINCT index_name) FROM elasticsearch_indices")
            total_indices = cursor.fetchone()[0]

            cursor.execute("SELECT model_name FROM embedding_models")
            models = [row[0] for row in cursor.fetchall()]

            video_details = None
            if selected_video_id:
                # LEFT JOINs keep the video row even when it has no index/model.
                cursor.execute("""
                    SELECT v.id, v.title, v.channel_name, v.processed_date,
                           ei.index_name, em.model_name
                    FROM videos v
                    LEFT JOIN elasticsearch_indices ei ON v.id = ei.video_id
                    LEFT JOIN embedding_models em ON ei.embedding_model_id = em.id
                    WHERE v.youtube_id = ?
                """, (selected_video_id,))
                video_details = cursor.fetchall()

        return {
            "total_videos": total_videos,
            "total_indices": total_indices,
            "models": models,
            "video_details": video_details,
        }
    except Exception as e:
        logger.exception(f"Error getting system status: {str(e)}")
        return None
def display_system_status(status, selected_video_id=None):
    """Render a status dict (as built by get_system_status) in the sidebar.

    A falsy ``status`` produces an error message instead. The per-video
    section is drawn only when a video is selected and detail rows exist.
    """
    sidebar = st.sidebar
    if not status:
        sidebar.error("Unable to fetch system status")
        return

    sidebar.header("System Status")

    # Two side-by-side headline metrics.
    videos_col, indices_col = sidebar.columns(2)
    with videos_col:
        st.metric("Total Videos", status["total_videos"])
    with indices_col:
        st.metric("Total Indices", status["total_indices"])

    sidebar.markdown("**Available Models:**")
    for model_name in status["models"]:
        sidebar.markdown(f"- {model_name}")

    # Short-circuit keeps "video_details" unread when nothing is selected.
    if not selected_video_id or not status["video_details"]:
        return

    sidebar.markdown("---")
    sidebar.markdown("**Selected Video Details:**")
    for _video_id, title, channel, processed_date, index_name, model in status["video_details"]:
        sidebar.markdown(f"""
            - **Title:** {title}
            - **Channel:** {channel}
            - **Processed:** {processed_date}
            - **Index:** {index_name or 'Not indexed'}
            - **Model:** {model or 'N/A'}
            """)
def _load_video_catalog(db_path):
    """Return one row per video with its comma-joined index names.

    Columns: youtube_id, title, channel_name, upload_date, indices; rows are
    ordered by upload_date, newest first.
    """
    from contextlib import closing

    query = """
        SELECT DISTINCT v.youtube_id, v.title, v.channel_name, v.upload_date,
               GROUP_CONCAT(ei.index_name) as indices
        FROM videos v
        LEFT JOIN elasticsearch_indices ei ON v.id = ei.video_id
        GROUP BY v.youtube_id
        ORDER BY v.upload_date DESC
    """
    # sqlite3's connection context manager does not close the connection;
    # closing() prevents leaking one handle per Streamlit rerun.
    with closing(sqlite3.connect(db_path)) as conn:
        return pd.read_sql_query(query, conn)


def _select_video(df):
    """Render the channel filter and video picker; return the chosen youtube_id."""
    st.sidebar.markdown(f"**Available Videos:** {len(df)}")

    # dropna(): a NULL channel_name would make sorted() compare None with str
    # and raise TypeError. Such videos stay reachable through "All".
    channels = sorted(df['channel_name'].dropna().unique())
    selected_channel = st.sidebar.selectbox(
        "Filter by Channel",
        ["All"] + channels,
        key="channel_filter",
    )
    filtered_df = df if selected_channel == "All" else df[df['channel_name'] == selected_channel]

    # Build the id -> title lookup once instead of filtering the frame per option.
    titles = dict(zip(filtered_df['youtube_id'], filtered_df['title']))
    return st.sidebar.selectbox(
        "Select a Video",
        filtered_df['youtube_id'].tolist(),
        format_func=lambda youtube_id: titles[youtube_id],
        key="video_select",
    )


def _render_processing_prompt(db_handler, data_processor, youtube_id):
    """Warn that the video is unindexed and offer to process it in place."""
    st.warning("This video hasn't been indexed yet. You can process it in the Data Ingestion page.")
    if not st.button("Process Now"):
        return
    with st.spinner("Processing video..."):
        try:
            embedding_model = data_processor.embedding_model.__class__.__name__
            index_name = process_single_video(db_handler, data_processor, youtube_id, embedding_model)
        except Exception as e:
            st.error(f"Error processing video: {str(e)}")
            logger.exception(f"Error processing video: {str(e)}")
        else:
            # Kept outside the try so st.rerun() is never seen by the handler above.
            if index_name:
                st.success("Video processed successfully!")
                st.rerun()


def main():
    """Entry point of the Chat Interface page."""
    st.title("Chat Interface 💬")

    components = init_components()
    # init_components() signals failure with a tuple of Nones, and any
    # non-empty tuple is truthy, so check the individual elements.
    if not components or any(component is None for component in components):
        st.error("Failed to initialize components. Please check the logs.")
        return
    db_handler, data_processor, rag_system, _query_rewriter = components

    init_session_state()

    st.sidebar.header("Video Selection")
    df = _load_video_catalog(db_handler.db_path)
    if df.empty:
        st.info("No videos available. Please process some videos in the Data Ingestion page first.")
        display_system_status(get_system_status(db_handler))
        return

    selected_video_id = _select_video(df)

    # Render the status panel once, directly below the video picker.
    system_status = get_system_status(db_handler, selected_video_id)
    display_system_status(system_status, selected_video_id)

    if not selected_video_id:
        return

    index_name = db_handler.get_elasticsearch_index_by_youtube_id(selected_video_id)
    if not index_name:
        _render_processing_prompt(db_handler, data_processor, selected_video_id)
        return

    st.sidebar.header("Chat Settings")
    rewrite_method = st.sidebar.radio(
        "Query Rewriting Method",
        ["None", "Chain of Thought", "ReAct"],
        key="rewrite_method",
    )
    search_method = st.sidebar.radio(
        "Search Method",
        ["Hybrid", "Text-only", "Embedding-only"],
        key="search_method",
    )

    create_chat_interface(
        db_handler,
        rag_system,
        selected_video_id,
        index_name,
        rewrite_method,
        search_method,
    )


if __name__ == "__main__":
    main()