File size: 8,825 Bytes
c2793b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
# app.py — Story Generator with Elegant UI
import os, json, re, pathlib, base64, time, uuid
from huggingface_hub import InferenceClient
from PIL import Image
import gradio as gr

# ---------- Config ----------
# The HF token is read from the environment. Importing this module raises
# immediately when it is missing or empty, before any client is created.
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("⚠️ Set HF_TOKEN environment variable (use Spaces Secrets).")

# NOTE(review): the canonical Hub id is "meta-llama/Llama-3.1-8B-Instruct";
# confirm the routed inference provider accepts this lower-case spelling.
CHAT_MODEL = "meta-llama/Llama-3.1-8b-instruct"
IMAGE_MODEL = "black-forest-labs/FLUX.1-schnell"
SAFETY_MODEL = "meta-llama/Meta-Llama-Guard-2-8B" # Safety model for content moderation
# Generated scene images are written here; the directory is created at import time.
OUT_DIR = pathlib.Path("/tmp/generated")
OUT_DIR.mkdir(parents=True, exist_ok=True)
# Single client shared by every inference call in this module.
# NOTE(review): provider="auto" presumably lets huggingface_hub pick an
# available inference provider per model — confirm for the installed version.
client = InferenceClient(api_key=HF_TOKEN, provider="auto")


# ---------- Utility ----------
def try_chat_completion(model_id, messages, max_tokens=3000):
    """Run a chat completion and return the text of the first choice.

    Args:
        model_id: Hub id of the chat model to call.
        messages: OpenAI-style list of ``{"role": ..., "content": ...}`` dicts.
        max_tokens: Upper bound on the number of generated tokens.

    Returns:
        The reply text. Returns ``""`` when the response has no choices or the
        message content is null, so callers always receive a ``str``.

    Raises:
        Whatever the inference client raises (network / provider errors are
        not caught here).
    """
    response = client.chat.completions.create(model=model_id, messages=messages, max_tokens=max_tokens)
    if not response.choices:
        return ""
    # `.get(..., "")` alone still yields None when the key exists with a null
    # value; `or ""` normalises both cases to an empty string.
    return response.choices[0].message.get("content") or ""

def extract_json_from_text(text):
    """Extract the first JSON object embedded in a model reply.

    Strategy:
      1. Prefer a Markdown code fence (```json ... ``` or a bare ``` ... ```,
         tag matched case-insensitively). The match is non-greedy, so a reply
         with several fences yields the first object, not a span running from
         the first fence to the last one.
      2. Otherwise decode starting at the first "{" with
         ``json.JSONDecoder.raw_decode``. It stops at the end of the object,
         so trailing prose is ignored, even prose that itself contains "}".

    Args:
        text: Raw model output; may be ``None`` or empty.

    Returns:
        The decoded object (a dict, since decoding always starts at "{"), or
        ``None`` when nothing parseable is found. Parse failures are printed
        for debugging.
    """
    if not text:
        return None

    fenced = re.search(r"```(?:json)?\s*(\{[\s\S]*?\})\s*```", text, re.IGNORECASE)
    if fenced:
        try:
            return json.loads(fenced.group(1))
        except json.JSONDecodeError as e:
            # Fall through to the unfenced strategy below.
            print(f"--- JSON PARSING FAILED (fenced block) ---\nError: {e}\nContent: {fenced.group(1)}\n--------------------------")

    start = text.find("{")
    if start == -1:
        return None
    try:
        obj, _end = json.JSONDecoder().raw_decode(text, start)
    except json.JSONDecodeError as e:
        print(f"--- JSON PARSING FAILED ---\nError: {e}\nContent: {text[start:]}\n--------------------------")
        return None
    return obj

