Spaces:
Build error
Build error
| import base64 | |
| import json | |
| import os | |
| import random | |
| import re | |
| import shutil | |
| import sys | |
| import tempfile | |
| import uuid | |
| import requests | |
| from datetime import datetime | |
| from io import BytesIO | |
| from pathlib import Path | |
| import gradio as gr | |
| from PIL import Image | |
| from dotenv import load_dotenv | |
| from graphviz import Digraph | |
| from huggingface_hub import InferenceClient | |
| from together import Together | |
# ────────────────────────────────
# ENV / API
# ────────────────────────────────
# Load variables from a local .env file (if present) into os.environ. This must
# run before the os.getenv() lookups below so .env values are visible to them.
load_dotenv()
HF_TOKEN = os.getenv("HF_TOKEN")  # <-- add your HF token to .env
# Together API key; falls back to "" when the variable is unset.
# NOTE(review): with an empty key, Together calls presumably fail with an auth
# error (either here or at request time) — confirm and consider failing fast.
TOGETHER_TOKEN = os.getenv("TOGETHER_API_KEY", "")
# Shared Together client used by the LLM-extraction and image-generation helpers.
together_client = Together(api_key=TOGETHER_TOKEN)
# NOTE(review): the trailing comment says a default model is "set later", but
# nothing in the visible source sets a model on image_client or calls it.
image_client = InferenceClient(token=HF_TOKEN)  # default model set later
# Optional Graphviz PATH helper (Windows only). Uncomment it if the Graphviz
# `dot` executable is installed but not on PATH; build_graph_png renders the
# character map through Graphviz.
# if shutil.which("dot") is None:
#     gv_path = r"C:\Program Files\Graphviz\bin"
#     if os.path.exists(gv_path):
#         os.environ["PATH"] = gv_path + os.pathsep + os.environ["PATH"]
#     else:
#         sys.exit("Graphviz not found. Please install Graphviz or remove the check.")
# ────────────────────────────────
# LLM templates
# ────────────────────────────────
# Character/relationship extraction prompt. It is filled with `% text` (exactly
# one %s placeholder), so any literal percent sign added here must be "%%".
# The JSON braces are safe because %-formatting ignores them.
LLAMA_JSON_PROMPT = """
Extract every character and any explicit relationship between them.
Return pure JSON ONLY in this schema:
{
"characters": ["Alice", "Bob"],
"relations": [
{"from":"Alice","to":"Bob","type":"friend"}
]
}
TEXT:
\"\"\"%s\"\"\"
"""
# Scene-description prompt. It is filled with `% (count, story_text)`: %d is the
# number of scenes requested and %s is the story text. The reply is parsed one
# scene per line by generate_image_prompts.
IMAGE_PROMPT_TEMPLATE = """
Based on the following story, write %d distinct vivid scene descriptions, one per line.
Each line should begin with a dash (-) followed by a detailed image-worthy scene.
Include setting, mood, characters, and visual cues.
Return ONLY the list of scenes, each on its own line.
Story:
\"\"\"%s\"\"\"
"""
# ────────────────────────────────
# Entity extraction
# ────────────────────────────────
def _normalize_entities(data):
    """Coerce parsed LLM JSON into ``{"characters": [str], "relations": [dict]}``.

    - Non-list ``characters``/``relations`` values are treated as empty.
    - Character names are converted to stripped strings; blanks and duplicates
      are dropped.
    - Relations that are not dicts, or that lack a ``from``/``to``, are dropped.
    - A missing ``type`` becomes ``"relation"``, the same default the
      manual-edit callback uses.

    Other top-level keys are left untouched.

    Raises:
        ValueError: if *data* is not a JSON object (dict).
    """
    if not isinstance(data, dict):
        raise ValueError("expected a JSON object")

    raw_chars = data.get("characters")
    characters = []
    for c in raw_chars if isinstance(raw_chars, list) else []:
        name = str(c).strip()
        if name and name not in characters:
            characters.append(name)

    raw_rels = data.get("relations")
    relations = []
    for r in raw_rels if isinstance(raw_rels, list) else []:
        if not isinstance(r, dict):
            continue
        frm = str(r.get("from") or "").strip()
        to = str(r.get("to") or "").strip()
        if not frm or not to:
            continue
        relations.append({"from": frm, "to": to, "type": str(r.get("type") or "relation")})

    data["characters"] = characters
    data["relations"] = relations
    return data


