feat: add NNGen project under NNGen/ and ignore local secrets
- .gitignore +4 -0
- NNGen/.gitignore +89 -0
- NNGen/AGENTS.md +38 -0
- NNGen/README.md +80 -0
- NNGen/app.py +9 -0
- NNGen/app/__init__.py +1 -0
- NNGen/app/cli.py +55 -0
- NNGen/app/graph.py +82 -0
- NNGen/app/llm/credentials.example.py +4 -0
- NNGen/app/llm/gemini.py +788 -0
- NNGen/app/nodes/__init__.py +1 -0
- NNGen/app/nodes/archive.py +35 -0
- NNGen/app/nodes/edit.py +80 -0
- NNGen/app/nodes/gen_fusion.py +73 -0
- NNGen/app/nodes/gen_generate.py +16 -0
- NNGen/app/nodes/gen_labels.py +60 -0
- NNGen/app/nodes/judge.py +46 -0
- NNGen/app/nodes/parser.py +15 -0
- NNGen/app/nodes/planner.py +15 -0
- NNGen/app/nodes/prompt_gen.py +15 -0
- NNGen/app/nodes/select.py +30 -0
- NNGen/app/prompts.py +65 -0
- NNGen/app/state.py +33 -0
- NNGen/demo.ipynb +0 -0
- NNGen/notebooks/demo.ipynb +151 -0
- NNGen/requirements.txt +4 -0
- NNGen/runtime.txt +1 -0
- NNGen/scripts/gradio_app.py +103 -0
- NNGen/spec/transformer.txt +32 -0
- NNGen/spec/vit.txt +7 -0
.gitignore
ADDED

@@ -0,0 +1,4 @@
+NNGen/.env
+NNGen/artifacts/
+NNGen/.venv/
+NNGen/__pycache__/
NNGen/.gitignore
ADDED

@@ -0,0 +1,89 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+*.egg-info/
+.installed.cfg
+*.egg
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+.python-version
+
+# pipenv
+Pipfile.lock
+
+# poetry
+poetry.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Environments
+.env
+.venv
+venv/
+ENV/
+env/
+
+# VS Code
+.vscode/
+
+# PyCharm
+.idea/
+
+# Artifacts and outputs
+artifacts/
+artifacts/**
+*.png
+*.jpg
+*.jpeg
+
+# Secrets
+app/llm/credentials.py
+app/llm/credentials.json
+
+# Mac/Windows
+.DS_Store
+Thumbs.db
NNGen/AGENTS.md
ADDED

@@ -0,0 +1,38 @@
+# Repository Guidelines
+
+## Project Structure & Module Organization
+- `app/cli.py` — CLI entry point; orchestrates a full run.
+- `app/graph.py` — lightweight pipeline runner and edit loop.
+- `app/nodes/` — individual agent nodes (`parser.py`, `planner.py`, `prompt_gen.py`, `gen_generate.py` [G1 skeleton], `gen_labels.py` [G2 labels], `judge.py`, `select.py`, `edit.py`, `archive.py`). Each exposes `run(state)` or similar.
+- `app/prompts.py` — centralized prompts for parsing/planning/generation/judging/editing.
+- `app/state.py` — typed `AppState` and artifact helpers.
+- `app/llm/gemini.py` — `call_gemini(kind, **kwargs)` wrapper; uses local placeholders if no API key.
+- `spec/` — example specs (e.g., `spec/vit.txt`).
+- `artifacts/` — run outputs (time-stamped folders with `final.png`).
+
+## Setup, Run, and Development Commands
+- Create env: `python -m venv .venv && source .venv/bin/activate` (Windows: `./.venv/Scripts/activate`).
+- Install deps: `pip install -r requirements.txt`.
+- Configure API (choose one):
+  - Env var: `export GEMINI_API_KEY=...` (supports `.env`).
+  - File: create `app/llm/credentials.py` like `credentials.example.py`.
+- Run sample: `python -m app.cli --spec spec/vit.txt --K 3 --T 1`.
+- Models: optionally set `GEMINI_MODEL`, `GEMINI_IMAGE_MODEL`, `GEMINI_IMAGE_EDIT_MODEL`.
+
+## Coding Style & Naming Conventions
+- Python 3.10+, PEP8, 4-space indentation, type hints required in public APIs.
+- Files: snake_case; functions: `snake_case`; classes: `PascalCase`.
+- Nodes are pure where possible: read from `state`, return a new `state`; side effects limited to writing under `artifacts/`.
+- Centralize prompt text in `app/prompts.py`; call models via `call_gemini` only.
+
+## Testing Guidelines
+- No formal test suite yet. Prefer pytest with files under `tests/` named `test_*.py`.
+- Minimal integration check: run the CLI and assert a `final.png` exists in the latest `artifacts/run_*` directory.
+
+## Commit & Pull Request Guidelines
+- Use Conventional Commits: `feat:`, `fix:`, `docs:`, `refactor:`, `chore:`.
+- PRs must include: purpose, linked issues, how to run, and a sample spec plus path to produced artifact (e.g., `artifacts/run_YYYYmmdd_HHMMSS/final.png`).
+
+## Security & Configuration Tips
+- Do not commit secrets (`.env`, `app/llm/credentials.py`). Rotate keys if exposed.
+- Large outputs live in `artifacts/`; avoid committing heavy assets unless necessary.
NNGen/README.md
ADDED

@@ -0,0 +1,80 @@
+# Multi-Agent Neural Network Diagram Generator (Skeleton) — Gemini 2.5 Flash Image
+
+This repository is a minimal, runnable skeleton that turns a textual NN spec into a publication-style diagram via a multi-agent pipeline:
+- Parser → Planner → Prompt-Generator → Image-Generator (G1) → Label-Generator (G2) → Judge → Selector → (Editor loop) → Archivist
+- All model calls flow through `call_gemini(...)`, making it easy to use Gemini 2.5 Flash for text and Gemini 2.5 Flash Image for images.
+
+Key additions in this version
+- Two-stage generation: G1 draws the geometry-only skeleton (no text), G2 overlays labels on top of the skeleton.
+- Hard violations: Judge returns actionable violations; missing labels are flagged as HARD to trigger edits reliably.
+- Parallelism: G1, G2, and Judge run in parallel; set `NNG_CONCURRENCY` (default 4).
+- Remote images by default: image generate/edit use Gemini 2.5 Flash Image models. If the API is missing, the system can fall back to a local placeholder to stay runnable.
+
+## Quick Start
+
+1) Python 3.10+
+
+2) Install deps
+```
+pip install -r requirements.txt
+```
+
+3) Configure Gemini (choose one)
+- Env var: `export GEMINI_API_KEY=YOUR_KEY`
+- File: create `app/llm/credentials.py` with `GEMINI_API_KEY = "YOUR_KEY"`
+
+4) Run (K=candidates, T=max edit rounds)
+```
+# Text mode (spec -> image)
+python -m app.cli --mode text --spec spec/vit.txt --K 4 --T 1
+
+# Image mode (text + image fusion/edit)
+# Example: edit an existing diagram with a component replacement using a reference image
+python -m app.cli --mode image --base-image path/to/base.png \
+  --ref-image path/to/transformer_ref.png \
+  --instructions "Replace the UNet backbone with a Transformer (DiT); keep layout, font, and colors consistent."
+```
+Artifacts are saved under `artifacts/run_YYYYmmdd_HHMMSS/` with `final.png` as the chosen result.
+
+## Gemini 2.5 Flash Image in This Project
+- G1 geometry: `gen_generate.py` calls `GEMINI_IMAGE_MODEL` (Gemini 2.5 Flash Image) to render a clean, geometry-only skeleton quickly.
+- G2 labels: `gen_labels.py` uses `GEMINI_IMAGE_EDIT_MODEL` to overlay text labels onto the G1 skeleton without redrawing everything.
+- Edit loop: `edit.py` performs targeted corrections via the same image model, enabling fast, iterative refinements instead of full regenerations.
+- Why it matters: the model’s speed and editability make multi-round diagram refinement practical while preserving layout quality.
+- Fallback: if no API key is available, the pipeline remains runnable using local placeholders generated by `app/llm/gemini.py`.
+
+## Models
+- `GEMINI_MODEL` (default `gemini-2.5-flash`): parsing, planning, prompt generation, and judging.
+- `GEMINI_IMAGE_MODEL` (recommended `gemini-2.5-flash-image` or `gemini-2.5-flash-image-preview`): image generation (G1).
+- `GEMINI_IMAGE_EDIT_MODEL` (recommended `gemini-2.5-flash-image` or `gemini-2.5-flash-image-preview`): image editing (G2, Editor).
+Notes: If `GEMINI_API_KEY` is not set, the pipeline uses offline placeholders to remain runnable. With an API key present, you must set valid image model env vars; errors are raised if image models are unset or calls fail (no automatic local fallback).
+
+## Fusion Mode (Text + Image)
+- Accepts a base diagram (`--base-image`) and optional reference images (`--ref-image` repeatable) plus instructions.
+- Uses Gemini 2.5 Flash Image to compose images under textual guidance – ideal for swapping a module (e.g., UNet → Transformer) while preserving style and layout.
+- Outputs multiple fused candidates (`K`) and archives the first as `final.png`.
+
+## Structure
+```
+app/
+  cli.py             # CLI entry (K/T/outdir)
+  graph.py           # Orchestrator + edit loop
+  state.py           # AppState + artifacts
+  prompts.py         # Centralized prompts (parse/plan/G1/G2/judge/edit)
+  nodes/
+    parser.py, planner.py, prompt_gen.py
+    gen_generate.py  # G1 skeleton images (no text)
+    gen_labels.py    # G2 label overlay edits
+    judge.py, select.py, edit.py, archive.py
+  llm/
+    gemini.py        # Unified wrapper (API + offline fallback)
+    credentials.example.py
+spec/
+  vit.txt            # Example ViT spec (English)
+artifacts/           # Outputs per run
+```
+
+## Tips
+- Concurrency: `NNG_CONCURRENCY=4 python -m app.cli --spec ...`
+- Tuning: Start with `K=4, T=1`; increase `T` for more correction rounds.
+- Debug: image calls write `*.resp.txt`/`*.meta.json` alongside outputs (can be removed later if undesired).
NNGen/app.py
ADDED

@@ -0,0 +1,9 @@
+from __future__ import annotations
+
+# Hugging Face Spaces entrypoint for Gradio
+# Exposes a global `demo` variable that HF will serve.
+
+from scripts.gradio_app import app as create_app
+
+demo = create_app()
+
NNGen/app/__init__.py
ADDED

@@ -0,0 +1 @@
+# empty package marker
NNGen/app/cli.py
ADDED

@@ -0,0 +1,55 @@
+from __future__ import annotations
+
+import argparse
+import json
+from pathlib import Path
+
+from .graph import run_pipeline, run_fusion_pipeline
+from .state import AppState
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="NN Diagram Multi-Agent Pipeline")
+    parser.add_argument("--mode", type=str, choices=["text", "image"], default="text", help="'text' (spec→image) or 'image' (text+image fusion/edit)")
+    parser.add_argument("--spec", type=str, required=False, help="Path to .txt user prompt or .json spec (text mode)")
+    parser.add_argument("--K", type=int, default=4, help="Number of candidates")
+    parser.add_argument("--T", type=int, default=1, help="Max edit rounds")
+    parser.add_argument("--outdir", type=str, default="", help="Output directory (optional)")
+    parser.add_argument("--base-image", type=str, default="", help="Base image to edit (image mode)")
+    parser.add_argument("--ref-image", action="append", default=None, help="Additional reference image(s) (repeatable)")
+    parser.add_argument("--instructions", type=str, default="", help="Edit/fusion instructions (image mode)")
+    args = parser.parse_args()
+
+    state: AppState = {"K": int(args.K), "T": int(args.T), "outdir": args.outdir or ""}
+
+    if args.mode == "text":
+        if not args.spec:
+            raise SystemExit("--spec is required in text mode")
+        spec_path = Path(args.spec)
+        if not spec_path.exists():
+            raise SystemExit(f"Spec file not found: {spec_path}")
+        if spec_path.suffix.lower() == ".json":
+            state["spec"] = json.loads(spec_path.read_text())
+        else:
+            state["user_text"] = spec_path.read_text()
+        final_state = run_pipeline(state)
+    else:
+        # image fusion/edit mode
+        base_image = args.base_image.strip()
+        ref_images = args.ref_image or []
+        if not base_image and not ref_images:
+            raise SystemExit("image mode requires --base-image and/or at least one --ref-image")
+        if base_image:
+            if not Path(base_image).exists():
+                raise SystemExit(f"Base image not found: {base_image}")
+            state["base_image"] = base_image
+        valid_refs = [p for p in ref_images if p and Path(p).exists()]
+        state["ref_images"] = valid_refs
+        state["instructions"] = args.instructions or "Compose and update the figure to reflect the requested component changes while keeping overall style consistent."
+        final_state = run_fusion_pipeline(state)
+
+    print(f"Artifacts saved under: {final_state['outdir']}")
+
+
+if __name__ == "__main__":
+    main()
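The CLI's text-mode branch in `app/cli.py` above routes `.json` specs into structured `state["spec"]` and any other file into raw `state["user_text"]`. As a hedged, self-contained sketch of that dispatch (the function name `load_spec_into_state` is my own, not from the commit):

```python
import json
from pathlib import Path


def load_spec_into_state(spec_path: Path, state: dict) -> dict:
    """Mirror the CLI's spec handling: JSON files are parsed into a
    structured spec; anything else is treated as a free-form prompt."""
    if spec_path.suffix.lower() == ".json":
        state["spec"] = json.loads(spec_path.read_text())
    else:
        state["user_text"] = spec_path.read_text()
    return state
```

The suffix check is case-insensitive, so `spec/vit.JSON` would also be parsed as JSON.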
NNGen/app/graph.py
ADDED

@@ -0,0 +1,82 @@
+from __future__ import annotations
+
+from datetime import datetime
+from pathlib import Path
+from typing import Callable
+
+from .state import AppState
+from .nodes import parser, planner, prompt_gen, gen_generate, gen_labels, judge, select, edit, archive
+from .nodes import gen_fusion
+
+
+def run_pipeline(state: AppState) -> AppState:
+    # ensure outdir
+    outdir = Path(state.get("outdir") or _default_outdir())
+    outdir.mkdir(parents=True, exist_ok=True)
+    state["outdir"] = str(outdir)
+    state["round"] = int(state.get("round", 0))
+
+    # 1) parse → 2) plan → 3) prompts → 4) generate (skeleton) → 5) generator_2 (labels) → 6) judge → 7) select
+    state = parser.run(state)
+    state = planner.run(state)
+    state = prompt_gen.run(state)
+    state = gen_generate.run(state)
+    state = gen_labels.run(state)
+    state = judge.run(state)
+    state = select.run(state)
+
+    # 8) edit loop (if hard violations or any violations, and round < T)
+    T = int(state.get("T", 0))
+    while (state.get("hard_violations") or state.get("violations")) and state.get("round", 0) < T:
+        state["round"] = int(state.get("round", 0)) + 1
+        state = edit.apply_edits(state)
+        # re-judge best image
+        state = _judge_best_only(state)
+        state = select.run(state)
+
+    # 9) archive
+    state = archive.run(state)
+    return state
+
+
+def run_fusion_pipeline(state: AppState) -> AppState:
+    # ensure outdir
+    outdir = Path(state.get("outdir") or _default_outdir())
+    outdir.mkdir(parents=True, exist_ok=True)
+    state["outdir"] = str(outdir)
+    state["round"] = int(state.get("round", 0))
+
+    # Generate fused candidates from images + text instructions
+    state = gen_fusion.run(state)
+
+    # If we have candidates, select first as best; optionally judge later
+    if state.get("images"):
+        state["best_image"] = state["images"][0]
+
+    # Archive results (final.png etc.)
+    state = archive.run(state)
+    return state
+
+
+def _judge_best_only(state: AppState) -> AppState:
+    # Only score the current best image again
+    from .llm.gemini import call_gemini
+
+    if not state.get("best_image"):
+        return state
+    res = call_gemini("judge", image_path=state["best_image"].path, spec=state.get("spec", {}))
+    vios = list(res.get("violations", []))
+    hard = [v for v in vios if str(v).strip().lower().startswith("hard:")]
+    if not hard:
+        hard = [v for v in vios if ("labels" in str(v).lower() and "missing" in str(v).lower())]
+    state["scores"] = [{
+        "image_path": state["best_image"].path,
+        "score": float(res.get("score", 0.0)),
+        "violations": vios,
+    }]
+    state["hard_violations"] = hard
+    return state
+
+
+def _default_outdir() -> str:
+    return f"artifacts/run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
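The edit loop in `NNGen/app/graph.py` above is gated on "hard" violations: an explicit `HARD:` prefix wins, and otherwise any missing-labels complaint is promoted to hard. That classification can be sketched in isolation (the function name `split_hard_violations` is illustrative, not from the commit, but the predicate mirrors `_judge_best_only`):

```python
def split_hard_violations(violations: list[str]) -> list[str]:
    """Return the violations that should force another edit round.

    Explicit 'HARD:'-prefixed entries take precedence; if none exist,
    any violation mentioning both 'labels' and 'missing' is treated as hard.
    """
    vios = [str(v) for v in violations]
    hard = [v for v in vios if v.strip().lower().startswith("hard:")]
    if not hard:
        hard = [v for v in vios if "labels" in v.lower() and "missing" in v.lower()]
    return hard
```

Because the offline judge emits `"HARD: labels: missing"` for unlabeled skeletons, this gating reliably triggers the G2 label pass even without an API key.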
NNGen/app/llm/credentials.example.py
ADDED

@@ -0,0 +1,4 @@
+# Copy this file to credentials.py and fill in your key.
+
+GEMINI_API_KEY = "YOUR_GEMINI_API_KEY"
+
NNGen/app/llm/gemini.py
ADDED
|
@@ -0,0 +1,788 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import base64
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import shutil
|
| 8 |
+
from pathlib import Path
|
| 9 |
+
from typing import Any, Dict, List, Optional
|
| 10 |
+
import os as _os
|
| 11 |
+
|
| 12 |
+
try:
|
| 13 |
+
# optional local credentials file
|
| 14 |
+
from . import credentials # type: ignore
|
| 15 |
+
except Exception:
|
| 16 |
+
credentials = None # type: ignore
|
| 17 |
+
|
| 18 |
+
# Load .env if present to populate environment variables
|
| 19 |
+
try:
|
| 20 |
+
from dotenv import load_dotenv # type: ignore
|
| 21 |
+
|
| 22 |
+
load_dotenv() # searches for .env in CWD/parents
|
| 23 |
+
except Exception:
|
| 24 |
+
pass
|
| 25 |
+
|
| 26 |
+
def _get_api_key() -> Optional[str]:
|
| 27 |
+
key = os.getenv("GEMINI_API_KEY")
|
| 28 |
+
if key:
|
| 29 |
+
return key
|
| 30 |
+
if credentials and getattr(credentials, "GEMINI_API_KEY", None):
|
| 31 |
+
return credentials.GEMINI_API_KEY # type: ignore
|
| 32 |
+
# Optional: read from ~/.config/gemini/api_key
|
| 33 |
+
try:
|
| 34 |
+
cfg_path = Path.home() / ".config" / "gemini" / "api_key"
|
| 35 |
+
if cfg_path.exists():
|
| 36 |
+
return cfg_path.read_text().strip()
|
| 37 |
+
except Exception:
|
| 38 |
+
pass
|
| 39 |
+
return None
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def call_gemini(kind: str, **kwargs) -> Dict[str, Any]:
|
| 43 |
+
"""Unified entry for Gemini calls.
|
| 44 |
+
|
| 45 |
+
kind: one of {"parse", "plan", "prompt_generate", "image_generate", "judge", "image_edit", "image_fuse"}
|
| 46 |
+
kwargs: payload for the corresponding action
|
| 47 |
+
|
| 48 |
+
If API key is missing or a call fails, falls back to deterministic local placeholders
|
| 49 |
+
so the pipeline remains runnable offline.
|
| 50 |
+
"""
|
| 51 |
+
api_key = _get_api_key()
|
| 52 |
+
if not api_key:
|
| 53 |
+
# Simplified behavior: if no API key, always use local placeholders
|
| 54 |
+
return _local_placeholder(kind, **kwargs)
|
| 55 |
+
|
| 56 |
+
# With an API key present, call the real service and surface errors directly
|
| 57 |
+
return _real_gemini(kind, api_key=api_key, **kwargs)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def _local_placeholder(kind: str, **kwargs) -> Dict[str, Any]:
|
| 61 |
+
# Deterministic pseudo behavior for offline usage
|
| 62 |
+
rng = random.Random(42)
|
| 63 |
+
|
| 64 |
+
if kind == "parse":
|
| 65 |
+
user_text = kwargs.get("user_text", "")
|
| 66 |
+
# Very rough parse: split by arrows/lines → nodes & edges
|
| 67 |
+
lines = [ln.strip() for ln in user_text.splitlines() if ln.strip()]
|
| 68 |
+
nodes = [f"N{i}:{ln[:24]}" for i, ln in enumerate(lines)] or ["Input", "Conv", "FC", "Softmax"]
|
| 69 |
+
edges = [[i, i + 1] for i in range(len(nodes) - 1)]
|
| 70 |
+
spec = {"nodes": nodes, "edges": edges, "constraints": {"arrows": "left_to_right"}}
|
| 71 |
+
return {"spec": spec}
|
| 72 |
+
|
| 73 |
+
if kind == "plan":
|
| 74 |
+
spec = kwargs.get("spec", {})
|
| 75 |
+
spec_text = (
|
| 76 |
+
"Neural Net Diagram\n" +
|
| 77 |
+
f"Nodes: {len(spec.get('nodes', []))}\n" +
|
| 78 |
+
f"Edges: {len(spec.get('edges', []))}\n" +
|
| 79 |
+
f"Constraints: {spec.get('constraints', {})}\n"
|
| 80 |
+
)
|
| 81 |
+
return {"spec_text": spec_text}
|
| 82 |
+
|
| 83 |
+
if kind == "prompt_generate":
|
| 84 |
+
K = int(kwargs.get("K", 3))
|
| 85 |
+
spec_text = kwargs.get("spec_text", "")
|
| 86 |
+
layouts = ["left-right", "top-down", "circular", "grid", "hierarchical"]
|
| 87 |
+
colors = ["blue", "green", "purple", "orange", "teal"]
|
| 88 |
+
prompts = [
|
| 89 |
+
f"Draw NN diagram ({spec_text[:40]}...) layout={layouts[i % len(layouts)]} color={colors[i % len(colors)]} seed={i}"
|
| 90 |
+
for i in range(K)
|
| 91 |
+
]
|
| 92 |
+
return {"prompts": prompts}
|
| 93 |
+
|
| 94 |
+
if kind == "image_generate":
|
| 95 |
+
prompts: List[str] = kwargs.get("prompts", [])
|
| 96 |
+
outdir: str = kwargs.get("outdir", "artifacts")
|
| 97 |
+
paths: List[str] = []
|
| 98 |
+
for i, p in enumerate(prompts):
|
| 99 |
+
pth = Path(outdir) / f"candidate_{i}.png"
|
| 100 |
+
_write_placeholder_diagram(pth, with_labels=False)
|
| 101 |
+
paths.append(str(pth))
|
| 102 |
+
return {"paths": paths}
|
| 103 |
+
|
| 104 |
+
if kind == "judge":
|
| 105 |
+
image_path: str = kwargs.get("image_path")
|
| 106 |
+
# produce a stable pseudo-score based on filename
|
| 107 |
+
base = sum(ord(c) for c in Path(image_path).name) % 100
|
| 108 |
+
score = 0.5 + (base / 200.0)
|
| 109 |
+
# fake violations: if filename has odd index
|
| 110 |
+
violations: List[str] = []
|
| 111 |
+
try:
|
| 112 |
+
idx = int(Path(image_path).stem.split("_")[-1])
|
| 113 |
+
if idx % 2 == 1:
|
| 114 |
+
violations = ["typo: layer name", "arrow: wrong direction"]
|
| 115 |
+
except Exception:
|
| 116 |
+
pass
|
| 117 |
+
# If still skeleton (no 'labeled_' in name), mark missing labels as HARD
|
| 118 |
+
name = Path(image_path).name.lower()
|
| 119 |
+
# Heuristic for offline mode: consider labeled or edited images as having labels
|
| 120 |
+
if ("labeled_" not in name) and ("edited_" not in name):
|
| 121 |
+
violations = ["HARD: labels: missing"] + violations
|
| 122 |
+
return {"score": score, "violations": violations}
|
| 123 |
+
|
| 124 |
+
if kind == "image_edit":
|
| 125 |
+
image_path: str = kwargs.get("image_path")
|
| 126 |
+
out_path: str = kwargs.get("out_path")
|
| 127 |
+
instructions: str = kwargs.get("instructions", "")
|
| 128 |
+
# Extract labels from instructions if present
|
| 129 |
+
import re
|
| 130 |
+
labels = re.findall(r'\d+\s*:\s*"([^"]+)"', instructions)
|
| 131 |
+
if not labels:
|
| 132 |
+
# fallback: quoted strings
|
| 133 |
+
labels = re.findall(r'"([^"]+)"', instructions)
|
| 134 |
+
# standardize
|
| 135 |
+
labels = [l.strip() for l in labels if l.strip()]
|
| 136 |
+
_write_placeholder_diagram(Path(out_path), with_labels=True, labels=labels)
|
| 137 |
+
return {"path": out_path}
|
| 138 |
+
|
| 139 |
+
raise ValueError(f"Unsupported kind={kind}")
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def _write_1x1_png(path: Path) -> None:
|
| 143 |
+
path.parent.mkdir(parents=True, exist_ok=True)
|
| 144 |
+
# 1x1 black pixel PNG
|
| 145 |
+
png_bytes = base64.b64decode(
|
| 146 |
+
b"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAQAAAC1HAwCAA" \
|
| 147 |
+
b"AAC0lEQVR42mP8/xcAAwMB/ax4u6kAAAAASUVORK5CYII="
|
| 148 |
+
)
|
| 149 |
+
with open(path, "wb") as f:
|
| 150 |
+
f.write(png_bytes)
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
def _write_placeholder_diagram(path: Path, *, with_labels: bool, labels: Optional[List[str]] = None) -> None:
    """Generate a simple skeleton diagram (and optionally labels).

    If Pillow is available, draw with anti-aliased vectors and real text;
    otherwise fall back to a pure-stdlib bitmap renderer.
    """
    try:
        from PIL import Image, ImageDraw, ImageFont  # type: ignore  # availability check only
        _write_placeholder_diagram_pil(path, with_labels=with_labels, labels=labels)
        return
    except Exception:
        pass

    # Fallback: stdlib bitmap renderer
    # White background, 3px black strokes, arrows, dashed group
    import zlib, struct, binascii

    W, H = 1200, 420
    # initialize white canvas
    pixels: List[List[List[int]]] = [[[255, 255, 255] for _ in range(W)] for _ in range(H)]

    def set_px(x: int, y: int, c: tuple[int, int, int]):
        if 0 <= x < W and 0 <= y < H:
            pixels[y][x][0] = c[0]
            pixels[y][x][1] = c[1]
            pixels[y][x][2] = c[2]

    def draw_line(x0: int, y0: int, x1: int, y1: int, c=(0, 0, 0), t: int = 3):
        # Bresenham line with a square brush of thickness t
        dx = abs(x1 - x0)
        sx = 1 if x0 < x1 else -1
        dy = -abs(y1 - y0)
        sy = 1 if y0 < y1 else -1
        err = dx + dy
        while True:
            for ox in range(-t // 2, t // 2 + 1):
                for oy in range(-t // 2, t // 2 + 1):
                    set_px(x0 + ox, y0 + oy, c)
            if x0 == x1 and y0 == y1:
                break
            e2 = 2 * err
            if e2 >= dy:
                err += dy
                x0 += sx
            if e2 <= dx:
                err += dx
                y0 += sy

    def draw_rect(x: int, y: int, w: int, h: int, c=(0, 0, 0), t: int = 3, dashed: bool = False):
        def dash_points(x0, y0, x1, y1):
            # Bresenham plus on/off dashes
            dx = abs(x1 - x0)
            sx = 1 if x0 < x1 else -1
            dy = -abs(y1 - y0)
            sy = 1 if y0 < y1 else -1
            err = dx + dy
            on = True
            step = 0
            period = 10
            points = []
            while True:
                if on:
                    points.append((x0, y0))
                step = (step + 1) % period
                if step == 0:
                    on = not on
                if x0 == x1 and y0 == y1:
                    break
                e2 = 2 * err
                if e2 >= dy:
                    err += dy
                    x0 += sx
                if e2 <= dx:
                    err += dx
                    y0 += sy
            return points

        if dashed:
            for (x0, y0, x1, y1) in [
                (x, y, x + w, y),
                (x + w, y, x + w, y + h),
                (x + w, y + h, x, y + h),
                (x, y + h, x, y),
            ]:
                for px, py in dash_points(x0, y0, x1, y1):
                    for ox in range(-t // 2, t // 2 + 1):
                        for oy in range(-t // 2, t // 2 + 1):
                            set_px(px + ox, py + oy, c)
        else:
            draw_line(x, y, x + w, y, c, t)
            draw_line(x + w, y, x + w, y + h, c, t)
            draw_line(x + w, y + h, x, y + h, c, t)
            draw_line(x, y + h, x, y, c, t)

    def draw_arrow(x0: int, y0: int, x1: int, y1: int, c=(0, 0, 0)):
        draw_line(x0, y0, x1, y1, c, 3)
        # simple arrow head
        vx, vy = x1 - x0, y1 - y0
        length = max((vx * vx + vy * vy) ** 0.5, 1.0)
        ux, uy = vx / length, vy / length
        # perpendicular
        px, py = -uy, ux
        ah = 10  # head length
        aw = 6  # head width
        hx, hy = int(x1 - ux * ah), int(y1 - uy * ah)
        lx, ly = int(hx + px * aw), int(hy + py * aw)
        rx, ry = int(hx - px * aw), int(hy - py * aw)
        draw_line(x1, y1, lx, ly, c, 2)
        draw_line(x1, y1, rx, ry, c, 2)

    # layout
    margin_x, margin_y = 60, 140
    box_w, box_h = 220, 90
    gap = 90
    y = margin_y
    xs = [margin_x + i * (box_w + gap) for i in range(4)]

    # dashed group around the middle two blocks
    group_x = xs[1] - 20
    group_y = y - 20
    group_w = box_w * 2 + gap + 40
    group_h = box_h + 40
    draw_rect(group_x, group_y, group_w, group_h, c=(140, 140, 140), t=2, dashed=True)

    # blocks
    for idx, x in enumerate(xs):
        draw_rect(x, y, box_w, box_h, c=(0, 0, 0), t=2)
        if with_labels:
            # draw simple 5x7 bitmap text; ASCII only, non-ASCII is removed
            label = None
            if labels and idx < len(labels):
                label = labels[idx]
            _draw_label_text(pixels, x, y, box_w, box_h, label)

    # arrows between blocks (center-right to center-left)
    cy = y + box_h // 2
    for i in range(3):
        x0 = xs[i] + box_w
        x1 = xs[i + 1]
        draw_arrow(x0 + 4, cy, x1 - 4, cy, c=(0, 0, 0))

    # write PNG
    def png_chunk(tag: bytes, data: bytes) -> bytes:
        return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", binascii.crc32(tag + data) & 0xFFFFFFFF)

    raw = bytearray()
    for row in pixels:
        raw.append(0)  # filter type 0 (None) before each scanline
        for r, g, b in row:
            raw.extend((r & 255, g & 255, b & 255))
    comp = zlib.compress(bytes(raw), level=9)
    sig = b"\x89PNG\r\n\x1a\n"
    ihdr = struct.pack(">IIBBBBB", W, H, 8, 2, 0, 0, 0)  # 8-bit, truecolor RGB
    png = sig + png_chunk(b"IHDR", ihdr) + png_chunk(b"IDAT", comp) + png_chunk(b"IEND", b"")
    path.parent.mkdir(parents=True, exist_ok=True)
    with open(path, "wb") as f:
        f.write(png)
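The hand-rolled PNG writer can be exercised in isolation. This sketch builds a 2x1 white PNG with the same chunk layout the fallback renderer uses (signature, IHDR, IDAT with one leading filter byte per scanline, empty IEND); the 2x1 dimensions are chosen only for the example:

```python
import binascii
import struct
import zlib

def png_chunk(tag: bytes, data: bytes) -> bytes:
    # length + tag + data + CRC32 over (tag + data)
    return struct.pack(">I", len(data)) + tag + data + struct.pack(">I", binascii.crc32(tag + data) & 0xFFFFFFFF)

W, H = 2, 1
# one filter byte (0 = None), then two RGB white pixels
raw = bytes([0, 255, 255, 255, 255, 255, 255])
png = (
    b"\x89PNG\r\n\x1a\n"
    + png_chunk(b"IHDR", struct.pack(">IIBBBBB", W, H, 8, 2, 0, 0, 0))  # 8-bit truecolor RGB
    + png_chunk(b"IDAT", zlib.compress(raw, 9))
    + png_chunk(b"IEND", b"")
)
```

Any PNG viewer or `PIL.Image.open` should accept the resulting bytes, since the chunk framing and CRC layout follow the PNG specification.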
def _write_placeholder_diagram_pil(path: Path, *, with_labels: bool, labels: Optional[List[str]] = None) -> None:
    from PIL import Image, ImageDraw, ImageFont  # type: ignore

    W, H = 1280, 480
    img = Image.new("RGB", (W, H), (255, 255, 255))
    draw = ImageDraw.Draw(img)

    margin_x, margin_y = 80, 160
    box_w, box_h = 240, 110
    gap = 110
    y = margin_y
    xs = [margin_x + i * (box_w + gap) for i in range(4)]

    # dashed group around the middle two blocks
    group_x = xs[1] - 24
    group_y = y - 24
    group_w = box_w * 2 + gap + 48
    group_h = box_h + 48
    draw.rounded_rectangle([group_x, group_y, group_x + group_w, group_y + group_h], radius=14, outline=(150, 150, 150), width=2)

    # manual dash overlay on all four sides
    def dashed_line(p0, p1, dash=12, gaplen=8, width=2, fill=(150, 150, 150)):
        (x0, y0), (x1, y1) = p0, p1
        dx, dy = x1 - x0, y1 - y0
        length = (dx * dx + dy * dy) ** 0.5
        if length == 0:
            return
        ux, uy = dx / length, dy / length
        dist = 0.0
        on = True
        while dist < length:
            l = dash if on else gaplen
            nx0 = x0 + ux * dist
            ny0 = y0 + uy * dist
            nx1 = x0 + ux * min(length, dist + l)
            ny1 = y0 + uy * min(length, dist + l)
            if on:
                draw.line([(nx0, ny0), (nx1, ny1)], fill=fill, width=width)
            dist += l
            on = not on

    dashed_line((group_x, group_y), (group_x + group_w, group_y), width=2)
    dashed_line((group_x, group_y + group_h), (group_x + group_w, group_y + group_h), width=2)
    dashed_line((group_x, group_y), (group_x, group_y + group_h), width=2)
    dashed_line((group_x + group_w, group_y), (group_x + group_w, group_y + group_h), width=2)

    # blocks
    for x in xs:
        draw.rounded_rectangle([x, y, x + box_w, y + box_h], radius=12, outline=(0, 0, 0), width=3)

    # arrows
    cy = y + box_h // 2
    for i in range(3):
        x0 = xs[i] + box_w + 6
        x1 = xs[i + 1] - 6
        draw.line([(x0, cy), (x1, cy)], fill=(0, 0, 0), width=3)
        # head
        ah, aw = 14, 8
        draw.line([(x1, cy), (x1 - ah, cy - aw)], fill=(0, 0, 0), width=3)
        draw.line([(x1, cy), (x1 - ah, cy + aw)], fill=(0, 0, 0), width=3)

    # labels
    if with_labels:
        # pick a font; fall back to Pillow's built-in bitmap font
        try:
            font = ImageFont.truetype("DejaVuSans.ttf", 22)
        except Exception:
            font = ImageFont.load_default()
        fallback = ["PATCH EMBED", "+ CLS + POSENC", "ENCODER xL", "CLASS HEAD"]
        for i, x in enumerate(xs):
            text = None
            if labels and i < len(labels):
                text = labels[i]
            if not text or not isinstance(text, str) or not text.strip():
                text = fallback[i % len(fallback)]
            # center the text in the block
            tw = draw.textlength(text, font=font)
            th = getattr(font, "size", 11) + 6  # load_default() may lack .size on older Pillow
            tx = x + (box_w - tw) / 2
            ty = y + (box_h - th) / 2
            draw.text((tx, ty), text, fill=(40, 40, 40), font=font, align="center")

    path.parent.mkdir(parents=True, exist_ok=True)
    img.save(path)
def _draw_label_text(pixels: List[List[List[int]]], x: int, y: int, w: int, h: int, label: Optional[str]) -> None:
    # Render up to two lines of ASCII uppercase text inside a box using a 5x7 bitmap font
    import re

    inner_x = x + 10
    inner_y = y + 10
    inner_w = w - 20
    inner_h = h - 20
    if inner_w <= 0 or inner_h <= 0:
        return
    if not label:
        # fallback: generic placeholder
        label = "BLOCK"
    # Normalize to ASCII uppercase; allow only A-Z, 0-9, space and '-'
    text = label.upper()
    text = re.sub(r"[^A-Z0-9 \-]", " ", text)
    if not text.strip():
        text = "BLOCK"
    # simple wrap by width
    scale = 3  # enlarge 5x7 glyphs for readability
    char_w = 6 * scale  # 5px glyph + 1px spacing
    max_chars = max(1, inner_w // char_w)
    words = text.split()
    lines: List[str] = []
    line = ""
    for wtok in words:
        token = wtok
        if line:
            candidate = line + " " + token
        else:
            candidate = token
        if len(candidate) <= max_chars:
            line = candidate
        else:
            if line:
                lines.append(line)
                line = token
            else:
                # force-cut an overlong token
                lines.append(token[:max_chars])
                line = token[max_chars:]
    if line:
        lines.append(line)
    # limit to 2 lines for clarity
    lines = lines[:2]
    # vertical centering
    total_h = len(lines) * ((7 * scale) + scale) - scale
    start_y = inner_y + (inner_h - total_h) // 2

    for i, ln in enumerate(lines):
        # center each line horizontally
        line_w = len(ln) * (6 * scale)
        start_x = inner_x + max(0, (inner_w - line_w) // 2)
        draw_text_5x7(pixels, start_x, start_y + i * ((7 * scale) + scale), ln, color=(40, 40, 40), scale=scale)
_FONT_5x7: Dict[str, List[str]] = {
    'A': ["01110", "10001", "10001", "11111", "10001", "10001", "10001"],
    'B': ["11110", "10001", "11110", "10001", "10001", "10001", "11110"],
    'C': ["01111", "10000", "10000", "10000", "10000", "10000", "01111"],
    'D': ["11110", "10001", "10001", "10001", "10001", "10001", "11110"],
    'E': ["11111", "10000", "11110", "10000", "10000", "10000", "11111"],
    'F': ["11111", "10000", "11110", "10000", "10000", "10000", "10000"],
    'G': ["01110", "10000", "10000", "10111", "10001", "10001", "01110"],
    'H': ["10001", "10001", "11111", "10001", "10001", "10001", "10001"],
    'I': ["11111", "00100", "00100", "00100", "00100", "00100", "11111"],
    'J': ["00111", "00010", "00010", "00010", "10010", "10010", "01100"],
    'K': ["10001", "10010", "10100", "11000", "10100", "10010", "10001"],
    'L': ["10000", "10000", "10000", "10000", "10000", "10000", "11111"],
    'M': ["10001", "11011", "10101", "10101", "10001", "10001", "10001"],
    'N': ["10001", "11001", "10101", "10011", "10001", "10001", "10001"],
    'O': ["01110", "10001", "10001", "10001", "10001", "10001", "01110"],
    'P': ["11110", "10001", "10001", "11110", "10000", "10000", "10000"],
    'Q': ["01110", "10001", "10001", "10001", "10101", "10010", "01101"],
    'R': ["11110", "10001", "10001", "11110", "10100", "10010", "10001"],
    'S': ["01111", "10000", "10000", "01110", "00001", "00001", "11110"],
    'T': ["11111", "00100", "00100", "00100", "00100", "00100", "00100"],
    'U': ["10001", "10001", "10001", "10001", "10001", "10001", "01110"],
    'V': ["10001", "10001", "10001", "10001", "01010", "01010", "00100"],
    'W': ["10001", "10001", "10001", "10101", "10101", "11011", "10001"],
    'X': ["10001", "01010", "00100", "00100", "01010", "10001", "10001"],
    'Y': ["10001", "01010", "00100", "00100", "00100", "00100", "00100"],
    'Z': ["11111", "00001", "00010", "00100", "01000", "10000", "11111"],
    '0': ["01110", "10001", "10011", "10101", "11001", "10001", "01110"],
    '1': ["00100", "01100", "00100", "00100", "00100", "00100", "01110"],
    '2': ["01110", "10001", "00001", "00010", "00100", "01000", "11111"],
    '3': ["11110", "00001", "00001", "00110", "00001", "00001", "11110"],
    '4': ["00010", "00110", "01010", "10010", "11111", "00010", "00010"],
    '5': ["11111", "10000", "11110", "00001", "00001", "10001", "01110"],
    '6': ["00110", "01000", "10000", "11110", "10001", "10001", "01110"],
    '7': ["11111", "00001", "00010", "00100", "01000", "01000", "01000"],
    '8': ["01110", "10001", "10001", "01110", "10001", "10001", "01110"],
    '9': ["01110", "10001", "10001", "01111", "00001", "00010", "01100"],
    '-': ["00000", "00000", "00000", "11111", "00000", "00000", "00000"],
    ' ': ["00000", "00000", "00000", "00000", "00000", "00000", "00000"],
}
def draw_text_5x7(pixels: List[List[List[int]]], x: int, y: int, text: str, color=(60, 60, 60), scale: int = 1) -> None:
    max_y = len(pixels) - 1
    max_x = len(pixels[0]) - 1 if pixels else -1

    def set_px(px: int, py: int):
        if 0 <= px <= max_x and 0 <= py <= max_y:
            pixels[py][px][0] = color[0]
            pixels[py][px][1] = color[1]
            pixels[py][px][2] = color[2]

    cx = x
    for ch in text:
        glyph = _FONT_5x7.get(ch, _FONT_5x7[' '])
        for gy, row in enumerate(glyph):
            for gx, bit in enumerate(row):
                if bit == '1':
                    # draw one scaled pixel
                    for sy in range(scale):
                        for sx in range(scale):
                            set_px(cx + gx * scale + sx, y + gy * scale + sy)
        cx += 6 * scale  # 5px glyph width + 1px spacing
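The rasterizer above boils down to stamping each glyph bit as a scale x scale square into the pixel grid. A self-contained sketch with a made-up 3x3 glyph (not the repo's 5x7 font table) isolates that scaling logic:

```python
GLYPH = ["101", "010", "101"]  # hypothetical 3x3 "X" glyph, 5 set bits

def stamp(glyph, scale):
    # Rasterize a bitmap glyph into a 2D grid, scaling each bit to a scale x scale square
    h, w = len(glyph) * scale, len(glyph[0]) * scale
    grid = [[0] * w for _ in range(h)]
    for gy, row in enumerate(glyph):
        for gx, bit in enumerate(row):
            if bit == "1":
                for sy in range(scale):
                    for sx in range(scale):
                        grid[gy * scale + sy][gx * scale + sx] = 1
    return grid

grid = stamp(GLYPH, 2)  # 6x6 grid; each set bit becomes a 2x2 block
```

The same idea, applied per character with a 1-pixel advance column, gives the `cx += 6 * scale` layout in `draw_text_5x7`.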
# ------------------------- Real Google GenAI calls -------------------------


def _real_gemini(kind: str, *, api_key: str, **kwargs) -> Dict[str, Any]:
    import google.generativeai as genai

    genai.configure(api_key=api_key)

    # Model names are configurable via env with safe defaults.
    # You can set GEMINI_MODEL to your provisioned model, e.g. "gemini-2.0-flash-exp" or "gemini-1.5-flash".
    text_model_name = os.getenv("GEMINI_MODEL", "gemini-1.5-flash")
    image_model_name = os.getenv("GEMINI_IMAGE_MODEL", "")  # e.g. "imagen-3.0-generate"
    image_edit_model_name = os.getenv("GEMINI_IMAGE_EDIT_MODEL", "")  # e.g. "imagen-3.0-edit"

    if kind == "parse":
        from .. import prompts as _p
        prompt = _p.build_parse_prompt()
        user_text = kwargs.get("user_text", "")
        model = genai.GenerativeModel(text_model_name)
        resp = model.generate_content([prompt, user_text])
        content = _first_text(resp)
        data = _robust_json(content)
        if not isinstance(data, dict):
            raise ValueError("parse: model did not return a JSON dict")
        return {"spec": data}

    if kind == "plan":
        spec = kwargs.get("spec", {})
        from .. import prompts as _p
        prompt = _p.build_plan_prompt()
        model = genai.GenerativeModel(text_model_name)
        resp = model.generate_content([prompt, json.dumps(spec, ensure_ascii=False)])
        spec_text = _first_text(resp)
        return {"spec_text": spec_text.strip()}

    if kind == "prompt_generate":
        K = int(kwargs.get("K", 3))
        spec_text = kwargs.get("spec_text", "")
        from .. import prompts as _p
        prompt = _p.build_promptgen_prompt(K, spec_text)
        model = genai.GenerativeModel(text_model_name)
        resp = model.generate_content(prompt)
        content = _first_text(resp)
        arr = _robust_json(content)
        if not isinstance(arr, list):
            # fallback: split lines
            arr = [ln.strip("- ") for ln in content.splitlines() if ln.strip()][:K]
        return {"prompts": arr[:K]}

    if kind == "image_generate":
        prompts: List[str] = kwargs.get("prompts", [])
        outdir: str = kwargs.get("outdir", "artifacts")
        if not image_model_name:
            raise ValueError("GEMINI_IMAGE_MODEL is not set. Please set it to a valid image model (e.g., 'gemini-2.5-flash-image' or 'gemini-2.5-flash-image-preview').")
        from concurrent.futures import ThreadPoolExecutor, as_completed

        Path(outdir).mkdir(parents=True, exist_ok=True)
        max_workers = max(1, min(len(prompts), int(os.getenv("NNG_CONCURRENCY", "4"))))

        def _gen_one(i: int, p: str) -> str:
            # new model per thread to avoid cross-thread state issues
            mdl = genai.GenerativeModel(model_name=image_model_name)
            resp = mdl.generate_content(p, request_options={"timeout": 180})
            try:
                (Path(outdir) / f"candidate_{i}.resp.txt").write_text(str(resp))
            except Exception:
                pass
            img_bytes, mime = _first_image_bytes(resp)
            if not img_bytes:
                raise ValueError("image model did not return image bytes; see *.resp.txt")
            ext = ".png" if mime == "image/png" else ".jpg"
            pth = Path(outdir) / f"candidate_{i}{ext}"
            with open(pth, "wb") as f:
                f.write(img_bytes)
            with open(str(pth) + ".meta.json", "w", encoding="utf-8") as mf:
                mf.write(json.dumps({"source": "gemini", "mime": mime, "bytes": len(img_bytes)}, ensure_ascii=False))
            return str(pth)

        futures = []
        with ThreadPoolExecutor(max_workers=max_workers) as ex:
            for i, p in enumerate(prompts):
                futures.append(ex.submit(_gen_one, i, p))
            # preserve order by index
            results = [None] * len(prompts)
            for fut in as_completed(futures):
                # recover the index from the result path name
                path = fut.result()
                stem = Path(path).stem
                try:
                    idx = int(stem.split("_")[-1])
                except Exception:
                    idx = 0
                results[idx] = path
        # fill any missing entries with empty strings, preserving order
        paths: List[str] = [r or "" for r in results]
        return {"paths": paths}

    if kind == "judge":
        image_path: str = kwargs.get("image_path")
        spec = kwargs.get("spec", {})
        model = genai.GenerativeModel(text_model_name)
        from .. import prompts as _p
        judge_prompt = _p.build_judge_prompt()
        image_part = _image_part_from_path(image_path)
        resp = model.generate_content([
            {"text": judge_prompt},
            {"text": json.dumps(spec, ensure_ascii=False)},
            image_part,
        ])
        content = _first_text(resp)
        data = _robust_json(content)
        if not isinstance(data, dict):
            raise ValueError("judge: model did not return JSON")
        score = float(max(0.0, min(1.0, data.get("score", 0.0))))
        violations = list(data.get("violations", []))
        return {"score": score, "violations": violations}

    if kind == "image_edit":
        image_path: str = kwargs.get("image_path")
        out_path: str = kwargs.get("out_path")
        instructions: str = kwargs.get("instructions", "")
        ref_images: List[str] = list(kwargs.get("ref_images", []) or [])
        if not image_edit_model_name:
            raise ValueError("GEMINI_IMAGE_EDIT_MODEL is not set. Please set it to a valid image edit model (e.g., 'gemini-2.5-flash-image' or 'gemini-2.5-flash-image-preview').")
        # Surface errors rather than falling back to local rendering.
        model = genai.GenerativeModel(model_name=image_edit_model_name)
        base_img = _image_part_from_path(image_path)
        from .. import prompts as _p
        parts = [{"text": _p.build_image_edit_prompt(instructions)}, base_img]
        for rp in ref_images:
            try:
                parts.append(_image_part_from_path(rp))
            except Exception:
                continue
        resp = model.generate_content(parts, request_options={"timeout": 120})
        out_p = Path(out_path)
        out_p.parent.mkdir(parents=True, exist_ok=True)
        try:
            (out_p.parent / (out_p.stem + ".resp.txt")).write_text(str(resp))
        except Exception:
            pass
        img_bytes, mime = _first_image_bytes(resp)
        if not img_bytes:
            raise ValueError("image edit returned no image; see *.resp.txt for the raw response")
        with open(out_p, "wb") as f:
            f.write(img_bytes)
        with open(str(out_p) + ".meta.json", "w", encoding="utf-8") as mf:
            mf.write(json.dumps({"source": "gemini", "mime": mime, "bytes": len(img_bytes)}, ensure_ascii=False))
        return {"path": str(out_p)}

    if kind == "image_fuse":
        # Create a new image by composing multiple reference images under textual instructions
        out_path: str = kwargs.get("out_path")
        instructions: str = kwargs.get("instructions", "")
        ref_images: List[str] = list(kwargs.get("ref_images", []) or [])
        if not image_model_name:
            raise ValueError("GEMINI_IMAGE_MODEL is not set. Please set it to a valid image model (e.g., 'gemini-2.5-flash-image' or 'gemini-2.5-flash-image-preview').")
        model = genai.GenerativeModel(model_name=image_model_name)
        from .. import prompts as _p
        parts = [{"text": _p.build_image_fusion_prompt(instructions)}]
        for rp in ref_images:
            try:
                parts.append(_image_part_from_path(rp))
            except Exception:
                continue
        resp = model.generate_content(parts, request_options={"timeout": 120})
        out_p = Path(out_path)
        out_p.parent.mkdir(parents=True, exist_ok=True)
        try:
            (out_p.parent / (out_p.stem + ".resp.txt")).write_text(str(resp))
        except Exception:
            pass
        img_bytes, mime = _first_image_bytes(resp)
        if not img_bytes:
            raise ValueError("image fuse returned no image; see *.resp.txt for the raw response")
        with open(out_p, "wb") as f:
            f.write(img_bytes)
        with open(str(out_p) + ".meta.json", "w", encoding="utf-8") as mf:
            mf.write(json.dumps({"source": "gemini", "mime": mime, "bytes": len(img_bytes)}, ensure_ascii=False))
        return {"path": str(out_p)}

    raise ValueError(f"Unsupported kind={kind}")
def _first_text(resp: Any) -> str:
    try:
        if hasattr(resp, "text"):
            return resp.text
        # Some SDK versions: candidates[0].content.parts[0].text
        cands = getattr(resp, "candidates", [])
        if cands:
            content = getattr(cands[0], "content", None)
            if content and getattr(content, "parts", None):
                for part in content.parts:
                    if getattr(part, "text", None):
                        return part.text
        return str(resp)
    except Exception:
        return str(resp)
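Because `_first_text` only uses `getattr` with defaults, it works against any object shaped like an SDK response. A stub built from `SimpleNamespace` (a stand-in for illustration, not the real SDK type) exercises the candidates-walking fallback in a simplified mirror of the same logic:

```python
from types import SimpleNamespace

def first_text(resp) -> str:
    # Simplified mirror of _first_text: prefer a top-level .text,
    # otherwise walk candidates -> content -> parts for the first text part
    if getattr(resp, "text", None):
        return resp.text
    for cand in getattr(resp, "candidates", []) or []:
        content = getattr(cand, "content", None)
        for part in getattr(content, "parts", []) or []:
            if getattr(part, "text", None):
                return part.text
    return str(resp)

stub = SimpleNamespace(
    candidates=[SimpleNamespace(content=SimpleNamespace(parts=[SimpleNamespace(text="hello")]))]
)
print(first_text(stub))  # hello
```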
def _first_image_bytes(resp: Any) -> tuple[bytes | None, str]:
    # Walk the content parts and return the first inline image's bytes
    try:
        # Newer SDK: resp.candidates[].content.parts[].inline_data
        cands = getattr(resp, "candidates", [])
        for c in cands or []:
            content = getattr(c, "content", None)
            parts = getattr(content, "parts", None) if content else None
            for part in parts or []:
                inline = getattr(part, "inline_data", None)
                if inline and getattr(inline, "data", None):
                    data = inline.data
                    mime = getattr(inline, "mime_type", "image/png")
                    if isinstance(data, bytes):
                        return data, mime
                    # some versions base64-encode the payload
                    try:
                        return base64.b64decode(data), mime
                    except Exception:
                        pass
        return None, ""
    except Exception:
        return None, ""
def _image_part_from_path(path: str) -> Dict[str, Any]:
    # google-generativeai accepts a dict with mime_type and raw data bytes for images
    p = Path(path)
    mime = "image/png" if p.suffix.lower() == ".png" else "image/jpeg"
    data = p.read_bytes()
    return {"mime_type": mime, "data": data}
def _robust_json(text: str) -> Any:
    # Try to parse the whole string, then attempt to extract the first {...} or [...] block
    try:
        return json.loads(text)
    except Exception:
        pass
    # crude extraction
    start = text.find("{")
    end = text.rfind("}")
    if start != -1 and end != -1 and end > start:
        try:
            return json.loads(text[start : end + 1])
        except Exception:
            pass
    start = text.find("[")
    end = text.rfind("]")
    if start != -1 and end != -1 and end > start:
        try:
            return json.loads(text[start : end + 1])
        except Exception:
            pass
    return {}
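The brace-scanning fallback in `_robust_json` is what lets the pipeline tolerate model replies that wrap JSON in prose or markdown fences. A standalone sketch of the object branch (the reply text is invented for the example):

```python
import json

def robust_json(text: str):
    # Parse the whole string if possible, else the outermost {...} block, else {}
    try:
        return json.loads(text)
    except Exception:
        pass
    start, end = text.find("{"), text.rfind("}")
    if start != -1 and end > start:
        try:
            return json.loads(text[start:end + 1])
        except Exception:
            pass
    return {}

reply = 'Here you go:\n```json\n{"score": 0.8, "violations": []}\n```'
print(robust_json(reply))  # {'score': 0.8, 'violations': []}
```

Scanning from the first `{` to the last `}` handles fences and lead-in prose, but will still fail on replies containing multiple unrelated JSON objects; the empty-dict fallback keeps callers from crashing in that case.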
NNGen/app/nodes/__init__.py
ADDED
@@ -0,0 +1 @@
from . import parser, planner, prompt_gen, gen_generate, gen_labels, judge, select, edit, archive
NNGen/app/nodes/archive.py
ADDED
@@ -0,0 +1,35 @@
from __future__ import annotations

import json
from pathlib import Path

from ..state import AppState


def run(state: AppState) -> AppState:
    outdir = Path(state["outdir"])  # ensured by CLI
    outdir.mkdir(parents=True, exist_ok=True)

    # dump spec
    if state.get("spec"):
        (outdir / "spec.json").write_text(json.dumps(state["spec"], ensure_ascii=False, indent=2))
    if state.get("spec_text"):
        (outdir / "spec.txt").write_text(state["spec_text"])

    # dump prompts
    if state.get("prompts"):
        (outdir / "prompts.json").write_text(json.dumps(state["prompts"], ensure_ascii=False, indent=2))

    # dump scores
    if state.get("scores"):
        (outdir / "scores.json").write_text(json.dumps(state["scores"], ensure_ascii=False, indent=2))

    # copy/rename final image
    if state.get("best_image"):
        src = Path(state["best_image"].path)
        dst = outdir / "final.png"
        if src.exists():
            dst.write_bytes(src.read_bytes())

    return state
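The archive node follows a write-whatever-exists pattern: each optional state key becomes a JSON artifact only if it is present and truthy. That pattern is easy to verify with a plain dict standing in for `AppState` (the temporary directory here is illustrative, not the repo's real `artifacts/` layout):

```python
import json
import tempfile
from pathlib import Path

# a minimal stand-in for AppState; "scores" is deliberately absent
state = {"spec": {"nodes": ["N0: input"]}, "prompts": ["p1", "p2"]}

outdir = Path(tempfile.mkdtemp())
for key in ("spec", "prompts", "scores"):
    if state.get(key):
        (outdir / f"{key}.json").write_text(json.dumps(state[key], ensure_ascii=False, indent=2))

written = sorted(p.name for p in outdir.iterdir())
print(written)  # ['prompts.json', 'spec.json']
```

Keys that are missing or empty simply produce no file, which keeps the artifacts directory free of empty placeholders.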
NNGen/app/nodes/edit.py
ADDED
|
@@ -0,0 +1,80 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from __future__ import annotations

from pathlib import Path
from typing import Dict

from ..llm.gemini import call_gemini
from ..state import AppState


def _labels_from_spec(state: AppState) -> list[str]:
    spec = state.get("spec", {}) or {}
    raw_nodes = spec.get("nodes", []) or []
    labels: list[str] = []
    for n in raw_nodes:
        label = None
        if isinstance(n, str):
            # Strip index prefixes like "N0: ..."
            parts = n.split(":", 1)
            label = parts[1].strip() if len(parts) == 2 else n.strip()
        elif isinstance(n, dict):
            label = n.get("label") or n.get("name") or n.get("id")
        if label:
            labels.append(str(label))
    # Keep order (including exact repeats) so mapping labels to blocks by position still works.
    return labels


def _ascii_friendly_labels(state: AppState) -> list[str]:
    labels = _labels_from_spec(state)

    def is_ascii(s: str) -> bool:
        try:
            s.encode("ascii")
            return True
        except Exception:
            return False

    if not labels or sum(1 for l in labels if is_ascii(l)) == 0:
        return [
            "PATCH EMBEDDING",
            "CLS + POSENC",
            "ENCODER xL",
            "CLASS HEAD",
        ]
    return labels


def plan_edits(state: AppState) -> str:
    hard_violations = [str(v) for v in state.get("hard_violations", [])]
    violations = [str(v) for v in state.get("violations", [])]
    # If the judge reports missing labels (HARD violations take priority),
    # return an add-labels instruction.
    hv = hard_violations or violations
    if any(("labels" in v.lower() and "missing" in v.lower()) for v in hv):
        labels = _labels_from_spec(state)
        numbered = "\n".join(f'{i + 1}: "{lbl}"' for i, lbl in enumerate(labels)) or "(no labels provided)"
        return (
            "Add text labels INSIDE each rectangular block without changing geometry, arrows, spacing, sizes, or colors. "
            "Map labels in left→right, top→bottom order; reuse identical labels for repeated blocks. "
            "Use a clean sans-serif font in solid black or dark gray, consistent size.\n"
            f"Labels list:\n{numbered}"
        )

    # Default: targeted fixes based on judge violations; always include the labels
    # list so text is preserved in offline mode.
    fixes = "; ".join(violations) if violations else "typos, arrow direction, spacing/legibility, and style compliance"
    labels = _ascii_friendly_labels(state)
    numbered = "\n".join(f'{i + 1}: "{lbl}"' for i, lbl in enumerate(labels)) or "(no labels provided)"
    return (
        f"Fix the following issues precisely: {fixes}. "
        "Do not move or reshape elements. Only adjust text (content/position/size), arrow direction styles, and minimal styling to reach paper standards.\n"
        f"Labels list:\n{numbered}"
    )


def apply_edits(state: AppState) -> AppState:
    if not state.get("best_image"):
        return state
    src = state["best_image"].path
    out_path = str(Path(state["outdir"]) / f"edited_round_{state.get('round', 0)}.png")
    _ = call_gemini("image_edit", image_path=src, out_path=out_path, instructions=plan_edits(state))
    # Replace best_image with the edited one.
    state["best_image"].path = out_path  # type: ignore
    return state
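As a hedged aside (a standalone sketch; `numbered_labels` is an illustrative helper name, not a function in the repo), the 1-based quoted label list that `plan_edits` embeds in its edit instructions can be reproduced in isolation:

```python
def numbered_labels(labels):
    # 1-based index, each label quoted, one per line; empty input gets a placeholder.
    return "\n".join(f'{i + 1}: "{lbl}"' for i, lbl in enumerate(labels)) or "(no labels provided)"

print(numbered_labels(["PATCH EMBEDDING", "CLASS HEAD"]))
```

This format keeps the label order explicit so the image model can map labels onto blocks positionally.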
NNGen/app/nodes/gen_fusion.py
ADDED
@@ -0,0 +1,73 @@
from __future__ import annotations

import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import Dict, List

from ..llm.gemini import call_gemini
from ..state import AppState, ImageArtifact


def run(state: AppState) -> AppState:
    """Generate K fused candidates by composing reference images under instructions.

    Expected in state:
    - outdir: str
    - K: int
    - base_image: optional path (treated as the first ref image if present)
    - ref_images: optional list[str]
    - instructions: str
    """
    outdir = Path(state["outdir"])  # ensured by graph
    K = int(state.get("K", 3))
    instructions: str = str(state.get("instructions", "")).strip()

    # Prepare the reference list: base image first, then deduplicated refs.
    refs: List[str] = []
    if state.get("base_image"):
        refs.append(str(state["base_image"]))
    for r in state.get("ref_images", []) or []:
        if r and str(r) not in refs:
            refs.append(str(r))

    if not refs:
        raise ValueError("Fusion mode requires at least one reference image (base or ref_images)")

    max_workers = max(1, min(K, int(os.getenv("NNG_CONCURRENCY", "4"))))

    def _fuse_one(i: int) -> str:
        out_path = str(outdir / f"fused_candidate_{i}.png")
        # Use image_edit if a base image is provided; otherwise image_fuse.
        if state.get("base_image"):
            call_gemini(
                "image_edit",
                image_path=str(state["base_image"]),
                out_path=out_path,
                instructions=f"Variant {i}: {instructions}",
                ref_images=[p for p in refs if p != str(state["base_image"])],
            )
        else:
            call_gemini(
                "image_fuse",
                out_path=out_path,
                instructions=f"Variant {i}: {instructions}",
                ref_images=refs,
            )
        return out_path

    paths: List[str] = [""] * K
    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futures = [ex.submit(_fuse_one, i) for i in range(K)]
        for fut in as_completed(futures):
            p = fut.result()
            try:
                # Recover the candidate index from the filename suffix.
                idx = int(Path(p).stem.split("_")[-1])
            except Exception:
                idx = 0
            paths[idx] = p

    images = [ImageArtifact(prompt=instructions, path=pth) for pth in paths if pth]
    state["images"] = images
    return state
NNGen/app/nodes/gen_generate.py
ADDED
@@ -0,0 +1,16 @@
from __future__ import annotations

from typing import Dict, List

from ..llm.gemini import call_gemini
from ..state import AppState, ImageArtifact


def run(state: AppState) -> AppState:
    prompts: List[str] = state.get("prompts", [])
    res: Dict = call_gemini("image_generate", prompts=prompts, outdir=state["outdir"])
    paths = res.get("paths", [])
    images = [ImageArtifact(prompt=p, path=pth) for p, pth in zip(prompts, paths)]
    state["images"] = images
    return state
NNGen/app/nodes/gen_labels.py
ADDED
@@ -0,0 +1,60 @@
from __future__ import annotations

import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from typing import List

from ..llm.gemini import call_gemini
from ..state import AppState, ImageArtifact
from .edit import _labels_from_spec  # reuse label extraction


def run(state: AppState) -> AppState:
    images: List[ImageArtifact] = state.get("images", []) or []
    if not images:
        return state

    labels = _labels_from_spec(state)

    # Fall back to ASCII-friendly defaults if labels are missing or all non-ASCII.
    def _is_mostly_ascii(s: str) -> bool:
        try:
            s.encode("ascii")
            return True
        except Exception:
            return False

    if not labels or sum(1 for l in labels if _is_mostly_ascii(l)) == 0:
        labels = [
            "PATCH EMBEDDING",
            "CLS + POSENC",
            "ENCODER xL",
            "CLASS HEAD",
        ]
    numbered = "\n".join(f'{i + 1}: "{lbl}"' for i, lbl in enumerate(labels)) or "(no labels provided)"

    instructions = (
        "Add labels INSIDE each rectangular block. Do not move/resize/add/remove shapes or arrows; keep layout, spacing, and colors unchanged. "
        "Map labels in left→right, top→bottom order; reuse identical labels for repeated blocks. Use each label string exactly as given (no translation or paraphrase). "
        "Typography: clean sans-serif, readable size, centered within blocks; at most two short lines; avoid covering arrows; no legends or titles. "
        "If block count ≠ label count, do NOT add/remove shapes; place labels sequentially on existing blocks.\n"
        f"Labels list:\n{numbered}"
    )

    outdir = Path(state["outdir"]) if state.get("outdir") else Path("artifacts")
    max_workers = max(1, min(len(images), int(os.getenv("NNG_CONCURRENCY", "4"))))
    results: List[ImageArtifact | None] = [None] * len(images)

    def _label_one(i: int, im: ImageArtifact) -> tuple[int, str]:
        out_path = str(outdir / f"labeled_candidate_{i}.png")
        _ = call_gemini("image_edit", image_path=im.path, out_path=out_path, instructions=instructions)
        return i, out_path

    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        futures = [ex.submit(_label_one, i, im) for i, im in enumerate(images)]
        for fut in as_completed(futures):
            i, out_path = fut.result()
            results[i] = ImageArtifact(prompt=images[i].prompt, path=out_path, meta={"stage": "labels"})

    state["images"] = [im for im in results if im is not None]
    return state
NNGen/app/nodes/judge.py
ADDED
@@ -0,0 +1,46 @@
from __future__ import annotations

import os
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Dict, List

from ..llm.gemini import call_gemini
from ..state import AppState, ScoreItem


def run(state: AppState) -> AppState:
    images = list(state.get("images", []))
    if not images:
        state["scores"] = []
        return state

    max_workers = max(1, min(len(images), int(os.getenv("NNG_CONCURRENCY", "4"))))
    results: List[ScoreItem | None] = [None] * len(images)

    def _judge_one(i: int) -> Dict:
        res: Dict = call_gemini("judge", image_path=images[i].path, spec=state.get("spec", {}))
        return res

    with ThreadPoolExecutor(max_workers=max_workers) as ex:
        # Map each future back to its image index so failures are attributed
        # to the right candidate (the loop variable alone is stale on error).
        future_to_idx = {ex.submit(_judge_one, i): i for i in range(len(images))}
        for fut in as_completed(future_to_idx):
            i = future_to_idx[fut]
            im = images[i]
            try:
                res = fut.result()
                results[i] = {
                    "image_path": im.path,
                    "score": float(res.get("score", 0.0)),
                    "violations": list(res.get("violations", [])),
                }
            except Exception as e:
                results[i] = {
                    "image_path": im.path,
                    "score": 0.0,
                    "violations": [f"judge error: {e}"],
                }

    state["scores"] = [s for s in results if s is not None]
    return state
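The worker-count clamp used by judge (and the other fan-out nodes) is easy to check on its own. A minimal sketch, with the `NNG_CONCURRENCY` setting passed as an argument instead of read from the environment:

```python
def clamp_workers(n_tasks, configured=4):
    # At least 1 worker; never more than the task count or the configured limit.
    return max(1, min(n_tasks, configured))

print(clamp_workers(6))   # → 4 (capped by the configured concurrency)
print(clamp_workers(2))   # → 2 (capped by the task count)
print(clamp_workers(0))   # → 1 (floor of one worker)
```

The floor of 1 matters: `ThreadPoolExecutor(max_workers=0)` raises a `ValueError`.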
NNGen/app/nodes/parser.py
ADDED
@@ -0,0 +1,15 @@
from __future__ import annotations

from typing import Dict

from ..llm.gemini import call_gemini
from ..state import AppState


def run(state: AppState) -> AppState:
    if state.get("spec"):
        return state
    res: Dict = call_gemini("parse", user_text=state.get("user_text", ""))
    state["spec"] = res.get("spec", {})
    return state
NNGen/app/nodes/planner.py
ADDED
@@ -0,0 +1,15 @@
from __future__ import annotations

from typing import Dict

from ..llm.gemini import call_gemini
from ..state import AppState


def run(state: AppState) -> AppState:
    if state.get("spec_text"):
        return state
    res: Dict = call_gemini("plan", spec=state.get("spec", {}))
    state["spec_text"] = res.get("spec_text", "")
    return state
NNGen/app/nodes/prompt_gen.py
ADDED
@@ -0,0 +1,15 @@
from __future__ import annotations

from typing import Dict

from ..llm.gemini import call_gemini
from ..state import AppState


def run(state: AppState) -> AppState:
    if state.get("prompts"):
        return state
    res: Dict = call_gemini("prompt_generate", spec_text=state.get("spec_text", ""), K=state.get("K", 3))
    state["prompts"] = res.get("prompts", [])
    return state
NNGen/app/nodes/select.py
ADDED
@@ -0,0 +1,30 @@
from __future__ import annotations

from pathlib import Path
from typing import List

from ..state import AppState, ImageArtifact


def run(state: AppState) -> AppState:
    scores = sorted(state.get("scores", []), key=lambda s: s["score"], reverse=True)
    if not scores:
        return state
    best = scores[0]
    best_img_path = best["image_path"]
    # Find the corresponding ImageArtifact.
    best_image: ImageArtifact | None = None
    for im in state.get("images", []):
        if im.path == best_img_path:
            best_image = im
            break
    vios = [str(v) for v in best.get("violations", [])]
    # Identify hard violations: an explicit HARD marker, or the labels-missing heuristic.
    hard = [v for v in vios if v.strip().lower().startswith("hard:")]
    if not hard:
        hard = [v for v in vios if ("labels" in v.lower() and "missing" in v.lower())]

    state["best_image"] = best_image
    state["violations"] = vios
    state["hard_violations"] = hard
    return state
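A hedged, standalone restatement of the hard-violation heuristic used in select (the function name `split_violations` is illustrative only; the repo applies this logic inline):

```python
def split_violations(vios):
    # All violations as strings; "hard" ones either carry an explicit "HARD:"
    # prefix from the judge, or match the labels-missing heuristic as a fallback.
    vios = [str(v) for v in vios]
    hard = [v for v in vios if v.strip().lower().startswith("hard:")]
    if not hard:
        hard = [v for v in vios if "labels" in v.lower() and "missing" in v.lower()]
    return vios, hard

print(split_violations(["HARD: labels: missing", "arrow reversed"])[1])
```

Keeping the heuristic fallback means older judge outputs that never emit the `HARD:` marker still trigger the add-labels edit path.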
NNGen/app/prompts.py
ADDED
@@ -0,0 +1,65 @@
from __future__ import annotations


def build_parse_prompt() -> str:
    return (
        "You are a strict parser for neural network architecture specs. "
        "Input is natural language. Return ONLY a JSON object with fields: "
        "nodes: string[], edges: [fromIndex, toIndex][], constraints: object. "
        "No prose."
    )


def build_plan_prompt() -> str:
    return (
        "Given a structured NN spec (JSON), produce a concise, fillable template text "
        "that preserves nodes, edges, and key constraints for diagram rendering. "
        "Emphasize left-to-right flow, explicit layer counts, and unambiguous labels."
    )


def build_promptgen_prompt(K: int, spec_text: str) -> str:
    # Stage G1: lighter, cleaner skeleton-only prompts (no hard stylistic numbers).
    return (
        "Create K concise prompts for an image model to draw ONLY the skeleton of a neural network diagram (no text).\n"
        "Aim for a clean paper-figure look: flat 2D, simple shapes, balanced margins, and a calm palette. Use rectangles for modules and clear left→right arrows.\n"
        "If the spec implies repetition (e.g., Encoder × L), you may show a dashed grouping around the repeated blocks. Avoid flashy effects (no 3D or heavy glow).\n"
        "Return ONLY a JSON array of exactly K strings; each item is one full prompt for image generation.\n"
        f"K={K}.\n"
        f"Spec (summary):\n{spec_text}\n"
        "Each prompt must mention: 'skeleton-only, no text'."
    )


def build_judge_prompt() -> str:
    # Judge content & style, optimized for the two-stage (skeleton → labels) flow.
    return (
        "You are a strict publication-figure QA judge. Given a spec (JSON) and a NN diagram image, "
        "evaluate (A) Content correctness and (B) Paper-style compliance.\n"
        "(A) Content (0.6): required modules present; edges/arrows reflect correct order; arrows left→right; labels exist and are spelled correctly; "
        "layer count L indicated when applicable. If the image has no labels, include violation EXACTLY 'HARD: labels: missing'.\n"
        "(B) Style (0.4): flat 2D; white background; minimal color (black/gray + ≤2 accents); no gradients/3D/glow/shadows/neon; "
        "consistent stroke width; consistent sans-serif font; adequate spacing; dashed boxes for repeated blocks; high print readability.\n"
        "Return ONLY strict JSON: {score: number in [0,1], violations: string[]}. Violations must be concrete and actionable."
    )


def build_image_edit_prompt(instructions: str) -> str:
    # G2 and later edits: add/adjust labels only; keep geometry fixed (light constraints).
    base = (
        "Add or adjust labels INSIDE each block, without changing any shapes, arrows, layout, spacing, or colors. "
        "Keep a clean, readable look: flat 2D, simple sans-serif font, good contrast, and consistent size across blocks. "
        "Center labels within blocks; use at most two short lines; avoid covering arrows; do not add legends or titles. "
        "Use each label string exactly as provided (no translation or paraphrase). "
    )
    return base + f"Instructions: {instructions}"


def build_image_fusion_prompt(instructions: str) -> str:
    # Compose multiple images guided by text while preserving key visual constraints.
    return (
        "Compose a new, clean technical diagram by integrating the following reference images. "
        "Preserve the overall paper-style look: flat 2D, white background, minimal color, consistent line width, and sans-serif text. "
        "Follow the instructions precisely; keep geometry aligned and readable; avoid extra decorations. "
        f"Instructions: {instructions}"
    )
NNGen/app/state.py
ADDED
@@ -0,0 +1,33 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, TypedDict


class ScoreItem(TypedDict):
    image_path: str
    score: float
    violations: List[str]


@dataclass
class ImageArtifact:
    prompt: str
    path: str
    meta: Dict[str, Any] = field(default_factory=dict)


class AppState(TypedDict, total=False):
    user_text: str
    spec: Dict[str, Any]
    spec_text: str
    K: int
    T: int
    round: int
    prompts: List[str]
    images: List[ImageArtifact]
    scores: List[ScoreItem]
    best_image: Optional[ImageArtifact]
    violations: List[str]
    hard_violations: List[str]
    outdir: str
NNGen/demo.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
NNGen/notebooks/demo.ipynb
ADDED
@@ -0,0 +1,151 @@
{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "0ce88361",
   "metadata": {},
   "source": [
    "# NNGen Demo — Gemini 2.5 Flash Image\n",
    "\n",
    "Interactive demo to generate a neural network diagram from a natural language prompt.\n",
    "- Uses the multi-agent pipeline (`parser → planner → prompt-gen → G1 → G2 → judge → select → edit loop → archive`).\n",
    "- All model calls go through `app.llm.gemini.call_gemini`. If `GEMINI_API_KEY` is not set, the pipeline falls back to local placeholders so the demo still runs.\n",
    "- For image generation/editing, set `GEMINI_IMAGE_MODEL`/`GEMINI_IMAGE_EDIT_MODEL` (e.g., `gemini-2.5-flash-image` or `gemini-2.5-flash-image-preview`).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a0f6490c",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Imports\n",
    "import os, json, sys\n",
    "from pathlib import Path\n",
    "\n",
    "# The notebook lives in notebooks/; make the repo root (which contains app/)\n",
    "# importable. The previously committed run failed here with\n",
    "# ModuleNotFoundError: No module named 'app'.\n",
    "repo_root = Path.cwd().parent if Path.cwd().name == 'notebooks' else Path.cwd()\n",
    "sys.path.insert(0, str(repo_root))\n",
    "\n",
    "from app.graph import run_pipeline\n",
    "from app.state import AppState\n",
    "from IPython.display import Image, display\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "722a09d3",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Optional: configure models here if not set in environment (.env is supported).\n",
    "# os.environ.setdefault(\"GEMINI_MODEL\", \"gemini-2.5-flash\")\n",
    "# os.environ.setdefault(\"GEMINI_IMAGE_MODEL\", \"gemini-2.5-flash-image\")\n",
    "# os.environ.setdefault(\"GEMINI_IMAGE_EDIT_MODEL\", \"gemini-2.5-flash-image\")\n",
    "print(\"GEMINI_MODEL=\", os.getenv(\"GEMINI_MODEL\", \"(default)\"))\n",
    "print(\"GEMINI_IMAGE_MODEL=\", os.getenv(\"GEMINI_IMAGE_MODEL\", \"(default)\"))\n",
    "print(\"GEMINI_IMAGE_EDIT_MODEL=\", os.getenv(\"GEMINI_IMAGE_EDIT_MODEL\", \"(default)\"))\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "61aacffe",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Enter a natural language NN spec (leave blank to use the sample in spec/vit.txt).\n",
    "print(\"Enter your NN spec prompt (blank for sample):\")\n",
    "user_text = input().strip()\n",
    "if not user_text:\n",
    "    user_text = (repo_root / 'spec' / 'vit.txt').read_text()\n",
    "\n",
    "# Number of candidates (K) and max edit rounds (T)\n",
    "K = 4\n",
    "T = 1\n",
    "print(f\"Configured: K={K}, T={T}\")\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f6c28bb1",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Run the multi-agent pipeline\n",
    "state: AppState = {\n",
    "    'K': K,\n",
    "    'T': T,\n",
    "    'user_text': user_text,\n",
    "    'outdir': ''  # use timestamped default\n",
    "}\n",
    "final_state = run_pipeline(state)\n",
    "print('Artifacts directory:', final_state['outdir'])\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "a402a108",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Display the final image\n",
    "final_path = Path(final_state['outdir']) / 'final.png'\n",
    "if final_path.exists():\n",
    "    display(Image(filename=str(final_path)))\n",
    "else:\n",
    "    print('final.png not found at', final_path)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f7518a91",
   "metadata": {},
   "outputs": [],
   "source": [
    "# (Optional) Inspect outputs: spec and scoring\n",
    "spec_txt = Path(final_state['outdir']) / 'spec.txt'\n",
    "scores_json = Path(final_state['outdir']) / 'scores.json'\n",
    "if spec_txt.exists():\n",
    "    print('--- spec.txt ---')\n",
    "    print(spec_txt.read_text())\n",
    "if scores_json.exists():\n",
    "    print('--- scores.json ---')\n",
    "    print(scores_json.read_text())\n"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.13.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
NNGen/requirements.txt
ADDED
@@ -0,0 +1,4 @@
langgraph>=0.2.0
google-generativeai>=0.7.0
python-dotenv>=1.0.1
gradio>=4.32.0
NNGen/runtime.txt
ADDED
@@ -0,0 +1 @@
python-3.10
NNGen/scripts/gradio_app.py
ADDED
@@ -0,0 +1,103 @@
from __future__ import annotations

import os
import shutil
from pathlib import Path
from typing import List, Tuple

import gradio as gr

from app.graph import run_pipeline, run_fusion_pipeline
from app.state import AppState


def _zip_outdir(outdir: str) -> str:
    out = Path(outdir)
    if not out.exists():
        return ""
    zip_path = str(out) + ".zip"
    # Remove a stale archive if one exists.
    try:
        if Path(zip_path).exists():
            Path(zip_path).unlink()
    except Exception:
        pass
    shutil.make_archive(str(out), "zip", root_dir=str(out))
    return zip_path


def run_text_mode(user_text: str, K: int, T: int, make_zip: bool) -> Tuple[str, List[str], str, str]:
    state: AppState = {"K": int(K), "T": int(T), "user_text": user_text or "", "outdir": ""}
    final_state = run_pipeline(state)
    outdir = final_state["outdir"]
    # Collect candidates if present.
    candidates = [im.path for im in (final_state.get("images") or [])]
    final_img = str(Path(outdir) / "final.png")
    zip_path = _zip_outdir(outdir) if make_zip else ""
    return final_img, candidates, outdir, zip_path


def run_image_mode(base_image, ref_images, instructions: str, K: int, make_zip: bool) -> Tuple[str, List[str], str, str]:
    state: AppState = {"K": int(K), "T": 0, "outdir": "", "instructions": instructions or ""}
    if base_image is not None:
        state["base_image"] = base_image if isinstance(base_image, str) else base_image.name
    refs: List[str] = []
    for f in (ref_images or []):
        p = f if isinstance(f, str) else getattr(f, "name", None)
        if p:
            refs.append(p)
    state["ref_images"] = refs

    final_state = run_fusion_pipeline(state)
    outdir = final_state["outdir"]
    candidates = [im.path for im in (final_state.get("images") or [])]
    final_img = str(Path(outdir) / "final.png")
    zip_path = _zip_outdir(outdir) if make_zip else ""
    return final_img, candidates, outdir, zip_path


def app() -> gr.Blocks:
    with gr.Blocks(title="NNGen — Gemini 2.5 Flash Image") as demo:
        gr.Markdown("""
        # NNGen — Gemini 2.5 Flash Image
        - Text mode: enter a natural language spec to generate a diagram (G1/G2/judge/edit).
        - Image mode: edit/fuse images with textual instructions (e.g., replace UNet with Transformer).
        - Offline works with placeholders if no `GEMINI_API_KEY` is set. With an API key, set `GEMINI_IMAGE_MODEL` and `GEMINI_IMAGE_EDIT_MODEL`.
        """)

        with gr.Tab("Text Mode"):
            user_text = gr.Textbox(label="NN spec (text)", lines=10, placeholder="Describe the architecture... e.g., Transformer encoder-decoder with cross-attention...")
            with gr.Row():
                K = gr.Slider(1, 6, value=4, step=1, label="K candidates")
                T = gr.Slider(0, 3, value=1, step=1, label="Max edit rounds (T)")
                zip_output = gr.Checkbox(value=False, label="Zip outputs")
            run_btn = gr.Button("Generate")
            final_img = gr.Image(label="final.png", type="filepath")
            # Gradio 4 (pinned in requirements.txt) removed Component.style();
            # pass layout options such as columns directly to the constructor.
            gallery = gr.Gallery(label="Candidates", columns=4)
            outdir = gr.Textbox(label="Artifacts directory", interactive=False)
            zip_file = gr.File(label="Download run.zip", interactive=False)

            run_btn.click(run_text_mode, inputs=[user_text, K, T, zip_output], outputs=[final_img, gallery, outdir, zip_file])

        with gr.Tab("Image Mode (Fusion/Edit)"):
            base = gr.Image(label="Base image (optional)", type="filepath")
            refs = gr.Files(label="Reference images (0..N)")
            instr = gr.Textbox(label="Instructions", lines=4, placeholder="Replace the UNet backbone with a Transformer (DiT); keep layout, fonts, colors, arrows, and dashed groups unchanged.")
            with gr.Row():
                K2 = gr.Slider(1, 6, value=4, step=1, label="K candidates")
                zip_output2 = gr.Checkbox(value=False, label="Zip outputs")
            run_btn2 = gr.Button("Compose / Edit")
            final_img2 = gr.Image(label="final.png", type="filepath")
            gallery2 = gr.Gallery(label="Fused Candidates", columns=4)
            outdir2 = gr.Textbox(label="Artifacts directory", interactive=False)
            zip_file2 = gr.File(label="Download run.zip", interactive=False)

            run_btn2.click(run_image_mode, inputs=[base, refs, instr, K2, zip_output2], outputs=[final_img2, gallery2, outdir2, zip_file2])
|
| 96 |
+
|
| 97 |
+
return demo
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
if __name__ == "__main__":
|
| 101 |
+
port = int(os.getenv("PORT", "7860"))
|
| 102 |
+
app().launch(server_name="0.0.0.0", server_port=port)
|
| 103 |
+
|
NNGen/spec/transformer.txt
ADDED
@@ -0,0 +1,32 @@
Title: Transformer Encoder–Decoder (Machine Translation)

Instructions:
- Produce a clean paper-style architecture diagram: flat 2D, white background, minimal color (black/gray + ≤2 accent colors), no gradients/3D/shadows.
- Layout left→right. Clear arrows indicate data flow.
- Draw boxes for major modules and show grouping for repeated layers.
- Text labels should be concise and capitalized (e.g., EMBEDDING, ENCODER xN, DECODER xN).

Architecture:
- Input: tokenized source sentence
- Source Embedding + Positional Encoding
- Encoder (N layers):
  - Multi-Head Self-Attention
  - Add & LayerNorm
  - Feed-Forward (MLP)
  - Add & LayerNorm
- Target: previous target tokens (for training)
- Target Embedding + Positional Encoding
- Decoder (N layers):
  - Masked Multi-Head Self-Attention
  - Add & LayerNorm
  - Cross-Attention (attends to Encoder outputs)
  - Add & LayerNorm
  - Feed-Forward (MLP)
  - Add & LayerNorm
- Output: Linear + Softmax

Style details:
- Use a dashed rounded rectangle to group the N repeated layers on both encoder and decoder.
- Keep arrows straight. Left→right overall; show a connection from Encoder outputs to the Cross-Attention in Decoder.
- If space is tight, abbreviate labels (e.g., SELF-ATTN, CROSS-ATTN, FFN).
NNGen/spec/vit.txt
ADDED
@@ -0,0 +1,7 @@
Generate a high-level diagram of a Vision Transformer (ViT):
- Input: 224×224 RGB image
- Patch Embedding: split into 16×16 patches and apply a linear projection
- Add CLS token and positional encoding
- Transformer Encoder stack: Multi-Head Self-Attention + MLP + residual + LayerNorm (repeat L layers)
- Classification head: feed the CLS token to a linear classifier
Layout requirements: left-to-right flow; clear arrow directions; correct spelling of all labels; show the number of layers L; keep colors readable.