def hf_text_to_image(model_id, prompt, out_path):
    """Generate an image for *prompt* with *model_id* and save it to *out_path*.

    Args:
        model_id: Hub id of the text-to-image model.
        prompt: Text description of the desired image.
        out_path: Destination path passed to ``Image.save``.

    Returns:
        *out_path*, unchanged, once the image has been written.

    Raises:
        RuntimeError: if the client returns anything other than a PIL image.
    """
    picture = client.text_to_image(prompt=prompt, model=model_id)
    if not isinstance(picture, Image.Image):
        raise RuntimeError("Invalid image response")
    picture.save(out_path)
    return out_path

def is_content_inappropriate(text_to_check, *, fail_closed=False):
    """Ask the Llama Guard safety model whether *text_to_check* is unsafe.

    Args:
        text_to_check: Text to classify. Empty or whitespace-only text is
            treated as safe without calling the model.
        fail_closed: Value returned when the safety call itself fails
            (network error, model unavailable, malformed response). The
            default ``False`` keeps the original fail-open policy, where
            content is treated as safe. Pass ``True`` to block content
            whenever moderation could not be performed.

    Returns:
        ``True`` if the model's reply contains "unsafe", otherwise ``False``
        (or ``fail_closed`` on error).
    """
    if not text_to_check or not str(text_to_check).strip():
        return False
    try:
        r = client.chat.completions.create(
            model=SAFETY_MODEL,
            messages=[{"role": "user", "content": f"Please evaluate if the following content is safe or unsafe based on typical safety guidelines for an AI assistant. Output 'safe' or 'unsafe'.\n\nContent: \"{text_to_check}\""}],
            max_tokens=20,
            temperature=0.1
        )
        # A null content field would otherwise make .lower() raise.
        response = (r.choices[0].message.get("content") or "").lower()
    except Exception as e:
        print(f"Error in safety check: {e}")
        return bool(fail_closed)
    # "unsafe" is not a substring of "safe", so this separates the two verdicts.
    return "unsafe" in response

# ---------- Story logic ----------
def make_prompt(user_prompt, nscenes=6, nsent=5):
    """Build the story-writing instruction sent to the chat model.

    Args:
        user_prompt: The user's story idea.
        nscenes: Number of scenes to request.
        nsent: Approximate number of sentences per scene.

    Returns:
        The full prompt text. It asks for a single JSON object with a
        "title" and a "scenes" list of {"id", "text", "visual_prompt"} items.

    The idea is embedded as a JSON string literal so that double quotes,
    backslashes or newlines in it cannot break out of the quoted field.
    Ordinary text renders exactly as a plain ``"idea"``.
    """
    quoted_prompt = json.dumps(str(user_prompt), ensure_ascii=False)
    return f"""
You are a creative story writer. Your task is to write a compelling story based on a user's prompt.
You MUST return the story in a single, valid JSON object. Do not write any text or explanations outside of the JSON structure.

Here is an example of the required JSON format:
{{
  "title": "A descriptive title for the entire story",
  "scenes": [
    {{
      "id": 1,
      "text": "The full story text for this scene. This should be a complete paragraph with around {nsent} sentences, describing the events and setting.",
      "visual_prompt": "A detailed, vivid description for an image generation model, capturing the key visual elements of this scene."
    }}
  ]
}}

Please use the following details for the story:
- Story Prompt: {quoted_prompt}
- Total number of scenes: {nscenes}

Now, generate the story in the specified JSON format.
"""

def _fallback_story(prompt, nscenes):
    """Return a placeholder story of *nscenes* scenes that reuse *prompt* for images."""
    return {
        "title": "Untitled",
        "scenes": [
            {
                "id": i + 1,
                "text": f"Scene {i+1}. The AI failed to generate a proper story, or the JSON was malformed.",
                "visual_prompt": prompt,
            }
            for i in range(nscenes)
        ],
    }


def _is_usable_story(story):
    """Return True if *story* is a dict whose "scenes" is a non-empty list of dicts."""
    if not isinstance(story, dict):
        return False
    scenes = story.get("scenes")
    return isinstance(scenes, list) and len(scenes) > 0 and all(isinstance(s, dict) for s in scenes)