def extract_entities(text: str):
    """Ask the LLM for the characters and relations mentioned in *text*.

    Returns:
        ``(data, None)`` on success. ``data`` always has a ``"characters"``
        list of strings and a ``"relations"`` list of ``{"from", "to", "type"}``
        string dicts, so downstream graph code can index it safely.
        ``(None, message)`` when the reply contains no usable JSON or the
        request or parsing fails.
    """
    try:
        prompt = LLAMA_JSON_PROMPT % text
        resp = together_client.chat.completions.create(
            model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
            messages=[{"role": "user", "content": prompt}],
            max_tokens=1024,
        )
        # content can be None (e.g. an empty completion); treat it as empty text.
        raw = (resp.choices[0].message.content or "").strip()
        # Models often wrap the JSON in prose or code fences, so take the span
        # from the first "{" to the last "}".
        m = re.search(r"\{[\s\S]*\}", raw)
        if not m:
            return None, f"โ ๏ธย No JSON block found.\n\n{raw}"
        data = json.loads(m.group(0))
        return _normalize_entities(data), None
    except Exception as e:
        return None, f"โ ๏ธย extractor error: {e}"
# ────────────────────────────────
# Build visual prompt
# ────────────────────────────────
# Leading characters the LLM may use to mark list items in its scene reply.
_SCENE_BULLETS = "-•*"


def generate_image_prompts(story_text: str, count=1):
    """Ask the LLM for up to *count* distinct scene descriptions of *story_text*.

    Returns:
        A list of at most ``count`` prompt strings with their list bullets
        removed. It may be shorter than ``count``, or empty if the request fails.
    """
    try:
        prompt_msg = IMAGE_PROMPT_TEMPLATE % (count, story_text)
        resp = together_client.chat.completions.create(
            model="meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8",
            messages=[{"role": "user", "content": prompt_msg}],
            # A flat 200-token budget truncates multi-scene replies, so scale
            # the budget with the number of scenes requested.
            max_tokens=200 * max(1, count),
        )
        raw_output = (resp.choices[0].message.content or "").strip()
        lines = [line.strip() for line in raw_output.splitlines() if line.strip()]
        # The template asks for dash-prefixed lines. When the model complies,
        # keep only those, so preambles like "Here are the scenes:" are not
        # used as prompts. Otherwise fall back to every non-empty line.
        bulleted = [line for line in lines if line[0] in _SCENE_BULLETS]
        candidates = bulleted or lines
        prompts = [line.lstrip(_SCENE_BULLETS + " ").strip() for line in candidates]
        prompts = [p for p in prompts if p]
        return prompts[:count]  # just in case LLM gives more than needed
    except Exception as e:
        print("โ ๏ธ LLM scene prompt generation failed:", e)
        return []
def _image_from_choice(choice):
    """Decode one Together image result (URL or base64 payload) into a PIL image.

    Returns None, after printing the reason, if the download or decode fails or
    the result carries neither a URL nor base64 data.
    """
    if getattr(choice, "url", None):
        try:
            http_resp = requests.get(choice.url, timeout=30)
            # Surface HTTP errors here rather than letting PIL choke on an error page.
            http_resp.raise_for_status()
            img = Image.open(BytesIO(http_resp.content))
            img.load()  # Image.open is lazy; force decoding so errors are caught here
            return img
        except Exception as e:
            print("โ ๏ธย URL fetch failed:", e)
    elif getattr(choice, "b64_json", None):
        try:
            img = Image.open(BytesIO(base64.b64decode(choice.b64_json)))
            img.load()
            return img
        except Exception as e:
            print("โ ๏ธย base64 decode failed:", e)
    return None


def generate_images_with_together(story, style, quality, count=1):
    """Generate up to *count* images for *story* with Together's FLUX model.

    One scene description per image is requested from the LLM, and each is
    combined with the *style* and *quality* hints.

    - If the LLM returns fewer scenes than requested, the last scene is reused.
    - If it returns none, the raw story text is used.
    - An image API error stops generation early.
    - Images that fail to download or decode are skipped.

    Returns:
        A list of PIL images (possibly empty).
    """
    if count <= 0:
        return []
    scenes = generate_image_prompts(story, count)
    images = []
    for i in range(count):
        if i < len(scenes):
            scene = scenes[i]
        elif scenes:
            scene = scenes[-1]
        else:
            scene = story
        full_prompt = f"{style} style, cinematic lighting, quality {quality}, {scene} [Scene {i + 1}]"
        seed = random.randint(1, 10_000_000)
        try:
            resp = together_client.images.generate(
                model="black-forest-labs/FLUX.1-schnell-Free",
                prompt=full_prompt,
                seed=seed,
                width=768,
                height=512,
                steps=4,
            )
        except Exception as e:
            print("๐ฅ Together image API error:", e)
            break
        img = _image_from_choice(resp.data[0]) if resp.data else None
        if img is not None:
            images.append(img)
        else:
            print(f"โ ๏ธย No image for scene {i+1}")
    return images
