3v324v23 committed
Commit 0bdbec3 · 1 Parent(s): 0be0be6

feat: add NNGen project under NNGen/ and ignore local secrets

.gitignore ADDED
@@ -0,0 +1,4 @@
+NNGen/.env
+NNGen/artifacts/
+NNGen/.venv/
+NNGen/__pycache__/
NNGen/.gitignore ADDED
@@ -0,0 +1,89 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+Pipfile.lock
+
+# poetry
+poetry.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Environments
+.env
+.venv
+venv/
+ENV/
+env/
+
+# VS Code
+.vscode/
+
+# PyCharm
+.idea/
+
+# Artifacts and outputs
+artifacts/
+artifacts/**
+*.png
+*.jpg
+*.jpeg
+
+# Secrets
+app/llm/credentials.py
+app/llm/credentials.json
+
+# Mac/Windows
+.DS_Store
+Thumbs.db
NNGen/AGENTS.md ADDED
@@ -0,0 +1,38 @@
+# Repository Guidelines
+
+## Project Structure & Module Organization
+- `app/cli.py` — CLI entry point; orchestrates a full run.
+- `app/graph.py` — lightweight pipeline runner and edit loop.
+- `app/nodes/` — individual agent nodes (`parser.py`, `planner.py`, `prompt_gen.py`, `gen_generate.py` [G1 skeleton], `gen_labels.py` [G2 labels], `judge.py`, `select.py`, `edit.py`, `archive.py`). Each exposes `run(state)` or similar.
+- `app/prompts.py` — centralized prompts for parsing/planning/generation/judging/editing.
+- `app/state.py` — typed `AppState` and artifact helpers.
+- `app/llm/gemini.py` — `call_gemini(kind, **kwargs)` wrapper; uses local placeholders if no API key.
+- `spec/` — example specs (e.g., `spec/vit.txt`).
+- `artifacts/` — run outputs (time-stamped folders with `final.png`).
+
+## Setup, Run, and Development Commands
+- Create env: `python -m venv .venv && source .venv/bin/activate` (Windows: `./.venv/Scripts/activate`).
+- Install deps: `pip install -r requirements.txt`.
+- Configure API (choose one):
+  - Env var: `export GEMINI_API_KEY=...` (supports `.env`).
+  - File: create `app/llm/credentials.py` like `credentials.example.py`.
+- Run sample: `python -m app.cli --spec spec/vit.txt --K 3 --T 1`.
+- Models: optionally set `GEMINI_MODEL`, `GEMINI_IMAGE_MODEL`, `GEMINI_IMAGE_EDIT_MODEL`.
+
+## Coding Style & Naming Conventions
+- Python 3.10+, PEP8, 4-space indentation, type hints required in public APIs.
+- Files: snake_case; functions: `snake_case`; classes: `PascalCase`.
+- Nodes are pure where possible: read from `state`, return a new `state`; side effects limited to writing under `artifacts/`.
+- Centralize prompt text in `app/prompts.py`; call models via `call_gemini` only.
+
+## Testing Guidelines
+- No formal test suite yet. Prefer pytest with files under `tests/` named `test_*.py`.
+- Minimal integration check: run the CLI and assert a `final.png` exists in the latest `artifacts/run_*` directory (a pytest sketch follows this file).
+
+## Commit & Pull Request Guidelines
+- Use Conventional Commits: `feat:`, `fix:`, `docs:`, `refactor:`, `chore:`.
+- PRs must include: purpose, linked issues, how to run, and a sample spec plus the path to a produced artifact (e.g., `artifacts/run_YYYYmmdd_HHMMSS/final.png`).
+
+## Security & Configuration Tips
+- Do not commit secrets (`.env`, `app/llm/credentials.py`). Rotate keys if exposed.
+- Large outputs live in `artifacts/`; avoid committing heavy assets unless necessary.
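
A minimal sketch of the integration check described above, assuming pytest and an offline run (no `GEMINI_API_KEY`, so the placeholder path still writes `final.png`); the module name `tests/test_cli_smoke.py` is hypothetical:

```
# tests/test_cli_smoke.py (hypothetical name) — minimal end-to-end check
import subprocess
import sys
from pathlib import Path


def test_cli_produces_final_png(tmp_path: Path) -> None:
    outdir = tmp_path / "run"
    # Without GEMINI_API_KEY the pipeline falls back to offline placeholders,
    # so the run should complete and still produce final.png.
    subprocess.run(
        [sys.executable, "-m", "app.cli",
         "--spec", "spec/vit.txt", "--K", "1", "--T", "0",
         "--outdir", str(outdir)],
        check=True,
    )
    assert (outdir / "final.png").exists()
```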
NNGen/README.md ADDED
@@ -0,0 +1,80 @@
+# Multi-Agent Neural Network Diagram Generator (Skeleton) — Gemini 2.5 Flash Image
+
+This repository is a minimal, runnable skeleton that turns a textual NN spec into a publication-style diagram via a multi-agent pipeline:
+- Parser → Planner → Prompt-Generator → Image-Generator (G1) → Label-Generator (G2) → Judge → Selector → (Editor loop) → Archivist
+- All model calls flow through `call_gemini(...)`, making it easy to use Gemini 2.5 Flash for text and Gemini 2.5 Flash Image for images.
+
+Key additions in this version:
+- Two-stage generation: G1 draws the geometry-only skeleton (no text), G2 overlays labels on top of the skeleton.
+- Hard violations: the Judge returns actionable violations; missing labels are flagged as HARD to trigger edits reliably.
+- Parallelism: G1, G2, and Judge run in parallel; set `NNG_CONCURRENCY` (default 4).
+- Remote images by default: image generate/edit use Gemini 2.5 Flash Image models. If the API key is missing, the system falls back to a local placeholder to stay runnable.
+
+## Quick Start
+
+1) Python 3.10+
+
+2) Install deps
+```
+pip install -r requirements.txt
+```
+
+3) Configure Gemini (choose one)
+- Env var: `export GEMINI_API_KEY=YOUR_KEY`
+- File: create `app/llm/credentials.py` with `GEMINI_API_KEY = "YOUR_KEY"`
+
+4) Run (K=candidates, T=max edit rounds)
+```
+# Text mode (spec -> image)
+python -m app.cli --mode text --spec spec/vit.txt --K 4 --T 1
+
+# Image mode (text + image fusion/edit)
+# Example: edit an existing diagram with a component replacement using a reference image
+python -m app.cli --mode image --base-image path/to/base.png \
+  --ref-image path/to/transformer_ref.png \
+  --instructions "Replace the UNet backbone with a Transformer (DiT); keep layout, font, and colors consistent."
```
+Artifacts are saved under `artifacts/run_YYYYmmdd_HHMMSS/` with `final.png` as the chosen result.
+
+## Gemini 2.5 Flash Image in This Project
+- G1 geometry: `gen_generate.py` calls `GEMINI_IMAGE_MODEL` (Gemini 2.5 Flash Image) to render a clean, geometry-only skeleton quickly.
+- G2 labels: `gen_labels.py` uses `GEMINI_IMAGE_EDIT_MODEL` to overlay text labels onto the G1 skeleton without redrawing everything.
+- Edit loop: `edit.py` performs targeted corrections via the same image model, enabling fast, iterative refinements instead of full regenerations.
+- Why it matters: the model’s speed and editability make multi-round diagram refinement practical while preserving layout quality.
+- Fallback: if no API key is available, the pipeline remains runnable using local placeholders generated by `app/llm/gemini.py`.
+
+## Models
+- `GEMINI_MODEL` (code default `gemini-1.5-flash`; set to e.g. `gemini-2.5-flash` if available): parsing, planning, prompt generation, and judging.
+- `GEMINI_IMAGE_MODEL` (recommended `gemini-2.5-flash-image` or `gemini-2.5-flash-image-preview`): image generation (G1).
+- `GEMINI_IMAGE_EDIT_MODEL` (recommended `gemini-2.5-flash-image` or `gemini-2.5-flash-image-preview`): image editing (G2, Editor).
+Notes: if `GEMINI_API_KEY` is not set, the pipeline uses offline placeholders to remain runnable. With an API key present, you must set valid image model env vars; errors are raised if image models are unset or calls fail (no automatic local fallback).
+
+## Fusion Mode (Text + Image)
+- Accepts a base diagram (`--base-image`) and optional reference images (`--ref-image`, repeatable) plus instructions.
+- Uses Gemini 2.5 Flash Image to compose images under textual guidance, which is ideal for swapping a module (e.g., UNet → Transformer) while preserving style and layout.
+- Outputs multiple fused candidates (`K`) and archives the first as `final.png`.
+
+## Structure
+```
+app/
+  cli.py            # CLI entry (K/T/outdir)
+  graph.py          # Orchestrator + edit loop
+  state.py          # AppState + artifacts
+  prompts.py        # Centralized prompts (parse/plan/G1/G2/judge/edit)
+  nodes/
+    parser.py, planner.py, prompt_gen.py
+    gen_generate.py # G1 skeleton images (no text)
+    gen_labels.py   # G2 label overlay edits
+    judge.py, select.py, edit.py, archive.py
+  llm/
+    gemini.py       # Unified wrapper (API + offline fallback)
+    credentials.example.py
+spec/
+  vit.txt           # Example ViT spec (English)
+artifacts/          # Outputs per run
+```
+
+## Tips
+- Concurrency: `NNG_CONCURRENCY=4 python -m app.cli --spec ...`
+- Tuning: start with `K=4, T=1`; increase `T` for more correction rounds.
+- Debug: image calls write `*.resp.txt`/`*.meta.json` alongside outputs (these can be removed later if undesired).
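
For `.json` specs passed to `--spec`, the schema is ultimately defined by the parser prompt; the offline placeholder in `app/llm/gemini.py` suggests a minimal shape like the following (shown as the equivalent Python dict; inferred, not a documented contract):

```
# Minimal spec shape inferred from the offline placeholder parser (assumption)
spec = {
    "nodes": ["Input", "Conv", "FC", "Softmax"],   # block labels, in draw order
    "edges": [[0, 1], [1, 2], [2, 3]],             # index pairs into nodes
    "constraints": {"arrows": "left_to_right"},
}
```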
NNGen/app.py ADDED
@@ -0,0 +1,9 @@
+from __future__ import annotations
+
+# Hugging Face Spaces entrypoint for Gradio
+# Exposes a global `demo` variable that HF will serve.
+
+from scripts.gradio_app import app as create_app
+
+demo = create_app()
+
NNGen/app/__init__.py ADDED
@@ -0,0 +1 @@
+# empty package marker
NNGen/app/cli.py ADDED
@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+import argparse
+import json
+from pathlib import Path
+
+from .graph import run_pipeline, run_fusion_pipeline
+from .state import AppState
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="NN Diagram Multi-Agent Pipeline")
+    parser.add_argument("--mode", type=str, choices=["text", "image"], default="text", help="'text' (spec→image) or 'image' (text+image fusion/edit)")
+    parser.add_argument("--spec", type=str, required=False, help="Path to .txt user prompt or .json spec (text mode)")
+    parser.add_argument("--K", type=int, default=4, help="Number of candidates")
+    parser.add_argument("--T", type=int, default=1, help="Max edit rounds")
+    parser.add_argument("--outdir", type=str, default="", help="Output directory (optional)")
+    parser.add_argument("--base-image", type=str, default="", help="Base image to edit (image mode)")
+    parser.add_argument("--ref-image", action="append", default=None, help="Additional reference image(s) (repeatable)")
+    parser.add_argument("--instructions", type=str, default="", help="Edit/fusion instructions (image mode)")
+    args = parser.parse_args()
+
+    state: AppState = {"K": int(args.K), "T": int(args.T), "outdir": args.outdir or ""}
+
+    if args.mode == "text":
+        if not args.spec:
+            raise SystemExit("--spec is required in text mode")
+        spec_path = Path(args.spec)
+        if not spec_path.exists():
+            raise SystemExit(f"Spec file not found: {spec_path}")
+        if spec_path.suffix.lower() == ".json":
+            state["spec"] = json.loads(spec_path.read_text())
+        else:
+            state["user_text"] = spec_path.read_text()
+        final_state = run_pipeline(state)
+    else:
+        # image fusion/edit mode
+        base_image = args.base_image.strip()
+        ref_images = args.ref_image or []
+        if not base_image and not ref_images:
+            raise SystemExit("image mode requires --base-image and/or at least one --ref-image")
+        if base_image:
+            if not Path(base_image).exists():
+                raise SystemExit(f"Base image not found: {base_image}")
+            state["base_image"] = base_image
+        valid_refs = [p for p in ref_images if p and Path(p).exists()]
+        state["ref_images"] = valid_refs
+        state["instructions"] = args.instructions or "Compose and update the figure to reflect the requested component changes while keeping overall style consistent."
+        final_state = run_fusion_pipeline(state)
+
+    print(f"Artifacts saved under: {final_state['outdir']}")
+
+
+if __name__ == "__main__":
+    main()
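
For embedding the pipeline in other code, the CLI can be bypassed and `run_pipeline` called directly; a minimal sketch, mirroring what `main()` does in text mode:

```
# Hypothetical programmatic invocation of the text-mode pipeline
from app.graph import run_pipeline
from app.state import AppState

state: AppState = {
    "K": 2,        # number of candidates
    "T": 1,        # max edit rounds
    "outdir": "",  # empty -> a time-stamped artifacts/run_* directory is created
    "user_text": "Input -> Conv -> FC -> Softmax",
}
final_state = run_pipeline(state)
print(final_state["outdir"])  # contains final.png, spec.json, scores.json, ...
```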
NNGen/app/graph.py ADDED
@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+from datetime import datetime
+from pathlib import Path
+from typing import Callable
+
+from .state import AppState
+from .nodes import parser, planner, prompt_gen, gen_generate, gen_labels, judge, select, edit, archive
+from .nodes import gen_fusion
+
+
+def run_pipeline(state: AppState) -> AppState:
+    # ensure outdir
+    outdir = Path(state.get("outdir") or _default_outdir())
+    outdir.mkdir(parents=True, exist_ok=True)
+    state["outdir"] = str(outdir)
+    state["round"] = int(state.get("round", 0))
+
+    # 1) parse → 2) plan → 3) prompts → 4) generate (skeleton) → 5) generator_2 (labels) → 6) judge → 7) select
+    state = parser.run(state)
+    state = planner.run(state)
+    state = prompt_gen.run(state)
+    state = gen_generate.run(state)
+    state = gen_labels.run(state)
+    state = judge.run(state)
+    state = select.run(state)
+
+    # 8) edit loop (if hard violations or any violations, and round < T)
+    T = int(state.get("T", 0))
+    while (state.get("hard_violations") or state.get("violations")) and state.get("round", 0) < T:
+        state["round"] = int(state.get("round", 0)) + 1
+        state = edit.apply_edits(state)
+        # re-judge best image
+        state = _judge_best_only(state)
+        state = select.run(state)
+
+    # 9) archive
+    state = archive.run(state)
+    return state
+
+
+def run_fusion_pipeline(state: AppState) -> AppState:
+    # ensure outdir
+    outdir = Path(state.get("outdir") or _default_outdir())
+    outdir.mkdir(parents=True, exist_ok=True)
+    state["outdir"] = str(outdir)
+    state["round"] = int(state.get("round", 0))
+
+    # Generate fused candidates from images + text instructions
+    state = gen_fusion.run(state)
+
+    # If we have candidates, select first as best; optionally judge later
+    if state.get("images"):
+        state["best_image"] = state["images"][0]
+
+    # Archive results (final.png etc.)
+    state = archive.run(state)
+    return state
+
+
+def _judge_best_only(state: AppState) -> AppState:
+    # Only score the current best image again
+    from .llm.gemini import call_gemini
+
+    if not state.get("best_image"):
+        return state
+    res = call_gemini("judge", image_path=state["best_image"].path, spec=state.get("spec", {}))
+    vios = list(res.get("violations", []))
+    hard = [v for v in vios if str(v).strip().lower().startswith("hard:")]
+    if not hard:
+        hard = [v for v in vios if ("labels" in str(v).lower() and "missing" in str(v).lower())]
+    state["scores"] = [{
+        "image_path": state["best_image"].path,
+        "score": float(res.get("score", 0.0)),
+        "violations": vios,
+    }]
+    state["hard_violations"] = hard
+    return state
+
+
+def _default_outdir() -> str:
+    return f"artifacts/run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
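
The edit loop hinges on a plain string convention: violations prefixed with `HARD:` (case-insensitive) are treated as blocking. A small sketch of the same classification `_judge_best_only` applies:

```
# Sketch: violation classification used by _judge_best_only
violations = ["HARD: labels: missing", "typo: layer name"]
hard = [v for v in violations if str(v).strip().lower().startswith("hard:")]
soft = [v for v in violations if v not in hard]
assert hard == ["HARD: labels: missing"] and soft == ["typo: layer name"]
```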
NNGen/app/llm/credentials.example.py ADDED
@@ -0,0 +1,4 @@
+# Copy this file to credentials.py and fill in your key.
+
+GEMINI_API_KEY = "YOUR_GEMINI_API_KEY"
+
NNGen/app/llm/gemini.py ADDED
@@ -0,0 +1,788 @@
+from __future__ import annotations
+
+import base64
+import json
+import os
+import random
+import shutil
+from pathlib import Path
+from typing import Any, Dict, List, Optional
+import os as _os
+
+try:
+    # optional local credentials file
+    from . import credentials  # type: ignore
+except Exception:
+    credentials = None  # type: ignore
+
+# Load .env if present to populate environment variables
+try:
+    from dotenv import load_dotenv  # type: ignore
+
+    load_dotenv()  # searches for .env in CWD/parents
+except Exception:
+    pass
+
+
+def _get_api_key() -> Optional[str]:
+    key = os.getenv("GEMINI_API_KEY")
+    if key:
+        return key
+    if credentials and getattr(credentials, "GEMINI_API_KEY", None):
+        return credentials.GEMINI_API_KEY  # type: ignore
+    # Optional: read from ~/.config/gemini/api_key
+    try:
+        cfg_path = Path.home() / ".config" / "gemini" / "api_key"
+        if cfg_path.exists():
+            return cfg_path.read_text().strip()
+    except Exception:
+        pass
+    return None
+
+
+def call_gemini(kind: str, **kwargs) -> Dict[str, Any]:
+    """Unified entry for Gemini calls.
+
+    kind: one of {"parse", "plan", "prompt_generate", "image_generate", "judge", "image_edit", "image_fuse"}
+    kwargs: payload for the corresponding action
+
+    If the API key is missing, falls back to deterministic local placeholders so the
+    pipeline remains runnable offline; with a key present, errors are surfaced directly.
+    """
+    api_key = _get_api_key()
+    if not api_key:
+        # Simplified behavior: if no API key, always use local placeholders
+        return _local_placeholder(kind, **kwargs)
+
+    # With an API key present, call the real service and surface errors directly
+    return _real_gemini(kind, api_key=api_key, **kwargs)
+
+
+def _local_placeholder(kind: str, **kwargs) -> Dict[str, Any]:
+    # Deterministic pseudo behavior for offline usage
+    rng = random.Random(42)
+
+    if kind == "parse":
+        user_text = kwargs.get("user_text", "")
+        # Very rough parse: split by arrows/lines → nodes & edges
+        lines = [ln.strip() for ln in user_text.splitlines() if ln.strip()]
+        nodes = [f"N{i}:{ln[:24]}" for i, ln in enumerate(lines)] or ["Input", "Conv", "FC", "Softmax"]
+        edges = [[i, i + 1] for i in range(len(nodes) - 1)]
+        spec = {"nodes": nodes, "edges": edges, "constraints": {"arrows": "left_to_right"}}
+        return {"spec": spec}
+
+    if kind == "plan":
+        spec = kwargs.get("spec", {})
+        spec_text = (
+            "Neural Net Diagram\n" +
+            f"Nodes: {len(spec.get('nodes', []))}\n" +
+            f"Edges: {len(spec.get('edges', []))}\n" +
+            f"Constraints: {spec.get('constraints', {})}\n"
+        )
+        return {"spec_text": spec_text}
+
+    if kind == "prompt_generate":
+        K = int(kwargs.get("K", 3))
+        spec_text = kwargs.get("spec_text", "")
+        layouts = ["left-right", "top-down", "circular", "grid", "hierarchical"]
+        colors = ["blue", "green", "purple", "orange", "teal"]
+        prompts = [
+            f"Draw NN diagram ({spec_text[:40]}...) layout={layouts[i % len(layouts)]} color={colors[i % len(colors)]} seed={i}"
+            for i in range(K)
+        ]
+        return {"prompts": prompts}
+
+    if kind == "image_generate":
+        prompts: List[str] = kwargs.get("prompts", [])
+        outdir: str = kwargs.get("outdir", "artifacts")
+        paths: List[str] = []
+        for i, p in enumerate(prompts):
+            pth = Path(outdir) / f"candidate_{i}.png"
+            _write_placeholder_diagram(pth, with_labels=False)
+            paths.append(str(pth))
+        return {"paths": paths}
+
+    if kind == "judge":
+        image_path: str = kwargs.get("image_path")
+        # produce a stable pseudo-score based on filename
+        base = sum(ord(c) for c in Path(image_path).name) % 100
+        score = 0.5 + (base / 200.0)
+        # fake violations: if filename has odd index
+        violations: List[str] = []
+        try:
+            idx = int(Path(image_path).stem.split("_")[-1])
+            if idx % 2 == 1:
+                violations = ["typo: layer name", "arrow: wrong direction"]
+        except Exception:
+            pass
+        # If still skeleton (no 'labeled_' in name), mark missing labels as HARD
+        name = Path(image_path).name.lower()
+        # Heuristic for offline mode: consider labeled or edited images as having labels
+        if ("labeled_" not in name) and ("edited_" not in name):
+            violations = ["HARD: labels: missing"] + violations
+        return {"score": score, "violations": violations}
+
+    if kind == "image_edit":
+        image_path: str = kwargs.get("image_path")
+        out_path: str = kwargs.get("out_path")
+        instructions: str = kwargs.get("instructions", "")
+        # Extract labels from instructions if present
+        import re
+        labels = re.findall(r'\d+\s*:\s*"([^"]+)"', instructions)
+        if not labels:
+            # fallback: quoted strings
+            labels = re.findall(r'"([^"]+)"', instructions)
+        # standardize
+        labels = [l.strip() for l in labels if l.strip()]
+        _write_placeholder_diagram(Path(out_path), with_labels=True, labels=labels)
+        return {"path": out_path}
+
+    raise ValueError(f"Unsupported kind={kind}")
+
+
+def _write_1x1_png(path: Path) -> None:
+    path.parent.mkdir(parents=True, exist_ok=True)
+    # 1x1 black pixel PNG
+    png_bytes = base64.b64decode(
+        b"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAA"
+        b"AAC0lEQVR42mP8/xcAAwMB/ax4u6kAAAAASUVORK5CYII="
+    )
+    with open(path, "wb") as f:
+        f.write(png_bytes)
+
+
+def _write_placeholder_diagram(path: Path, *, with_labels: bool, labels: Optional[List[str]] = None) -> None:
+    """Generate a simple skeleton diagram (and optionally labels).
+
+    If Pillow is available, draw with anti-aliased vectors and real text;
+    otherwise fall back to a pure-stdlib bitmap renderer.
+    """
+    try:
+        from PIL import Image, ImageDraw, ImageFont  # type: ignore
+        _write_placeholder_diagram_pil(path, with_labels=with_labels, labels=labels)
+        return
+    except Exception:
+        pass
+
+    # Fallback: stdlib bitmap renderer
+    # White background, 3px black strokes, arrows, dashed group
+    import zlib, struct, binascii
+
+    W, H = 1200, 420
+    # initialize white canvas
+    pixels: List[List[List[int]]] = [[[255, 255, 255] for _ in range(W)] for _ in range(H)]
+
+    def set_px(x: int, y: int, c: tuple[int, int, int]):
+        if 0 <= x < W and 0 <= y < H:
+            pixels[y][x][0] = c[0]
+            pixels[y][x][1] = c[1]
+            pixels[y][x][2] = c[2]
+
+    def draw_line(x0: int, y0: int, x1: int, y1: int, c=(0, 0, 0), t: int = 3):
+        dx = abs(x1 - x0)
+        sx = 1 if x0 < x1 else -1
+        dy = -abs(y1 - y0)
+        sy = 1 if y0 < y1 else -1
+        err = dx + dy
+        while True:
+            for ox in range(-t // 2, t // 2 + 1):
+                for oy in range(-t // 2, t // 2 + 1):
+                    set_px(x0 + ox, y0 + oy, c)
+            if x0 == x1 and y0 == y1:
+                break
+            e2 = 2 * err
+            if e2 >= dy:
+                err += dy
+                x0 += sx
+            if e2 <= dx:
+                err += dx
+                y0 += sy
+
+    def draw_rect(x: int, y: int, w: int, h: int, c=(0, 0, 0), t: int = 3, dashed: bool = False):
+        def dash_points(x0, y0, x1, y1):
+            # Bresenham plus on/off dashes
+            dx = abs(x1 - x0)
+            sx = 1 if x0 < x1 else -1
+            dy = -abs(y1 - y0)
+            sy = 1 if y0 < y1 else -1
+            err = dx + dy
+            on = True
+            step = 0
+            period = 10
+            points = []
+            while True:
+                if on:
+                    points.append((x0, y0))
+                step = (step + 1) % period
+                if step == 0:
+                    on = not on
+                if x0 == x1 and y0 == y1:
+                    break
+                e2 = 2 * err
+                if e2 >= dy:
+                    err += dy
+                    x0 += sx
+                if e2 <= dx:
+                    err += dx
+                    y0 += sy
+            return points
+
+        if dashed:
+            for (x0, y0, x1, y1) in [
+                (x, y, x + w, y),
+                (x + w, y, x + w, y + h),
+                (x + w, y + h, x, y + h),
+                (x, y + h, x, y),
+            ]:
+                for px, py in dash_points(x0, y0, x1, y1):
+                    for ox in range(-t // 2, t // 2 + 1):
+                        for oy in range(-t // 2, t // 2 + 1):
+                            set_px(px + ox, py + oy, c)
+        else:
+            draw_line(x, y, x + w, y, c, t)
+            draw_line(x + w, y, x + w, y + h, c, t)
+            draw_line(x + w, y + h, x, y + h, c, t)
+            draw_line(x, y + h, x, y, c, t)
+
+    def draw_arrow(x0: int, y0: int, x1: int, y1: int, c=(0, 0, 0)):
+        draw_line(x0, y0, x1, y1, c, 3)
+        # simple arrow head
+        vx, vy = x1 - x0, y1 - y0
+        length = max((vx * vx + vy * vy) ** 0.5, 1.0)
+        ux, uy = vx / length, vy / length
+        # perpendicular
+        px, py = -uy, ux
+        ah = 10  # head length
+        aw = 6   # head width
+        hx, hy = int(x1 - ux * ah), int(y1 - uy * ah)
+        lx, ly = int(hx + px * aw), int(hy + py * aw)
+        rx, ry = int(hx - px * aw), int(hy - py * aw)
+        draw_line(x1, y1, lx, ly, c, 2)
+        draw_line(x1, y1, rx, ry, c, 2)
+
+    # layout
+    margin_x, margin_y = 60, 140
+    box_w, box_h = 220, 90
+    gap = 90
+    y = margin_y
+    xs = [margin_x + i * (box_w + gap) for i in range(4)]
+
+    # dashed group around middle two blocks
+    group_x = xs[1] - 20
+    group_y = y - 20
+    group_w = box_w * 2 + gap + 40
+    group_h = box_h + 40
+    draw_rect(group_x, group_y, group_w, group_h, c=(140, 140, 140), t=2, dashed=True)
+
+    # blocks
+    for idx, x in enumerate(xs):
+        draw_rect(x, y, box_w, box_h, c=(0, 0, 0), t=2)
+        if with_labels:
+            # draw simple 5x7 bitmap text using ASCII-only; non-ASCII removed
+            label = None
+            if labels and idx < len(labels):
+                label = labels[idx]
+            _draw_label_text(pixels, x, y, box_w, box_h, label)
+
+    # arrows between blocks (center-right to center-left)
+    cy = y + box_h // 2
+    for i in range(3):
+        x0 = xs[i] + box_w
+        x1 = xs[i + 1]
+        draw_arrow(x0 + 4, cy, x1 - 4, cy, c=(0, 0, 0))
+
+    # write PNG
+    def png_chunk(tag: bytes, data: bytes) -> bytes:
+        return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", binascii.crc32(tag + data) & 0xFFFFFFFF)
+
+    raw = bytearray()
+    for row in pixels:
+        raw.append(0)  # filter type 0
+        for r, g, b in row:
+            raw.extend((r & 255, g & 255, b & 255))
+    comp = zlib.compress(bytes(raw), level=9)
+    sig = b"\x89PNG\r\n\x1a\n"
+    ihdr = struct.pack(">IIBBBBB", W, H, 8, 2, 0, 0, 0)  # 8-bit, truecolor RGB
+    png = sig + png_chunk(b"IHDR", ihdr) + png_chunk(b"IDAT", comp) + png_chunk(b"IEND", b"")
+    path.parent.mkdir(parents=True, exist_ok=True)
+    with open(path, "wb") as f:
+        f.write(png)
+
+    # end _write_placeholder_diagram
+
+
+def _write_placeholder_diagram_pil(path: Path, *, with_labels: bool, labels: Optional[List[str]] = None) -> None:
+    from PIL import Image, ImageDraw, ImageFont  # type: ignore
+
+    W, H = 1280, 480
+    img = Image.new("RGB", (W, H), (255, 255, 255))
+    draw = ImageDraw.Draw(img)
+
+    margin_x, margin_y = 80, 160
+    box_w, box_h = 240, 110
+    gap = 110
+    y = margin_y
+    xs = [margin_x + i * (box_w + gap) for i in range(4)]
+
+    # dashed group
+    group_x = xs[1] - 24
+    group_y = y - 24
+    group_w = box_w * 2 + gap + 48
+    group_h = box_h + 48
+    draw.rounded_rectangle([group_x, group_y, group_x + group_w, group_y + group_h], radius=14, outline=(150, 150, 150), width=2)
+    # manual dash overlay
+    # top and bottom dashed
+    def dashed_line(p0, p1, dash=12, gaplen=8, width=2, fill=(150, 150, 150)):
+        from math import hypot
+        (x0, y0), (x1, y1) = p0, p1
+        dx, dy = x1 - x0, y1 - y0
+        length = (dx * dx + dy * dy) ** 0.5
+        if length == 0:
+            return
+        ux, uy = dx / length, dy / length
+        dist = 0.0
+        on = True
+        while dist < length:
+            l = dash if on else gaplen
+            nx0 = x0 + ux * dist
+            ny0 = y0 + uy * dist
+            nx1 = x0 + ux * min(length, dist + l)
+            ny1 = y0 + uy * min(length, dist + l)
+            if on:
+                draw.line([(nx0, ny0), (nx1, ny1)], fill=fill, width=width)
+            dist += l
+            on = not on
+
+    dashed_line((group_x, group_y), (group_x + group_w, group_y), width=2)
+    dashed_line((group_x, group_y + group_h), (group_x + group_w, group_y + group_h), width=2)
+    dashed_line((group_x, group_y), (group_x, group_y + group_h), width=2)
+    dashed_line((group_x + group_w, group_y), (group_x + group_w, group_y + group_h), width=2)
+
+    # blocks
+    for x in xs:
+        draw.rounded_rectangle([x, y, x + box_w, y + box_h], radius=12, outline=(0, 0, 0), width=3)
+
+    # arrows
+    cy = y + box_h // 2
+    for i in range(3):
+        x0 = xs[i] + box_w + 6
+        x1 = xs[i + 1] - 6
+        draw.line([(x0, cy), (x1, cy)], fill=(0, 0, 0), width=3)
+        # head
+        ah, aw = 14, 8
+        draw.line([(x1, cy), (x1 - ah, cy - aw)], fill=(0, 0, 0), width=3)
+        draw.line([(x1, cy), (x1 - ah, cy + aw)], fill=(0, 0, 0), width=3)
+
+    # labels
+    if with_labels:
+        # pick font
+        try:
+            font = ImageFont.truetype("DejaVuSans.ttf", 22)
+        except Exception:
+            font = ImageFont.load_default()
+        fallback = ["PATCH EMBED", "+ CLS + POSENC", "ENCODER xL", "CLASS HEAD"]
+        for i, x in enumerate(xs):
+            text = None
+            if labels and i < len(labels):
+                text = labels[i]
+            if not text or not isinstance(text, str) or not text.strip():
+                text = fallback[i % len(fallback)]
+            # center text
+            tw, th = draw.textlength(text, font=font), font.size + 6
+            tx = x + (box_w - tw) / 2
+            ty = y + (box_h - th) / 2
+            draw.text((tx, ty), text, fill=(40, 40, 40), font=font, align="center")
+
+    path.parent.mkdir(parents=True, exist_ok=True)
+    img.save(path)
+
+
+def _draw_label_text(pixels: List[List[List[int]]], x: int, y: int, w: int, h: int, label: Optional[str]) -> None:
+    # Render up to two lines of ASCII uppercase text inside a box using a 5x7 bitmap font
+    import re
+    inner_x = x + 10
+    inner_y = y + 10
+    inner_w = w - 20
+    inner_h = h - 20
+    if inner_w <= 0 or inner_h <= 0:
+        return
+    if not label:
+        # fallback: generic placeholder
+        label = "BLOCK"
+    # Normalize to ASCII upper
+    text = label.upper()
+    # allow only A-Z, 0-9, space and '-'
+    text = re.sub(r"[^A-Z0-9 \-]", " ", text)
+    if not text.strip():
+        text = "BLOCK"
+    # simple wrap by width
+    scale = 3  # enlarge 5x7 glyphs for readability
+    char_w = 6 * scale  # 5px glyph + 1px spacing
+    max_chars = max(1, inner_w // char_w)
+    words = text.split()
+    lines: List[str] = []
+    line = ""
+    for wtok in words:
+        token = wtok
+        if line:
+            candidate = line + " " + token
+        else:
+            candidate = token
+        if len(candidate) <= max_chars:
+            line = candidate
+        else:
+            if line:
+                lines.append(line)
+                line = token
+            else:
+                # force cut
+                lines.append(token[:max_chars])
+                line = token[max_chars:]
+    if line:
+        lines.append(line)
+    # limit to 2 lines for clarity
+    lines = lines[:2]
+    # vertical centering
+    total_h = len(lines) * ((7 * scale) + scale) - scale
+    start_y = inner_y + (inner_h - total_h) // 2
+
+    for i, ln in enumerate(lines):
+        # center each line horizontally
+        line_w = len(ln) * (6 * scale)
+        start_x = inner_x + max(0, (inner_w - line_w) // 2)
+        draw_text_5x7(pixels, start_x, start_y + i * ((7 * scale) + scale), ln, color=(40, 40, 40), scale=scale)
+
+
+_FONT_5x7: Dict[str, List[str]] = {
+    'A': ["01110","10001","10001","11111","10001","10001","10001"],
+    'B': ["11110","10001","11110","10001","10001","10001","11110"],
+    'C': ["01111","10000","10000","10000","10000","10000","01111"],
+    'D': ["11110","10001","10001","10001","10001","10001","11110"],
+    'E': ["11111","10000","11110","10000","10000","10000","11111"],
+    'F': ["11111","10000","11110","10000","10000","10000","10000"],
+    'G': ["01110","10000","10000","10111","10001","10001","01110"],
+    'H': ["10001","10001","11111","10001","10001","10001","10001"],
+    'I': ["11111","00100","00100","00100","00100","00100","11111"],
+    'J': ["00111","00010","00010","00010","10010","10010","01100"],
+    'K': ["10001","10010","10100","11000","10100","10010","10001"],
+    'L': ["10000","10000","10000","10000","10000","10000","11111"],
+    'M': ["10001","11011","10101","10101","10001","10001","10001"],
+    'N': ["10001","11001","10101","10011","10001","10001","10001"],
+    'O': ["01110","10001","10001","10001","10001","10001","01110"],
+    'P': ["11110","10001","10001","11110","10000","10000","10000"],
+    'Q': ["01110","10001","10001","10001","10101","10010","01101"],
+    'R': ["11110","10001","10001","11110","10100","10010","10001"],
+    'S': ["01111","10000","10000","01110","00001","00001","11110"],
+    'T': ["11111","00100","00100","00100","00100","00100","00100"],
+    'U': ["10001","10001","10001","10001","10001","10001","01110"],
+    'V': ["10001","10001","10001","10001","01010","01010","00100"],
+    'W': ["10001","10001","10001","10101","10101","11011","10001"],
+    'X': ["10001","01010","00100","00100","01010","10001","10001"],
+    'Y': ["10001","01010","00100","00100","00100","00100","00100"],
+    'Z': ["11111","00001","00010","00100","01000","10000","11111"],
+    '0': ["01110","10001","10011","10101","11001","10001","01110"],
+    '1': ["00100","01100","00100","00100","00100","00100","01110"],
+    '2': ["01110","10001","00001","00010","00100","01000","11111"],
+    '3': ["11110","00001","00001","00110","00001","00001","11110"],
+    '4': ["00010","00110","01010","10010","11111","00010","00010"],
+    '5': ["11111","10000","11110","00001","00001","10001","01110"],
+    '6': ["00110","01000","10000","11110","10001","10001","01110"],
+    '7': ["11111","00001","00010","00100","01000","01000","01000"],
+    '8': ["01110","10001","10001","01110","10001","10001","01110"],
+    '9': ["01110","10001","10001","01111","00001","00010","01100"],
+    '-': ["00000","00000","00000","11111","00000","00000","00000"],
+    ' ': ["00000","00000","00000","00000","00000","00000","00000"],
+}
+
+
+def draw_text_5x7(pixels: List[List[List[int]]], x: int, y: int, text: str, color=(60, 60, 60), scale: int = 1) -> None:
+    max_y = len(pixels) - 1
+    max_x = len(pixels[0]) - 1 if pixels else -1
+
+    def set_px(px: int, py: int):
+        if 0 <= px <= max_x and 0 <= py <= max_y:
+            pixels[py][px][0] = color[0]
+            pixels[py][px][1] = color[1]
+            pixels[py][px][2] = color[2]
+
+    cx = x
+    for ch in text:
+        glyph = _FONT_5x7.get(ch, _FONT_5x7[' '])
+        for gy, row in enumerate(glyph):
+            for gx, bit in enumerate(row):
+                if bit == '1':
+                    # draw scaled pixel
+                    for sy in range(scale):
+                        for sx in range(scale):
+                            set_px(cx + gx * scale + sx, y + gy * scale + sy)
+        cx += 6 * scale  # 5px width + spacing
+
+
+# ------------------------- Real Google GenAI calls -------------------------
+
+def _real_gemini(kind: str, *, api_key: str, **kwargs) -> Dict[str, Any]:
+    import google.generativeai as genai
+
+    genai.configure(api_key=api_key)
+
+    # Model names are configurable via env with safe defaults.
+    # You can set GEMINI_MODEL to your provisioned model, e.g. "gemini-2.0-flash-exp" or "gemini-1.5-flash".
+    text_model_name = os.getenv("GEMINI_MODEL", "gemini-1.5-flash")
+    image_model_name = os.getenv("GEMINI_IMAGE_MODEL", "")  # e.g. "gemini-2.5-flash-image"
+    image_edit_model_name = os.getenv("GEMINI_IMAGE_EDIT_MODEL", "")  # e.g. "gemini-2.5-flash-image"
+
+    if kind == "parse":
+        from .. import prompts as _p
+        prompt = _p.build_parse_prompt()
+        user_text = kwargs.get("user_text", "")
+        model = genai.GenerativeModel(text_model_name)
+        resp = model.generate_content([prompt, user_text])
+        content = _first_text(resp)
+        data = _robust_json(content)
+        if not isinstance(data, dict):
+            raise ValueError("parse: model did not return JSON dict")
+        return {"spec": data}
+
+    if kind == "plan":
+        spec = kwargs.get("spec", {})
+        from .. import prompts as _p
+        prompt = _p.build_plan_prompt()
+        model = genai.GenerativeModel(text_model_name)
+        resp = model.generate_content([prompt, json.dumps(spec, ensure_ascii=False)])
+        spec_text = _first_text(resp)
+        return {"spec_text": spec_text.strip()}
+
+    if kind == "prompt_generate":
+        K = int(kwargs.get("K", 3))
+        spec_text = kwargs.get("spec_text", "")
+        from .. import prompts as _p
+        prompt = _p.build_promptgen_prompt(K, spec_text)
+        model = genai.GenerativeModel(text_model_name)
+        resp = model.generate_content(prompt)
+        content = _first_text(resp)
+        arr = _robust_json(content)
+        if not isinstance(arr, list):
+            # fallback: split lines
+            arr = [ln.strip("- ") for ln in content.splitlines() if ln.strip()][:K]
+        return {"prompts": arr[:K]}
+
+    if kind == "image_generate":
+        prompts: List[str] = kwargs.get("prompts", [])
+        outdir: str = kwargs.get("outdir", "artifacts")
+        if not image_model_name:
+            raise ValueError("GEMINI_IMAGE_MODEL is not set. Please set it to a valid image model (e.g., 'gemini-2.5-flash-image' or 'gemini-2.5-flash-image-preview').")
+        try:
+            from concurrent.futures import ThreadPoolExecutor, as_completed
+
+            Path(outdir).mkdir(parents=True, exist_ok=True)
+            max_workers = max(1, min(len(prompts), int(_os.getenv("NNG_CONCURRENCY", "4"))))
+
+            def _gen_one(i: int, p: str) -> str:
+                # new model per thread to avoid cross-thread state issues
+                mdl = genai.GenerativeModel(model_name=image_model_name)
+                resp = mdl.generate_content(p, request_options={"timeout": 180})
+                try:
+                    (Path(outdir) / f"candidate_{i}.resp.txt").write_text(str(resp))
+                except Exception:
+                    pass
+                img_bytes, mime = _first_image_bytes(resp)
+                if not img_bytes:
+                    raise ValueError("image model did not return image bytes; see *.resp.txt")
+                ext = ".png" if mime == "image/png" else ".jpg"
+                pth = Path(outdir) / f"candidate_{i}{ext}"
+                with open(pth, "wb") as f:
+                    f.write(img_bytes)
+                with open(str(pth) + ".meta.json", "w", encoding="utf-8") as mf:
+                    mf.write(json.dumps({"source": "gemini", "mime": mime, "bytes": len(img_bytes)}, ensure_ascii=False))
+                return str(pth)
+
+            futures = []
+            with ThreadPoolExecutor(max_workers=max_workers) as ex:
+                for i, p in enumerate(prompts):
+                    futures.append(ex.submit(_gen_one, i, p))
+                # preserve order by index
+                results = [None] * len(prompts)
+                for fut in as_completed(futures):
+                    # find index by result path name
+                    path = fut.result()
+                    stem = Path(path).stem
+                    try:
+                        idx = int(stem.split("_")[-1])
+                    except Exception:
+                        idx = 0
+                    results[idx] = path
+            # fill any missing in order fallback
+            paths: List[str] = [r or "" for r in results]
+            return {"paths": paths}
+        except Exception:
+            raise
+
+    if kind == "judge":
+        image_path: str = kwargs.get("image_path")
+        spec = kwargs.get("spec", {})
+        model = genai.GenerativeModel(text_model_name)
+        from .. import prompts as _p
+        judge_prompt = _p.build_judge_prompt()
+        image_part = _image_part_from_path(image_path)
+        resp = model.generate_content([
+            {"text": judge_prompt},
+            {"text": json.dumps(spec, ensure_ascii=False)},
+            image_part,
+        ])
+        content = _first_text(resp)
+        data = _robust_json(content)
+        if not isinstance(data, dict):
+            raise ValueError("judge: non-JSON")
+        score = float(max(0.0, min(1.0, data.get("score", 0.0))))
+        violations = list(data.get("violations", []))
+        return {"score": score, "violations": violations}
+
+    if kind == "image_edit":
+        image_path: str = kwargs.get("image_path")
+        out_path: str = kwargs.get("out_path")
+        instructions: str = kwargs.get("instructions", "")
+        ref_images: List[str] = list(kwargs.get("ref_images", []) or [])
+        if not image_edit_model_name:
+            raise ValueError("GEMINI_IMAGE_EDIT_MODEL is not set. Please set it to a valid image edit model (e.g., 'gemini-2.5-flash-image' or 'gemini-2.5-flash-image-preview').")
+        try:
+            model = genai.GenerativeModel(model_name=image_edit_model_name)
+            base_img = _image_part_from_path(image_path)
+            from .. import prompts as _p
+            parts = [{"text": _p.build_image_edit_prompt(instructions)}, base_img]
+            for rp in ref_images:
+                try:
+                    parts.append(_image_part_from_path(rp))
+                except Exception:
+                    continue
+            resp = model.generate_content(parts, request_options={"timeout": 120})
+            try:
+                out_p = Path(out_path)
+                out_p.parent.mkdir(parents=True, exist_ok=True)
+                (out_p.parent / (out_p.stem + ".resp.txt")).write_text(str(resp))
+            except Exception:
+                pass
+            img_bytes, mime = _first_image_bytes(resp)
+            if not img_bytes:
+                raise ValueError("image edit returned no image; see *.resp.txt for raw response")
+            ext = ".png" if mime == "image/png" else ".jpg"
+            out_p = Path(out_path)
+            out_p.parent.mkdir(parents=True, exist_ok=True)
+            with open(out_p, "wb") as f:
+                f.write(img_bytes)
+            with open(str(out_p) + ".meta.json", "w", encoding="utf-8") as mf:
+                mf.write(json.dumps({"source": "gemini", "mime": mime, "bytes": len(img_bytes)}, ensure_ascii=False))
+            return {"path": str(out_p)}
+        except Exception:
+            # surface the error rather than fall back, per the no-local-rendering requirement
+            raise
+
+    if kind == "image_fuse":
+        # Create a new image by composing multiple reference images under textual instructions
+        out_path: str = kwargs.get("out_path")
+        instructions: str = kwargs.get("instructions", "")
+        ref_images: List[str] = list(kwargs.get("ref_images", []) or [])
+        if not image_model_name:
+            raise ValueError("GEMINI_IMAGE_MODEL is not set. Please set it to a valid image model (e.g., 'gemini-2.5-flash-image' or 'gemini-2.5-flash-image-preview').")
+        try:
+            model = genai.GenerativeModel(model_name=image_model_name)
+            from .. import prompts as _p
+            parts = [{"text": _p.build_image_fusion_prompt(instructions)}]
+            for rp in ref_images:
+                try:
+                    parts.append(_image_part_from_path(rp))
+                except Exception:
+                    continue
+            resp = model.generate_content(parts, request_options={"timeout": 120})
+            try:
+                out_p = Path(out_path)
+                out_p.parent.mkdir(parents=True, exist_ok=True)
+                (out_p.parent / (out_p.stem + ".resp.txt")).write_text(str(resp))
+            except Exception:
+                pass
+            img_bytes, mime = _first_image_bytes(resp)
+            if not img_bytes:
+                raise ValueError("image fuse returned no image; see *.resp.txt for raw response")
+            out_p = Path(out_path)
+            out_p.parent.mkdir(parents=True, exist_ok=True)
+            with open(out_p, "wb") as f:
+                f.write(img_bytes)
+            with open(str(out_p) + ".meta.json", "w", encoding="utf-8") as mf:
+                mf.write(json.dumps({"source": "gemini", "mime": mime, "bytes": len(img_bytes)}, ensure_ascii=False))
+            return {"path": str(out_p)}
+        except Exception:
+            raise
+
+    raise ValueError(f"Unsupported kind={kind}")
+
+
+def _first_text(resp: Any) -> str:
+    try:
+        if hasattr(resp, "text"):
+            return resp.text
+        # Some SDK versions: candidates[0].content.parts[0].text
+        cands = getattr(resp, "candidates", [])
+        if cands:
+            parts = getattr(cands[0], "content", None)
+            if parts and getattr(parts, "parts", None):
+                for part in parts.parts:
+                    if getattr(part, "text", None):
+                        return part.text
+        return str(resp)
+    except Exception:
+        return str(resp)
+
+
+def _first_image_bytes(resp: Any) -> tuple[bytes | None, str]:
+    # Try to walk through content parts and return the first inline image bytes
+    try:
+        # Newer SDK: resp.candidates[].content.parts[].inline_data
+        cands = getattr(resp, "candidates", [])
+        for c in cands or []:
+            content = getattr(c, "content", None)
+            parts = getattr(content, "parts", None) if content else None
+            for part in parts or []:
+                inline = getattr(part, "inline_data", None)
+                if inline and getattr(inline, "data", None):
+                    data = inline.data
+                    mime = getattr(inline, "mime_type", "image/png")
+                    if isinstance(data, bytes):
+                        return data, mime
+                    # some versions may base64-encode
+                    try:
+                        return base64.b64decode(data), mime
+                    except Exception:
+                        pass
+        return None, ""
+    except Exception:
+        return None, ""
+
+
+def _image_part_from_path(path: str) -> Dict[str, Any]:
+    # google-generativeai accepts a dict with mime_type and raw data bytes for images
+    p = Path(path)
+    mime = "image/png" if p.suffix.lower() == ".png" else "image/jpeg"
+    data = p.read_bytes()
+    return {"mime_type": mime, "data": data}
+
+
+def _robust_json(text: str) -> Any:
+    # Try to parse the whole text, then attempt to extract the first {...} or [...] block
+    try:
+        return json.loads(text)
+    except Exception:
+        pass
+    # crude extraction
+    start = text.find("{")
+    end = text.rfind("}")
+    if start != -1 and end != -1 and end > start:
+        try:
+            return json.loads(text[start : end + 1])
+        except Exception:
+            pass
+    start = text.find("[")
+    end = text.rfind("]")
+    if start != -1 and end != -1 and end > start:
+        try:
+            return json.loads(text[start : end + 1])
+        except Exception:
+            pass
+    return {}
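
Because `call_gemini` degrades to deterministic placeholders when no key is configured, the wrapper can be exercised directly; a sketch, assuming `GEMINI_API_KEY` is unset:

```
# Offline smoke test of the unified wrapper (uses local placeholders)
from app.llm.gemini import call_gemini

spec = call_gemini("parse", user_text="Input\nConv\nFC\nSoftmax")["spec"]
print(spec["nodes"])  # ['N0:Input', 'N1:Conv', 'N2:FC', 'N3:Softmax']
paths = call_gemini("image_generate",
                    prompts=["p0", "p1"], outdir="artifacts/demo")["paths"]
print(paths)  # ['artifacts/demo/candidate_0.png', 'artifacts/demo/candidate_1.png']
```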
NNGen/app/nodes/__init__.py ADDED
@@ -0,0 +1 @@
+from . import parser, planner, prompt_gen, gen_generate, gen_labels, judge, select, edit, archive
NNGen/app/nodes/archive.py ADDED
@@ -0,0 +1,35 @@
+from __future__ import annotations
+
+import json
+from pathlib import Path
+
+from ..state import AppState
+
+
+def run(state: AppState) -> AppState:
+    outdir = Path(state["outdir"])  # ensured by CLI
+    outdir.mkdir(parents=True, exist_ok=True)
+
+    # dump spec
+    if state.get("spec"):
+        (outdir / "spec.json").write_text(json.dumps(state["spec"], ensure_ascii=False, indent=2))
+    if state.get("spec_text"):
+        (outdir / "spec.txt").write_text(state["spec_text"])
+
+    # dump prompts
+    if state.get("prompts"):
+        (outdir / "prompts.json").write_text(json.dumps(state["prompts"], ensure_ascii=False, indent=2))
+
+    # dump scores
+    if state.get("scores"):
+        (outdir / "scores.json").write_text(json.dumps(state["scores"], ensure_ascii=False, indent=2))
+
+    # copy/rename final image
+    if state.get("best_image"):
+        src = Path(state["best_image"].path)
+        dst = outdir / "final.png"
+        if src.exists():
+            dst.write_bytes(src.read_bytes())
+
+    return state
+
NNGen/app/nodes/edit.py ADDED
@@ -0,0 +1,80 @@
+from __future__ import annotations
+
+from pathlib import Path
+from typing import Dict
+
+from ..llm.gemini import call_gemini
+from ..state import AppState
+
+
+def _labels_from_spec(state: AppState) -> list[str]:
+    spec = state.get("spec", {}) or {}
+    raw_nodes = spec.get("nodes", []) or []
+    labels: list[str] = []
+    for n in raw_nodes:
+        label = None
+        if isinstance(n, str):
+            # strip index prefixes like "N0: ..."
+            parts = n.split(":", 1)
+            label = parts[1].strip() if len(parts) == 2 else n.strip()
+        elif isinstance(n, dict):
+            label = n.get("label") or n.get("name") or n.get("id")
+        if label:
+            labels.append(str(label))
+    # keep order and duplicates: labels map to blocks by position
+    return labels
+
+
+def _ascii_friendly_labels(state: AppState) -> list[str]:
+    labels = _labels_from_spec(state)
+
+    def is_ascii(s: str) -> bool:
+        try:
+            s.encode('ascii')
+            return True
+        except Exception:
+            return False
+
+    if not labels or sum(1 for l in labels if is_ascii(l)) == 0:
+        return [
+            "PATCH EMBEDDING",
+            "CLS + POSENC",
+            "ENCODER xL",
+            "CLASS HEAD",
+        ]
+    return labels
+
+
+def plan_edits(state: AppState) -> str:
+    hard_violations = [str(v) for v in state.get("hard_violations", [])]
+    violations = [str(v) for v in state.get("violations", [])]
+    # If the judge reports missing labels (prefer HARD), provide an add-labels instruction
+    hv = hard_violations or violations
+    if any(("labels" in v.lower() and "missing" in v.lower()) for v in hv):
+        labels = _labels_from_spec(state)
+        numbered = "\n".join([f"{i+1}: \"{lbl}\"" for i, lbl in enumerate(labels)]) or "(no labels provided)"
+        return (
+            "Add text labels INSIDE each rectangular block without changing geometry, arrows, spacing, sizes, or colors. "
+            "Map labels in left→right, top→bottom order; reuse identical labels for repeated blocks. "
+            "Use a clean sans-serif font in solid black or dark gray, consistent size.\n"
+            f"Labels list:\n{numbered}"
+        )
+
+    # Default: targeted fixes based on judge violations, but always provide the labels list to preserve text in offline mode
+    fixes = "; ".join(violations) if violations else "typos, arrow direction, spacing/legibility, and style compliance"
+    labels = _ascii_friendly_labels(state)
+    numbered = "\n".join([f"{i+1}: \"{lbl}\"" for i, lbl in enumerate(labels)]) or "(no labels provided)"
+    return (
+        f"Fix the following issues precisely: {fixes}. "
+        "Do not move or reshape elements. Only adjust text (content/position/size), arrow direction styles, and minimal styling to reach paper standards.\n"
+        f"Labels list:\n{numbered}"
+    )
+
+
+def apply_edits(state: AppState) -> AppState:
+    if not state.get("best_image"):
+        return state
+    src = state["best_image"].path
+    out_path = str(Path(state["outdir"]) / f"edited_round_{state.get('round', 0)}.png")
+    _ = call_gemini("image_edit", image_path=src, out_path=out_path, instructions=plan_edits(state))
+    # replace best_image with the edited one
+    state["best_image"].path = out_path  # type: ignore
+    return state
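
To see what `plan_edits` sends to the image editor for a missing-labels violation, it can be called with a hand-built state; a sketch (a plain dict is used for illustration):

```
# Sketch: inspect the edit instruction produced for a missing-labels violation
from app.nodes.edit import plan_edits

state = {
    "hard_violations": ["HARD: labels: missing"],
    "violations": [],
    "spec": {"nodes": ["N0:Patch Embed", "N1:Encoder xL"]},
}
print(plan_edits(state))  # add-labels directive ending with: 1: "Patch Embed" / 2: "Encoder xL"
```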
NNGen/app/nodes/gen_fusion.py ADDED
@@ -0,0 +1,73 @@
+from __future__ import annotations
+
+import os
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from pathlib import Path
+from typing import Dict, List
+
+from ..llm.gemini import call_gemini
+from ..state import AppState, ImageArtifact
+
+
+def run(state: AppState) -> AppState:
+    """Generate K fused candidates by composing reference images under instructions.
+
+    Expected in state:
+    - outdir: str
+    - K: int
+    - base_image: optional path (treated as first ref image if present)
+    - ref_images: optional list[str]
+    - instructions: str
+    """
+    outdir = Path(state["outdir"])  # ensured by graph
+    K = int(state.get("K", 3))
+    instructions: str = str(state.get("instructions", "")).strip()
+
+    # prepare reference list
+    refs: List[str] = []
+    if state.get("base_image"):
+        refs.append(str(state["base_image"]))
+    for r in state.get("ref_images", []) or []:
+        if r and str(r) not in refs:
+            refs.append(str(r))
+
+    if not refs:
+        raise ValueError("Fusion mode requires at least one reference image (base or ref_images)")
+
+    max_workers = max(1, min(K, int(os.getenv("NNG_CONCURRENCY", "4"))))
+
+    def _fuse_one(i: int) -> str:
+        out_path = str(outdir / f"fused_candidate_{i}.png")
+        # Use image_edit if a base image is provided; otherwise image_fuse
+        if state.get("base_image"):
+            call_gemini(
+                "image_edit",
+                image_path=str(state["base_image"]),
+                out_path=out_path,
+                instructions=f"Variant {i}: {instructions}",
+                ref_images=[p for p in refs if p != str(state["base_image"])],
+            )
+        else:
+            call_gemini(
+                "image_fuse",
+                out_path=out_path,
+                instructions=f"Variant {i}: {instructions}",
+                ref_images=refs,
+            )
+        return out_path
+
+    paths: List[str] = [""] * K
+    with ThreadPoolExecutor(max_workers=max_workers) as ex:
+        futures = [ex.submit(_fuse_one, i) for i in range(K)]
+        for fut in as_completed(futures):
+            p = fut.result()
+            try:
+                idx = int(Path(p).stem.split("_")[-1])
+            except Exception:
+                idx = 0
+            paths[idx] = p
+
+    images = [ImageArtifact(prompt=instructions, path=pth) for pth in paths if pth]
+    state["images"] = images
+    return state
+
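
`run_fusion_pipeline` consumes the same state dict as the CLI builds in image mode; a sketch of driving fusion programmatically (paths are placeholders):

```
# Hypothetical fusion-mode invocation (mirrors app/cli.py image mode)
from app.graph import run_fusion_pipeline

state = {
    "K": 2,
    "T": 0,
    "outdir": "",
    "base_image": "path/to/base.png",   # placeholder path
    "ref_images": ["path/to/ref.png"],  # placeholder path
    "instructions": "Replace the UNet backbone with a Transformer (DiT).",
}
final_state = run_fusion_pipeline(state)  # archives the first candidate as final.png
```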
NNGen/app/nodes/gen_generate.py ADDED
@@ -0,0 +1,16 @@
+from __future__ import annotations
+
+from typing import Dict, List
+
+from ..llm.gemini import call_gemini
+from ..state import AppState, ImageArtifact
+
+
+def run(state: AppState) -> AppState:
+    prompts: List[str] = state.get("prompts", [])
+    res: Dict = call_gemini("image_generate", prompts=prompts, outdir=state["outdir"])
+    paths = res.get("paths", [])
+    images = [ImageArtifact(prompt=p, path=pth) for p, pth in zip(prompts, paths)]
+    state["images"] = images
+    return state
+
NNGen/app/nodes/gen_labels.py ADDED
@@ -0,0 +1,60 @@
+ from __future__ import annotations
+
+ from pathlib import Path
+ from typing import List
+ import os
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+
+ from ..llm.gemini import call_gemini
+ from ..state import AppState, ImageArtifact
+ from .edit import _labels_from_spec  # reuse label extraction
+
+
+ def run(state: AppState) -> AppState:
+     images: List[ImageArtifact] = state.get("images", []) or []
+     if not images:
+         return state
+
+     labels = _labels_from_spec(state)
+     # Fall back to ASCII-friendly defaults if labels are missing or none are ASCII
+     def _is_ascii(s: str) -> bool:
+         try:
+             s.encode("ascii")
+             return True
+         except UnicodeEncodeError:
+             return False
+     if not labels or not any(_is_ascii(l) for l in labels):
+         labels = [
+             "PATCH EMBEDDING",
+             "CLS + POSENC",
+             "ENCODER xL",
+             "CLASS HEAD",
+         ]
+     numbered = "\n".join([f"{i+1}: \"{lbl}\"" for i, lbl in enumerate(labels)]) or "(no labels provided)"
+
+     instructions = (
+         "Add labels INSIDE each rectangular block. Do not move/resize/add/remove shapes or arrows; keep layout, spacing, and colors unchanged. "
+         "Map labels in left→right, top→bottom order; reuse identical labels for repeated blocks. Use each label string exactly as given (no translation or paraphrase). "
+         "Typography: clean sans-serif, readable size, centered within blocks; at most two short lines; avoid covering arrows; no legends or titles. "
+         "If block count ≠ label count, do NOT add/remove shapes; place labels sequentially on existing blocks.\n"
+         f"Labels list:\n{numbered}"
+     )
+
+     outdir = Path(state["outdir"]) if state.get("outdir") else Path("artifacts")
+     max_workers = max(1, min(len(images), int(os.getenv("NNG_CONCURRENCY", "4"))))
+     results: List[ImageArtifact | None] = [None] * len(images)
+
+     def _label_one(i: int, im: ImageArtifact) -> tuple[int, str]:
+         src = im.path
+         out_path = str(outdir / f"labeled_candidate_{i}.png")
+         _ = call_gemini("image_edit", image_path=src, out_path=out_path, instructions=instructions)
+         return i, out_path
+
+     with ThreadPoolExecutor(max_workers=max_workers) as ex:
+         futures = [ex.submit(_label_one, i, im) for i, im in enumerate(images)]
+         for fut in as_completed(futures):
+             i, out_path = fut.result()
+             results[i] = ImageArtifact(prompt=images[i].prompt, path=out_path, meta={"stage": "labels"})
+
+     state["images"] = [im for im in results if im is not None]
+     return state
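For reference, the numbered "Labels list" block this node appends to its edit instructions renders like so (a tiny standalone sketch of the same f-string):

# Minimal sketch: the numbered label block gen_labels sends to the image editor.
labels = ["PATCH EMBEDDING", "CLS + POSENC", "ENCODER xL", "CLASS HEAD"]
numbered = "\n".join(f'{i + 1}: "{lbl}"' for i, lbl in enumerate(labels))
print(numbered)
# 1: "PATCH EMBEDDING"
# 2: "CLS + POSENC"
# 3: "ENCODER xL"
# 4: "CLASS HEAD"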
NNGen/app/nodes/judge.py ADDED
@@ -0,0 +1,46 @@
+ from __future__ import annotations
+
+ from typing import Dict, List
+ import os
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+
+ from ..llm.gemini import call_gemini
+ from ..state import AppState, ScoreItem
+
+
+ def run(state: AppState) -> AppState:
+     images = list(state.get("images", []))
+     if not images:
+         state["scores"] = []
+         return state
+
+     max_workers = max(1, min(len(images), int(os.getenv("NNG_CONCURRENCY", "4"))))
+     results: List[ScoreItem | None] = [None] * len(images)
+
+     def _judge_one(i: int) -> tuple[int, Dict]:
+         im = images[i]
+         res: Dict = call_gemini("judge", image_path=im.path, spec=state.get("spec", {}))
+         return i, res
+
+     with ThreadPoolExecutor(max_workers=max_workers) as ex:
+         # Map futures to indices so a failed call is scored against the right image
+         future_to_idx = {ex.submit(_judge_one, i): i for i in range(len(images))}
+         for fut in as_completed(future_to_idx):
+             idx = future_to_idx[fut]
+             im = images[idx]
+             try:
+                 _, res = fut.result()
+                 results[idx] = {
+                     "image_path": im.path,
+                     "score": float(res.get("score", 0.0)),
+                     "violations": list(res.get("violations", [])),
+                 }
+             except Exception as e:
+                 results[idx] = {
+                     "image_path": im.path,
+                     "score": 0.0,
+                     "violations": [f"judge error: {e}"],
+                 }
+
+     state["scores"] = [s for s in results if s is not None]
+     return state
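The judge contract is "strict JSON in, ScoreItem out"; a hedged sketch of the validation a caller might add (the `raw` payload and the `to_score_item` helper are illustrative, not part of the repo):

# Minimal sketch: coerce a raw judge reply into the ScoreItem shape used above.
import json

from app.state import ScoreItem

raw = '{"score": 0.72, "violations": ["HARD: labels: missing"]}'  # illustrative payload

def to_score_item(image_path: str, raw_json: str) -> ScoreItem:
    try:
        data = json.loads(raw_json)
    except json.JSONDecodeError:
        data = {}
    score = float(data.get("score", 0.0)) if isinstance(data, dict) else 0.0
    vios = data.get("violations", []) if isinstance(data, dict) else []
    return {
        "image_path": image_path,
        "score": max(0.0, min(1.0, score)),  # clamp into the promised [0, 1] range
        "violations": [str(v) for v in vios],
    }

print(to_score_item("artifacts/candidate_0.png", raw))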
NNGen/app/nodes/parser.py ADDED
@@ -0,0 +1,15 @@
+ from __future__ import annotations
+
+ from typing import Dict
+
+ from ..llm.gemini import call_gemini
+ from ..state import AppState
+
+
+ def run(state: AppState) -> AppState:
+     if state.get("spec"):
+         return state
+     res: Dict = call_gemini("parse", user_text=state.get("user_text", ""))
+     state["spec"] = res.get("spec", {})
+     return state
+
NNGen/app/nodes/planner.py ADDED
@@ -0,0 +1,15 @@
+ from __future__ import annotations
+
+ from typing import Dict
+
+ from ..llm.gemini import call_gemini
+ from ..state import AppState
+
+
+ def run(state: AppState) -> AppState:
+     if state.get("spec_text"):
+         return state
+     res: Dict = call_gemini("plan", spec=state.get("spec", {}))
+     state["spec_text"] = res.get("spec_text", "")
+     return state
+
NNGen/app/nodes/prompt_gen.py ADDED
@@ -0,0 +1,15 @@
+ from __future__ import annotations
+
+ from typing import Dict
+
+ from ..llm.gemini import call_gemini
+ from ..state import AppState
+
+
+ def run(state: AppState) -> AppState:
+     if state.get("prompts"):
+         return state
+     res: Dict = call_gemini("prompt_generate", spec_text=state.get("spec_text", ""), K=state.get("K", 3))
+     state["prompts"] = res.get("prompts", [])
+     return state
+
NNGen/app/nodes/select.py ADDED
@@ -0,0 +1,27 @@
+ from __future__ import annotations
+
+ from ..state import AppState, ImageArtifact
+
+
+ def run(state: AppState) -> AppState:
+     scores = sorted(state.get("scores", []), key=lambda s: s["score"], reverse=True)
+     if not scores:
+         return state
+     best = scores[0]
+     best_img_path = best["image_path"]
+     # Find the ImageArtifact corresponding to the best-scored path
+     best_image: ImageArtifact | None = None
+     for im in state.get("images", []):
+         if im.path == best_img_path:
+             best_image = im
+             break
+     vios = [str(v) for v in best.get("violations", [])]
+     # Identify hard violations: an explicit "HARD:" marker, else a missing-labels heuristic
+     hard = [v for v in vios if v.strip().lower().startswith("hard:")]
+     if not hard:
+         hard = [v for v in vios if ("labels" in v.lower() and "missing" in v.lower())]
+
+     state["best_image"] = best_image
+     state["violations"] = vios
+     state["hard_violations"] = hard
+     return state
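The hard-violation heuristic is easy to check in isolation; a small sketch with made-up violation strings:

# Minimal sketch: the same hard-violation heuristic as select.run, on sample inputs.
vios = ["HARD: labels: missing", "arrow spacing uneven", "Labels missing on two blocks"]

hard = [v for v in vios if v.strip().lower().startswith("hard:")]
if not hard:  # fallback: treat "labels ... missing" phrasing as hard
    hard = [v for v in vios if "labels" in v.lower() and "missing" in v.lower()]

print(hard)  # -> ['HARD: labels: missing']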
NNGen/app/prompts.py ADDED
@@ -0,0 +1,65 @@
+ from __future__ import annotations
+
+
+ def build_parse_prompt() -> str:
+     return (
+         "You are a strict parser for neural network architecture specs. "
+         "Input is natural language. Return ONLY a JSON object with fields: "
+         "nodes: string[], edges: [fromIndex, toIndex][], constraints: object. "
+         "No prose."
+     )
+
+
+ def build_plan_prompt() -> str:
+     return (
+         "Given a structured NN spec (JSON), produce a concise, fillable template text "
+         "that preserves nodes, edges, and key constraints for diagram rendering. "
+         "Emphasize left-to-right flow, explicit layer counts, and unambiguous labels."
+     )
+
+
+ def build_promptgen_prompt(K: int, spec_text: str) -> str:
+     # Stage G1: lighter, cleaner skeleton-only prompts (no hard stylistic numbers)
+     return (
+         "Create K concise prompts for an image model to draw ONLY the skeleton of a neural network diagram (no text).\n"
+         "Aim for a clean paper-figure look: flat 2D, simple shapes, balanced margins, and a calm palette. Use rectangles for modules and clear left→right arrows.\n"
+         "If the spec implies repetition (e.g., Encoder × L), you may show a dashed grouping around the repeated blocks. Avoid flashy effects (no 3D or heavy glow).\n"
+         "Return ONLY a JSON array of exactly K strings; each item is one full prompt for image generation.\n"
+         f"K={K}.\n"
+         f"Spec (summary):\n{spec_text}\n"
+         "Each prompt must mention: 'skeleton-only, no text'."
+     )
+
+
+ def build_judge_prompt() -> str:
+     # Judge content & style, optimized for two-stage (skeleton→labels) flow
+     return (
+         "You are a strict publication-figure QA judge. Given a spec (JSON) and a NN diagram image, "
+         "evaluate (A) Content correctness and (B) Paper-style compliance.\n"
+         "(A) Content (0.6): required modules present; edges/arrows reflect correct order; arrows left→right; labels exist and are spelled correctly; "
+         "layer count L indicated when applicable. If the image has no labels, include violation EXACTLY 'HARD: labels: missing'.\n"
+         "(B) Style (0.4): flat 2D; white background; minimal color (black/gray + ≤2 accents); no gradients/3D/glow/shadows/neon; "
+         "consistent stroke width; consistent sans-serif font; adequate spacing; dashed boxes for repeated blocks; high print readability.\n"
+         "Return ONLY strict JSON: {score: number in [0,1], violations: string[]}. Violations must be concrete and actionable."
+     )
+
+
+ def build_image_edit_prompt(instructions: str) -> str:
+     # G2 and later edits: add/adjust labels only; keep geometry fixed (light constraints)
+     base = (
+         "Add or adjust labels INSIDE each block, without changing any shapes, arrows, layout, spacing, or colors. "
+         "Keep a clean, readable look: flat 2D, simple sans-serif font, good contrast, and consistent size across blocks. "
+         "Center labels within blocks; use at most two short lines; avoid covering arrows; do not add legends or titles. "
+         "Use each label string exactly as provided (no translation or paraphrase). "
+     )
+     return base + f"Instructions: {instructions}"
+
+
+ def build_image_fusion_prompt(instructions: str) -> str:
+     # Compose multiple images guided by text while preserving key visual constraints
+     return (
+         "Compose a new, clean technical diagram by integrating the following reference images. "
+         "Preserve the overall paper-style look: flat 2D, white background, minimal color, consistent line width, and sans-serif text. "
+         "Follow the instructions precisely; keep geometry aligned and readable; avoid extra decorations. "
+         f"Instructions: {instructions}"
+     )
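Several of these prompts demand "ONLY JSON", which still benefits from defensive parsing on the caller side; a sketch under the assumption that models sometimes wrap output in Markdown fences (this helper is not part of the repo):

# Minimal sketch: defensively parse the "JSON array of exactly K strings" contract
# from build_promptgen_prompt. The fence stripping is a common-case guess.
import json

def parse_prompt_list(text: str, K: int) -> list[str]:
    cleaned = text.strip()
    if cleaned.startswith("```"):
        cleaned = cleaned.strip("`")
        cleaned = cleaned[cleaned.find("["):]  # drop a leading "json" language tag
    try:
        prompts = json.loads(cleaned)
    except json.JSONDecodeError:
        return []
    prompts = [p for p in prompts if isinstance(p, str)]
    return prompts[:K]  # enforce at most K prompts

print(parse_prompt_list('```json\n["p1", "p2", "p3"]\n```', 3))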
NNGen/app/state.py ADDED
@@ -0,0 +1,33 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass, field
+ from typing import Any, Dict, List, Optional, TypedDict
+
+
+ class ScoreItem(TypedDict):
+     image_path: str
+     score: float
+     violations: List[str]
+
+
+ @dataclass
+ class ImageArtifact:
+     prompt: str
+     path: str
+     meta: Dict[str, Any] = field(default_factory=dict)
+
+
+ class AppState(TypedDict, total=False):
+     user_text: str
+     spec: Dict[str, Any]
+     spec_text: str
+     K: int
+     T: int
+     round: int
+     prompts: List[str]
+     images: List[ImageArtifact]
+     scores: List[ScoreItem]
+     best_image: Optional[ImageArtifact]
+     violations: List[str]
+     hard_violations: List[str]
+     outdir: str
The diff for this file is too large to render. See raw diff
 
NNGen/notebooks/demo.ipynb ADDED
@@ -0,0 +1,141 @@
+ {
+  "cells": [
+   {
+    "cell_type": "markdown",
+    "id": "0ce88361",
+    "metadata": {},
+    "source": [
+     "# NNGen Demo — Gemini 2.5 Flash Image\n",
+     "\n",
+     "Interactive demo to generate a neural network diagram from a natural language prompt.\n",
+     "- Uses the multi-agent pipeline (`parser → planner → prompt-gen → G1 → G2 → judge → select → edit loop → archive`).\n",
+     "- All model calls go through `app.llm.gemini.call_gemini`. If `GEMINI_API_KEY` is not set, the pipeline falls back to local placeholders so the demo still runs.\n",
+     "- For image generation/editing, set `GEMINI_IMAGE_MODEL`/`GEMINI_IMAGE_EDIT_MODEL` (e.g., `gemini-2.5-flash-image` or `gemini-2.5-flash-image-preview`).\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "a0f6490c",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Imports (hop to the repo root so `app` and relative paths like spec/ resolve)\n",
+     "import os, json\n",
+     "from pathlib import Path\n",
+     "if not (Path.cwd() / 'app').exists():\n",
+     "    os.chdir(Path.cwd().parent)\n",
+     "from app.graph import run_pipeline\n",
+     "from app.state import AppState\n",
+     "from IPython.display import Image, display\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "722a09d3",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Optional: configure models here if not set in environment (.env is supported).\n",
+     "# os.environ.setdefault(\"GEMINI_MODEL\", \"gemini-2.5-flash\")\n",
+     "# os.environ.setdefault(\"GEMINI_IMAGE_MODEL\", \"gemini-2.5-flash-image\")\n",
+     "# os.environ.setdefault(\"GEMINI_IMAGE_EDIT_MODEL\", \"gemini-2.5-flash-image\")\n",
+     "print(\"GEMINI_MODEL=\", os.getenv(\"GEMINI_MODEL\", \"(default)\"))\n",
+     "print(\"GEMINI_IMAGE_MODEL=\", os.getenv(\"GEMINI_IMAGE_MODEL\", \"(default)\"))\n",
+     "print(\"GEMINI_IMAGE_EDIT_MODEL=\", os.getenv(\"GEMINI_IMAGE_EDIT_MODEL\", \"(default)\"))\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "61aacffe",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Enter a natural language NN spec (leave blank to use the sample in spec/vit.txt).\n",
+     "print(\"Enter your NN spec prompt (blank for sample):\")\n",
+     "user_text = input().strip()\n",
+     "if not user_text:\n",
+     "    user_text = Path('spec/vit.txt').read_text()\n",
+     "\n",
+     "# Number of candidates (K) and max edit rounds (T)\n",
+     "K = 4\n",
+     "T = 1\n",
+     "print(f\"Configured: K={K}, T={T}\")\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "f6c28bb1",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Run the multi-agent pipeline\n",
+     "state: AppState = {\n",
+     "    'K': K,\n",
+     "    'T': T,\n",
+     "    'user_text': user_text,\n",
+     "    'outdir': ''  # use timestamped default\n",
+     "}\n",
+     "final_state = run_pipeline(state)\n",
+     "print('Artifacts directory:', final_state['outdir'])\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "a402a108",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# Display the final image\n",
+     "final_path = Path(final_state['outdir']) / 'final.png'\n",
+     "if final_path.exists():\n",
+     "    display(Image(filename=str(final_path)))\n",
+     "else:\n",
+     "    print('final.png not found at', final_path)\n"
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": null,
+    "id": "f7518a91",
+    "metadata": {},
+    "outputs": [],
+    "source": [
+     "# (Optional) Inspect outputs: spec and scoring\n",
+     "spec_txt = Path(final_state['outdir']) / 'spec.txt'\n",
+     "scores_json = Path(final_state['outdir']) / 'scores.json'\n",
+     "if spec_txt.exists():\n",
+     "    print('--- spec.txt ---')\n",
+     "    print(spec_txt.read_text())\n",
+     "if scores_json.exists():\n",
+     "    print('--- scores.json ---')\n",
+     "    print(scores_json.read_text())\n"
+    ]
+   }
+  ],
+  "metadata": {
+   "kernelspec": {
+    "display_name": "Python 3 (ipykernel)",
+    "language": "python",
+    "name": "python3"
+   },
+   "language_info": {
+    "codemirror_mode": {
+     "name": "ipython",
+     "version": 3
+    },
+    "file_extension": ".py",
+    "mimetype": "text/x-python",
+    "name": "python",
+    "nbconvert_exporter": "python",
+    "pygments_lexer": "ipython3",
+    "version": "3.13.3"
+   }
+  },
+  "nbformat": 4,
+  "nbformat_minor": 5
+ }
NNGen/requirements.txt ADDED
@@ -0,0 +1,4 @@
+ langgraph>=0.2.0
+ google-generativeai>=0.7.0
+ python-dotenv>=1.0.1
+ gradio>=4.32.0
NNGen/runtime.txt ADDED
@@ -0,0 +1 @@
+ python-3.10
NNGen/scripts/gradio_app.py ADDED
@@ -0,0 +1,104 @@
+ from __future__ import annotations
+
+ import os
+ import shutil
+ from pathlib import Path
+ from typing import List, Tuple
+
+ import gradio as gr
+
+ from app.graph import run_pipeline, run_fusion_pipeline
+ from app.state import AppState
+
+
+ def _zip_outdir(outdir: str) -> str:
+     out = Path(outdir)
+     if not out.exists():
+         return ""
+     zip_path = str(out) + ".zip"
+     # Remove a stale archive if one exists
+     try:
+         if Path(zip_path).exists():
+             Path(zip_path).unlink()
+     except Exception:
+         pass
+     shutil.make_archive(str(out), "zip", root_dir=str(out))
+     return zip_path
+
+
+ def run_text_mode(user_text: str, K: int, T: int, make_zip: bool) -> Tuple[str, List[str], str, str]:
+     state: AppState = {"K": int(K), "T": int(T), "user_text": user_text or "", "outdir": ""}
+     final_state = run_pipeline(state)
+     outdir = final_state["outdir"]
+     # Collect candidates if present
+     candidates = [im.path for im in (final_state.get("images") or [])]
+     final_img = str(Path(outdir) / "final.png")
+     zip_path = _zip_outdir(outdir) if make_zip else ""
+     return final_img, candidates, outdir, zip_path
+
+
+ def run_image_mode(base_image, ref_images, instructions: str, K: int, make_zip: bool) -> Tuple[str, List[str], str, str]:
+     state: AppState = {"K": int(K), "T": 0, "outdir": "", "instructions": instructions or ""}
+     if base_image is not None:
+         state["base_image"] = base_image if isinstance(base_image, str) else base_image.name
+     refs: List[str] = []
+     for f in (ref_images or []):
+         p = f if isinstance(f, str) else getattr(f, "name", None)
+         if p:
+             refs.append(p)
+     state["ref_images"] = refs
+
+     final_state = run_fusion_pipeline(state)
+     outdir = final_state["outdir"]
+     candidates = [im.path for im in (final_state.get("images") or [])]
+     final_img = str(Path(outdir) / "final.png")
+     zip_path = _zip_outdir(outdir) if make_zip else ""
+     return final_img, candidates, outdir, zip_path
+
+
+ def app() -> gr.Blocks:
+     with gr.Blocks(title="NNGen — Gemini 2.5 Flash Image") as demo:
+         gr.Markdown("""
+         # NNGen — Gemini 2.5 Flash Image
+         - Text mode: enter a natural language spec to generate a diagram (G1/G2/judge/edit).
+         - Image mode: edit/fuse images with textual instructions (e.g., replace UNet with Transformer).
+         - Works offline with placeholders if no `GEMINI_API_KEY` is set. With an API key, set `GEMINI_IMAGE_MODEL` and `GEMINI_IMAGE_EDIT_MODEL`.
+         """)
+
+         with gr.Tab("Text Mode"):
+             user_text = gr.Textbox(label="NN spec (text)", lines=10, placeholder="Describe the architecture... e.g., Transformer encoder-decoder with cross-attention...")
+             with gr.Row():
+                 K = gr.Slider(1, 6, value=4, step=1, label="K candidates")
+                 T = gr.Slider(0, 3, value=1, step=1, label="Max edit rounds (T)")
+                 zip_output = gr.Checkbox(value=False, label="Zip outputs")
+             run_btn = gr.Button("Generate")
+             final_img = gr.Image(label="final.png", type="filepath")
+             # Gradio 4.x removed Component.style(); pass layout options as constructor kwargs
+             gallery = gr.Gallery(label="Candidates", columns=4)
+             outdir = gr.Textbox(label="Artifacts directory", interactive=False)
+             zip_file = gr.File(label="Download run.zip", interactive=False)
+
+             run_btn.click(run_text_mode, inputs=[user_text, K, T, zip_output], outputs=[final_img, gallery, outdir, zip_file])
+
+         with gr.Tab("Image Mode (Fusion/Edit)"):
+             base = gr.Image(label="Base image (optional)", type="filepath")
+             refs = gr.Files(label="Reference images (0..N)")
+             instr = gr.Textbox(label="Instructions", lines=4, placeholder="Replace the UNet backbone with a Transformer (DiT); keep layout, fonts, colors, arrows, and dashed groups unchanged.")
+             with gr.Row():
+                 K2 = gr.Slider(1, 6, value=4, step=1, label="K candidates")
+                 zip_output2 = gr.Checkbox(value=False, label="Zip outputs")
+             run_btn2 = gr.Button("Compose / Edit")
+             final_img2 = gr.Image(label="final.png", type="filepath")
+             gallery2 = gr.Gallery(label="Fused Candidates", columns=4)
+             outdir2 = gr.Textbox(label="Artifacts directory", interactive=False)
+             zip_file2 = gr.File(label="Download run.zip", interactive=False)
+
+             run_btn2.click(run_image_mode, inputs=[base, refs, instr, K2, zip_output2], outputs=[final_img2, gallery2, outdir2, zip_file2])
+
+     return demo
+
+
+ if __name__ == "__main__":
+     port = int(os.getenv("PORT", "7860"))
+     app().launch(server_name="0.0.0.0", server_port=port)
+
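Because the handlers are plain functions, they can be smoke-tested without launching the UI (a sketch; with no API key set this runs against the local placeholders, and assumes it is executed from the NNGen root so `scripts` and `app` import):

# Minimal sketch: call the text-mode handler directly, bypassing Gradio.
from scripts.gradio_app import run_text_mode

final_img, candidates, outdir, zip_path = run_text_mode(
    user_text="ViT: patch embedding -> encoder xL -> class head",
    K=2, T=0, make_zip=False,
)
print(outdir, final_img, len(candidates))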
NNGen/spec/transformer.txt ADDED
@@ -0,0 +1,32 @@
+ Title: Transformer Encoder–Decoder (Machine Translation)
+
+ Instructions:
+ - Produce a clean paper-style architecture diagram: flat 2D, white background, minimal color (black/gray + ≤2 accent colors), no gradients/3D/shadows.
+ - Layout left→right. Clear arrows indicate data flow.
+ - Draw boxes for major modules and show grouping for repeated layers.
+ - Text labels should be concise and capitalized (e.g., EMBEDDING, ENCODER xN, DECODER xN).
+
+ Architecture:
+ - Input: tokenized source sentence
+ - Source Embedding + Positional Encoding
+ - Encoder (N layers):
+   - Multi-Head Self-Attention
+   - Add & LayerNorm
+   - Feed-Forward (MLP)
+   - Add & LayerNorm
+ - Target: previous target tokens (for training)
+ - Target Embedding + Positional Encoding
+ - Decoder (N layers):
+   - Masked Multi-Head Self-Attention
+   - Add & LayerNorm
+   - Cross-Attention (attends to Encoder outputs)
+   - Add & LayerNorm
+   - Feed-Forward (MLP)
+   - Add & LayerNorm
+ - Output: Linear + Softmax
+
+ Style details:
+ - Use a dashed rounded rectangle to group the N repeated layers on both encoder and decoder.
+ - Keep arrows straight. Left→right overall; show a connection from Encoder outputs to the Cross-Attention in Decoder.
+ - If space is tight, abbreviate labels (e.g., SELF-ATTN, CROSS-ATTN, FFN).
+
NNGen/spec/vit.txt ADDED
@@ -0,0 +1,7 @@
+ Generate a high-level diagram of a Vision Transformer (ViT):
+ - Input: 224×224 RGB image
+ - Patch Embedding: split into 16×16 patches and apply a linear projection
+ - Add CLS token and positional encoding
+ - Transformer Encoder stack: Multi-Head Self-Attention + MLP + residual + LayerNorm (repeat L layers)
+ - Classification head: take CLS token for linear classification
+ Layout requirements: left-to-right flow; clear arrow directions; correct spelling of all labels; show the number of layers L; keep colors readable.