import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces
import re
from typing import List, Dict, Tuple, Optional
import sys
from pathlib import Path

# Add rag-db to path for imports
sys.path.append(str(Path(__file__).parent / "rag-db"))
from retriever import create_retriever, GprMaxRAGRetriever

# Initialize model and tokenizer
MODEL_NAME = "jfang/gprmax-ft-Qwen3-4B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading model: {MODEL_NAME}")
print(f"Using device: {DEVICE}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto",
    trust_remote_code=True
)

# Initialize RAG retriever
RAG_DB_PATH = Path(__file__).parent / "rag-db" / "chroma_db"
retriever: Optional[GprMaxRAGRetriever] = None


def generate_database_if_needed():
    """Generate the RAG database if it doesn't exist"""
    if not RAG_DB_PATH.exists():
        print("=" * 60)
        print("RAG database not found. Generating database...")
        print("This is a one-time process and may take a few minutes.")
        print("=" * 60)

        import subprocess
        try:
            # Run the generation script
            result = subprocess.run(
                ["python", str(Path(__file__).parent / "rag-db" / "generate_db.py")],
                capture_output=True,
                text=True,
                check=True
            )
            print(result.stdout)
            print("✅ Database generated successfully!")
            return True
        except subprocess.CalledProcessError as e:
            print(f"❌ Failed to generate database: {e}")
            if e.stderr:
                print(f"Error output: {e.stderr}")
            return False
    return True


# Generate database if needed and load retriever
if generate_database_if_needed():
    try:
        print(f"Loading RAG database from {RAG_DB_PATH}")
        retriever = create_retriever(db_path=RAG_DB_PATH)
        print("RAG database loaded successfully")
    except Exception as e:
        print(f"Error loading RAG database: {e}")
        print("RAG features will be disabled.")
        retriever = None
else:
    print("RAG features will be disabled due to database generation failure.")
    retriever = None


