Spaces:
Sleeping
Sleeping
# app.py — Story Generator with Elegant UI
import base64
import html
import json
import os
import pathlib
import re
import time
import uuid

import gradio as gr
from huggingface_hub import InferenceClient
from PIL import Image
# ---------- Config ----------
# The Hugging Face token comes from the environment (on Spaces, store it as a secret).
# Surrounding whitespace is stripped so a pasted secret with a trailing newline still works,
# and a whitespace-only value is rejected here instead of failing later at auth time.
HF_TOKEN = (os.environ.get("HF_TOKEN") or "").strip()
if not HF_TOKEN:
    raise RuntimeError("⚠️ Set HF_TOKEN environment variable (use Spaces Secrets).")
CHAT_MODEL = "meta-llama/Llama-3.1-8b-instruct"  # writes the story JSON
IMAGE_MODEL = "black-forest-labs/FLUX.1-schnell"  # default text-to-image model
SAFETY_MODEL = "meta-llama/Meta-Llama-Guard-2-8B"  # Safety model for content moderation
# Scene images are written here; the directory is created at import time.
OUT_DIR = pathlib.Path("/tmp/generated")
OUT_DIR.mkdir(parents=True, exist_ok=True)
# Single shared client used by every chat, image and safety call in this module.
client = InferenceClient(api_key=HF_TOKEN, provider="auto")
# ---------- Utility ----------
def try_chat_completion(model_id, messages, max_tokens=3000):
    """Send a chat-completion request and return the reply text.

    Args:
        model_id: Hugging Face model id to query.
        messages: Chat messages as a list of ``{"role": ..., "content": ...}`` dicts.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The first choice's message content as a ``str``. Returns ``""`` when
        the response has no choices or the content is missing/None.

    Exceptions raised by the inference client are not caught here.
    """
    response = client.chat.completions.create(model=model_id, messages=messages, max_tokens=max_tokens)
    if not response.choices:
        return ""
    # `content` may be present but None, so `.get(..., "")` alone is not enough.
    return response.choices[0].message.get("content") or ""
def extract_json_from_text(text):
    """Extract the JSON object embedded in a model response.

    A ```json fenced block is tried first, then the raw text starting at its
    first ``{``. Each candidate is parsed with ``json.JSONDecoder.raw_decode``,
    which stops at the end of the object. Trailing prose (even prose containing
    braces) and additional fenced blocks therefore no longer break extraction,
    as they did with a greedy ``\\{.*\\}`` regex.

    Args:
        text: Raw model output; may be None or empty.

    Returns:
        The parsed object as a ``dict``, or None if no valid JSON object is found.
        A parse failure is printed for debugging.
    """
    if not text:
        return None
    decoder = json.JSONDecoder()
    candidates = []
    # Lazy match: stop at the first closing fence rather than the last one in the text.
    fenced = re.search(r"```json\s*(\{[\s\S]*?)```", text, re.IGNORECASE)
    if fenced:
        candidates.append((fenced.group(1), 0))
    first_brace = text.find("{")
    if first_brace != -1:
        candidates.append((text, first_brace))
    last_error = None
    for source, start in candidates:
        try:
            # Parses exactly one JSON value beginning at `start`; whatever follows is ignored.
            obj, _ = decoder.raw_decode(source, start)
        except json.JSONDecodeError as exc:
            last_error = exc
            continue
        return obj
    if last_error is not None:
        print(f"--- JSON PARSING FAILED ---\nError: {last_error}\nContent: {text}\n--------------------------")
    return None
def hf_text_to_image(model_id, prompt, out_path):
    """Generate an image for ``prompt`` and save it to ``out_path``.

    Args:
        model_id: Hugging Face model id passed to ``client.text_to_image``.
        prompt: Text description of the desired image.
        out_path: Destination path, passed straight to ``Image.save``.

    Returns:
        ``out_path`` unchanged, after the image has been written.

    Raises:
        RuntimeError: If the client returns something other than a PIL image.
    """
    generated = client.text_to_image(prompt=prompt, model=model_id)
    if not isinstance(generated, Image.Image):
        raise RuntimeError("Invalid image response")
    generated.save(out_path)
    return out_path
def is_content_inappropriate(text_to_check):
    """Use Llama Guard (``SAFETY_MODEL``) to check text for inappropriate content.

    Args:
        text_to_check: Text to moderate (user prompt, visual prompt or scene text).

    Returns:
        True if the safety model's reply contains "unsafe", otherwise False.
        Empty or whitespace-only input returns False without a network call.
        Any error during the request is printed and also yields False, so the
        check fails open.
    """
    if not text_to_check or not str(text_to_check).strip():
        return False
    try:
        r = client.chat.completions.create(
            model=SAFETY_MODEL,
            messages=[{"role": "user", "content": f"Please evaluate if the following content is safe or unsafe based on typical safety guidelines for an AI assistant. Output 'safe' or 'unsafe'.\n\nContent: \"{text_to_check}\""}],
            max_tokens=20,
            temperature=0.1,
        )
        # `content` may be None; coerce before lowercasing instead of raising into the except branch.
        response = (r.choices[0].message.get("content") or "").lower()
        return "unsafe" in response
    except Exception as e:
        print(f"Error in safety check: {e}")
        return False
