import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces
import re
from typing import List, Dict, Tuple, Optional
import sys
from pathlib import Path
# Add rag-db to path for imports
sys.path.append(str(Path(__file__).parent / "rag-db"))
from retriever import create_retriever, GprMaxRAGRetriever
# --- Model and tokenizer --------------------------------------------------
MODEL_NAME = "jfang/gprmax-ft-Qwen3-4B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model: {MODEL_NAME}")
print(f"Using device: {DEVICE}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
# fp16 weights on GPU, fp32 on CPU. Weight placement is delegated to
# device_map="auto"; generation later moves its inputs to model.device.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    trust_remote_code=True,
    device_map="auto",
    torch_dtype=(torch.float32 if DEVICE != "cuda" else torch.float16),
)
# --- RAG retriever --------------------------------------------------------
# Vector-store directory shipped next to this file; `retriever` stays None
# until the loading code further down succeeds.
RAG_DB_PATH = Path(__file__).parent / "rag-db" / "chroma_db"
retriever: Optional[GprMaxRAGRetriever] = None
def generate_database_if_needed():
    """Build the RAG database by running ``rag-db/generate_db.py`` if it is missing.

    Returns:
        True when ``RAG_DB_PATH`` already exists or the generation script exits
        successfully; False when the script fails or cannot be launched.
    """
    if RAG_DB_PATH.exists():
        return True

    print("=" * 60)
    print("RAG database not found. Generating database...")
    print("This is a one-time process and may take a few minutes.")
    print("=" * 60)

    import subprocess

    script = Path(__file__).parent / "rag-db" / "generate_db.py"
    try:
        # Run the generator with the interpreter that is running this app, not
        # whatever "python" resolves to on PATH (it may be absent or lack deps).
        result = subprocess.run(
            [sys.executable, str(script)],
            capture_output=True,
            text=True,
            check=True,
        )
    except subprocess.CalledProcessError as e:
        print(f"❌ Failed to generate database: {e}")
        if e.stderr:
            print(f"Error output: {e.stderr}")
        return False
    except OSError as e:
        # The process could not be started at all; degrade to "RAG disabled"
        # instead of crashing the app at import time.
        print(f"❌ Failed to generate database: {e}")
        return False

    print(result.stdout)
    print("✅ Database generated successfully!")
    return True
# Build the database on first start-up if necessary, then attach a retriever.
# Every failure path leaves `retriever` as None, which disables RAG features.
if not generate_database_if_needed():
    print("RAG features will be disabled due to database generation failure.")
    retriever = None
else:
    try:
        print(f"Loading RAG database from {RAG_DB_PATH}")
        retriever = create_retriever(db_path=RAG_DB_PATH)
    except Exception as e:
        print(f"Error loading RAG database: {e}")
        print("RAG features will be disabled.")
        retriever = None
    else:
        print("RAG database loaded successfully")
# Literal tags delimiting the model's reasoning. The parser below assumes they
# survive decoding with skip_special_tokens=True (i.e. the tokenizer does not
# flag them as special tokens) -- confirm against the tokenizer config.
_THINK_OPEN = "<think>"
_THINK_CLOSE = "</think>"


def _strip_partial_tag(text: str, tag: str) -> str:
    """Drop a trailing prefix of ``tag`` (e.g. ``"</thi"``) from ``text``.

    A tag can arrive split across streamed chunks; hiding the fragment keeps
    half a tag from flashing in the UI until the rest of it arrives.
    """
    for size in range(min(len(tag) - 1, len(text)), 0, -1):
        if text.endswith(tag[:size]):
            return text[:-size]
    return text


def _split_thinking(text: str) -> Tuple[str, str, bool]:
    """Split raw model output into ``(thinking, main, thinking_open)``.

    Only the first ``<think>`` block counts as reasoning; text before it and
    after its closing tag is main content. ``thinking_open`` is True while the
    block has been opened but not yet closed; the thinking text is then
    returned unstripped, minus any trailing fragment of the closing tag.
    """
    start = text.find(_THINK_OPEN)
    if start == -1:
        return "", text, False
    before = text[:start]
    rest = text[start + len(_THINK_OPEN):]
    end = rest.find(_THINK_CLOSE)
    if end == -1:
        return _strip_partial_tag(rest, _THINK_CLOSE), before, True
    return rest[:end].strip(), before + rest[end + len(_THINK_CLOSE):], False