@spaces.GPU(duration=60)
def generate_response_stream(
    message: str,
    history: List[Dict[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Generate streaming response using the fine-tuned Qwen3 model.
    Returns both thinking content and main response separately.

    Args:
        message: User's input message
        history: Conversation history
        system_message: System prompt
        max_tokens: Maximum tokens to generate
        temperature: Sampling temperature
        top_p: Nucleus sampling parameter

    Yields:
        Tuple of (thinking_content, response_content)
    """
    # Construct messages for Qwen3 format
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})

    # Add conversation history
    for msg in history:
        if msg.get("role") and msg.get("content"):
            messages.append({"role": msg["role"], "content": msg["content"]})

    # Add current user message with thinking instruction
    thinking_instruction = "Please think step by step about this problem, showing your reasoning process."
messages.append({"role": "user", "content": f"{thinking_instruction}\n\n{message}"}) # Prepare input using Qwen's chat template text = tokenizer.apply_chat_template( messages, tokenize=False, add_generation_prompt=True ) inputs = tokenizer([text], return_tensors="pt").to(model.device) # Setup streaming from transformers import TextIteratorStreamer from threading import Thread streamer = TextIteratorStreamer( tokenizer, skip_prompt=True, skip_special_tokens=True, timeout=60.0 ) # Generation kwargs generation_kwargs = dict( **inputs, max_new_tokens=max_tokens, temperature=temperature, top_p=top_p, do_sample=True if temperature > 0 else False, pad_token_id=tokenizer.eos_token_id, eos_token_id=tokenizer.eos_token_id, streamer=streamer ) # Start generation in a separate thread thread = Thread(target=model.generate, kwargs=generation_kwargs) thread.start() # Collect and yield tokens full_response = "" thinking_buffer = "" main_buffer = "" in_thinking = False thinking_complete = False think_start_seen = False for new_text in streamer: full_response += new_text # Check if we're starting to see a thinking block if "" in full_response and not think_start_seen: think_start_seen = True think_start_idx = full_response.find("") # Any content before is main content main_buffer = full_response[:think_start_idx] in_thinking = True # Start capturing everything after thinking_buffer = full_response[think_start_idx + 7:] # Skip "" itself # Yield what we have so far yield ("", main_buffer) continue # If we're in thinking mode if in_thinking and not thinking_complete: # Update thinking buffer with latest content current_pos = full_response.find("") + 7 thinking_buffer = full_response[current_pos:] # Check if thinking is complete if "" in thinking_buffer: # Extract content before end_idx = thinking_buffer.find("") final_thinking = thinking_buffer[:end_idx].strip() # Get content after after_thinking = thinking_buffer[end_idx + 8:] # Skip "" main_buffer = main_buffer + after_thinking thinking_complete = True in_thinking = False # Yield final thinking and updated main content yield (final_thinking, main_buffer) else: # Still accumulating thinking content - stream it in real-time # Remove any partial .*?', '', full_response, flags=re.DOTALL) main_buffer = clean_response else: # No thinking tags seen yet, everything is main content main_buffer = full_response # Get the final thinking content if it exists if thinking_complete: think_match = re.search(r'(.*?)', full_response, re.DOTALL) if think_match: final_thinking = think_match.group(1).strip() yield (final_thinking, main_buffer) else: yield ("", main_buffer) else: yield ("", main_buffer) # Final cleanup - handle incomplete thinking blocks if in_thinking: # Generation ended while in thinking mode - likely hit max token limit incomplete_thinking_msg = thinking_buffer.strip() if thinking_buffer else "" if incomplete_thinking_msg: # Add warning about incomplete thinking incomplete_thinking_msg += "\n\n⚠️ **Thinking was cut off due to token limit. Try increasing 'Max New Tokens' in settings.**" # Show incomplete thinking warning in main response too error_msg = "⚠️ *AI's thinking process was interrupted due to token limit. 
        yield (incomplete_thinking_msg, error_msg)

    thread.join()


# Tool definitions in Qwen3 format
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "search_documentation",
            "description": "Search gprMax documentation for relevant information about commands, syntax, parameters, or usage",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query to find relevant documentation"
                    },
                    "num_results": {
                        "type": "integer",
                        "description": "Number of results to return",
                        "default": 10
                    }
                },
                "required": ["query"]
            }
        }
    }
]


def format_tools_prompt() -> str:
    """Format tools for inclusion in the system prompt"""
    import json
    return json.dumps(TOOLS, indent=2)


def perform_rag_search(query: str, k: int = 10) -> Tuple[str, List[Dict]]:
    """
    Perform RAG search and return formatted context and sources

    Returns:
        Tuple of (context_for_llm, source_list_for_display)
    """
    if not retriever:
        print("[DEBUG] Retriever is None!")
        return "", []

    try:
        print(f"[DEBUG] Searching for: '{query}' with k={k}")
        # Search for relevant documents
        results = retriever.search(query, k=k)
        print(f"[DEBUG] Search returned {len(results) if results else 0} results")

        if not results:
            return "", []

        # Format context for LLM - pass all text content
        context_parts = []
        source_list = []

        for i, result in enumerate(results, 1):
            # Add the full document text to the context for the LLM
            context_parts.append(f"[Document {i}]: {result.text}")

            # Add to source list for display (limited preview)
            source_list.append({
                "index": i,
                "source": result.metadata.get("source", "Unknown"),
                "score": result.score,
                "preview": result.text[:150] + "..." if len(result.text) > 150 else result.text
            })

        context = "\n\n".join(context_parts)
        return context, source_list

    except Exception as e:
        print(f"[DEBUG] RAG search error: {e}")
        import traceback
        traceback.print_exc()
        return "", []


def respond(
    message: str,
    history: List[Dict[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Response function with proper Qwen3 tool calling
    """
    import json

    sources_content = ""

    try:
        # Use system message as-is (already has tools included)
        system_with_tools = system_message

        # First, get initial response from model to see if it wants to use tools
        tool_call = None
        accumulated_response = ""
        final_thinking = ""
        is_complete = False

        # Collect the full response (thinking + potential tool call)
        for thinking, response in generate_response_stream(
            message=message,
            history=history,
            system_message=system_with_tools,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        ):
            final_thinking = thinking if thinking else final_thinking
            accumulated_response = response
            # Show thinking progress only
            if thinking:
                yield thinking, "⏳ *AI is analyzing your request...*", sources_content

        # After streaming completes, check what we got
        if accumulated_response and accumulated_response.strip():
            # Check if the complete response is a JSON tool call
            if accumulated_response.strip().startswith('{'):
                try:
                    # Try to parse the entire response as JSON
                    response_json = json.loads(accumulated_response.strip())
                    if "tool_call" in response_json:
                        tool_call = response_json["tool_call"]
                        # Show status that we're processing the tool call
                        yield final_thinking, "🔍 *Processing documentation search request...*", sources_content
                        is_complete = True
                except json.JSONDecodeError:
                    # Invalid JSON, treat as normal response
                    yield final_thinking, accumulated_response, sources_content
                    is_complete = True
                except Exception:
                    yield final_thinking, accumulated_response, sources_content
                    is_complete = True
            else:
                # It's a normal text response, not a tool call
                yield final_thinking, accumulated_response, sources_content
                is_complete = True

        # If tool was called, execute it
        if tool_call and retriever:
            tool_name = tool_call.get("name")
            print(f"[DEBUG] Tool called: {tool_name}")
            print(f"[DEBUG] Tool call details: {tool_call}")

            if tool_name == "search_documentation":
                # Update status
                yield "🔍 *Searching documentation...*", "⏳ *Preparing to search...*", "📚 *Retrieving relevant documents...*"

                # Get search query
                query = tool_call.get("arguments", {}).get("query", message)
                num_results = tool_call.get("arguments", {}).get("num_results", 10)
                print(f"[DEBUG] Query extracted: '{query}', num_results: {num_results}")

                # Perform search
                context, sources_list = perform_rag_search(query, k=num_results)
                print(f"[DEBUG] Search results - Context length: {len(context)}, Sources: {len(sources_list)}")

                if context:
                    # Format sources for display
                    if sources_list:
                        sources_parts = ["## 📚 Documentation Sources\n"]
                        for source in sources_list:
                            sources_parts.append(
                                f"**[{source['index']}] {source['source']}** (Score: {source['score']:.3f})\n"
                                f"```\n{source['preview']}\n```\n"
                            )
                        sources_content = "\n".join(sources_parts)
                    else:
                        sources_content = "*No relevant documentation found*"

                    yield "✅ *Documentation retrieved*", "⏳ *Generating response with context...*", sources_content

                    # Now generate response with the retrieved context
                    augmented_message = f"""Tool call result for search_documentation:

{context}

Original question: {message}

Please provide a comprehensive answer based on the documentation above."""

                    # Generate final response with context
                    for thinking, response in generate_response_stream(
                        message=augmented_message,
                        history=history,
                        system_message=system_message,  # Use original system message for final response
                        max_tokens=max_tokens,
                        temperature=temperature,
                        top_p=top_p,
                    ):
                        yield thinking, response, sources_content
                else:
                    sources_content = "*No relevant documentation found*"
                    yield final_thinking, "⚠️ *Unable to retrieve documentation. Providing general answer...*", sources_content

                    # Generate response without documentation context
                    fallback_message = f"""The user asked about: {message}

No relevant documentation was found in the database. Please provide a helpful answer based on your general knowledge of gprMax."""

                    for thinking, response in generate_response_stream(
                        message=fallback_message,
                        history=history,
                        system_message=system_message,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        top_p=top_p,
                    ):
                        yield thinking, response, sources_content

        # If tool was called but retriever is not available
        elif tool_call and not retriever:
            yield final_thinking, "⚠️ *Documentation search is not available. Providing answer based on general knowledge...*", ""
            # Generate response without RAG
            for thinking, response in generate_response_stream(
                message=message,
                history=history,
                system_message=system_message,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            ):
                yield thinking, response, ""

        # If no tool call and response wasn't already yielded
        elif not tool_call and not is_complete:
            # This shouldn't happen but handle it just in case
            if accumulated_response and not accumulated_response.strip().startswith('{'):
                yield final_thinking, accumulated_response, sources_content

    except Exception as e:
        error_message = f"❌ Error generating response: {str(e)}"
        yield "", error_message, ""


# Default system prompt for gprMax assistance
def get_default_system_prompt():
    """Get the system prompt with tools formatted"""
    tools_json = format_tools_prompt()
    return f"""You are a helpful assistant specialized in gprMax, open-source software that simulates electromagnetic wave propagation.

You help users with:
1. Creating gprMax input files (.in files)
2. Understanding gprMax commands and syntax
3. Setting up simulations for GPR (Ground Penetrating Radar) and other EM applications
4. Troubleshooting simulation issues
5. Optimizing simulation parameters

You have access to the following tools:
{tools_json}

When you need to search documentation, respond with a tool call in this JSON format:
{{
  "thought": "I need to search the documentation for...",
  "tool_call": {{
    "name": "search_documentation",
    "arguments": {{
      "query": "your search query here"
    }}
  }}
}}

After receiving tool results, provide a comprehensive answer based on the documentation.

If you give code blocks, enclose them inside ```. There is no need to always give a full input file; be sure to understand what the user needs and intends to do. Sometimes a single line of code will do, and sometimes the user wants an explanation rather than code.

Provide clear, accurate, and practical guidance for gprMax users."""


# Create custom interface with collapsible thinking section
with gr.Blocks(title="gprMax Support", theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        """
        # 🛡️ gprMax Support Assistant

        Welcome to the gprMax Support Assistant, powered by a Qwen3-4B model fine-tuned specifically for gprMax assistance.

        ### Features:
        - 💬 **Expert Guidance**: Get help with gprMax input files, commands, and simulations
        - 🧠 **Transparent Reasoning**: See the AI's thinking process in a dedicated collapsible section
        - 📚 **Documentation Search**: RAG-powered retrieval from the gprMax documentation
        - 🤖 **Agent Mode**: Coming soon - automated workflow assistance
        """
    )

    with gr.Row():
        with gr.Column(scale=3):
            # Chat interface
            chatbot = gr.Chatbot(
                label="Chat",
                type="messages",
                height=500,
                show_copy_button=True,
            )

            with gr.Row():
                msg = gr.Textbox(
                    label="Message",
                    placeholder="Ask about gprMax commands, simulations, or troubleshooting...",
                    lines=2,
                    scale=4,
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1)

            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear Chat", scale=1)

            # Examples
            gr.Examples(
                examples=[
                    "How do I create a basic gprMax input file for a simple GPR simulation?",
                    "What's the difference between #domain and #dx_dy_dz commands?",
                    "How can I model a heterogeneous soil with different dielectric properties?",
                    "My simulation is taking too long. How can I optimize it?",
                    "How do I add a Ricker wavelet source to my model?",
                ],
                inputs=msg,
                label="Example Questions",
            )

        with gr.Column(scale=2):
            # Thinking process in collapsible accordion
            with gr.Accordion("🧠 AI Thinking Process", open=False) as thinking_accordion:
                thinking_display = gr.Markdown(
                    value="*Thinking process will appear here when the AI is reasoning through your question...*",
                    label="Thinking",
                    height=300,
                )

            # Documentation sources in collapsible accordion
            with gr.Accordion("📚 Documentation Sources", open=False) as sources_accordion:
                sources_display = gr.Markdown(
                    value="*Documentation sources will appear here when RAG search is performed...*",
                    label="Sources",
                    height=300,
                )

            # Settings
            with gr.Accordion("⚙️ Settings", open=True):
                system_message = gr.Textbox(
                    value=get_default_system_prompt(),
                    label="System Message",
                    lines=5,
                    info="Customize the assistant's behavior"
                )
                max_tokens = gr.Slider(
                    minimum=1,
                    maximum=4096,
                    value=1536,
                    step=1,
                    label="Max New Tokens",
                    info="Maximum length of generated response"
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Controls randomness (0=deterministic, higher=more creative)"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p (Nucleus Sampling)",
                    info="Controls diversity of responses"
                )

    # Chat functionality
    def user_submit(message, history):
        if not message:
            return "", history
        history = history + [{"role": "user", "content": message}]
        return "", history

    def bot_respond(history, system_msg, max_tok, temp, top_p_val):
        if not history or history[-1]["role"] != "user":
            yield history, "*No thinking process*", "*No sources*"
            return

        user_message = history[-1]["content"]
        history_for_model = history[:-1]  # Exclude the last user message

        # Add placeholder for assistant response
        history = history + [{"role": "assistant", "content": ""}]

        thinking_text = ""
        sources_text = ""
        is_thinking = False
        has_main_content = False
        is_searching = False

        for thinking, response, sources in respond(
            user_message, history_for_model, system_msg, max_tok, temp, top_p_val
        ):
            # Update thinking display
            if thinking:
                if "Searching documentation" in thinking:
                    thinking_text = thinking
                    is_searching = True
                elif "Documentation retrieved" in thinking:
                    thinking_text = thinking
                    is_searching = False
                else:
                    thinking_text = f"## Reasoning Process\n\n{thinking}"
                    is_thinking = True
            elif not thinking and not is_searching:
                thinking_text = "*Waiting for response...*"

            # Update sources display
            if sources:
                sources_text = sources

            # Update chat response
            if response and response.strip():
                # We have actual response content
                if "Preparing to search" in response or "Generating response" in response:
                    # Status messages
                    history[-1]["content"] = response
                else:
                    # Actual content
                    history[-1]["content"] = response
                    has_main_content = True
            elif is_thinking and not has_main_content:
                # Still thinking, no main response yet
                history[-1]["content"] = "🤔 *AI is thinking... Check the right pane for thinking details*"
            elif is_searching:
                history[-1]["content"] = "🔍 *Searching documentation...*"
            elif not response:
                # No response yet and no thinking detected
                history[-1]["content"] = "⏳ *Generating response...*"

            yield history, thinking_text, sources_text

    # Event handlers
    msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
        bot_respond,
        [chatbot, system_message, max_tokens, temperature, top_p],
        [chatbot, thinking_display, sources_display]
    )

    submit_btn.click(user_submit, [msg, chatbot], [msg, chatbot]).then(
        bot_respond,
        [chatbot, system_message, max_tokens, temperature, top_p],
        [chatbot, thinking_display, sources_display]
    )

    clear_btn.click(
        lambda: (
            [],
            "*Thinking process will appear here when the AI is reasoning through your question...*",
            "*Documentation sources will appear here when RAG search is performed...*"
        ),
        outputs=[chatbot, thinking_display, sources_display]
    )

    # RAG status indicator
    rag_status = "✅ Documentation search enabled" if retriever else "⚠️ Documentation search disabled (run generate_db.py)"
    gr.Markdown(
        f"""
        ---
        ### About
        This assistant uses `jfang/gprmax-ft-Qwen3-4B-Instruct`, a model fine-tuned specifically for gprMax support.

        **RAG Status**: {rag_status}

        **Note**: For best results, be specific about your gprMax version and simulation requirements.
        """
    )

if __name__ == "__main__":
    demo.launch()