jfang committed 0496ab5 (verified) · parent: 3718631

Update README.md

Files changed (1): README.md (+356 −7)

---
title: Gprmax Support
emoji: 👀
colorFrom: yellow
colorTo: purple
sdk: gradio
sdk_version: 5.44.1
app_file: app.py
pinned: true
---

# gprMax AI Support Assistant (GSoC 2025)

**What it is:** a small web app that helps people write gprMax `.in` files, understand commands, and troubleshoot simulations in a simple chat UI.
**Why it matters:** new users struggle with syntax and parameter choices. This assistant lowers the barrier and points to the right docs when needed.

**Live demo:** [Gprmax Support - a Hugging Face Space by jfang](https://huggingface.co/spaces/jfang/gprmax-support-gsoc25)
**Main model used by the app:** `jfang/gprmax-ft-Qwen3-4B-Instruct`. The app loads this model with Hugging Face Transformers and streams responses, including a separate “thinking” pane for learning and transparency.

---

## What I built (GSoC progress)

- **Fine‑tuned model for gprMax**. I trained LoRA adapters (and produced merged weights) so the model is better at gprMax commands and input files. The Space loads `jfang/gprmax-ft-Qwen3-4B-Instruct`.

- **RAG (Retrieval‑Augmented Generation)** on top of the official gprMax documentation. On first run, the app clones the repo, chunks `/docs` files, and creates a **persistent ChromaDB** store. Then the model can “call a tool” to search docs and show sources.

- **Friendly UI** with Gradio: left side is chat; right side has two collapsible panels, **AI Thinking Process** and **Documentation Sources**. There are also **Settings** so people can tune temperature, max tokens, etc.

- **Reproducible fine‑tuning recipe** with LoRA (PEFT). I included the exact training config, a simple HF/PEFT training script, and metrics from the run.

- **Model Zoo (finetuned weights)**: I trained several variants and organized them here:
  [https://huggingface.co/collections/jfang/gprmax-command-finetuned](https://huggingface.co/collections/jfang/gprmax-command-finetuned)

> The evaluation plan and overall approach follow the project proposal: set baselines, fine‑tune with LoRA, add RAG, and then test by pass rate on required fields plus flexible checks on “creative” parts.

---

## Quick start

### 1) Use it online (Hugging Face Space)

1. Open the Space.

2. Ask a question like “How do I add a Ricker wavelet source?” or paste part of an input file.

3. Check the right panels:

   - **AI Thinking Process** shows the model’s step‑by‑step reasoning (what it’s thinking).

   - **Documentation Sources** shows the retriever’s citations and short previews.

> The Space wraps generation with `@spaces.GPU(duration=60)` to keep GPU usage small and predictable.
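
For reference, a minimal sketch of that decorator pattern (illustrative only; the function name and generation settings here are assumptions, not the Space’s actual code):

```python
# Illustrative sketch of the @spaces.GPU pattern; not the Space's actual code.
import spaces
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

MODEL_NAME = "jfang/gprmax-ft-Qwen3-4B-Instruct"
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True
)

@spaces.GPU(duration=60)              # request a GPU for at most ~60 s per call
def generate_reply(prompt: str) -> str:
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    output = model.generate(**inputs, max_new_tokens=512)
    return tokenizer.decode(output[0], skip_special_tokens=True)
```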

### 2) Run it locally

```bash
pip install "torch" "transformers" "gradio" "chromadb" "gitpython" "tqdm" "spaces"

gradio app.py
```

- First run: if the vector DB is missing, the app will **auto‑build** it (clone gprMax, chunk docs, and index). You’ll see logs about generating the database and then “RAG database loaded.”

- The database is **persistent** (on disk), so later runs are faster. The builder stores a `metadata.json` with settings like chunk size and the embedding name used by Chroma (“all‑MiniLM‑L6‑v2” by default); an illustrative example is shown below.
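
A hypothetical `metadata.json` along these lines (the exact field names are an assumption based on the settings above, not a dump of the real file):

```json
{
  "collection_name": "gprmax_docs_v1",
  "chunk_size": 1000,
  "chunk_overlap": 200,
  "embedding_model": "ChromaDB Default (all-MiniLM-L6-v2)"
}
```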

---

## Using the app (what to try)

Ask things like:

- “How do I create a basic gprMax input file for a simple GPR simulation?”

- “What’s the difference between `#domain` and `#dx_dy_dz`?”

- “How do I add a Ricker wavelet source?”

- “My simulation is taking too long—any tips to speed it up?”

- “How do I model a soil with different dielectric properties?”

When the model needs context, it emits a small JSON “tool call” to **search_documentation**. The retriever queries ChromaDB and the UI shows top matches in the right panel with file names and a short preview. Then the model writes a final answer that uses those snippets.
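
For illustration, such a tool call might look roughly like this (the exact field names are an assumption; the real format is whatever the system prompt in `app.py` specifies):

```json
{
  "tool": "search_documentation",
  "arguments": { "query": "Ricker wavelet source syntax" }
}
```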

---

## Design principles (in simple terms)

- **Keep it modular.** Model, retriever, and UI are separate pieces. We can upgrade any part later.

- **Ground answers in docs.** The model can look things up and show sources, not just “guess.”

- **Make it light.** A 4B model plus a local vector DB runs on modest hardware and fits on Spaces.

- **Be transparent.** Show what the model is thinking and where facts come from.

- **Future‑proof.** Rebuild the DB when docs change; swap in new models or embeddings later.

---

## Architecture (at a glance)

```
User ↔ Gradio Chat UI
        │
        ▼
Transformers (Qwen3‑4B fine‑tuned) → streams text + <think> ... </think>
        │
        ▼  (optional tool call as JSON)
search_documentation(query)
        │
        ▼
GprMaxRAGRetriever ── ChromaDB (persistent on disk)
        │                     │
        ▼                     ▼
   gprMax docs (cloned → chunked → indexed)
```

- **Model loading & streaming.** The app uses `AutoTokenizer/AutoModelForCausalLM` with `device_map="auto"`. The generator splits `<think>…</think>` into a separate “AI Thinking Process” pane.

- **Tool calling.** The system prompt describes a `search_documentation` tool and the exact JSON format for calling it.

- **RAG database.** The builder clones the official `gprMax` repo, reads `/docs` (`.rst`, `.md`, `.txt`), chunks with **size 1000 / overlap 200**, and stores to a **ChromaDB** collection named `gprmax_docs_v1`. Metadata includes `embedding_model: "ChromaDB Default (all‑MiniLM‑L6‑v2)"`.

- **Retriever.** Uses a persistent Chroma client and queries via `query_texts`. Distances are turned into scores with a simple `1 - (dist/2)` conversion for display (see the sketch after this list).
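
A minimal sketch of that retrieval step (the collection name and score conversion come from this README; the path, function name, and result handling are assumptions, not the actual `retriever.py`):

```python
# Minimal sketch of the retrieval step; not the app's actual retriever code.
import chromadb

client = chromadb.PersistentClient(path="rag-db/chroma_db")   # persistent store on disk
collection = client.get_collection("gprmax_docs_v1")          # collection built by generate_db.py

def search_documentation(query: str, n_results: int = 3):
    res = collection.query(query_texts=[query], n_results=n_results)
    hits = []
    for doc, meta, dist in zip(res["documents"][0], res["metadatas"][0], res["distances"][0]):
        hits.append({
            "source": meta.get("source", "unknown"),
            "preview": doc[:200],
            "score": 1 - (dist / 2),   # simple distance-to-score conversion for display
        })
    return hits
```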

---

## Technical choices (frameworks and why)

- **Transformers** to load and run the fine‑tuned Qwen3‑4B model, with `device_map="auto"` and `trust_remote_code=True`. This keeps the code short and makes GPU/CPU selection automatic (a loading/streaming sketch follows this list).

- **Gradio** for the web UI (Blocks + Chatbot + Accordions + Sliders). It’s easy to read and extend.

- **ChromaDB** for a simple, persistent vector store that ships with the app. No external service is required.

- **GitPython + tqdm** to clone the gprMax docs and show progress when building the DB.
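
A rough sketch of that loading-and-streaming pattern (illustrative; the real `app.py` differs in details such as prompt construction, the `<think>` handling, and UI wiring):

```python
# Illustrative sketch of the loading/streaming pattern; not the actual app.py.
from threading import Thread
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_NAME = "jfang/gprmax-ft-Qwen3-4B-Instruct"

tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(
    MODEL_NAME,
    torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
    device_map="auto",          # pick GPU if available, otherwise CPU
    trust_remote_code=True,
)

def stream_answer(messages, max_new_tokens=1024):
    prompt = tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # keep special tokens so <think> ... </think> can be routed to the "thinking" pane
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
    Thread(target=model.generate,
           kwargs=dict(**inputs, streamer=streamer, max_new_tokens=max_new_tokens)).start()
    for piece in streamer:      # the app splits out <think>…</think> before display
        yield piece
```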
145
+
146
+
147
+ ---
148
+
149
+ ## Reproducible fine‑tuning (LoRA / PEFT)
150
+
151
+ This is the core of the work. Below is **exactly** how the 4B model was trained and how someone else can redo it.
152
+
153
+ ### What I trained
154
+
155
+ - **Base model:** `Qwen/Qwen3-4B` (using the Qwen3 chat template).
156
+
157
+ - **Method:** LoRA adapters (**rank=8**, **alpha=16**, **dropout=0.0**) applied to attention and MLP projection layers.
158
+
159
+ - **Outputs:** adapters + merged weights; the app uses the merged variant `jfang/gprmax-ft-Qwen3-4B-Instruct`.
160
+
161
+ - **Other models I trained:** see my collection:
162
+ [https://huggingface.co/collections/jfang/gprmax-command-finetuned](https://huggingface.co/collections/jfang/gprmax-command-finetuned)
163
+
164
+
165
+

### Exact config used (YAML)

```yaml
bf16: true
cutoff_len: 2048
dataset: gpr-train
dataset_dir: data
ddp_timeout: 180000000
do_train: true
enable_thinking: true
finetuning_type: lora
flash_attn: auto
gradient_accumulation_steps: 8
include_num_input_tokens_seen: true
learning_rate: 5.0e-05
logging_steps: 5
lora_alpha: 16
lora_dropout: 0
lora_rank: 8
lora_target: all
lr_scheduler_type: cosine
max_grad_norm: 1.0
max_samples: 100000
model_name_or_path: Qwen/Qwen3-4B
num_train_epochs: 2.0
optim: adamw_torch
output_dir: saves/Qwen3-4B-Instruct/lora/train_2025-07-09-08-47-27
packing: false
per_device_train_batch_size: 4
plot_loss: true
preprocessing_num_workers: 16
report_to: none
save_steps: 100
stage: sft
template: qwen3
trust_remote_code: true
warmup_steps: 0
```
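
The field names (`stage: sft`, `finetuning_type: lora`, `lora_target: all`, `template: qwen3`) look like a LLaMA‑Factory training config. If that is indeed how it was produced, a run along these lines should reproduce it (the command and file name are assumptions, adjust to your setup):

```bash
# Assumes the YAML above is saved as qwen3_4b_gprmax_lora.yaml and LLaMA-Factory is installed
llamafactory-cli train qwen3_4b_gprmax_lora.yaml
```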

**Metrics reported (4B run):**

```json
{
  "epoch": 2.0,
  "num_input_tokens_seen": 48562016,
  "total_flos": 1.0635160197775688e+18,
  "train_loss": 0.3312762507200241,
  "train_runtime": 16760.735,
  "train_samples_per_second": 1.909,
  "train_steps_per_second": 0.06
}
```

**Loss curve:**

![Training loss curve](training_loss.png)

### Path A — Simple HF/PEFT training script

```python
# train_lora_peft.py
import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer
from peft import LoraConfig

BASE = "Qwen/Qwen3-4B"

tok = AutoTokenizer.from_pretrained(BASE, trust_remote_code=True)
tok.padding_side = "right"
if tok.pad_token is None:
    tok.pad_token = tok.eos_token

ds = load_dataset("json", data_files={"train": "data/gpr-train.jsonl"})

def to_text(ex):
    return {"text": tok.apply_chat_template(ex["messages"], tokenize=False, add_generation_prompt=False)}

ds = ds.map(to_text, remove_columns=ds["train"].column_names)

dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=dtype, device_map="auto", trust_remote_code=True)

peft_cfg = LoraConfig(
    r=8, lora_alpha=16, lora_dropout=0.0,
    target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
    task_type="CAUSAL_LM"
)

args = TrainingArguments(
    output_dir="saves/Qwen3-4B-Instruct/lora/run-peft",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=8,
    learning_rate=5e-5,
    num_train_epochs=2,
    lr_scheduler_type="cosine",
    logging_steps=5,
    save_steps=100,
    bf16=True,
    report_to="none",
    max_grad_norm=1.0
)

trainer = SFTTrainer(
    model=model,
    args=args,
    peft_config=peft_cfg,
    tokenizer=tok,
    train_dataset=ds["train"],
    dataset_text_field="text",
    max_seq_length=2048,
    packing=False
)

trainer.train()
trainer.save_model("saves/Qwen3-4B-Instruct/lora/run-peft")
tok.save_pretrained("saves/Qwen3-4B-Instruct/lora/run-peft")
```
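
The script reads `data/gpr-train.jsonl` and expects one JSON object per line with a `messages` list. A record might look something like this (illustrative shape only, not taken from the real dataset, and wrapped here for readability):

```json
{"messages": [
  {"role": "user", "content": "Create a minimal 2D gprMax input file with a 100 MHz Ricker source."},
  {"role": "assistant", "content": "#domain: 0.6 0.6 0.002\n#dx_dy_dz: 0.002 0.002 0.002\n#time_window: 12e-9\n#waveform: ricker 1 100e6 my_ricker\n#hertzian_dipole: z 0.1 0.1 0 my_ricker"}
]}
```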

**Inference with adapter (or merge):**

```python
from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

base = "Qwen/Qwen3-4B"
adapter = "saves/Qwen3-4B-Instruct/lora/run-peft"

tok = AutoTokenizer.from_pretrained(base, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(base, device_map="auto", torch_dtype=torch.bfloat16, trust_remote_code=True)
model = PeftModel.from_pretrained(model, adapter)

prompt = tok.apply_chat_template(
    [{"role": "user", "content": "Give a minimal gprMax 2D model with a 100 MHz Ricker source."}],
    tokenize=False, add_generation_prompt=True
)
inputs = tok(prompt, return_tensors="pt").to(model.device)
out = model.generate(**inputs, max_new_tokens=512)
print(tok.decode(out[0], skip_special_tokens=True))

# Optional: merge LoRA into base weights for publishing
# model = model.merge_and_unload()
# model.save_pretrained("merged-qwen3-4b-gprmax")
# tok.save_pretrained("merged-qwen3-4b-gprmax")
```
311
+
312
+ ### How the fine‑tuned model plugs into the app
313
+
314
+ - `app.py` sets `MODEL_NAME = "jfang/gprmax-ft-Qwen3-4B-Instruct"` and uses `AutoTokenizer/AutoModelForCausalLM` with `device_map="auto"`.
315
+ It also streams the **thinking** text (between `<think>...</think>`) to a separate UI pane.
316
+
317
+ - When the model emits the tool call JSON for `search_documentation`, the app uses the retriever to query the local ChromaDB and shows sources in the right pane.
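
A rough sketch of that dispatch step (the tool‑call format and helper names are assumptions; the real app follows whatever format its system prompt defines):

```python
# Illustrative sketch of detecting and serving a tool call; not the actual app.py logic.
import json
import re

def maybe_handle_tool_call(model_output: str):
    """If the output contains a search_documentation tool call, run the retriever."""
    match = re.search(r'\{.*"search_documentation".*\}', model_output, re.DOTALL)
    if not match:
        return None
    try:
        call = json.loads(match.group(0))
    except json.JSONDecodeError:
        return None
    query = call.get("arguments", {}).get("query", "")
    return search_documentation(query)   # retriever sketch shown earlier in this README
```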
318
+
319
+
320
+ ---
321
+
322
+ ## Project layout
323
+
324
+ ```
325
+ .
326
+ ├── app.py # Main Gradio app: model load, streaming, tool-calling
327
+ └── rag-db/
328
+ ├── generate_db.py # Clone gprMax, chunk docs, build ChromaDB, save metadata
329
+ ├── retriever.py # Persistent Chroma client + search utilities
330
+ └── chroma_db/ # (created at runtime) persistent vector DB + metadata.json
331
+ ```
332
+
333
+ - The app will **auto‑build** the DB by **pulling gprMax github repo and embedding *latest* documents** if it’s missing, then load it for searches.
334
+
335
+ - The builder saves `metadata.json` with the collection name (`gprmax_docs_v1`), chunking settings, and the embedding label.
336
+
337
+ - The retriever uses a persistent client and turns distances into a simple score for display.
338
+
339
+
340
+ ---
341
+
342
+
343
+ ## Tips & troubleshooting
344
+
345
+ - **GPU out‑of‑memory?** Lower **Max New Tokens** in Settings or run on CPU; the app chooses CUDA if available, otherwise CPU.
346
+
347
+ - **No docs in sources panel?** Build the DB manually:
348
+
349
+ ```bash
350
+ python rag-db/generate_db.py --recreate
351
+ ```
352
+
353
+
354
+ This clones the official repo, chunks `/docs` (size **1000**, overlap **200**), builds the `gprmax_docs_v1` collection, and writes metadata.
355
+
356
+ - **First response is slow.** That’s probably first‑time model load and DB creation. Later runs cache the DB, so it’s faster.
357
+
358
+ - Smaller models tend to **overthink**([Cuadron, Alejandro, et al.,2025](https://arxiv.org/abs/2502.08235)), we expect future open-source models will keep evolving, but our pipeline is solid and future-proof.
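
For reference, a minimal sketch of the fixed‑size chunking described above (size 1000, overlap 200); the function name is an assumption, not the actual `generate_db.py` code:

```python
# Minimal sketch of fixed-size chunking with overlap; not the actual generate_db.py code.
def chunk_text(text: str, chunk_size: int = 1000, overlap: int = 200):
    """Split a document into overlapping chunks for indexing."""
    chunks = []
    step = chunk_size - overlap
    for start in range(0, len(text), step):
        chunk = text[start:start + chunk_size]
        if chunk.strip():
            chunks.append(chunk)
    return chunks
```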
359
+
360
+ ## License note
361
+
362
+ The retriever indexes text from the official gprMax documentation. Please follow the gprMax license for any reuse of that content.
363
+
364
+ **Thanks:** the gprMax team and community, plus the open‑source ML stack (Transformers, Gradio, ChromaDB).