# ---------- Story logic ----------
def make_prompt(user_prompt, nscenes=6, nsent=5):
    """Build the instruction prompt that asks the chat model for a JSON story.

    The prompt includes a one-scene example of the required JSON shape
    (``title`` plus a ``scenes`` list whose items carry ``id``, ``text`` and
    ``visual_prompt``), the user's idea, and the requested scene count.

    Args:
        user_prompt: The user's story idea, quoted verbatim into the prompt.
        nscenes: Number of scenes to request.
        nsent: Approximate sentences per scene, mentioned in the example.

    Returns:
        The prompt text, which begins and ends with a newline.
    """
    prompt_lines = [
        "",
        "You are a creative story writer. Your task is to write a compelling story based on a user's prompt.",
        "You MUST return the story in a single, valid JSON object. Do not write any text or explanations outside of the JSON structure.",
        "Here is an example of the required JSON format:",
        "{",
        '"title": "A descriptive title for the entire story",',
        '"scenes": [',
        "{",
        '"id": 1,',
        f'"text": "The full story text for this scene. This should be a complete paragraph with around {nsent} sentences, describing the events and setting.",',
        '"visual_prompt": "A detailed, vivid description for an image generation model, capturing the key visual elements of this scene."',
        "}",
        "]",
        "}",
        "Please use the following details for the story:",
        f'- Story Prompt: "{user_prompt}"',
        f"- Total number of scenes: {nscenes}",
        "Now, generate the story in the specified JSON format.",
        "",
    ]
    return "\n".join(prompt_lines)
def generate_story_and_images(prompt, nscenes, nsent, img_model):
    """Generate a story with the chat model, then one image per scene.

    Args:
        prompt: The user's story idea.
        nscenes: Requested number of scenes; also the size of the fallback story.
        nsent: Requested sentences per scene.
        img_model: Model id passed to ``hf_text_to_image``.

    Returns:
        A ``(story, image_paths, log_text)`` tuple.
        - ``story`` is a dict with ``"title"`` and a non-empty ``"scenes"``
          list of dicts, each guaranteed to have an ``"id"``.
        - ``image_paths`` lists the PNG files that were actually written.
        - ``log_text`` is the newline-joined progress log.
    """
    start = time.time()
    logs = ["🎬 Generating story..."]
    raw = try_chat_completion(CHAT_MODEL, [{"role": "user", "content": make_prompt(prompt, nscenes, nsent)}])
    story = extract_json_from_text(raw)

    # Valid JSON is not necessarily a valid story: keep only dict-shaped scenes
    # so the lookups below cannot raise on a missing/non-list "scenes" key.
    raw_scenes = story.get("scenes") if isinstance(story, dict) else None
    scenes = [s for s in raw_scenes if isinstance(s, dict)] if isinstance(raw_scenes, list) else []
    if scenes:
        for position, scene in enumerate(scenes, start=1):
            scene.setdefault("id", position)
        story["scenes"] = scenes
        logs.append("✅ Story JSON parsed.")
    else:
        story = {
            "title": "Untitled",
            "scenes": [
                {
                    "id": i + 1,
                    "text": f"Scene {i+1}. The AI failed to generate a proper story, or the JSON was malformed.",
                    "visual_prompt": prompt,
                }
                for i in range(nscenes)
            ],
        }
        logs.append("⚠️ Failed to parse story JSON, using fallback.")

    image_paths = []
    for position, scene in enumerate(story["scenes"], start=1):
        scene_id = scene["id"]
        # Prefer the visual prompt, then the scene text, then the user's prompt when a field is empty/null.
        visual_prompt = str(scene.get("visual_prompt") or scene.get("text") or prompt)
        if is_content_inappropriate(visual_prompt):
            gr.Warning(f"Visual prompt for Scene {scene_id} was moderated for safety. Generating a default image.")
            visual_prompt = "A serene landscape with gentle colors."
        # Name the file by loop position, not the model-supplied id: an id containing
        # "/" or ".." would otherwise place the file outside OUT_DIR.
        name = OUT_DIR / f"{uuid.uuid4().hex[:6]}_scene_{position}.png"
        logs.append(f"🎨 Generating image for Scene {scene_id}...")
        try:
            hf_text_to_image(img_model, visual_prompt, str(name))
        except Exception as e:
            # One failed image should not discard the whole story; record it and move on.
            logs.append(f"⚠️ Image generation failed for Scene {scene_id}: {e}")
            continue
        image_paths.append(str(name))
    total = time.time() - start
    logs.append(f"✨ Done in {total:.2f}s")
    return story, image_paths, "\n".join(logs)
