# app.py — Story Generator with Elegant UI
import os
import json
import re
import pathlib
import time
import uuid

from huggingface_hub import InferenceClient
from PIL import Image
import gradio as gr

# ---------- Config ----------
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("⚠️ Set HF_TOKEN environment variable (use Spaces Secrets).")

CHAT_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
IMAGE_MODEL = "black-forest-labs/FLUX.1-schnell"
SAFETY_MODEL = "meta-llama/Meta-Llama-Guard-2-8B"  # Safety model for content moderation

OUT_DIR = pathlib.Path("/tmp/generated")
OUT_DIR.mkdir(parents=True, exist_ok=True)

client = InferenceClient(api_key=HF_TOKEN, provider="auto")


# ---------- Utility ----------
def try_chat_completion(model_id, messages, max_tokens=3000):
    """Run a chat completion and return the assistant message text."""
    r = client.chat.completions.create(model=model_id, messages=messages, max_tokens=max_tokens)
    # Read the message text via attribute access; fall back to "" if the model returned none.
    return r.choices[0].message.content or ""


def extract_json_from_text(text):
    """Extract a JSON object from the model response, with or without a ```json fence."""
    if not text:
        return None
    match = re.search(r'```json\s*(\{[\s\S]*\})\s*```', text)
    if match:
        json_str = match.group(1)
    else:
        match = re.search(r'\{[\s\S]*\}', text)
        if not match:
            return None
        json_str = match.group(0)
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        print(f"--- JSON PARSING FAILED ---\nError: {e}\nContent: {json_str}\n--------------------------")
        return None
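
# Illustrative example (not executed): given a fenced response such as
#   '```json\n{"title": "The Tide Cave", "scenes": []}\n```'
# extract_json_from_text returns {"title": "The Tide Cave", "scenes": []};
# it returns None whenever no JSON object can be parsed from the text.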


def hf_text_to_image(model_id, prompt, out_path):
    """Generate an image for a prompt and save it to out_path."""
    img = client.text_to_image(prompt=prompt, model=model_id)
    if isinstance(img, Image.Image):
        img.save(out_path)
        return out_path
    raise RuntimeError("Invalid image response")


def is_content_inappropriate(text_to_check):
    """Uses Llama Guard to check for inappropriate content.

    This is a heuristic check; on any API error it fails open (returns False).
    """
    try:
        r = client.chat.completions.create(
            model=SAFETY_MODEL,
            messages=[{
                "role": "user",
                "content": (
                    "Please evaluate if the following content is safe or unsafe based on "
                    "typical safety guidelines for an AI assistant. Output 'safe' or 'unsafe'.\n\n"
                    f'Content: "{text_to_check}"'
                ),
            }],
            max_tokens=20,
            temperature=0.1,
        )
        response = (r.choices[0].message.content or "").lower()
        return "unsafe" in response
    except Exception as e:
        print(f"Error in safety check: {e}")
        return False


# ---------- Story logic ----------
def make_prompt(user_prompt, nscenes=6, nsent=5):
    # Instruct the model to return a single JSON object so the response can be parsed reliably.
    return f"""
You are a creative story writer. Your task is to write a compelling story based on a user's prompt.
You MUST return the story in a single, valid JSON object. Do not write any text or explanations outside of the JSON structure.

Here is an example of the required JSON format:
{{
  "title": "A descriptive title for the entire story",
  "scenes": [
    {{
      "id": 1,
      "text": "The full story text for this scene. This should be a complete paragraph with around {nsent} sentences, describing the events and setting.",
      "visual_prompt": "A detailed, vivid description for an image generation model, capturing the key visual elements of this scene."
    }}
  ]
}}

Please use the following details for the story:
- Story Prompt: "{user_prompt}"
- Total number of scenes: {nscenes}

Now, generate the story in the specified JSON format.
"""


def generate_story_and_images(prompt, nscenes, nsent, img_model):
    """Generate the story JSON, then one image per scene. Returns (story, image_paths, logs)."""
    start = time.time()
    logs = []
    logs.append("🎬 Generating story...")
    raw = try_chat_completion(CHAT_MODEL, [{"role": "user", "content": make_prompt(prompt, nscenes, nsent)}])
    story = extract_json_from_text(raw)
    if not story:
        # Fallback story so the UI still renders something if parsing fails.
        story = {
            "title": "Untitled",
            "scenes": [
                {
                    "id": i + 1,
                    "text": f"Scene {i + 1}. The AI failed to generate a proper story, or the JSON was malformed.",
                    "visual_prompt": prompt,
                }
                for i in range(nscenes)
            ],
        }
        logs.append("⚠️ Failed to parse story JSON, using fallback.")
    else:
        logs.append("✅ Story JSON parsed.")

    image_paths = []
    for s in story["scenes"]:
        visual_prompt = s.get("visual_prompt", s.get("text", prompt))
        if is_content_inappropriate(visual_prompt):
            gr.Warning(f"Visual prompt for Scene {s['id']} was moderated for safety. Generating a default image.")
            visual_prompt = "A serene landscape with gentle colors."
        name = OUT_DIR / f"{uuid.uuid4().hex[:6]}_scene_{s['id']}.png"
        logs.append(f"🎨 Generating image for Scene {s['id']}...")
        hf_text_to_image(img_model, visual_prompt, str(name))
        image_paths.append(str(name))

    total = time.time() - start
    logs.append(f"✨ Done in {total:.2f}s")
    return story, image_paths, "\n".join(logs)
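
# Minimal smoke test (a sketch, not executed on import): assumes HF_TOKEN is set and the
# models above are reachable through your inference provider.
#
#     story, images, logs = generate_story_and_images(
#         "A pirate discovering a hidden island", nscenes=3, nsent=4, img_model=IMAGE_MODEL
#     )
#     print(story["title"], images, logs, sep="\n")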


# ---------- UI ----------
def build_ui():
    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .prompt-section { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 32px; border-radius: 16px; margin-bottom: 24px; }
    .prompt-box textarea { font-size: 16px !important; }
    .story-panel { background: rgba(255,255,255,0.05); padding: 24px; border-radius: 12px; backdrop-filter: blur(10px); border: 1px solid rgba(255,255,255,0.1); max-height: 800px; overflow-y: auto; }
    .story-title { font-size: 32px; font-weight: 700; margin-bottom: 24px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; }
    .story-content { line-height: 1.8; font-size: 16px; color: rgba(255,255,255,0.9); }
    .story-content p { margin-bottom: 16px; }
    """

    with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Story Generator") as demo:
        gr.Markdown("# 📚 AI Story Generator", elem_classes="main-title")
        with gr.Column(elem_classes="main-container"):
            with gr.Column(elem_classes="prompt-section"):
                prompt_box = gr.Textbox(
                    label="✨ Enter your story idea",
                    placeholder="e.g. A pirate discovering a hidden island...",
                    lines=3,
                    elem_classes="prompt-box",
                )
                with gr.Row():
                    generate_btn = gr.Button("🚀 Generate Story", variant="primary", size="lg", scale=3)
                    with gr.Column(scale=1):
                        with gr.Accordion("⚙️ Settings", open=False):
                            nscenes = gr.Slider(2, 12, value=6, step=1, label="📖 Number of scenes")
                            nsent = gr.Slider(2, 8, value=5, step=1, label="📝 Sentences per scene")
                            img_model = gr.Dropdown(choices=[IMAGE_MODEL], value=IMAGE_MODEL, label="🎨 Image model")
            log_box = gr.Textbox(label="📋 Generation Logs", lines=6, interactive=False)
            with gr.Row():
                with gr.Column(scale=5, elem_classes="story-panel"):
                    story_html = gr.HTML(
                        "<div class='story-content'>"
                        "<p>Your story will appear here...</p>"
                        "<p>Click 'Generate Story' to begin! ✨</p>"
                        "</div>"
                    )
                with gr.Column(scale=7):
                    image_gallery = gr.Gallery(
                        label="📷 Scene Visuals",
                        show_label=False,
                        elem_id="gallery",
                        columns=2,
                        object_fit="cover",
                        height="auto",
                        preview=True,
                    )

        def on_generate(prompt, nscenes, nsent, img_model):
            if is_content_inappropriate(prompt):
                gr.Warning("Your prompt seems to violate the safety policy. Please try again.")
                return (
                    "<div class='story-content'><p>Prompt rejected due to safety policy.</p></div>",
                    [],
                    "Prompt rejected by safety filter.",
                )
            story, imgs, logs = generate_story_and_images(prompt, int(nscenes), int(nsent), img_model)
            # Build the story panel HTML using the .story-title / .story-content classes defined in css.
            story_output_html = f"<div class='story-title'>{story.get('title', 'Untitled')}</div>\n<div class='story-content'>\n"
            for s in story.get('scenes', []):
                scene_text = s.get('text', '')
                if is_content_inappropriate(scene_text):
                    scene_text = f"<strong>[Scene {s['id']} was moderated for safety and replaced with a placeholder.]</strong>"
                    gr.Info(f"Scene {s['id']} content was flagged and replaced.")
                story_output_html += f"<p>{scene_text}</p>\n\n"
            story_output_html += "</div>"
            return story_output_html, imgs, logs

        generate_btn.click(
            on_generate,
            inputs=[prompt_box, nscenes, nsent, img_model],
            outputs=[story_html, image_gallery, log_box],
        )
    return demo


app = build_ui()

if __name__ == "__main__":
    app.launch()