from typing import List

from diffusers.modular_pipelines import (
    PipelineState,
    ModularPipelineBlocks,
    InputParam,
    OutputParam,
)
from transformers import AutoModelForCausalLM, AutoTokenizer

SYSTEM_PROMPT = (
    "You are an expert image generation assistant. "
    "Take the user's short description and expand it into a vivid, detailed, and clear image generation prompt. "
    "Ensure rich colors, depth, realistic lighting, and an imaginative composition. "
    "Avoid vague terms — be specific about style, perspective, and mood. "
    "Try to keep the output under 512 tokens. "
    "Please don't return any prefix or suffix tokens, just the expanded user description."
)

class QwenPromptExpander(ModularPipelineBlocks):
    """Pipeline block that expands a short user prompt into a detailed,
    vivid image generation prompt using a Qwen instruct model."""

    def __init__(self, model_id="Qwen/Qwen2.5-3B-Instruct", system_prompt=SYSTEM_PROMPT):
        super().__init__()

        self.system_prompt = system_prompt
        # Load the expander LLM; the dtype is resolved automatically from the
        # checkpoint and the model is placed on the GPU.
        self.model = AutoModelForCausalLM.from_pretrained(
            model_id, torch_dtype="auto"
        ).to("cuda")
        self.tokenizer = AutoTokenizer.from_pretrained(model_id)

    @property
    def expected_components(self):
        return []

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "prompt",
                type_hint=str,
                required=True,
                description="Prompt to use",
            )
        ]
    
    @property
    def intermediate_inputs(self) -> List[InputParam]:
        return []

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "prompt",
                type_hint=str,
                description="Expanded prompt by the LLM",
            ),
            OutputParam(
                "old_prompt",
                type_hint=str,
                description="Original prompt provided by the user",
            ),
        ]

    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)

        old_prompt = block_state.prompt
        print(f"Original prompt: {old_prompt}")

        # Build a chat-style request and let the LLM rewrite the prompt.
        messages = [
            {"role": "system", "content": self.system_prompt},
            {"role": "user", "content": old_prompt},
        ]
        text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
        generated_ids = self.model.generate(**model_inputs, max_new_tokens=512)
        # Strip the prompt tokens so only the newly generated completion is decoded.
        generated_ids = [
            output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
        ]

        # Overwrite the prompt with the expanded version and keep the original around.
        block_state.prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
        block_state.old_prompt = old_prompt
        print(f"{block_state.prompt=}")
        self.set_block_state(state, block_state)

        return components, state
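
# --- Usage sketch (illustrative, not part of the block itself) ---
# A minimal, hedged example of how this block could be composed into a runnable
# pipeline with diffusers' modular pipelines API. The block name "prompt_expander"
# and the bare `init_pipeline()` call are assumptions made for illustration.
#
# from diffusers.modular_pipelines import SequentialPipelineBlocks
#
# blocks = SequentialPipelineBlocks.from_blocks_dict(
#     {"prompt_expander": QwenPromptExpander()}
# )
# pipeline = blocks.init_pipeline()
# state = pipeline(prompt="a cozy cabin in a snowy forest at dusk")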