# ---------- UI ----------
def build_ui():
    """Build the Gradio Blocks interface for the story generator.

    Layout:
    - a prompt box with a Generate button and a collapsible settings accordion
      (scene count, sentences per scene, image model);
    - a read-only log box;
    - a row with the rendered story (HTML) next to an image gallery.

    Returns:
        The ``gr.Blocks`` app. The caller is responsible for launching it.
    """
    css = """
.main-container { max-width: 1400px; margin: 0 auto; }
.prompt-section { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 32px; border-radius: 16px; margin-bottom: 24px; }
.prompt-box textarea { font-size: 16px !important; }
.story-panel { background: rgba(255,255,255,0.05); padding: 24px; border-radius: 12px; backdrop-filter: blur(10px); border: 1px solid rgba(255,255,255,0.1); max-height: 800px; overflow-y: auto; }
.story-title { font-size: 32px; font-weight: 700; margin-bottom: 24px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; }
.story-content { line-height: 1.8; font-size: 16px; color: rgba(255,255,255,0.9); }
.story-content p { margin-bottom: 16px; }
"""
    placeholder_style = "text-align:center;padding:40px;color:#888;"

    def render_story(story):
        """Render a story dict as HTML, moderating each scene's text.

        The title and scene text are model output, so they are HTML-escaped
        before being embedded. A flagged scene is replaced with a placeholder.
        """
        title = html.escape(str(story.get("title") or "Untitled"))
        parts = [f"<div class='story-title'>{title}</div>\n<div class='story-content'>\n"]
        for position, scene in enumerate(story.get("scenes", []), start=1):
            scene_id = scene.get("id", position)
            scene_text = scene.get("text") or ""
            if is_content_inappropriate(scene_text):
                gr.Info(f"Scene {scene_id} content was flagged and replaced.")
                # HTML emphasis: Markdown "**" would be displayed literally inside gr.HTML.
                body = f"<strong>[Scene {html.escape(str(scene_id))} was moderated for safety and replaced with a placeholder.]</strong>"
            else:
                body = html.escape(str(scene_text))
            parts.append(f"<p>{body}</p>\n\n")
        parts.append("</div>")
        return "".join(parts)

    with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Story Generator") as demo:
        gr.Markdown("# 📖 AI Story Generator", elem_classes="main-title")
        with gr.Column(elem_classes="main-container"):
            with gr.Column(elem_classes="prompt-section"):
                prompt_box = gr.Textbox(
                    label="✨ Enter your story idea",
                    placeholder="e.g. A pirate discovering a hidden island...",
                    lines=3,
                    elem_classes="prompt-box",
                )
                with gr.Row():
                    generate_btn = gr.Button("🚀 Generate Story", variant="primary", size="lg", scale=3)
                    with gr.Column(scale=1):
                        with gr.Accordion("⚙️ Settings", open=False):
                            nscenes = gr.Slider(2, 12, value=6, step=1, label="📑 Number of scenes")
                            nsent = gr.Slider(2, 8, value=5, step=1, label="📝 Sentences per scene")
                            img_model = gr.Dropdown(choices=[IMAGE_MODEL], value=IMAGE_MODEL, label="🎨 Image model")
            log_box = gr.Textbox(label="📋 Generation Logs", lines=6, interactive=False)
            with gr.Row():
                with gr.Column(scale=5, elem_classes="story-panel"):
                    story_html = gr.HTML(f"<div style='{placeholder_style}'>Your story will appear here...<br><br>Click 'Generate Story' to begin! ✨</div>")
                with gr.Column(scale=7):
                    image_gallery = gr.Gallery(
                        label="📷 Scene Visuals", show_label=False, elem_id="gallery",
                        columns=2, object_fit="cover", height="auto", preview=True
                    )

        def on_generate(prompt, nscenes, nsent, img_model):
            """Handle a Generate click: validate and moderate the prompt, generate, and render.

            Returns ``(story_html, gallery_paths, log_text)``, matching the
            ``outputs`` list wired to ``generate_btn.click`` below.
            """
            if not prompt or not prompt.strip():
                gr.Warning("Please enter a story idea first.")
                return f"<div style='{placeholder_style}'>Please enter a story idea to begin.</div>", [], "No prompt provided."
            if is_content_inappropriate(prompt):
                gr.Warning("Your prompt seems to violate the safety policy. Please try again.")
                return f"<div style='{placeholder_style}'>Prompt rejected due to safety policy.</div>", [], "Prompt rejected by safety filter."
            story, imgs, logs = generate_story_and_images(prompt, int(nscenes), int(nsent), img_model)
            return render_story(story), imgs, logs

        generate_btn.click(
            on_generate,
            inputs=[prompt_box, nscenes, nsent, img_model],
            outputs=[story_html, image_gallery, log_box]
        )
    return demo
# Build the Blocks app once at import time and keep it as the module-level `app`;
# a server is started only when this file is executed directly.
app = build_ui()

if __name__ == "__main__":
    app.launch()