# app.py - Story Generator with Elegant UI
import os, json, re, pathlib, base64, time, uuid
from huggingface_hub import InferenceClient
from PIL import Image
import gradio as gr
# ---------- Config ----------
HF_TOKEN = os.environ.get("HF_TOKEN")
if not HF_TOKEN:
    raise RuntimeError("⚠️ Set HF_TOKEN environment variable (use Spaces Secrets).")
CHAT_MODEL = "meta-llama/Llama-3.1-8B-Instruct"
IMAGE_MODEL = "black-forest-labs/FLUX.1-schnell"
SAFETY_MODEL = "meta-llama/Meta-Llama-Guard-2-8B" # Safety model for content moderation
OUT_DIR = pathlib.Path("/tmp/generated")
OUT_DIR.mkdir(parents=True, exist_ok=True)
client = InferenceClient(api_key=HF_TOKEN, provider="auto")
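# Note: with provider="auto", huggingface_hub routes each call to an available
# Inference Provider for the requested model; a fixed provider string (e.g. "hf-inference")
# can be passed instead if deterministic routing is preferred.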
# ---------- Utility ----------
def try_chat_completion(model_id, messages, max_tokens=3000):
    r = client.chat.completions.create(model=model_id, messages=messages, max_tokens=max_tokens)
    return r.choices[0].message.get("content", "")
def extract_json_from_text(text):
    if not text:
        return None
    # Prefer a fenced ```json block; otherwise fall back to the first {...} span.
    match = re.search(r'```json\s*(\{[\s\S]*\})\s*```', text)
    if match:
        json_str = match.group(1)
    else:
        match = re.search(r'\{[\s\S]*\}', text)
        if not match:
            return None
        json_str = match.group(0)
    try:
        return json.loads(json_str)
    except json.JSONDecodeError as e:
        print(f"--- JSON PARSING FAILED ---\nError: {e}\nContent: {json_str}\n--------------------------")
        return None
def hf_text_to_image(model_id, prompt, out_path):
    img = client.text_to_image(prompt=prompt, model=model_id)
    if isinstance(img, Image.Image):
        img.save(out_path)
        return out_path
    raise RuntimeError("Invalid image response")
def is_content_inappropriate(text_to_check):
    """Uses Llama Guard to check for inappropriate content."""
    try:
        r = client.chat.completions.create(
            model=SAFETY_MODEL,
            messages=[{"role": "user", "content": f"Please evaluate if the following content is safe or unsafe based on typical safety guidelines for an AI assistant. Output 'safe' or 'unsafe'.\n\nContent: \"{text_to_check}\""}],
            max_tokens=20,
            temperature=0.1
        )
        response = r.choices[0].message.get("content", "").lower()
        return "unsafe" in response
    except Exception as e:
        print(f"Error in safety check: {e}")
        return False
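# Note on is_content_inappropriate above: the check fails open (returns False on any
# exception), so a transient error from the safety model will not block generation.
# Return True in the except branch instead if the app should fail closed.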
# ---------- Story logic ----------
def make_prompt(user_prompt, nscenes=6, nsent=5):
    # Build a single prompt that instructs the model to return the story as strict JSON.
    return f"""
You are a creative story writer. Your task is to write a compelling story based on a user's prompt.
You MUST return the story in a single, valid JSON object. Do not write any text or explanations outside of the JSON structure.
Here is an example of the required JSON format:
{{
"title": "A descriptive title for the entire story",
"scenes": [
{{
"id": 1,
"text": "The full story text for this scene. This should be a complete paragraph with around {nsent} sentences, describing the events and setting.",
"visual_prompt": "A detailed, vivid description for an image generation model, capturing the key visual elements of this scene."
}}
]
}}
Please use the following details for the story:
- Story Prompt: "{user_prompt}"
- Total number of scenes: {nscenes}
Now, generate the story in the specified JSON format.
"""
def generate_story_and_images(prompt, nscenes, nsent, img_model):
    start = time.time()
    logs = []
    logs.append("🎬 Generating story...")
    raw = try_chat_completion(CHAT_MODEL, [{"role": "user", "content": make_prompt(prompt, nscenes, nsent)}])
    story = extract_json_from_text(raw)
    if not story:
        story = {"title": "Untitled", "scenes": [{"id": i + 1, "text": f"Scene {i+1}. The AI failed to generate a proper story, or the JSON was malformed.", "visual_prompt": prompt} for i in range(nscenes)]}
        logs.append("⚠️ Failed to parse story JSON, using fallback.")
    else:
        logs.append("✅ Story JSON parsed.")
    image_paths = []
    for s in story["scenes"]:
        visual_prompt = s.get("visual_prompt", s.get("text", prompt))
        if is_content_inappropriate(visual_prompt):
            gr.Warning(f"Visual prompt for Scene {s['id']} was moderated for safety. Generating a default image.")
            visual_prompt = "A serene landscape with gentle colors."
        name = OUT_DIR / f"{uuid.uuid4().hex[:6]}_scene_{s['id']}.png"
        logs.append(f"🎨 Generating image for Scene {s['id']}...")
        hf_text_to_image(img_model, visual_prompt, str(name))
        image_paths.append(str(name))
    total = time.time() - start
    logs.append(f"✨ Done in {total:.2f}s")
    return story, image_paths, "\n".join(logs)
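# generate_story_and_images returns (story_dict, image_paths, log_text); the Gradio
# callback defined in build_ui below turns story_dict into HTML and passes image_paths
# straight to the gallery component.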
# ---------- UI ----------
def build_ui():
css = """
.main-container { max-width: 1400px; margin: 0 auto; }
.prompt-section { background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); padding: 32px; border-radius: 16px; margin-bottom: 24px; }
.prompt-box textarea { font-size: 16px !important; }
.story-panel { background: rgba(255,255,255,0.05); padding: 24px; border-radius: 12px; backdrop-filter: blur(10px); border: 1px solid rgba(255,255,255,0.1); max-height: 800px; overflow-y: auto; }
.story-title { font-size: 32px; font-weight: 700; margin-bottom: 24px; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); -webkit-background-clip: text; -webkit-text-fill-color: transparent; }
.story-content { line-height: 1.8; font-size: 16px; color: rgba(255,255,255,0.9); }
.story-content p { margin-bottom: 16px; }
"""
with gr.Blocks(css=css, theme=gr.themes.Soft(), title="Story Generator") as demo:
gr.Markdown("# π AI Story Generator", elem_classes="main-title")
with gr.Column(elem_classes="main-container"):
with gr.Column(elem_classes="prompt-section"):
prompt_box = gr.Textbox(label="β¨ Enter your story idea", placeholder="e.g. A pirate discovering a hidden island...", lines=3, elem_classes="prompt-box")
with gr.Row():
generate_btn = gr.Button("π Generate Story", variant="primary", size="lg", scale=3)
with gr.Column(scale=1):
with gr.Accordion("βοΈ Settings", open=False):
nscenes = gr.Slider(2, 12, value=6, step=1, label="π Number of scenes")
nsent = gr.Slider(2, 8, value=5, step=1, label="π Sentences per scene")
img_model = gr.Dropdown(choices=[IMAGE_MODEL], value=IMAGE_MODEL, label="π¨ Image model")
log_box = gr.Textbox(label="π Generation Logs", lines=6, interactive=False)
with gr.Row():
with gr.Column(scale=5, elem_classes="story-panel"):
story_html = gr.HTML("<div style='text-align:center;padding:40px;color:#888;'>Your story will appear here...<br><br>Click 'Generate Story' to begin! β¨</div>")
with gr.Column(scale=7):
image_gallery = gr.Gallery(
label="π· Scene Visuals", show_label=False, elem_id="gallery",
columns=2, object_fit="cover", height="auto", preview=True
)
        def on_generate(prompt, nscenes, nsent, img_model):
            if is_content_inappropriate(prompt):
                gr.Warning("Your prompt seems to violate the safety policy. Please try again.")
                return "<div style='text-align:center;padding:40px;color:#888;'>Prompt rejected due to safety policy.</div>", [], "Prompt rejected by safety filter."
            story, imgs, logs = generate_story_and_images(prompt, int(nscenes), int(nsent), img_model)
            story_output_html = f"<div class='story-title'>{story.get('title', 'Untitled')}</div>\n<div class='story-content'>\n"
            for s in story.get('scenes', []):
                scene_text = s.get('text', '')
                if is_content_inappropriate(scene_text):
                    # Use an HTML tag here because the output is rendered by gr.HTML, not Markdown.
                    scene_text = f"<strong>[Scene {s['id']} was moderated for safety and replaced with a placeholder.]</strong>"
                    gr.Info(f"Scene {s['id']} content was flagged and replaced.")
                story_output_html += f"<p>{scene_text}</p>\n\n"
            story_output_html += "</div>"
            return story_output_html, imgs, logs

        generate_btn.click(
            on_generate,
            inputs=[prompt_box, nscenes, nsent, img_model],
            outputs=[story_html, image_gallery, log_box]
        )
    return demo
app = build_ui()
if __name__ == "__main__":
    app.launch()
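# Local run (illustrative): export HF_TOKEN=<your token> and run `python app.py`, then
# open the printed Gradio URL. On Spaces, store HF_TOKEN as a Space secret instead of
# hard-coding it.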