import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import spaces
import re
from typing import List, Dict, Tuple, Optional
import sys
from pathlib import Path
# Add rag-db to path for imports
sys.path.append(str(Path(__file__).parent / "rag-db"))
from retriever import create_retriever, GprMaxRAGRetriever
# Initialize model and tokenizer
MODEL_NAME = "jfang/gprmax-ft-Qwen3-4B-Instruct"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"Loading model: {MODEL_NAME}")
print(f"Using device: {DEVICE}")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
torch_dtype=torch.float16 if DEVICE == "cuda" else torch.float32,
device_map="auto",
trust_remote_code=True
)
# Initialize RAG retriever
RAG_DB_PATH = Path(__file__).parent / "rag-db" / "chroma_db"
retriever: Optional[GprMaxRAGRetriever] = None
def generate_database_if_needed():
"""Generate the RAG database if it doesn't exist"""
if not RAG_DB_PATH.exists():
print("=" * 60)
print("RAG database not found. Generating database...")
print("This is a one-time process and may take a few minutes.")
print("=" * 60)
import subprocess
try:
# Run the generation script
result = subprocess.run(
["python", str(Path(__file__).parent / "rag-db" / "generate_db.py")],
capture_output=True,
text=True,
check=True
)
print(result.stdout)
print("✅ Database generated successfully!")
return True
except subprocess.CalledProcessError as e:
print(f"❌ Failed to generate database: {e}")
if e.stderr:
print(f"Error output: {e.stderr}")
return False
return True
# Generate database if needed and load retriever
if generate_database_if_needed():
try:
print(f"Loading RAG database from {RAG_DB_PATH}")
retriever = create_retriever(db_path=RAG_DB_PATH)
print("RAG database loaded successfully")
except Exception as e:
print(f"Error loading RAG database: {e}")
print("RAG features will be disabled.")
retriever = None
else:
print("RAG features will be disabled due to database generation failure.")
retriever = None
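# Note: if the RAG database could not be built, the app still runs; the model then
# answers from its fine-tuned knowledge without documentation retrieval.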
@spaces.GPU(duration=60)
def generate_response_stream(
message: str,
history: List[Dict[str, str]],
system_message: str,
max_tokens: int,
temperature: float,
top_p: float,
):
"""
Generate streaming response using the fine-tuned Qwen3 model.
Returns both thinking content and main response separately.
Args:
message: User's input message
history: Conversation history
system_message: System prompt
max_tokens: Maximum tokens to generate
temperature: Sampling temperature
top_p: Nucleus sampling parameter
Yields:
Tuple of (thinking_content, response_content)
"""
# Construct messages for Qwen3 format
messages = []
if system_message:
messages.append({"role": "system", "content": system_message})
# Add conversation history
for msg in history:
if msg.get("role") and msg.get("content"):
messages.append({"role": msg["role"], "content": msg["content"]})
# Add current user message with thinking instruction
thinking_instruction = "Please think step by step about this problem, showing your reasoning process."
messages.append({"role": "user", "content": f"{thinking_instruction}\n\n{message}"})
# Prepare input using Qwen's chat template
text = tokenizer.apply_chat_template(
messages,
tokenize=False,
add_generation_prompt=True
)
inputs = tokenizer([text], return_tensors="pt").to(model.device)
# Setup streaming
from transformers import TextIteratorStreamer
from threading import Thread
streamer = TextIteratorStreamer(
tokenizer,
skip_prompt=True,
skip_special_tokens=True,
timeout=60.0
)
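    # skip_prompt keeps the prompt tokens out of the streamed output; the timeout
    # prevents the iterator from blocking indefinitely if the generation thread stalls.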
# Generation kwargs
generation_kwargs = dict(
**inputs,
max_new_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
        do_sample=temperature > 0,
pad_token_id=tokenizer.eos_token_id,
eos_token_id=tokenizer.eos_token_id,
streamer=streamer
)
# Start generation in a separate thread
thread = Thread(target=model.generate, kwargs=generation_kwargs)
thread.start()
# Collect and yield tokens
full_response = ""
thinking_buffer = ""
main_buffer = ""
in_thinking = False
thinking_complete = False
think_start_seen = False
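    # The Qwen3 model emits its reasoning wrapped in <think>...</think> tags before
    # the final answer; parse the stream so the two can be displayed separately.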
for new_text in streamer:
full_response += new_text
# Check if we're starting to see a thinking block
if "" in full_response and not think_start_seen:
think_start_seen = True
think_start_idx = full_response.find("")
# Any content before is main content
main_buffer = full_response[:think_start_idx]
in_thinking = True
# Start capturing everything after
thinking_buffer = full_response[think_start_idx + 7:] # Skip "" itself
# Yield what we have so far
yield ("", main_buffer)
continue
# If we're in thinking mode
if in_thinking and not thinking_complete:
# Update thinking buffer with latest content
            current_pos = full_response.find("<think>") + 7
            thinking_buffer = full_response[current_pos:]
            # Check if thinking is complete
            if "</think>" in thinking_buffer:
                # Extract content before </think>
                end_idx = thinking_buffer.find("</think>")
                final_thinking = thinking_buffer[:end_idx].strip()
                # Get content after </think>
                after_thinking = thinking_buffer[end_idx + 8:]  # Skip "</think>"
main_buffer = main_buffer + after_thinking
thinking_complete = True
in_thinking = False
# Yield final thinking and updated main content
yield (final_thinking, main_buffer)
            else:
                # Still accumulating thinking content - stream it in real-time,
                # dropping any partially generated "</think>" tag at the end
                partial_thinking = re.sub(r'<[^>]*$', '', thinking_buffer)
                yield (partial_thinking.strip(), main_buffer)
        elif thinking_complete:
            # Thinking already finished - strip the <think> block, the rest is the answer
            clean_response = re.sub(r'<think>.*?</think>', '', full_response, flags=re.DOTALL)
            main_buffer = clean_response
        else:
            # No thinking tags seen yet, everything is main content
            main_buffer = full_response
# Get the final thinking content if it exists
if thinking_complete:
            think_match = re.search(r'<think>(.*?)</think>', full_response, re.DOTALL)
if think_match:
final_thinking = think_match.group(1).strip()
yield (final_thinking, main_buffer)
else:
yield ("", main_buffer)
else:
yield ("", main_buffer)
# Final cleanup - handle incomplete thinking blocks
if in_thinking:
# Generation ended while in thinking mode - likely hit max token limit
incomplete_thinking_msg = thinking_buffer.strip() if thinking_buffer else ""
if incomplete_thinking_msg:
# Add warning about incomplete thinking
incomplete_thinking_msg += "\n\n⚠️ **Thinking was cut off due to token limit. Try increasing 'Max New Tokens' in settings.**"
# Show incomplete thinking warning in main response too
error_msg = "⚠️ *AI's thinking process was interrupted due to token limit. Try increasing 'Max New Tokens' and retry.*"
yield (incomplete_thinking_msg, error_msg)
thread.join()
# Tool definitions in Qwen3 format
TOOLS = [
{
"type": "function",
"function": {
"name": "search_documentation",
"description": "Search gprMax documentation for relevant information about commands, syntax, parameters, or usage",
"parameters": {
"type": "object",
"properties": {
"query": {
"type": "string",
"description": "The search query to find relevant documentation"
},
"num_results": {
"type": "integer",
"description": "Number of results to return",
"default": 10
}
},
"required": ["query"]
}
}
}
]
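# The tool schema above is injected into the system prompt (see get_default_system_prompt);
# the model signals a call by replying with a JSON object containing a "tool_call" key.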
def format_tools_prompt() -> str:
"""Format tools for inclusion in system prompt"""
import json
return json.dumps(TOOLS, indent=2)
def perform_rag_search(query: str, k: int = 10) -> Tuple[str, List[Dict]]:
"""
Perform RAG search and return formatted context and sources
Returns:
Tuple of (context_for_llm, source_list_for_display)
"""
if not retriever:
print(f"[DEBUG] Retriever is None!")
return "", []
try:
print(f"[DEBUG] Searching for: '{query}' with k={k}")
# Search for relevant documents
results = retriever.search(query, k=k)
print(f"[DEBUG] Search returned {len(results) if results else 0} results")
if not results:
return "", []
# Format context for LLM - pass all text content
context_parts = []
source_list = []
for i, result in enumerate(results, 1):
            # Pass the full text of each document to the LLM context
context_parts.append(f"[Document {i}]: {result.text}")
# Add to source list for display (limited preview)
source_list.append({
"index": i,
"source": result.metadata.get("source", "Unknown"),
"score": result.score,
"preview": result.text[:150] + "..." if len(result.text) > 150 else result.text
})
context = "\n\n".join(context_parts)
return context, source_list
except Exception as e:
print(f"[DEBUG] RAG search error: {e}")
import traceback
traceback.print_exc()
return "", []
def respond(
message: str,
history: List[Dict[str, str]],
system_message: str,
max_tokens: int,
temperature: float,
top_p: float,
):
"""
Response function with proper Qwen3 tool calling
"""
import json
import re
sources_content = ""
try:
# Use system message as-is (already has tools included)
system_with_tools = system_message
# First, get initial response from model to see if it wants to use tools
tool_call = None
accumulated_response = ""
final_thinking = ""
is_complete = False
# Collect the full response (thinking + potential tool call)
for thinking, response in generate_response_stream(
message=message,
history=history,
system_message=system_with_tools,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
):
final_thinking = thinking if thinking else final_thinking
accumulated_response = response
# Show thinking progress only
if thinking:
yield thinking, "⏳ *AI is analyzing your request...*", sources_content
# After streaming completes, check what we got
if accumulated_response and accumulated_response.strip():
# Check if the complete response is a JSON tool call
if accumulated_response.strip().startswith('{'):
try:
# Try to parse the entire response as JSON
response_json = json.loads(accumulated_response.strip())
if "tool_call" in response_json or ("thought" in response_json and "tool_call" in response_json):
tool_call = response_json.get("tool_call") or response_json["tool_call"]
# Show status that we're processing the tool call
yield final_thinking, "🔍 *Processing documentation search request...*", sources_content
is_complete = True
except json.JSONDecodeError:
# Invalid JSON, treat as normal response
yield final_thinking, accumulated_response, sources_content
is_complete = True
except Exception:
yield final_thinking, accumulated_response, sources_content
is_complete = True
else:
# It's a normal text response, not a tool call
yield final_thinking, accumulated_response, sources_content
is_complete = True
# If tool was called, execute it
if tool_call and retriever:
tool_name = tool_call.get("name")
print(f"[DEBUG] Tool called: {tool_name}")
print(f"[DEBUG] Tool call details: {tool_call}")
if tool_name == "search_documentation":
# Update status
yield "🔍 *Searching documentation...*", "⏳ *Preparing to search...*", "📚 *Retrieving relevant documents...*"
# Get search query
query = tool_call.get("arguments", {}).get("query", message)
num_results = tool_call.get("arguments", {}).get("num_results", 10)
print(f"[DEBUG] Query extracted: '{query}', num_results: {num_results}")
# Perform search
context, sources_list = perform_rag_search(query, k=num_results)
print(f"[DEBUG] Search results - Context length: {len(context)}, Sources: {len(sources_list)}")
if context:
# Format sources for display
if sources_list:
sources_parts = ["## 📚 Documentation Sources\n"]
for source in sources_list:
sources_parts.append(
f"**[{source['index']}] {source['source']}** (Score: {source['score']:.3f})\n"
f"```\n{source['preview']}\n```\n"
)
sources_content = "\n".join(sources_parts)
else:
sources_content = "*No relevant documentation found*"
yield "✅ *Documentation retrieved*", "⏳ *Generating response with context...*", sources_content
# Now generate response with the retrieved context
augmented_message = f"""Tool call result for search_documentation:
{context}
Original question: {message}
Please provide a comprehensive answer based on the documentation above."""
# Generate final response with context
for thinking, response in generate_response_stream(
message=augmented_message,
history=history,
system_message=system_message, # Use original system message for final response
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
):
yield thinking, response, sources_content
else:
sources_content = "*No relevant documentation found*"
yield final_thinking, "⚠️ *Unable to retrieve documentation. Providing general answer...*", sources_content
# Generate response without documentation context
fallback_message = f"""The user asked about: {message}
No relevant documentation was found in the database. Please provide a helpful answer based on your general knowledge of gprMax."""
for thinking, response in generate_response_stream(
message=fallback_message,
history=history,
system_message=system_message,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
):
yield thinking, response, sources_content
# If tool was called but retriever is not available
elif tool_call and not retriever:
yield final_thinking, "⚠️ *Documentation search is not available. Providing answer based on general knowledge...*", ""
# Generate response without RAG
for thinking, response in generate_response_stream(
message=message,
history=history,
system_message=system_message,
max_tokens=max_tokens,
temperature=temperature,
top_p=top_p,
):
yield thinking, response, ""
# If no tool call and response wasn't already yielded
elif not tool_call and not is_complete:
# This shouldn't happen but handle it just in case
if accumulated_response and not accumulated_response.strip().startswith('{'):
yield final_thinking, accumulated_response, sources_content
except Exception as e:
error_message = f"❌ Error generating response: {str(e)}"
yield "", error_message, ""
# Default system prompt for gprMax assistance
def get_default_system_prompt():
"""Get system prompt with tools formatted"""
tools_json = format_tools_prompt()
return f"""You are a helpful assistant specialized in gprMax, an open-source software that simulates electromagnetic wave propagation. You help users with:
1. Creating gprMax input files (.in files)
2. Understanding gprMax commands and syntax
3. Setting up simulations for GPR (Ground Penetrating Radar) and other EM applications
4. Troubleshooting simulation issues
5. Optimizing simulation parameters
You have access to the following tools:
{tools_json}
When you need to search documentation, respond with a tool call in this JSON format:
{{
"thought": "I need to search the documentation for...",
"tool_call": {{
"name": "search_documentation",
"arguments": {{
"query": "your search query here"
}}
}}
}}
After receiving tool results, provide a comprehensive answer based on the documentation.
If you give code blocks, be sure to enclose them in triple backticks (```).
There is no need to always provide a full input file; understand what the user actually needs and intends to do. Sometimes a single line of code is enough, and sometimes the user wants an explanation rather than code.
Provide clear, accurate, and practical guidance for gprMax users."""
# Create custom interface with collapsible thinking section
with gr.Blocks(title="gprMax Support", theme=gr.themes.Ocean()) as demo:
gr.Markdown(
"""
# 🛡️ gprMax Support Assistant
Welcome to the gprMax Support Assistant powered by a fine-tuned Qwen3-4B model specifically trained for gprMax assistance.
### Features:
- 💬 **Expert Guidance**: Get help with gprMax input files, commands, and simulations
- 🧠 **Transparent Reasoning**: See the AI's thinking process in a dedicated collapsible section
    - 📚 **Documentation Search**: RAG-powered retrieval from the gprMax documentation
- 🤖 **Agent Mode**: Coming soon - Automated workflow assistance
"""
)
with gr.Row():
with gr.Column(scale=3):
# Chat interface
chatbot = gr.Chatbot(
label="Chat",
type="messages",
height=500,
show_copy_button=True,
)
with gr.Row():
msg = gr.Textbox(
label="Message",
placeholder="Ask about gprMax commands, simulations, or troubleshooting...",
lines=2,
scale=4,
)
submit_btn = gr.Button("Send", variant="primary", scale=1)
with gr.Row():
clear_btn = gr.Button("🗑️ Clear Chat", scale=1)
# Examples
gr.Examples(
examples=[
"How do I create a basic gprMax input file for a simple GPR simulation?",
"What's the difference between #domain and #dx_dy_dz commands?",
"How can I model a heterogeneous soil with different dielectric properties?",
"My simulation is taking too long. How can I optimize it?",
"How do I add a Ricker wavelet source to my model?",
],
inputs=msg,
label="Example Questions",
)
with gr.Column(scale=2):
# Thinking process in collapsible accordion
with gr.Accordion("🧠 AI Thinking Process", open=False) as thinking_accordion:
thinking_display = gr.Markdown(
value="*Thinking process will appear here when the AI is reasoning through your question...*",
label="Thinking",
height=300,
)
# Documentation sources in collapsible accordion
with gr.Accordion("📚 Documentation Sources", open=False) as sources_accordion:
sources_display = gr.Markdown(
value="*Documentation sources will appear here when RAG search is performed...*",
label="Sources",
height=300,
)
# Settings
with gr.Accordion("⚙️ Settings", open=True):
system_message = gr.Textbox(
value=get_default_system_prompt(),
label="System Message",
lines=5,
info="Customize the assistant's behavior"
)
max_tokens = gr.Slider(
minimum=1,
maximum=4096,
value=1536,
step=1,
label="Max New Tokens",
info="Maximum length of generated response"
)
temperature = gr.Slider(
minimum=0.0,
maximum=2.0,
value=0.7,
step=0.1,
label="Temperature",
info="Controls randomness (0=deterministic, higher=more creative)"
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
value=0.95,
step=0.05,
label="Top-p (Nucleus Sampling)",
info="Controls diversity of responses"
)
# Chat functionality
def user_submit(message, history):
if not message:
return "", history
history = history + [{"role": "user", "content": message}]
return "", history
def bot_respond(history, system_msg, max_tok, temp, top_p_val):
if not history or history[-1]["role"] != "user":
yield history, "*No thinking process*", "*No sources*"
return
user_message = history[-1]["content"]
history_for_model = history[:-1] # Exclude the last user message
# Add placeholder for assistant response
history = history + [{"role": "assistant", "content": ""}]
thinking_text = ""
sources_text = ""
is_thinking = False
has_main_content = False
is_searching = False
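        # respond() yields (thinking, response, sources) triples; route each part
        # to its pane: the chatbot, the thinking accordion, and the sources accordion.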
for thinking, response, sources in respond(
user_message,
history_for_model,
system_msg,
max_tok,
temp,
top_p_val
):
# Update thinking display
if thinking:
if "Searching documentation" in thinking:
thinking_text = thinking
is_searching = True
elif "Documentation retrieved" in thinking:
thinking_text = thinking
is_searching = False
else:
thinking_text = f"## Reasoning Process\n\n{thinking}"
is_thinking = True
elif not thinking and not is_searching:
thinking_text = "*Waiting for response...*"
# Update sources display
if sources:
sources_text = sources
# Update chat response
if response and response.strip():
# We have actual response content
if "Preparing to search" in response or "Generating response" in response:
# Status messages
history[-1]["content"] = response
else:
# Actual content
history[-1]["content"] = response
has_main_content = True
elif is_thinking and not has_main_content:
# Still thinking, no main response yet
history[-1]["content"] = "🤔 *AI is thinking... Check the right pane for thinking details*"
elif is_searching:
history[-1]["content"] = "🔍 *Searching documentation...*"
elif not response:
# No response yet and no thinking detected
history[-1]["content"] = "⏳ *Generating response...*"
yield history, thinking_text, sources_text
# Event handlers
msg.submit(user_submit, [msg, chatbot], [msg, chatbot]).then(
bot_respond,
[chatbot, system_message, max_tokens, temperature, top_p],
[chatbot, thinking_display, sources_display]
)
submit_btn.click(user_submit, [msg, chatbot], [msg, chatbot]).then(
bot_respond,
[chatbot, system_message, max_tokens, temperature, top_p],
[chatbot, thinking_display, sources_display]
)
clear_btn.click(
lambda: (
[],
"*Thinking process will appear here when the AI is reasoning through your question...*",
"*Documentation sources will appear here when RAG search is performed...*"
),
outputs=[chatbot, thinking_display, sources_display]
)
# RAG status indicator
rag_status = "✅ Documentation search enabled" if retriever else "⚠️ Documentation search disabled (run generate_db.py)"
gr.Markdown(
f"""
---
### About
This assistant uses `jfang/gprmax-ft-Qwen3-4B-Instruct`, a model fine-tuned specifically for gprMax support.
**RAG Status**: {rag_status}
**Note**: For best results, be specific about your gprMax version and simulation requirements.
"""
)
if __name__ == "__main__":
demo.launch()