def generate_story_and_images(prompt, nscenes, nsent, img_model):
    """Generate a story for *prompt* and one image per scene.

    Args:
        prompt: The user's story idea.
        nscenes: Number of scenes to request (also used for the fallback story).
        nsent: Approximate sentences per scene, forwarded to the prompt.
        img_model: Hub id of the text-to-image model.

    Returns:
        ``(story, image_paths, logs)``:
          - ``story`` is a dict with "title" and a list of scene dicts, each
            guaranteed to carry an "id".
          - ``image_paths`` lists the successfully written image files.
          - ``logs`` is a newline-joined progress log.

    If the model output is not usable JSON, a placeholder story is used. A
    failed image is logged and skipped rather than aborting the whole story.
    """
    start = time.perf_counter()
    logs = []

    logs.append("🎬 Generating story...")
    raw = try_chat_completion(CHAT_MODEL, [{"role": "user", "content": make_prompt(prompt, nscenes, nsent)}])
    story = extract_json_from_text(raw)

    if _is_usable_story(story):
        logs.append("✅ Story JSON parsed.")
    else:
        story = _fallback_story(prompt, nscenes)
        logs.append("⚠️ Failed to parse story JSON, using fallback.")

    # Model output may omit ids; give every scene one so labels and file names work.
    for position, scene in enumerate(story["scenes"], start=1):
        scene.setdefault("id", position)

    image_paths = []
    for scene in story["scenes"]:
        scene_id = scene["id"]
        # `or` rather than a .get default so empty/None values also fall through.
        visual_prompt = scene.get("visual_prompt") or scene.get("text") or prompt
        if is_content_inappropriate(visual_prompt):
            gr.Warning(f"Visual prompt for Scene {scene_id} was moderated for safety. Generating a default image.")
            visual_prompt = "A serene landscape with gentle colors."

        # The id comes from model output; keep only filename-safe characters.
        safe_id = re.sub(r"[^A-Za-z0-9_-]", "_", str(scene_id))
        name = OUT_DIR / f"{uuid.uuid4().hex[:6]}_scene_{safe_id}.png"
        logs.append(f"🎨 Generating image for Scene {scene_id}...")
        try:
            hf_text_to_image(img_model, visual_prompt, str(name))
        except Exception as e:
            # One failed image should not discard the story and the other images.
            print(f"Image generation failed for Scene {scene_id}: {e}")
            logs.append(f"❌ Image for Scene {scene_id} failed: {e}")
            gr.Warning(f"Image for Scene {scene_id} could not be generated.")
            continue
        image_paths.append(str(name))
    total = time.perf_counter() - start
    logs.append(f"✨ Done in {total:.2f}s")
    return story, image_paths, "\n".join(logs)

# ---------- UI ----------
def _render_story_html(story):
    """Render *story* as HTML for the story panel.

    The title and scene texts come from the language model, so they are
    HTML-escaped before insertion. Scene texts flagged by
    ``is_content_inappropriate`` are replaced with a placeholder notice.

    Args:
        story: Dict with an optional "title" and an optional "scenes" list of
            dicts carrying "id" and "text"; non-dict scenes are skipped.

    Returns:
        An HTML string using the ``story-title`` / ``story-content`` classes.
    """
    import html  # only this helper needs escaping

    title = html.escape(str(story.get("title") or "Untitled"))
    parts = [f"<div class='story-title'>{title}</div>\n<div class='story-content'>\n"]
    for position, scene in enumerate(story.get("scenes") or [], start=1):
        if not isinstance(scene, dict):
            continue
        scene_id = scene.get("id", position)
        scene_text = str(scene.get("text") or "")
        if is_content_inappropriate(scene_text):
            gr.Info(f"Scene {scene_id} content was flagged and replaced.")
            # Real HTML emphasis: Markdown ** would render literally inside gr.HTML.
            body = f"<strong>[Scene {html.escape(str(scene_id))} was moderated for safety and replaced with a placeholder.]</strong>"
        else:
            body = html.escape(scene_text)
        parts.append(f"<p>{body}</p>\n\n")
    parts.append("</div>")
    return "".join(parts)


