# qwen-prompt-expander/qwen_prompt_expander.py
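"""Prompt-expansion block for Modular Diffusers.

Wraps Qwen2.5-3B-Instruct in a `ModularPipelineBlocks` step that rewrites a short
user prompt into a richer, more detailed image-generation prompt, storing the
expanded text under ``prompt`` and the original text under ``old_prompt``.
"""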
from typing import List
from diffusers.modular_pipelines import (
PipelineState,
ModularPipelineBlocks,
InputParam,
OutputParam,
)
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

SYSTEM_PROMPT = (
"You are an expert image generation assistant. "
"Take the user's short description and expand it into a vivid, detailed, and clear image generation prompt. "
"Ensure rich colors, depth, realistic lighting, and an imaginative composition. "
"Avoid vague terms — be specific about style, perspective, and mood. "
"Try to keep the output under 512 tokens. "
"Please don't return any prefix or suffix tokens, just the expanded user description."
)
class QwenPromptExpander(ModularPipelineBlocks):
    def __init__(self, model_id="Qwen/Qwen2.5-3B-Instruct", system_prompt=SYSTEM_PROMPT):
        super().__init__()
        self.system_prompt = system_prompt
        # Load the LLM and its tokenizer once at construction time; fall back to
        # CPU when no GPU is available.
        device = "cuda" if torch.cuda.is_available() else "cpu"
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype="auto"
        ).to(device)
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)
    @property
    def expected_components(self):
        # The block manages its own model and tokenizer, so it does not request
        # any components from the pipeline.
        return []
@property
def inputs(self) -> List[InputParam]:
return [
InputParam(
"prompt",
type_hint=str,
required=True,
description="Prompt to use",
)
]
@property
def intermediate_inputs(self) -> List[InputParam]:
return []
@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(
"prompt",
type_hint=str,
description="Expanded prompt by the LLM",
),
OutputParam(
"old_prompt",
type_hint=str,
description="Old prompt provided by the user",
)
]
    def __call__(self, components, state: PipelineState):
        # Expand the incoming prompt with the LLM and write the result back to the state.
        block_state = self.get_block_state(state)
        old_prompt = block_state.prompt
        print(f"Original prompt: {old_prompt}")
messages = [
{"role": "system", "content": self.system_prompt},
{"role": "user", "content": old_prompt}
]
        # Apply the chat template and tokenize the conversation.
        text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=512)
        # Drop the prompt tokens so only the newly generated continuation is decoded.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]
        block_state.prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
block_state.old_prompt = old_prompt
print(f"{block_state.prompt=}")
self.set_block_state(state, block_state)
return components, state
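

# Illustrative usage sketch, not part of the original block. It assumes the
# Modular Diffusers entry points `init_pipeline()` and calling the resulting
# pipeline with an `output=` argument behave as described in the diffusers
# modular pipelines docs; adjust to the diffusers version you are running.
# In practice this block is meant to be composed with downstream text-to-image
# blocks (e.g. via `SequentialPipelineBlocks.from_blocks_dict`) so the expanded
# prompt feeds the text encoder.
if __name__ == "__main__":
    expander = QwenPromptExpander()
    # Wrap the single block into a runnable pipeline and request the expanded prompt.
    pipeline = expander.init_pipeline()
    expanded = pipeline(prompt="a cozy cabin in the woods at dusk", output="prompt")
    print(expanded)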