# ────────────────────────────────
# Graph → PNG (Graphviz)
# ────────────────────────────────
def build_graph_png(data: dict) -> str:
    """Render the character/relation graph in *data* to a PNG and return its path.

    *data* should look like ``{"characters": [...], "relations":
    [{"from": ..., "to": ..., "type": ...}, ...]}``.

    - Missing keys are treated as empty.
    - Names are converted to strings.
    - Relations without both endpoints are skipped.
    - A relation without a "type" gets an unlabeled edge.

    Raises whatever graphviz raises if rendering fails, presumably
    ``graphviz.ExecutableNotFound`` when the ``dot`` binary is missing.
    """
    dot = Digraph(format="png")
    dot.attr(rankdir="LR", bgcolor="white", fontsize="11")
    for c in data.get("characters", []):
        dot.node(str(c), shape="ellipse", style="filled", fillcolor="#8ecae6")
    for r in data.get("relations", []):
        frm, to = r.get("from"), r.get("to")
        if frm is None or to is None:
            continue
        dot.edge(str(frm), str(to), label=str(r.get("type") or ""), fontsize="10")
    # Use a fresh temp dir and a unique name per render so sessions never
    # overwrite each other's image. Gradio serves the file by path, so it is
    # intentionally left on disk.
    tmpdir = Path(tempfile.mkdtemp())
    rendered = dot.render(f"graph_{uuid.uuid4().hex}", directory=tmpdir, cleanup=True)
    # Use the path graphviz reports instead of rebuilding it by hand.
    return str(rendered)
# ────────────────────────────────
# Core generation
# ────────────────────────────────
def generate_assets(prompt, style, quality, num_images, state):
    """Gradio handler for the "Generate Assets" button.

    Extracts characters and relations from *prompt*, renders the graph, and,
    if ``num_images > 0``, generates that many scene images.

    Returns:
        ``(images, graph_path, status_message, new_state)``.
        - For an empty prompt or failed extraction, the previous *state* is
          returned unchanged.
        - Once extraction succeeds, the extracted data becomes the new state,
          even if the graph or the images then fail.
    """
    if not (prompt or "").strip():
        return [], None, "Enter a story prompt.", state
    data, err = extract_entities(prompt)
    if not data:
        return [], None, err or "No data.", state
    try:
        graph_path = build_graph_png(data)
    except Exception as e:
        # e.g. the Graphviz binary is missing: report it instead of crashing
        # the handler, and keep the extracted data so it can still be saved.
        return [], None, f"Graph rendering failed: {e}", data
    n_images = int(num_images or 0)
    images = []
    if n_images > 0:
        try:
            images = generate_images_with_together(prompt, style, quality, n_images)
        except Exception as e:
            status = f"โ ๏ธ Image generation failed: {e}"
            return [], graph_path, status, data
    status = "โ All assets generated." if images else "โ Graph generated (no images)."
    return images, graph_path, status, data
def _regen_graph(state):
    """Re-render the character map from *state* and return it as a Gradio update.

    Used after manual edits so the graph image reflects the current workspace.
    """
    png_path = build_graph_png(state)
    return gr.update(value=png_path)
# ────────────────────────────────
# Manual tweak callbacks
# ────────────────────────────────
# These callbacks return (graph_update, status_message, state). graph_update is
# a re-rendered graph, or a no-op gr.update() when nothing changed.
def add_character(name, state):
    """Add character *name* to *state*, ignoring surrounding whitespace.

    Blank names and names already present are rejected with a message.
    """
    # Strip so "  " is rejected and " Alice" does not duplicate "Alice".
    name = (name or "").strip()
    if not name:
        return gr.update(), "Enter a character name.", state
    if name in state["characters"]:
        return gr.update(), f"{name} already exists.", state
    state["characters"].append(name)
    return _regen_graph(state), "โ ย Character added.", state
def add_relation(frm, to, typ, state):
    """Add a directed relation ``frm -> to`` labelled *typ* (default "relation").

    Both endpoints must already be characters in *state*. Surrounding
    whitespace in all three inputs is ignored.
    """
    frm = (frm or "").strip()
    to = (to or "").strip()
    typ = (typ or "").strip()
    if frm not in state["characters"] or to not in state["characters"]:
        return gr.update(), "Both characters must exist first.", state
    state["relations"].append({"from": frm, "to": to, "type": typ or "relation"})
    return _regen_graph(state), "โ ย Relation added.", state
def delete_character(name, state):
    """Remove character *name* and every relation that starts or ends at it.

    Surrounding whitespace in *name* is ignored.
    """
    name = (name or "").strip()
    if name not in state["characters"]:
        return gr.update(), "Character not found.", state
    state["characters"].remove(name)
    state["relations"] = [r for r in state["relations"] if r["from"] != name and r["to"] != name]
    return _regen_graph(state), f"๐ฎย {name} deleted.", state
