from typing import List

from diffusers.modular_pipelines import (
    PipelineState,
    ModularPipelineBlocks,
    InputParam,
    OutputParam,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

SYSTEM_PROMPT = (
    "You are an expert image generation assistant. "
    "Take the user's short description and expand it into a vivid, detailed, and clear image generation prompt. "
    "Ensure rich colors, depth, realistic lighting, and an imaginative composition. "
    "Avoid vague terms — be specific about style, perspective, and mood. "
    "Try to keep the output under 512 tokens. "
    "Please don't return any prefix or suffix tokens, just the expanded user description."
)


class QwenPromptExpander(ModularPipelineBlocks):
    def __init__(self, model_id="Qwen/Qwen2.5-3B-Instruct", system_prompt=SYSTEM_PROMPT):
        super().__init__()
        self.system_prompt = system_prompt
        # Load the LLM used for prompt expansion onto the GPU.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype="auto"
        ).to("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

    @property
    def expected_components(self):
        # The block carries its own model and tokenizer, so it needs no pipeline components.
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            )
        ]

    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt",
                type_hint=str,
                description="Expanded prompt produced by the LLM",
            ),
            OutputParam(
                "old_prompt",
                type_hint=str,
                description="Original prompt provided by the user",
            ),
        ]

    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        old_prompt = block_state.prompt
        print(f"Original prompt: {old_prompt}")

        # Build a chat-formatted request and let the LLM expand the prompt.
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": old_prompt},
        ]
        text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=512)
        # Strip the input tokens so only the newly generated text remains.
        generated_ids = [
            output_ids[len(input_ids):]
            for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        block_state.prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        block_state.old_prompt = old_prompt
        print(f"{block_state.prompt=}")

        self.set_block_state(state, block_state)
        return components, state
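

if __name__ == "__main__":
    # Usage sketch, not part of the block definition. It follows the general
    # modular-pipelines flow as documented for diffusers: `init_pipeline()` turns
    # a block into a runnable ModularPipeline, and `output=` selects a value from
    # the resulting pipeline state. Treat the exact call names and signatures as
    # assumptions and verify them against the diffusers version you use.
    expander = QwenPromptExpander()
    pipe = expander.init_pipeline()  # assumed helper on ModularPipelineBlocks
    expanded = pipe(prompt="a cat playing chess", output="prompt")  # assumed output selector
    print(expanded)
    # In a full text-to-image workflow this block would typically sit in front of
    # the text-encoder step (e.g. inside a SequentialPipelineBlocks) so that
    # downstream blocks consume the expanded prompt instead of the raw one.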