import gradio as gr
import requests
import json
import os
import re
import time
import tempfile
from datetime import datetime
from functools import lru_cache
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
# Configuration
BASE_URL = "https://zxzbfrlg3ssrk7d9.us-east-1.aws.endpoints.huggingface.cloud/v1/"
HF_TOKEN = os.environ.get("HF_TOKEN")
TAVILY_API_KEY = os.environ.get("TAVILY_API_KEY")

# Validate required environment variables
if not HF_TOKEN:
    raise ValueError("HF_TOKEN environment variable is required")

# Get current date and time information
CURRENT_DATE = datetime.now()
DATE_INFO = CURRENT_DATE.strftime("%A, %B %d, %Y")
TIME_INFO = CURRENT_DATE.strftime("%I:%M %p")
FORMATTED_DATE_TIME = f"Current Date: {DATE_INFO}\nCurrent Time: {TIME_INFO}"
# Initialize session with retry strategy
session = requests.Session()
retry_strategy = Retry(
    total=3,
    backoff_factor=1,
    status_forcelist=[429, 500, 502, 503, 504],
)
adapter = HTTPAdapter(max_retries=retry_strategy)
session.mount("http://", adapter)
session.mount("https://", adapter)
# Initialize Tavily client
try:
    from tavily import TavilyClient
    tavily_client = TavilyClient(api_key=TAVILY_API_KEY) if TAVILY_API_KEY else None
    TAVILY_AVAILABLE = True
except ImportError:
    tavily_client = None
    TAVILY_AVAILABLE = False
    print("Tavily not available: Please install tavily-python")
# Rate limiter class
class RateLimiter:
    def __init__(self, max_calls=10, time_window=60):
        self.max_calls = max_calls
        self.time_window = time_window
        self.calls = []

    def is_allowed(self):
        now = time.time()
        self.calls = [call for call in self.calls if now - call < self.time_window]
        if len(self.calls) < self.max_calls:
            self.calls.append(now)
            return True
        return False

rate_limiter = RateLimiter(max_calls=20, time_window=60)
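# Sliding-window limiting: timestamps older than the window are pruned on every
# check, so the cap is 20 calls per rolling 60 seconds rather than per fixed minute.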
# Feedback storage
feedback_data = []

def get_preloaded_context():
    """Get preloaded context information"""
    context = f"""{FORMATTED_DATE_TIME}
System Information: You are an AI assistant with access to current information through web search.
Always provide sources for factual information.
Available APIs:
- Web Search (Tavily)"""
    return context
def clean_query_for_current_info(query):
    """Clean query to focus on current/fresh information"""
    # Remove old dates
    query = re.sub(r'\d{4}-\d{2}-\d{2}', '', query)
    query = re.sub(r'\d{4}/\d{2}/\d{2}', '', query)
    query = re.sub(r'\d{2}/\d{2}/\d{4}', '', query)
    return query.strip()
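# Example: clean_query_for_current_info("election results 2024-11-05") returns
# "election results", so searches aren't pinned to a stale hard-coded date.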
def tavily_search(query):
    """Perform a web search using Tavily"""
    if not TAVILY_AVAILABLE or not tavily_client:
        return "Tavily search is not configured. Please set TAVILY_API_KEY."
    try:
        # Clean query for current info
        clean_query = clean_query_for_current_info(query)
        if not clean_query:
            return "No valid search query provided."
        response = tavily_client.search(
            clean_query,
            search_depth="advanced",
            topic="general",
            max_results=5
        )
        results = []
        for result in response.get("results", [])[:5]:
            title = result.get("title", "")
            content = result.get("content", "")
            if title and content:
                results.append(f"{title}: {content}")
            elif content:
                results.append(content)
        if results:
            return "\n\n".join(results)
        else:
            return "No relevant information found."
    except Exception as e:
        return f"Tavily search error: {str(e)}"
def truncate_history(messages, max_tokens=4000):
    """Truncate conversation history to prevent context overflow"""
    if not messages:
        return []
    # Simplified token estimation (4 chars ~ 1 token)
    estimated_tokens = sum(len(msg.get("content", "")) for msg in messages) // 4
    if estimated_tokens <= max_tokens:
        return messages
    # Truncate older messages, keeping the system message (if any) at the front
    system = []
    if messages[0].get("role") == "system":
        system = [messages[0]]
        messages = messages[1:]
    # Add recent messages (newest first) up to the token limit
    truncated = []
    current_tokens = 0
    for message in reversed(messages):
        content = message.get("content", "")
        message_tokens = len(content) // 4
        if current_tokens + message_tokens > max_tokens:
            break
        truncated.insert(0, message)
        current_tokens += message_tokens
    return system + truncated
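# The 4-characters-per-token estimate is a rough heuristic for English text;
# substitute a real tokenizer if exact context budgeting ever matters.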
def manage_conversation_memory(messages, max_turns=10):
    """Keep conversation focused and prevent context overflow"""
    if len(messages) > max_turns * 2:  # *2 for user/assistant pairs
        # Keep system message + last N turns
        system_msg = [msg for msg in messages if msg.get("role") == "system"]
        recent_messages = messages[-(max_turns * 2):]
        return system_msg + recent_messages if system_msg else recent_messages
    return messages
def perform_search(query):
    """Perform search using Tavily"""
    if TAVILY_AVAILABLE and tavily_client:
        web_result = tavily_search(query)
        return f"[SEARCH RESULTS FOR '{query}']:\nSource: Web Search\n{web_result}"
    else:
        return "Web search not available."
def is_looping_content(content):
    """Detect if content is stuck in a loop"""
    if len(content) > 2000:  # Too long, likely looping
        return True
    if content.count("let's do") > 15:  # Repeated phrases
        return True
    if content.count("search") > 40:  # Excessive repetition
        return True
    return False
def analyze_search_results(query, search_results):
    """Create a prompt for the model to analyze search results"""
    analysis_prompt = f"""Based on the search results below, please answer the original question: "{query}"
Search Results: {search_results}
Please provide a clear, concise answer based on these sources. Include specific names, facts, and cite the sources where possible. Do not mention that you are analyzing search results - just provide the answer directly."""
    return analysis_prompt
def validate_history(chat_history):
    """Ensure proper user/assistant alternation in chat_history"""
    if not chat_history:
        return []
    validated = []
    expected_role = "user"
    for message in chat_history:
        role = message.get("role")
        content = message.get("content", "")
        # Skip empty messages
        if not content:
            continue
        # Only add messages that follow proper alternation
        if role == expected_role:
            validated.append(message)
            expected_role = "assistant" if expected_role == "user" else "user"
        elif role == "system" and len(validated) == 0:
            # Allow system message at start
            validated.append(message)
    return validated
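# Messages that break the alternation are dropped rather than reordered; many
# chat-completion endpoints reject histories with consecutive same-role turns.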
def generate_follow_up_questions(last_response):
    """Generate up to 3 simple follow-up questions"""
    if not last_response:
        return []
    # Simple heuristic-based questions
    question_words = ["What", "How", "Why"]
    # Extract key topics from response (simplified)
    words = last_response.split()[:20]  # First 20 words
    key_topics = [word for word in words if len(word) > 4][:3]  # Simple filtering
    questions = []
    for i, word in enumerate(question_words):
        topic = key_topics[i % len(key_topics)] if key_topics else "this"
        questions.append(f"{word} about {topic}?")
    return questions[:3]  # Return max 3 questions
def format_code_blocks(text):
    """Detect and format code blocks with syntax highlighting"""
    # Simple pattern to detect fenced code blocks
    pattern = r'```(\w+)?\n(.*?)```'
    # Replace with HTML formatted code (simplified)
    formatted = re.sub(pattern, r'<pre><code class="language-\1">\2</code></pre>', text, flags=re.DOTALL)
    return formatted
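# On Python 3.5+ an unmatched optional group substitutes as an empty string in
# re.sub, so a fence without a language tag yields class="language-".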
def extract_and_format_citations(search_results):
    """Extract sources and create clickable citations"""
    # Simple citation extraction (can be enhanced)
    citations = []
    if "Source:" in search_results:
        lines = search_results.split('\n')
        for line in lines:
            if "http" in line:
                citations.append(line.strip())
    return citations
def track_usage(user_id, query, response_time, tokens_used):
    """Track usage metrics for improvement"""
    metrics = {
        "timestamp": datetime.now().isoformat(),
        "user_id": user_id or "anonymous",
        "query_length": len(query),
        "response_time": response_time,
        "tokens_used": tokens_used
    }
    # In a real app, you'd store this in a database
    print(f"Usage tracked: {metrics}")
    return metrics
def collect_feedback(feedback, query, chat_history):
    """Collect user feedback for model improvement"""
    # The Gradio handler passes the full chat history; pull the last assistant reply
    response = ""
    if isinstance(chat_history, list):
        assistant_turns = [m.get("content", "") for m in chat_history
                           if isinstance(m, dict) and m.get("role") == "assistant"]
        response = assistant_turns[-1] if assistant_turns else ""
    feedback_entry = {
        "timestamp": datetime.now().isoformat(),
        "feedback": feedback,
        "query": query,
        "response": response[:100] + "..." if len(response) > 100 else response
    }
    feedback_data.append(feedback_entry)
    print(f"Feedback collected: {feedback_entry}")
@lru_cache(maxsize=128)
def cached_search(query):
    """Cache frequent searches"""
    return perform_search(query)
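# Caveat: lru_cache keys on the exact query string and never expires entries,
# so cached results can go stale for time-sensitive searches.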
def handle_api_failure(error_type, fallback_strategy="retry"):
    """Handle different types of API failures gracefully"""
    # Simplified error handling
    return f"API Error: {error_type}. Strategy: {fallback_strategy}"
def export_conversation(chat_history, export_format):
    """Export conversation to a downloadable file (gr.File expects a path)"""
    if not chat_history:
        return None
    # Filter out system messages for export
    exportable_history = [msg for msg in chat_history if msg.get("role") != "system"]
    if export_format == "JSON":
        content = json.dumps(exportable_history, indent=2, ensure_ascii=False)
        suffix = ".json"
    else:  # "Text"
        lines = [f"{msg.get('role', 'unknown').upper()}: {msg.get('content', '')}"
                 for msg in exportable_history]
        content = "\n".join(lines)
        suffix = ".txt"
    with tempfile.NamedTemporaryFile("w", suffix=suffix, delete=False, encoding="utf-8") as f:
        f.write(content)
        return f.name
def is_news_related_query(query):
    """Check if query is related to news"""
    news_keywords = ['news', 'headline', 'breaking', 'latest', 'today', 'current event', 'update', 'report']
    query_lower = query.lower()
    return any(word in query_lower for word in news_keywords)
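# Example: "latest AI developments" matches on "latest", which routes the query
# through the search-then-analyze path in respond() below.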
def generate_with_streaming(messages, model, max_tokens=8192, temperature=0.7, top_p=0.9):
    """Generate text with streaming"""
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json"
    }
    # Validate history to prevent errors
    validated_messages = validate_history(messages)
    payload = {
        "model": model,
        "messages": validated_messages,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "top_p": top_p,
        "stream": True
    }
    start_time = time.time()
    try:
        response = session.post(
            f"{BASE_URL}chat/completions",
            headers=headers,
            json=payload,
            timeout=300,
            stream=True
        )
        if response.status_code == 200:
            full_response = ""
            for line in response.iter_lines():
                if line:
                    decoded_line = line.decode('utf-8')
                    if decoded_line.startswith('data: '):
                        data = decoded_line[6:]
                        if data != '[DONE]':
                            try:
                                json_data = json.loads(data)
                                if 'choices' in json_data and len(json_data['choices']) > 0:
                                    delta = json_data['choices'][0].get('delta', {})
                                    content = delta.get('content', '')
                                    if content:
                                        full_response += content
                                        yield full_response
                            except json.JSONDecodeError:
                                # Skip malformed keep-alive or partial chunks
                                continue
        else:
            yield f"Error: {response.status_code} - {response.text}"
    except Exception as e:
        yield f"Connection error: {str(e)}"
    finally:
        end_time = time.time()
        # Track usage (simplified)
        track_usage("user123", str(messages[-1]) if messages else "",
                    end_time - start_time, len(str(messages)))
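# The endpoint streams OpenAI-style server-sent events: each "data: {...}" line
# carries a JSON delta with incremental content, and "data: [DONE]" ends the stream.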
def respond(message, chat_history, model_choice, max_tokens, temperature, top_p,
            creativity, precision, system_prompt, use_web_search, theme):
    """Main response handler with conversation history"""
    if not message:
        yield "", chat_history, "", gr.update(choices=[], visible=False)
        return
    # Rate limiting check
    if not rate_limiter.is_allowed():
        yield "", chat_history + [{"role": "assistant", "content": "Rate limit exceeded. Please wait a moment before sending another message."}], "", gr.update(choices=[], visible=False)
        return
    # Add custom system prompt or preloaded context on the first turn
    if not chat_history:
        if system_prompt:
            system_message = {"role": "system", "content": system_prompt}
        else:
            preloaded_context = get_preloaded_context()
            system_message = {"role": "system", "content": preloaded_context}
        chat_history = [system_message] + chat_history
    # Check if the message contains search results that need analysis
    if "SEARCH RESULTS" in message or "[SEARCH RESULTS" in message:
        # Extract the original query from the results header
        lines = message.split('\n')
        if len(lines) > 2:
            first_line = lines[0]
            if "'" in first_line:
                query = first_line.split("'")[1]
            else:
                query = message[:100]  # Fallback
        else:
            query = "summary request"
        # Perform analysis
        analysis_prompt = analyze_search_results(query, message)
        # Create history with analysis prompt
        analysis_history = chat_history + [{"role": "user", "content": analysis_prompt}]
        # Generate analyzed response, streaming partial output
        full_response = ""
        for chunk in generate_with_streaming(analysis_history, model_choice, max_tokens, temperature * creativity, top_p * precision):
            if isinstance(chunk, str):
                full_response = chunk
                yield "", chat_history + [{"role": "user", "content": message}, {"role": "assistant", "content": full_response}], message, gr.update(choices=[], visible=False)
        # Generate follow-up questions once the answer is complete
        follow_ups = generate_follow_up_questions(full_response)
        yield "", chat_history + [{"role": "user", "content": message}, {"role": "assistant", "content": full_response}], message, gr.update(choices=follow_ups, visible=bool(follow_ups))
        return
    # Check if we should perform a search
    user_message = {"role": "user", "content": message}
    # Always perform search if web search is enabled
    if use_web_search:
        search_result = perform_search(message)
        # If this is a news-related query, automatically analyze the results
        if is_news_related_query(message):
            # Perform analysis of the search results against the original message
            analysis_prompt = analyze_search_results(message, search_result)
            # Create history with analysis prompt
            analysis_history = chat_history + [user_message, {"role": "assistant", "content": search_result}, {"role": "user", "content": analysis_prompt}]
            # Generate analyzed response, streaming partial output
            full_response = ""
            search_results_output = search_result  # Store raw search results
            for chunk in generate_with_streaming(analysis_history, model_choice, max_tokens, temperature * creativity, top_p * precision):
                if isinstance(chunk, str):
                    full_response = chunk
                    yield "", chat_history + [user_message, {"role": "assistant", "content": search_result}, {"role": "assistant", "content": full_response}], search_results_output, gr.update(choices=[], visible=False)
            # Generate follow-up questions once the analysis is complete
            follow_ups = generate_follow_up_questions(full_response)
            yield "", chat_history + [user_message, {"role": "assistant", "content": search_result}, {"role": "assistant", "content": full_response}], search_results_output, gr.update(choices=follow_ups, visible=bool(follow_ups))
            return
        else:
            # Non-news search, just return the search results
            follow_ups = generate_follow_up_questions(search_result)
            yield "", chat_history + [user_message, {"role": "assistant", "content": search_result}], search_result, gr.update(choices=follow_ups, visible=bool(follow_ups))
            return
    # Normal flow - generate response
    current_history = chat_history + [user_message]
    full_response = ""
    for chunk in generate_with_streaming(current_history, model_choice, max_tokens, temperature * creativity, top_p * precision):
        if isinstance(chunk, str):
            full_response = chunk
        # Break infinite loops mid-stream
        if is_looping_content(full_response):
            # Force search instead of looping
            search_result = perform_search(message)
            follow_ups = generate_follow_up_questions(search_result)
            yield "", chat_history + [user_message, {"role": "assistant", "content": f"[LOOP DETECTED - PERFORMING SEARCH]\n{search_result}"}], search_result, gr.update(choices=follow_ups, visible=bool(follow_ups))
            return
        # Stream the partial response
        yield "", chat_history + [user_message, {"role": "assistant", "content": full_response}], "", gr.update(choices=[], visible=False)
    # Check once more for looping content after completion
    if is_looping_content(full_response):
        search_result = perform_search(message)
        follow_ups = generate_follow_up_questions(search_result)
        yield "", chat_history + [user_message, {"role": "assistant", "content": f"[LOOP DETECTED - PERFORMING SEARCH]\n{search_result}"}], search_result, gr.update(choices=follow_ups, visible=bool(follow_ups))
        return
    # Normal completion
    follow_ups = generate_follow_up_questions(full_response)
    yield "", chat_history + [user_message, {"role": "assistant", "content": full_response}], "", gr.update(choices=follow_ups, visible=bool(follow_ups))
def apply_theme(theme):
    """Apply theme-specific CSS"""
    if theme == "Dark":
        return """
        <style>
        body { background-color: #1a1a1a; color: #ffffff; }
        .message { background-color: #2d2d2d; }
        </style>
        """
    else:
        return """
        <style>
        body { background-color: #ffffff; color: #000000; }
        .message { background-color: #f0f0f0; }
        </style>
        """
# Gradio Interface
with gr.Blocks(title="GPT-OSS Chat") as demo:
    gr.Markdown("# 🤖 GPT-OSS 20B Chat")
    gr.Markdown(f"Chat with automatic web search capabilities\n\n**Current Date/Time**: {FORMATTED_DATE_TIME}")
    # Theme CSS
    theme_css = gr.HTML()
    with gr.Row():
        chatbot = gr.Chatbot(height=500, type="messages", label="Conversation")
    with gr.Row():
        msg = gr.Textbox(label="Message", placeholder="Ask anything...", scale=9)
        submit = gr.Button("Send", scale=1)
    with gr.Row():
        clear = gr.Button("Clear")
        theme_toggle = gr.Radio(choices=["Light", "Dark"], value="Light", label="Theme")
        feedback_radio = gr.Radio(
            choices=["👍 Helpful", "👎 Not Helpful", "📝 Needs Improvement"],
            label="Rate Last Response"
        )
    with gr.Row():
        with gr.Column():
            follow_up_questions = gr.Radio(
                choices=[],
                label="Suggested Follow-up Questions",
                visible=False
            )
        with gr.Column():
            with gr.Row():
                export_format = gr.Radio(choices=["JSON", "Text"], value="JSON", label="Export Format")
                export_btn = gr.Button("Export Conversation")
            export_output = gr.File(label="Download")
    with gr.Accordion("Search Results", open=False):
        search_results = gr.Textbox(label="Raw Search Data", interactive=False, max_lines=10)
| with gr.Accordion("Settings", open=False): | |
| with gr.Row(): | |
| model_choice = gr.Dropdown( | |
| choices=[ | |
| "DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf", | |
| "other-model-variants" | |
| ], | |
| value="DavidAU/OpenAi-GPT-oss-20b-abliterated-uncensored-NEO-Imatrix-gguf", | |
| label="Model" | |
| ) | |
| with gr.Row(): | |
| max_tokens = gr.Slider(50, 8192, value=8192, label="Max Tokens") | |
| temperature = gr.Slider(0.1, 1.0, value=0.7, label="Base Temperature") | |
| top_p = gr.Slider(0.1, 1.0, value=0.9, label="Top P") | |
| with gr.Row(): | |
| creativity = gr.Slider(0.1, 1.0, value=0.7, label="Creativity") | |
| precision = gr.Slider(0.1, 1.0, value=0.9, label="Precision") | |
| system_prompt = gr.Textbox( | |
| label="System Prompt", | |
| value="", | |
| placeholder="Enter custom system prompt...", | |
| max_lines=3 | |
| ) | |
| use_web_search = gr.Checkbox(label="Enable Web Search", value=True) | |
    # Event handling
    submit_event = submit.click(
        respond,
        [msg, chatbot, model_choice, max_tokens, temperature, top_p, creativity, precision, system_prompt, use_web_search, theme_toggle],
        [msg, chatbot, search_results, follow_up_questions],
        queue=True
    )
    msg_event = msg.submit(
        respond,
        [msg, chatbot, model_choice, max_tokens, temperature, top_p, creativity, precision, system_prompt, use_web_search, theme_toggle],
        [msg, chatbot, search_results, follow_up_questions],
        queue=True
    )
    clear.click(lambda: None, None, chatbot, queue=False)
    theme_toggle.change(
        apply_theme,
        [theme_toggle],
        [theme_css]
    )
    # collect_feedback logs the rating and returns nothing, so no outputs are wired
    feedback_radio.change(
        collect_feedback,
        [feedback_radio, msg, chatbot],
        None
    )
    follow_up_questions.change(
        lambda x: x,
        [follow_up_questions],
        [msg]
    )
    export_btn.click(
        export_conversation,
        [chatbot, export_format],
        [export_output]
    )
if __name__ == "__main__":
    demo.launch()