import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces
import re
from typing import List, Dict, Tuple, Optional
import sys
from pathlib import Path

# Add rag-db to the import path
sys.path.append(str(Path(__file__).parent / "rag-db"))

from retriever import create_retriever, GprMaxRAGRetriever

# Initialize the model and tokenizer
MODEL_NAME = "jfang/gprmax-ft-Qwen3-4B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

print(f"Loading model: {MODEL_NAME}")
print(f"Using device: {DEVICE}")

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
    device_map="auto",
    trust_remote_code=True
)
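# Illustrative sketch (comment only, not executed): a minimal smoke test of the
# loaded model using the same chat-template flow the app uses further below.
# The prompt text is an arbitrary example, not part of the application.
#
#   _msgs = [{"role": "user", "content": "What does the #domain command do?"}]
#   _prompt = tokenizer.apply_chat_template(_msgs, tokenize=False, add_generation_prompt=True)
#   _inputs = tokenizer([_prompt], return_tensors="pt").to(model.device)
#   _out = model.generate(**_inputs, max_new_tokens=64)
#   print(tokenizer.decode(_out[0][_inputs["input_ids"].shape[1]:], skip_special_tokens=True))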
# Initialize the RAG retriever
RAG_DB_PATH = Path(__file__).parent / "rag-db" / "chroma_db"
retriever: Optional[GprMaxRAGRetriever] = None


def generate_database_if_needed():
    """Generate the RAG database if it doesn't exist yet."""
    if not RAG_DB_PATH.exists():
        print("=" * 60)
        print("RAG database not found. Generating database...")
        print("This is a one-time process and may take a few minutes.")
        print("=" * 60)
        import subprocess
        try:
            # Run the generation script
            result = subprocess.run(
                ["python", str(Path(__file__).parent / "rag-db" / "generate_db.py")],
                capture_output=True,
                text=True,
                check=True
            )
            print(result.stdout)
            print("✅ Database generated successfully!")
            return True
        except subprocess.CalledProcessError as e:
            print(f"❌ Failed to generate database: {e}")
            if e.stderr:
                print(f"Error output: {e.stderr}")
            return False
    return True
# Generate the database if needed, then load the retriever
if generate_database_if_needed():
    try:
        print(f"Loading RAG database from {RAG_DB_PATH}")
        retriever = create_retriever(db_path=RAG_DB_PATH)
        print("RAG database loaded successfully")
    except Exception as e:
        print(f"Error loading RAG database: {e}")
        print("RAG features will be disabled.")
        retriever = None
else:
    print("RAG features will be disabled due to database generation failure.")
    retriever = None
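# Illustrative sketch (comment only, not executed): querying the retriever
# directly. This assumes the GprMaxRAGRetriever interface used later in this
# file: a .search(query, k) method returning hits with .text, .metadata, .score.
#
#   if retriever is not None:
#       for hit in retriever.search("ricker wavelet source", k=3):
#           print(hit.score, hit.metadata.get("source", "Unknown"), hit.text[:80])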
# The `spaces` import above is only used by the decorator below. Assumption:
# this Space runs on ZeroGPU ("Running on Zero"), where GPU access has to be
# requested per call via @spaces.GPU.
@spaces.GPU
def generate_response_stream(
    message: str,
    history: List[Dict[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
| """ | |
| Generate streaming response using the fine-tuned Qwen3 model. | |
| Returns both thinking content and main response separately. | |
| Args: | |
| message: User's input message | |
| history: Conversation history | |
| system_message: System prompt | |
| max_tokens: Maximum tokens to generate | |
| temperature: Sampling temperature | |
| top_p: Nucleus sampling parameter | |
| Yields: | |
| Tuple of (thinking_content, response_content) | |
| """ | |
    # Construct messages in the Qwen3 chat format
    messages = []
    if system_message:
        messages.append({"role": "system", "content": system_message})

    # Add conversation history
    for msg in history:
        if msg.get("role") and msg.get("content"):
            messages.append({"role": msg["role"], "content": msg["content"]})

    # Add the current user message with a thinking instruction
    thinking_instruction = "Please think step by step about this problem, showing your reasoning process."
    messages.append({"role": "user", "content": f"{thinking_instruction}\n\n{message}"})

    # Prepare the input using Qwen's chat template
    text = tokenizer.apply_chat_template(
        messages,
        tokenize=False,
        add_generation_prompt=True
    )
    inputs = tokenizer([text], return_tensors="pt").to(model.device)

    # Set up streaming
    from transformers import TextIteratorStreamer
    from threading import Thread

    streamer = TextIteratorStreamer(
        tokenizer,
        skip_prompt=True,
        skip_special_tokens=True,
        timeout=60.0
    )

    # Generation kwargs
    generation_kwargs = dict(
        **inputs,
        max_new_tokens=max_tokens,
        temperature=temperature,
        top_p=top_p,
        do_sample=temperature > 0,
        pad_token_id=tokenizer.eos_token_id,
        eos_token_id=tokenizer.eos_token_id,
        streamer=streamer
    )
    # Start generation in a separate thread
    thread = Thread(target=model.generate, kwargs=generation_kwargs)
    thread.start()

    # Collect and yield tokens
    full_response = ""
    thinking_buffer = ""
    main_buffer = ""
    in_thinking = False
    thinking_complete = False
    think_start_seen = False

    for new_text in streamer:
        full_response += new_text

        # Check whether a <think> block is starting
        if "<think>" in full_response and not think_start_seen:
            think_start_seen = True
            think_start_idx = full_response.find("<think>")
            # Any content before <think> is main content
            main_buffer = full_response[:think_start_idx]
            in_thinking = True
            # Start capturing everything after <think>
            thinking_buffer = full_response[think_start_idx + 7:]  # Skip "<think>" itself
            # Yield what we have so far
            yield ("", main_buffer)
            continue
        # If we're inside the thinking block
        if in_thinking and not thinking_complete:
            # Update the thinking buffer with the latest content
            current_pos = full_response.find("<think>") + 7
            thinking_buffer = full_response[current_pos:]

            # Check whether thinking is complete
            if "</think>" in thinking_buffer:
                # Extract the content before </think>
                end_idx = thinking_buffer.find("</think>")
                final_thinking = thinking_buffer[:end_idx].strip()
                # Content after </think> belongs to the main response
                after_thinking = thinking_buffer[end_idx + 8:]  # Skip "</think>"
                main_buffer = main_buffer + after_thinking
                thinking_complete = True
                in_thinking = False
                # Yield the final thinking and the updated main content
                yield (final_thinking, main_buffer)
            else:
                # Still accumulating thinking content - stream it in real time,
                # but hide any partially generated closing tag (e.g. "</thi") at the end
                display_thinking = thinking_buffer
                partial_close_tags = ("</", "</t", "</th", "</thi", "</thin", "</think")
                if display_thinking.endswith(partial_close_tags):
                    display_thinking = display_thinking[:display_thinking.rfind("<")]
                yield (display_thinking.strip(), main_buffer)
            continue
        # Normal streaming after thinking is complete, or if there was no thinking at all
        if thinking_complete or not think_start_seen:
            if thinking_complete:
                # Remove the entire thinking block and stream the rest
                clean_response = re.sub(r'<think>.*?</think>', '', full_response, flags=re.DOTALL)
                main_buffer = clean_response
            else:
                # No thinking tags seen yet, everything is main content
                main_buffer = full_response

            # Re-extract the final thinking content if it exists
            if thinking_complete:
                think_match = re.search(r'<think>(.*?)</think>', full_response, re.DOTALL)
                if think_match:
                    final_thinking = think_match.group(1).strip()
                    yield (final_thinking, main_buffer)
                else:
                    yield ("", main_buffer)
            else:
                yield ("", main_buffer)

    # Final cleanup - handle an incomplete thinking block
    if in_thinking:
        # Generation ended while still in thinking mode - likely hit the token limit
        incomplete_thinking_msg = thinking_buffer.strip() if thinking_buffer else ""
        if incomplete_thinking_msg:
            # Add a warning about the incomplete thinking
            incomplete_thinking_msg += "\n\n⚠️ **Thinking was cut off due to the token limit. Try increasing 'Max New Tokens' in settings.**"
        # Show the incomplete-thinking warning in the main response too
        error_msg = "⚠️ *The AI's thinking process was interrupted due to the token limit. Try increasing 'Max New Tokens' and retry.*"
        yield (incomplete_thinking_msg, error_msg)

    thread.join()
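# Illustrative sketch (comment only, not executed): how the streaming generator
# above is consumed. Each iteration yields the latest (thinking, response) pair,
# so a caller only needs to keep the most recent values; respond() below does
# exactly this. get_default_system_prompt() is defined further down in this file.
#
#   final_thinking, final_response = "", ""
#   for thinking, response in generate_response_stream(
#       message="How do I add a Ricker wavelet source?",
#       history=[],
#       system_message=get_default_system_prompt(),
#       max_tokens=512,
#       temperature=0.7,
#       top_p=0.95,
#   ):
#       final_thinking, final_response = thinking or final_thinking, response
#   print(final_response)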
# Tool definitions in Qwen3 format
TOOLS = [
    {
        "type": "function",
        "function": {
            "name": "search_documentation",
            "description": "Search gprMax documentation for relevant information about commands, syntax, parameters, or usage",
            "parameters": {
                "type": "object",
                "properties": {
                    "query": {
                        "type": "string",
                        "description": "The search query to find relevant documentation"
                    },
                    "num_results": {
                        "type": "integer",
                        "description": "Number of results to return",
                        "default": 10
                    }
                },
                "required": ["query"]
            }
        }
    }
]


def format_tools_prompt() -> str:
    """Format the tool definitions for inclusion in the system prompt."""
    import json
    return json.dumps(TOOLS, indent=2)
def perform_rag_search(query: str, k: int = 10) -> Tuple[str, List[Dict]]:
    """
    Perform a RAG search and return formatted context and sources.

    Returns:
        Tuple of (context_for_llm, source_list_for_display)
    """
    if not retriever:
        print("[DEBUG] Retriever is None!")
        return "", []

    try:
        print(f"[DEBUG] Searching for: '{query}' with k={k}")
        # Search for relevant documents
        results = retriever.search(query, k=k)
        print(f"[DEBUG] Search returned {len(results) if results else 0} results")

        if not results:
            return "", []

        # Format context for the LLM and collect display metadata
        context_parts = []
        source_list = []

        for i, result in enumerate(results, 1):
            # Pass the full document text to the LLM context
            context_parts.append(f"[Document {i}]: {result.text}")
            # Add to the source list for display (limited preview)
            source_list.append({
                "index": i,
                "source": result.metadata.get("source", "Unknown"),
                "score": result.score,
                "preview": result.text[:150] + "..." if len(result.text) > 150 else result.text
            })

        context = "\n\n".join(context_parts)
        return context, source_list

    except Exception as e:
        print(f"[DEBUG] RAG search error: {e}")
        import traceback
        traceback.print_exc()
        return "", []
def respond(
    message: str,
    history: List[Dict[str, str]],
    system_message: str,
    max_tokens: int,
    temperature: float,
    top_p: float,
):
    """
    Top-level response function with Qwen3-style tool calling.
    """
    import json
    sources_content = ""

    try:
        # Use the system message as-is (it already includes the tool definitions)
        system_with_tools = system_message

        # First, get an initial response from the model to see whether it wants to use a tool
        tool_call = None
        accumulated_response = ""
        final_thinking = ""
        is_complete = False

        # Collect the full response (thinking + potential tool call)
        for thinking, response in generate_response_stream(
            message=message,
            history=history,
            system_message=system_with_tools,
            max_tokens=max_tokens,
            temperature=temperature,
            top_p=top_p,
        ):
            final_thinking = thinking if thinking else final_thinking
            accumulated_response = response
            # Show thinking progress only
            if thinking:
                yield thinking, "⏳ *AI is analyzing your request...*", sources_content

        # After streaming completes, check what we got
        if accumulated_response and accumulated_response.strip():
            # Check whether the complete response is a JSON tool call
            if accumulated_response.strip().startswith('{'):
                try:
                    # Try to parse the entire response as JSON
                    response_json = json.loads(accumulated_response.strip())
                    if "tool_call" in response_json:
                        tool_call = response_json["tool_call"]
                        # Show a status while the tool call is processed
                        yield final_thinking, "🔍 *Processing documentation search request...*", sources_content
                        is_complete = True
                except json.JSONDecodeError:
                    # Invalid JSON, treat it as a normal response
                    yield final_thinking, accumulated_response, sources_content
                    is_complete = True
                except Exception:
                    yield final_thinking, accumulated_response, sources_content
                    is_complete = True
            else:
                # It's a normal text response, not a tool call
                yield final_thinking, accumulated_response, sources_content
                is_complete = True
        # If a tool was called, execute it
        if tool_call and retriever:
            tool_name = tool_call.get("name")
            print(f"[DEBUG] Tool called: {tool_name}")
            print(f"[DEBUG] Tool call details: {tool_call}")

            if tool_name == "search_documentation":
                # Update status
                yield "🔍 *Searching documentation...*", "⏳ *Preparing to search...*", "🔍 *Retrieving relevant documents...*"

                # Get the search query
                query = tool_call.get("arguments", {}).get("query", message)
                num_results = tool_call.get("arguments", {}).get("num_results", 10)
                print(f"[DEBUG] Query extracted: '{query}', num_results: {num_results}")

                # Perform the search
                context, sources_list = perform_rag_search(query, k=num_results)
                print(f"[DEBUG] Search results - Context length: {len(context)}, Sources: {len(sources_list)}")

                if context:
                    # Format sources for display
                    if sources_list:
                        sources_parts = ["## 📚 Documentation Sources\n"]
                        for source in sources_list:
                            sources_parts.append(
                                f"**[{source['index']}] {source['source']}** (Score: {source['score']:.3f})\n"
                                f"```\n{source['preview']}\n```\n"
                            )
                        sources_content = "\n".join(sources_parts)
                    else:
                        sources_content = "*No relevant documentation found*"

                    yield "✅ *Documentation retrieved*", "⏳ *Generating response with context...*", sources_content

                    # Now generate a response with the retrieved context
                    augmented_message = f"""Tool call result for search_documentation:
{context}
Original question: {message}
Please provide a comprehensive answer based on the documentation above."""
                    # Generate the final response with the retrieved context
                    for thinking, response in generate_response_stream(
                        message=augmented_message,
                        history=history,
                        system_message=system_message,  # Use the original system message for the final response
                        max_tokens=max_tokens,
                        temperature=temperature,
                        top_p=top_p,
                    ):
                        yield thinking, response, sources_content
                else:
                    sources_content = "*No relevant documentation found*"
                    yield final_thinking, "⚠️ *Unable to retrieve documentation. Providing a general answer...*", sources_content

                    # Generate a response without documentation context
                    fallback_message = f"""The user asked about: {message}
No relevant documentation was found in the database. Please provide a helpful answer based on your general knowledge of gprMax."""

                    for thinking, response in generate_response_stream(
                        message=fallback_message,
                        history=history,
                        system_message=system_message,
                        max_tokens=max_tokens,
                        temperature=temperature,
                        top_p=top_p,
                    ):
                        yield thinking, response, sources_content

        # If a tool was called but the retriever is not available
        elif tool_call and not retriever:
            yield final_thinking, "⚠️ *Documentation search is not available. Providing an answer based on general knowledge...*", ""

            # Generate a response without RAG
            for thinking, response in generate_response_stream(
                message=message,
                history=history,
                system_message=system_message,
                max_tokens=max_tokens,
                temperature=temperature,
                top_p=top_p,
            ):
                yield thinking, response, ""

        # If there was no tool call and the response wasn't already yielded
        elif not tool_call and not is_complete:
            # This shouldn't happen, but handle it just in case
            if accumulated_response and not accumulated_response.strip().startswith('{'):
                yield final_thinking, accumulated_response, sources_content

    except Exception as e:
        error_message = f"❌ Error generating response: {str(e)}"
        yield "", error_message, ""
# Default system prompt for gprMax assistance
def get_default_system_prompt():
    """Get the default system prompt with the tool definitions included."""
    tools_json = format_tools_prompt()
    return f"""You are a helpful assistant specialized in gprMax, an open-source software package that simulates electromagnetic wave propagation. You help users with:
1. Creating gprMax input files (.in files)
2. Understanding gprMax commands and syntax
3. Setting up simulations for GPR (Ground Penetrating Radar) and other EM applications
4. Troubleshooting simulation issues
5. Optimizing simulation parameters
You have access to the following tools:
{tools_json}
When you need to search documentation, respond with a tool call in this JSON format:
{{
    "thought": "I need to search the documentation for...",
    "tool_call": {{
        "name": "search_documentation",
        "arguments": {{
            "query": "your search query here"
        }}
    }}
}}
After receiving tool results, provide a comprehensive answer based on the documentation.
If you give code blocks, make sure to enclose them in ```.
There is no need to always give a full input file; be sure to understand what the user needs and intends to do. Sometimes a single line of code is enough, and sometimes the user wants an explanation rather than code.
Provide clear, accurate, and practical guidance for gprMax users."""
# Create the custom interface with a collapsible thinking section
with gr.Blocks(title="gprMax Support", theme=gr.themes.Ocean()) as demo:
    gr.Markdown(
        """
# 📡 gprMax Support Assistant
Welcome to the gprMax Support Assistant, powered by a Qwen3-4B model fine-tuned specifically for gprMax assistance.
### Features:
- 💬 **Expert Guidance**: Get help with gprMax input files, commands, and simulations
- 🧠 **Transparent Reasoning**: See the AI's thinking process in a dedicated collapsible section
- 📚 **Documentation Search**: RAG-powered retrieval of relevant gprMax documentation
- 🤖 **Agent Mode**: Coming soon - Automated workflow assistance
        """
    )
    with gr.Row():
        with gr.Column(scale=3):
            # Chat interface
            chatbot = gr.Chatbot(
                label="Chat",
                type="messages",
                height=500,
                show_copy_button=True,
            )
            with gr.Row():
                msg = gr.Textbox(
                    label="Message",
                    placeholder="Ask about gprMax commands, simulations, or troubleshooting...",
                    lines=2,
                    scale=4,
                )
                submit_btn = gr.Button("Send", variant="primary", scale=1)
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear Chat", scale=1)

            # Example questions
            gr.Examples(
                examples=[
                    "How do I create a basic gprMax input file for a simple GPR simulation?",
                    "What's the difference between #domain and #dx_dy_dz commands?",
                    "How can I model a heterogeneous soil with different dielectric properties?",
                    "My simulation is taking too long. How can I optimize it?",
                    "How do I add a Ricker wavelet source to my model?",
                ],
                inputs=msg,
                label="Example Questions",
            )

        with gr.Column(scale=2):
            # Thinking process in a collapsible accordion
            with gr.Accordion("🧠 AI Thinking Process", open=False) as thinking_accordion:
                thinking_display = gr.Markdown(
                    value="*Thinking process will appear here when the AI is reasoning through your question...*",
                    label="Thinking",
                    height=300,
                )

            # Documentation sources in a collapsible accordion
            with gr.Accordion("📚 Documentation Sources", open=False) as sources_accordion:
                sources_display = gr.Markdown(
                    value="*Documentation sources will appear here when RAG search is performed...*",
                    label="Sources",
                    height=300,
                )
            # Settings
            with gr.Accordion("⚙️ Settings", open=True):
                system_message = gr.Textbox(
                    value=get_default_system_prompt(),
                    label="System Message",
                    lines=5,
                    info="Customize the assistant's behavior"
                )
                max_tokens = gr.Slider(
                    minimum=1,
                    maximum=4096,
                    value=1536,
                    step=1,
                    label="Max New Tokens",
                    info="Maximum length of the generated response"
                )
                temperature = gr.Slider(
                    minimum=0.0,
                    maximum=2.0,
                    value=0.7,
                    step=0.1,
                    label="Temperature",
                    info="Controls randomness (0 = deterministic, higher = more creative)"
                )
                top_p = gr.Slider(
                    minimum=0.1,
                    maximum=1.0,
                    value=0.95,
                    step=0.05,
                    label="Top-p (Nucleus Sampling)",
                    info="Controls diversity of responses"
                )
    # Chat functionality
    def user_submit(message, history):
        if not message:
            return "", history
        history = history + [{"role": "user", "content": message}]
        return "", history

    def bot_respond(history, system_msg, max_tok, temp, top_p_val):
        if not history or history[-1]["role"] != "user":
            yield history, "*No thinking process*", "*No sources*"
            return

        user_message = history[-1]["content"]
        history_for_model = history[:-1]  # Exclude the last user message

        # Add a placeholder for the assistant response
        history = history + [{"role": "assistant", "content": ""}]

        thinking_text = ""
        sources_text = ""
        is_thinking = False
        has_main_content = False
        is_searching = False

        for thinking, response, sources in respond(
            user_message,
            history_for_model,
            system_msg,
            max_tok,
            temp,
            top_p_val
        ):
            # Update the thinking display
            if thinking:
                if "Searching documentation" in thinking:
                    thinking_text = thinking
                    is_searching = True
                elif "Documentation retrieved" in thinking:
                    thinking_text = thinking
                    is_searching = False
                else:
                    thinking_text = f"## Reasoning Process\n\n{thinking}"
                    is_thinking = True
            elif not thinking and not is_searching:
                thinking_text = "*Waiting for response...*"

            # Update the sources display
            if sources:
                sources_text = sources

            # Update the chat response
            if response and response.strip():
                if "Preparing to search" in response or "Generating response" in response:
                    # Transient status message
                    history[-1]["content"] = response
                else:
                    # Actual response content
                    history[-1]["content"] = response
                    has_main_content = True
            elif is_thinking and not has_main_content:
                # Still thinking, no main response yet
                history[-1]["content"] = "🤔 *AI is thinking... Check the right pane for thinking details*"
            elif is_searching:
                history[-1]["content"] = "🔍 *Searching documentation...*"
            elif not response:
                # No response yet and no thinking detected
                history[-1]["content"] = "⏳ *Generating response...*"

            yield history, thinking_text, sources_text
    # Event handlers
    msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
        bot_respond,
        [chatbot, system_message, max_tokens, temperature, top_p],
        [chatbot, thinking_display, sources_display]
    )

    submit_btn.click(user_submit, [msg, chatbot], [msg, chatbot]).then(
        bot_respond,
        [chatbot, system_message, max_tokens, temperature, top_p],
        [chatbot, thinking_display, sources_display]
    )

    clear_btn.click(
        lambda: (
            [],
            "*Thinking process will appear here when the AI is reasoning through your question...*",
            "*Documentation sources will appear here when RAG search is performed...*"
        ),
        outputs=[chatbot, thinking_display, sources_display]
    )
    # RAG status indicator
    rag_status = "✅ Documentation search enabled" if retriever else "⚠️ Documentation search disabled (run generate_db.py)"
    gr.Markdown(
        f"""
---
### About
This assistant uses `jfang/gprmax-ft-Qwen3-4B-Instruct`, a model fine-tuned specifically for gprMax support.

**RAG Status**: {rag_status}

**Note**: For best results, be specific about your gprMax version and simulation requirements.
        """
    )
if __name__ == "__main__":
    demo.launch()
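# Note (assumption, not part of this Space): for local debugging one could pass
# standard Gradio launch options instead of the defaults, e.g.
# demo.launch(share=True) for a temporary public URL, or
# demo.launch(server_port=7860) to pin the port.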