@spaces.GPU(duration=60)
def generate_response_stream(
    message: str,
    history: List[Dict[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Generate a streaming response with the fine-tuned Qwen3 model, keeping the
    model's ``<think>`` reasoning separate from the answer.

    Args:
        message: User's input message (a step-by-step instruction is prepended).
        history: Prior chat turns as ``{"role", "content"}`` dicts; entries with
            a missing or empty role/content are skipped.
        system_message: System prompt; omitted when empty.
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature; ``<= 0`` selects greedy decoding.
        top_p: Nucleus sampling parameter (only used when sampling).

    Yields:
        ``(thinking_content, response_content)`` after every streamed chunk.
        While a ``<think>`` block is still open, ``response_content`` holds only
        the text that preceded it. If generation stops inside the block (most
        likely the token limit), one extra tuple carries cut-off warnings.

    Raises:
        Any exception raised by ``model.generate`` in the worker thread, and
        ``queue.Empty`` if the streamer receives no text for 60 seconds.
    """
    from threading import Thread

    from transformers import TextIteratorStreamer

    # Build the conversation in the chat-message format used by the template.
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})
    for msg in history:
        if msg.get("role") and msg.get("content"):
            messages.append({"role": msg["role"], "content": msg["content"]})
    thinking_instruction = "Please think step by step about this problem, showing your reasoning process."
    messages.append({"role": "user", "content": f"{thinking_instruction}\n\n{message}"})

    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True,
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        timeout=60.0,
    )
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer,
    )
    if temperature > 0:
        generation_kwargs.update(do_sample=True, temperature=temperature, top_p=top_p)
    else:
        # Greedy decoding: temperature/top_p would be ignored (and warned about).
        generation_kwargs["do_sample"] = False

    errors: List[Exception] = []

    def _generate() -> None:
        try:
            model.generate(**generation_kwargs)
        except Exception as exc:
            # Record the failure and close the stream so the consumer loop below
            # ends now instead of waiting for the streamer timeout.
            errors.append(exc)
            streamer.end()

    thread = Thread(target=_generate)
    thread.start()

    full_response = ""
    thinking, main, thinking_open = "", "", False
    for new_text in streamer:
        full_response += new_text
        # Re-parse the whole output each time: tags may span chunk boundaries.
        thinking, main, thinking_open = _split_thinking(full_response)
        yield (thinking, main)

    thread.join()
    if errors:
        raise errors[0]

    if thinking_open:
        # Generation ended inside the <think> block - likely the token limit.
        incomplete_thinking_msg = thinking.strip()
        if incomplete_thinking_msg:
            incomplete_thinking_msg += "\n\n⚠️ **Thinking was cut off due to token limit. Try increasing 'Max New Tokens' in settings.**"
        error_msg = "⚠️ *AI's thinking process was interrupted due to token limit. Try increasing 'Max New Tokens' and retry.*"
        yield (incomplete_thinking_msg, error_msg)
# Tool definitions in Qwen3 format: a list of function specs whose "parameters"
# entry is a JSON-schema object. format_tools_prompt() serializes this list
# into the system prompt.
TOOLS = [
    dict(
        type="function",
        function=dict(
            name="search_documentation",
            description="Search gprMax documentation for relevant information about commands, syntax, parameters, or usage",
            parameters=dict(
                type="object",
                properties=dict(
                    query=dict(
                        type="string",
                        description="The search query to find relevant documentation",
                    ),
                    num_results=dict(
                        type="integer",
                        description="Number of results to return",
                        default=10,
                    ),
                ),
                required=["query"],
            ),
        ),
    ),
]
def format_tools_prompt() -> str:
    """Return ``TOOLS`` as pretty-printed JSON (2-space indent) for the system prompt."""
    from json import dumps

    return dumps(TOOLS, indent=2)
def perform_rag_search(query: str, k: int = 10) -> Tuple[str, List[Dict]]:
    """
    Search the RAG store and format the hits for the LLM and for display.

    Args:
        query: Free-text search query.
        k: Maximum number of documents to request. The value may come straight
            from a model-generated tool call, so it is coerced to an int >= 1
            (falling back to 10 when it is not numeric).

    Returns:
        Tuple of (context_for_llm, source_list_for_display). The context joins
        every hit's full text as "[Document i]: ..." blocks; each source entry
        has "index", "source", "score" and a 150-character "preview". Both are
        empty when the retriever is unavailable, nothing matches, or the search
        raises.
    """
    if retriever is None:
        print("[DEBUG] Retriever is None!")
        return "", []

    try:
        k = max(1, int(k))
    except (TypeError, ValueError, OverflowError):
        k = 10

    try:
        print(f"[DEBUG] Searching for: '{query}' with k={k}")
        results = retriever.search(query, k=k)

        print(f"[DEBUG] Search returned {len(results) if results else 0} results")
        if not results:
            return "", []

        context_parts = []
        source_list = []
        for i, result in enumerate(results, 1):
            text = result.text
            # The LLM receives each document's full text (no truncation).
            context_parts.append(f"[Document {i}]: {text}")
            # The UI only needs a short preview.
            source_list.append({
                "index": i,
                "source": (result.metadata or {}).get("source", "Unknown"),
                "score": result.score,
                "preview": text[:150] + "..." if len(text) > 150 else text,
            })

        return "\n\n".join(context_parts), source_list

    except Exception as e:
        print(f"[DEBUG] RAG search error: {e}")
        import traceback
        traceback.print_exc()
        return "", []
def _parse_tool_call(text: str) -> Optional[Dict]:
    """Return the ``tool_call`` object if ``text`` is a JSON tool call, else None.

    The whole (stripped) text must be one JSON object with a dict under
    ``"tool_call"``, optionally wrapped in a Markdown code fence. A JSON-encoded
    string ``arguments`` is decoded; any other non-dict ``arguments`` becomes
    ``{}``, so callers can always use ``result["arguments"].get(...)``.
    """
    import json

    candidate = text.strip()
    # Small models often wrap JSON in ```json ... ``` despite instructions.
    if candidate.startswith("```") and candidate.endswith("```"):
        candidate = candidate[3:-3]
        if candidate.startswith("json"):
            candidate = candidate[4:]
        candidate = candidate.strip()
    if not candidate.startswith("{"):
        return None
    try:
        payload = json.loads(candidate)
    except json.JSONDecodeError:
        return None
    tool_call = payload.get("tool_call") if isinstance(payload, dict) else None
    if not isinstance(tool_call, dict):
        return None
    arguments = tool_call.get("arguments")
    if isinstance(arguments, str):
        try:
            arguments = json.loads(arguments)
        except json.JSONDecodeError:
            arguments = {}
    if not isinstance(arguments, dict):
        arguments = {}
    return {**tool_call, "arguments": arguments}


def _format_sources(sources_list: List[Dict]) -> str:
    """Render search hits (as returned by perform_rag_search) as Markdown for the sources pane."""
    if not sources_list:
        return "*No relevant documentation found*"
    sources_parts = ["## 📚 Documentation Sources\n"]
    for source in sources_list:
        sources_parts.append(
            f"**[{source['index']}] {source['source']}** (Score: {source['score']:.3f})\n"
            f"```\n{source['preview']}\n```\n"
        )
    return "\n".join(sources_parts)


def respond(
    message: str,
    history: List[Dict[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Answer ``message``, running a documentation search first if the model asks for one.

    Flow:
      1. Generate a first answer; only its reasoning is streamed while it runs,
         because the answer may turn out to be a JSON tool call.
      2. If it is a ``search_documentation`` call and the retriever is loaded,
         search, show the sources, and stream a second answer that is given
         the retrieved context.
      3. A plain answer (including JSON that is not a tool call) is shown as-is.
         A tool call that cannot be served (no retriever, unknown tool, no hits)
         triggers a second answer from general knowledge.

    Yields:
        (thinking_markdown, response_markdown, sources_markdown) tuples.
        Exceptions are reported as a final tuple instead of being raised.
    """
    import json  # noqa: F401  (kept for parity with the module's local-import style)

    sources_content = ""

    def _stream(prompt: str, sources: str):
        """Stream an answer to ``prompt``, tagging every chunk with ``sources``."""
        for thinking, response in generate_response_stream(
            message=prompt,
            history=history,
            system_message=system_message,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        ):
            yield thinking, response, sources

    def _general_knowledge_prompt(reason: str) -> str:
        """Build a fallback prompt that steers the model away from calling tools again."""
        return (
            f"The user asked about: {message}\n\n"
            f"{reason} Please provide a helpful answer based on your general knowledge of gprMax."
        )

    try:
        final_thinking = ""
        accumulated_response = ""
        for thinking, response, _ in _stream(message, sources_content):
            final_thinking = thinking or final_thinking
            accumulated_response = response
            # Show thinking progress only
            if thinking:
                yield thinking, "⏳ *AI is analyzing your request...*", sources_content

        tool_call = _parse_tool_call(accumulated_response or "")
        if tool_call is None:
            # Normal text answer (or JSON without a "tool_call" key).
            if accumulated_response and accumulated_response.strip():
                yield final_thinking, accumulated_response, sources_content
            return

        yield final_thinking, "🔍 *Processing documentation search request...*", sources_content

        if not retriever:
            yield final_thinking, "⚠️ *Documentation search is not available. Providing answer based on general knowledge...*", ""
            yield from _stream(
                _general_knowledge_prompt("Documentation search is not available, so do not call any tools."),
                "",
            )
            return

        tool_name = tool_call.get("name")
        print(f"[DEBUG] Tool called: {tool_name}")
        print(f"[DEBUG] Tool call details: {tool_call}")

        if tool_name != "search_documentation":
            yield final_thinking, f"⚠️ *Unsupported tool `{tool_name}` requested. Providing answer based on general knowledge...*", ""
            yield from _stream(
                _general_knowledge_prompt(f"The tool '{tool_name}' is not available, so do not call any tools."),
                "",
            )
            return

        yield "🔍 *Searching documentation...*", "⏳ *Preparing to search...*", "📚 *Retrieving relevant documents...*"

        # Arguments come from model output: validate before use.
        arguments = tool_call["arguments"]
        query = arguments.get("query")
        if not isinstance(query, str) or not query.strip():
            query = message
        try:
            num_results = max(1, int(arguments.get("num_results", 10)))
        except (TypeError, ValueError, OverflowError):
            num_results = 10
        print(f"[DEBUG] Query extracted: '{query}', num_results: {num_results}")

        context, sources_list = perform_rag_search(query, k=num_results)
        print(f"[DEBUG] Search results - Context length: {len(context)}, Sources: {len(sources_list)}")

        if not context:
            sources_content = "*No relevant documentation found*"
            yield final_thinking, "⚠️ *Unable to retrieve documentation. Providing general answer...*", sources_content
            yield from _stream(
                _general_knowledge_prompt("No relevant documentation was found in the database."),
                sources_content,
            )
            return

        sources_content = _format_sources(sources_list)
        yield "✅ *Documentation retrieved*", "⏳ *Generating response with context...*", sources_content

        augmented_message = (
            "Tool call result for search_documentation:\n"
            f"{context}\n"
            f"Original question: {message}\n"
            "Please provide a comprehensive answer based on the documentation above."
        )
        yield from _stream(augmented_message, sources_content)

    except Exception as e:
        error_message = f"❌ Error generating response: {str(e)}"
        yield "", error_message, ""
# Default system prompt for gprMax assistance
def get_default_system_prompt():
    """Build the default system prompt, embedding the JSON tool definitions.

    The tool-call instructions ask for a bare JSON object with no surrounding
    text, because ``respond`` only recognises a tool call when the model's
    whole (stripped) answer parses as a JSON object with a "tool_call" key.
    """
    tools_json = format_tools_prompt()
    return f"""You are a helpful assistant specialized in gprMax, an open-source software that simulates electromagnetic wave propagation. You help users with:
1. Creating gprMax input files (.in files)
2. Understanding gprMax commands and syntax
3. Setting up simulations for GPR (Ground Penetrating Radar) and other EM applications
4. Troubleshooting simulation issues
5. Optimizing simulation parameters
You have access to the following tools:
{tools_json}
When you need to search documentation, respond with ONLY a tool call in this JSON format, with no other text before or after it:
{{
  "thought": "I need to search the documentation for...",
  "tool_call": {{
    "name": "search_documentation",
    "arguments": {{
      "query": "your search query here"
    }}
  }}
}}
After receiving tool results, provide a comprehensive answer based on the documentation.
If you give code blocks, be sure to enclose them in ```.
There is no need to always give full input files; make sure you understand what the user needs and intends to do. Sometimes a single line of code will do, and sometimes the user wants an explanation rather than code.
Provide clear, accurate, and practical guidance for gprMax users."""
# Create custom interface with collapsible thinking section
with gr.Blocks(title="gprMax Support", theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        """
        # 🛡️ gprMax Support Assistant
        
        Welcome to the gprMax Support Assistant powered by a fine-tuned Qwen3-4B model specifically trained for gprMax assistance.
        
        ### Features:
        - 💬 **Expert Guidance**: Get help with gprMax input files, commands, and simulations
        - 🧠 **Transparent Reasoning**: See the AI's thinking process in a dedicated collapsible section
        - 📚 **Documentation Search**: Relevant gprMax documentation is retrieved (RAG) and its sources are listed in a dedicated pane
        - 🤖 **Agent Mode**: Coming soon - Automated workflow assistance
        """
    )
    
    with gr.Row():
        with gr.Column(scale=3):
            # Chat interface
            chatbot = gr.Chatbot(
                label="Chat",
                type="messages",
                height=500,
                show_copy_button=True,
            )
            
            with gr.Row():
                msg = gr.Textbox(
                    label="Message",
                    placeholder="Ask about gprMax commands, simulations, or troubleshooting...",
                    lines=2,
                    scale=4,
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1)
            
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
                
            # Examples
            gr.Examples(
                examples=[
                    "How do I create a basic gprMax input file for a simple GPR simulation?",
                    "What's the difference between #domain and #dx_dy_dz commands?",
                    "How can I model a heterogeneous soil with different dielectric properties?",
                    "My simulation is taking too long. How can I optimize it?",
                    "How do I add a Ricker wavelet source to my model?",
                ],
                inputs=msg,
                label="Example Questions",
            )
        
        with gr.Column(scale=2):
            # Thinking process in collapsible accordion
            with gr.Accordion("🧠 AI Thinking Process", open=False) as thinking_accordion:
                thinking_display = gr.Markdown(
                    value="*Thinking process will appear here when the AI is reasoning through your question...*",
                    label="Thinking",
                    height=300,
                )
            
            # Documentation sources in collapsible accordion
            with gr.Accordion("📚 Documentation Sources", open=False) as sources_accordion:
                sources_display = gr.Markdown(
                    value="*Documentation sources will appear here when RAG search is performed...*",
                    label="Sources",
                    height=300,
                )
            
            # Settings
            with gr.Accordion("⚙️ Settings", open=True):
                system_message = gr.Textbox(
                    value=get_default_system_prompt(),
                    label="System Message",
                    lines=5,
                    info="Customize the assistant's behavior"
                )
                
                max_tokens = gr.Slider(
                    minimum=1,
                    maximum=4096,
                    value=1536,
                    step=1,
                    label="Max New Tokens",
                    info="Maximum length of generated response"
                )
                
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Controls randomness (0=deterministic, higher=more creative)"
                )
                
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p (Nucleus Sampling)",
                    info="Controls diversity of responses"
                )
    
    # Chat functionality
    def user_submit(message, history):
        """Append the typed message to the chat as a user turn and clear the textbox.

        Empty or whitespace-only input is ignored (the textbox is still cleared).
        """
        if not message or not message.strip():
            return "", history
        history = history + [{"role": "user", "content": message}]
        return "", history
    
    def bot_respond(history, system_msg, max_tok, temp, top_p_val):
        """Stream the assistant's reply to the last user message in ``history``.

        Yields ``(history, thinking_markdown, sources_markdown)`` for the chatbot,
        the thinking pane and the sources pane. A placeholder assistant message
        is appended and then overwritten with status text and the streamed answer.
        """
        if not history or history[-1]["role"] != "user":
            yield history, "*No thinking process*", "*No sources*"
            return
        
        user_message = history[-1]["content"]
        history_for_model = history[:-1]  # Exclude the last user message
        
        # Add placeholder for assistant response
        history = history + [{"role": "assistant", "content": ""}]
        
        thinking_text = ""
        sources_text = ""
        is_thinking = False
        has_main_content = False
        is_searching = False
        
        for thinking, response, sources in respond(
            user_message,
            history_for_model,
            system_msg,
            max_tok,
            temp,
            top_p_val
        ):
            # Thinking pane: search status lines temporarily replace reasoning.
            if thinking:
                if "Searching documentation" in thinking:
                    thinking_text = thinking
                    is_searching = True
                elif "Documentation retrieved" in thinking:
                    thinking_text = thinking
                    is_searching = False
                else:
                    thinking_text = f"## Reasoning Process\n\n{thinking}"
                    is_thinking = True
            elif not is_searching:
                thinking_text = "*Waiting for response...*"
            
            # Sources pane keeps the latest non-empty value.
            if sources:
                sources_text = sources
            
            # Chat bubble
            if response and response.strip():
                history[-1]["content"] = response
                # Search/generation status placeholders are not a real answer yet.
                if "Preparing to search" not in response and "Generating response" not in response:
                    has_main_content = True
            elif is_thinking and not has_main_content:
                # Still thinking, no main response yet
                history[-1]["content"] = "🤔 *AI is thinking... Check the right pane for thinking details*"
            elif is_searching:
                history[-1]["content"] = "🔍 *Searching documentation...*"
            elif not response:
                # No response yet and no thinking detected
                history[-1]["content"] = "⏳ *Generating response...*"
            
            yield history, thinking_text, sources_text
        
        if not history[-1]["content"]:
            # respond() produced nothing displayable; don't leave a blank bubble.
            history[-1]["content"] = "⚠️ *No response was generated. Please try again.*"
            yield history, thinking_text, sources_text
    
    # Event handlers
    msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
        bot_respond, 
        [chatbot, system_message, max_tokens, temperature, top_p],
        [chatbot, thinking_display, sources_display]
    )
    
    submit_btn.click(user_submit, [msg, chatbot], [msg, chatbot]).then(
        bot_respond,
        [chatbot, system_message, max_tokens, temperature, top_p],
        [chatbot, thinking_display, sources_display]
    )
    
    clear_btn.click(
        lambda: (
            [], 
            "*Thinking process will appear here when the AI is reasoning through your question...*",
            "*Documentation sources will appear here when RAG search is performed...*"
        ), 
        outputs=[chatbot, thinking_display, sources_display]
    )
    
    # RAG status indicator
    rag_status = "✅ Documentation search enabled" if retriever else "⚠️ Documentation search disabled (run generate_db.py)"
    
    gr.Markdown(
        f"""
        ---
        ### About
        This assistant uses `{MODEL_NAME}`, a model fine-tuned specifically for gprMax support.
        
        **RAG Status**: {rag_status}
        
        **Note**: For best results, be specific about your gprMax version and simulation requirements.
        """
    )
# Start the Gradio server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()