def build_ui():
    """Build and return the Gradio Blocks app for the story generator.

    Layout:
      - A prompt section with the idea textbox, the generate button and a
        collapsed settings accordion (sliders, model dropdown, log box).
      - A row with the rendered story on the left and the scene gallery on
        the right.
    """
    # .story-content uses the theme's text colour so it stays readable in
    # light mode; the old near-white value is only a fallback.
    css = """
    .main-container { max-width: 1400px; margin: 0 auto; }
    .prompt-section { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 32px; border-radius: 16px; margin-bottom: 24px; }
    .prompt-box textarea { font-size: 16px !important; }
    .story-panel { background: rgba(255,255,255,0.05); padding: 24px; border-radius: 12px; backdrop-filter: blur(10px); border: 1px solid rgba(255,255,255,0.1); max-height: 800px; overflow-y: auto; }
    .story-title { font-size: 32px; font-weight: 700; margin-bottom: 24px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; }
    .story-content { line-height: 1.8; font-size: 16px; color: var(--body-text-color, rgba(255,255,255,0.9)); }
    .story-content p { margin-bottom: 16px; }
    """
    placeholder_html = "<div style='text-align:center;padding:40px;color:#888;'>Your story will appear here...<br><br>Click 'Generate Story' to begin! ✨</div>"

    with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Story Generator") as demo:
        gr.Markdown("# 📚 AI Story Generator", elem_classes="main-title")

        with gr.Column(elem_classes="main-container"):
            with gr.Column(elem_classes="prompt-section"):
                prompt_box = gr.Textbox(label="✨ Enter your story idea", placeholder="e.g. A pirate discovering a hidden island...", lines=3, elem_classes="prompt-box")
                with gr.Row():
                    generate_btn = gr.Button("🚀 Generate Story", variant="primary", size="lg", scale=3)
                    with gr.Column(scale=1):
                        with gr.Accordion("⚙️ Settings", open=False):
                            nscenes = gr.Slider(2, 12, value=6, step=1, label="📖 Number of scenes")
                            nsent = gr.Slider(2, 8, value=5, step=1, label="📝 Sentences per scene")
                            img_model = gr.Dropdown(choices=[IMAGE_MODEL], value=IMAGE_MODEL, label="🎨 Image model")
                            log_box = gr.Textbox(label="📋 Generation Logs", lines=6, interactive=False)

            with gr.Row():
                with gr.Column(scale=5, elem_classes="story-panel"):
                    story_html = gr.HTML(placeholder_html)

                with gr.Column(scale=7):
                    image_gallery = gr.Gallery(
                        label="📷 Scene Visuals", show_label=False, elem_id="gallery",
                        columns=2, object_fit="cover", height="auto", preview=True
                    )

        def on_generate(prompt, nscenes, nsent, img_model):
            """Click handler: validate/moderate the prompt, generate, render outputs.

            Returns ``(story_html, image_paths, logs)`` for the three outputs.
            """
            if not prompt or not prompt.strip():
                gr.Warning("Please enter a story idea first.")
                return placeholder_html, [], "No prompt provided."
            if is_content_inappropriate(prompt):
                gr.Warning("Your prompt seems to violate the safety policy. Please try again.")
                return "<div style='text-align:center;padding:40px;color:#888;'>Prompt rejected due to safety policy.</div>", [], "Prompt rejected by safety filter."

            story, imgs, logs = generate_story_and_images(prompt, int(nscenes), int(nsent), img_model)
            return _render_story_html(story), imgs, logs

        generate_btn.click(
            on_generate,
            inputs=[prompt_box, nscenes, nsent, img_model],
            outputs=[story_html, image_gallery, log_box]
        )

    return demo

# Build the UI at import time and keep the Blocks object in the module-level
# name `app` (NOTE(review): presumably so a hosting runner can import it — confirm).
app = build_ui()

if __name__ == "__main__":
    # Start the Gradio server only when this file is executed directly.
    app.launch()