# Save / Load
def save_json(state):
    """Write *state* as pretty-printed JSON to a new temp file and return its path.

    The filename uses a compact timestamp (with microseconds, so rapid saves
    don't collide). ``datetime.isoformat()`` is avoided because its ":"
    characters are not allowed in Windows filenames.
    """
    stamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
    fp = Path(tempfile.gettempdir()) / f"story_{stamp}.json"
    fp.write_text(json.dumps(state, indent=2), encoding="utf-8")
    return str(fp)
def load_json(file_obj, state):
    """Load a saved workspace JSON file into the app state and redraw the graph.

    *file_obj* is presumably a filepath string (gr.File's default in recent
    Gradio) or a tempfile-like object exposing the path as ``.name`` (older
    Gradio). Both are accepted.

    Returns:
        ``(graph_update, status_message, state)``. On any problem the previous
        *state* is kept and an explanatory message is returned.
    """
    if not file_obj:
        return gr.update(), "No file uploaded.", state
    path = Path(getattr(file_obj, "name", file_obj))
    if not path.exists():
        return gr.update(), "No file uploaded.", state
    try:
        data = json.loads(path.read_text(encoding="utf-8"))
        # Explicit check instead of `assert`, which is stripped under `python -O`.
        if not (
            isinstance(data, dict)
            and isinstance(data.get("characters"), list)
            and isinstance(data.get("relations"), list)
        ):
            raise ValueError('expected an object with "characters" and "relations" lists')
        return _regen_graph(data), "โ ย File loaded.", data
    except Exception as e:
        return gr.update(), f"Load error: {e}", state
# ────────────────────────────────
# UI (same tabs you designed)
# ────────────────────────────────
with gr.Blocks(theme=gr.themes.Soft(primary_hue="blue", secondary_hue="cyan")) as demo:
    gr.Markdown("## โจ EpicFrame โ Narrative Workbench")
    # Per-session workspace, shared by all callbacks: {"characters": [...], "relations": [...]}.
    state = gr.State({"characters": [], "relations": []})

    # Input tab
    with gr.Tab("Input"):
        text_input = gr.Textbox(label="Story prompt", lines=6)
        style_dropdown = gr.Dropdown(["Realistic", "Anime", "Sketch"], value="Realistic", label="Style")
        quality_slider = gr.Slider(1, 10, value=7, step=1, label="Image Quality")
        num_images_sl = gr.Slider(0, 4, value=0, step=1, label="Images to generate (0 = skip)")
        generate_btn = gr.Button("โถ๏ธ Generate Assets")
        status_box = gr.Textbox(label="Status", lines=2)

    # Images tab
    with gr.Tab("Images"):
        gallery = gr.Gallery(label="๐ผ๏ธ Images", columns=4)

    # Graph/Edit tab
    with gr.Tab("Graph / Edit"):
        graph_img = gr.Image(label="๐ Character Map", interactive=False, height=500)
        with gr.Row():
            add_char_name = gr.Textbox(label="Add Character โ Name")
            add_char_btn = gr.Button("Add")
        with gr.Row():
            rel_from = gr.Textbox(label="Relation From")
            rel_to = gr.Textbox(label="To")
            rel_type = gr.Textbox(label="Type")
            add_rel_btn = gr.Button("Add Relation")
        with gr.Row():
            del_char_name = gr.Textbox(label="Delete Character โ Name")
            del_char_btn = gr.Button("Delete")
        tweak_msg = gr.Textbox(label="โฐ Status", max_lines=2)

    # Save/Load tab
    with gr.Tab("Save / Load"):
        save_btn = gr.Button("๐พ Download JSON")
        # Receives the path returned by save_json so the user gets a download
        # link. Sending the path to the Button itself would only relabel it.
        save_file = gr.File(label="Saved JSON", interactive=False)
        load_file = gr.File(label="Load JSON")
        load_btn = gr.Button("โคต๏ธ Load into workspace")
        save_msg = gr.Textbox(label="Status", max_lines=2)

    # callbacks
    generate_btn.click(
        generate_assets,
        inputs=[text_input, style_dropdown, quality_slider, num_images_sl, state],
        outputs=[gallery, graph_img, status_box, state],
    )
    add_char_btn.click(add_character,
                       inputs=[add_char_name, state],
                       outputs=[graph_img, tweak_msg, state])
    add_rel_btn.click(add_relation,
                      inputs=[rel_from, rel_to, rel_type, state],
                      outputs=[graph_img, tweak_msg, state])
    del_char_btn.click(delete_character,
                       inputs=[del_char_name, state],
                       outputs=[graph_img, tweak_msg, state])
    # The follow-up step has no inputs, so Gradio calls it with zero arguments;
    # the lambda must not declare a parameter.
    save_btn.click(save_json, inputs=state, outputs=save_file, api_name="download") \
        .then(lambda: "โ JSON ready.", outputs=save_msg)
    load_btn.click(load_json, inputs=[load_file, state],
                   outputs=[graph_img, save_msg, state])

demo.launch()