sayakpaul (HF Staff) committed
Commit b215789 · verified · 1 Parent(s): a9e1f11

Upload folder using huggingface_hub

Files changed (3):
  1. modular_config.json +7 -0
  2. qwen_prompt_expander.py +87 -0
  3. test.py +8 -0
modular_config.json ADDED
@@ -0,0 +1,7 @@
+ {
+   "_class_name": "QwenPromptExpander",
+   "_diffusers_version": "0.36.0.dev0",
+   "auto_map": {
+     "ModularPipelineBlocks": "qwen_prompt_expander.QwenPromptExpander"
+   }
+ }
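
The "auto_map" entry points the modular-pipelines loader at the QwenPromptExpander class defined in qwen_prompt_expander.py, so the block can be re-instantiated straight from the uploaded repo. A minimal loading sketch, assuming the ModularPipelineBlocks.from_pretrained entry point with remote code enabled (the repo id below is a placeholder):

from diffusers.modular_pipelines import ModularPipelineBlocks

# trust_remote_code is required because auto_map executes qwen_prompt_expander.py from the repo
blocks = ModularPipelineBlocks.from_pretrained(
    "<user>/<repo>",  # placeholder repo id
    trust_remote_code=True,
)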
qwen_prompt_expander.py ADDED
@@ -0,0 +1,87 @@
+ from typing import List
+ from diffusers.modular_pipelines import (
+     PipelineState,
+     ModularPipelineBlocks,
+     InputParam,
+     OutputParam,
+ )
+ from transformers import AutoModelForCausalLM, AutoTokenizer
+ import os
+
+ SYSTEM_PROMPT = (
+     "You are an expert image generation assistant. "
+     "Take the user's short description and expand it into a vivid, detailed, and clear image generation prompt. "
+     "Ensure rich colors, depth, realistic lighting, and an imaginative composition. "
+     "Avoid vague terms — be specific about style, perspective, and mood. "
+     "Try to keep the output under 512 tokens. "
+     "Please don't return any prefix or suffix tokens, just the expanded user description."
+ )
+
+ class QwenPromptExpander(ModularPipelineBlocks):
+     def __init__(self, model_id="Qwen/Qwen2.5-3B-Instruct", system_prompt=SYSTEM_PROMPT):
+         super().__init__()
+
+         self.system_prompt = system_prompt
+         self.model = AutoModelForCausalLM.from_pretrained(
+             model_id, torch_dtype="auto"
+         ).to("cuda")
+         self.tokenizer = AutoTokenizer.from_pretrained(model_id)
+
+     @property
+     def expected_components(self):
+         return []
+
+     @property
+     def inputs(self) -> List[InputParam]:
+         return [
+             InputParam(
+                 "prompt",
+                 type_hint=str,
+                 required=True,
+                 description="Prompt to use",
+             )
+         ]
+
+     @property
+     def intermediate_inputs(self) -> List[InputParam]:
+         return []
+
+     @property
+     def intermediate_outputs(self) -> List[OutputParam]:
+         return [
+             OutputParam(
+                 "prompt",
+                 type_hint=str,
+                 description="Expanded prompt by the LLM",
+             ),
+             OutputParam(
+                 "old_prompt",
+                 type_hint=str,
+                 description="Old prompt provided by the user",
+             )
+         ]
+
+
+     def __call__(self, components, state: PipelineState) -> PipelineState:
+         block_state = self.get_block_state(state)
+
+         old_prompt = block_state.prompt
+         print(f"Actual prompt: {old_prompt}")
+
+         messages = [
+             {"role": "system", "content": self.system_prompt},
+             {"role": "user", "content": old_prompt}
+         ]
+         text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
+         model_inputs = self.tokenizer([text], return_tensors="pt").to(self.model.device)
+         generated_ids = self.model.generate(**model_inputs, max_new_tokens=512)
+         generated_ids = [
+             output_ids[len(input_ids):] for input_ids, output_ids in zip(model_inputs.input_ids, generated_ids)
+         ]
+
+         block_state.prompt = self.tokenizer.batch_decode(generated_ids, skip_special_tokens=True)[0]
+         block_state.old_prompt = old_prompt
+         print(f"{block_state.prompt=}")
+         self.set_block_state(state, block_state)
+
+         return components, state
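
Because the block overwrites "prompt" in the pipeline state (keeping the user's text under "old_prompt"), it is intended to run ahead of whatever blocks consume "prompt" downstream. A rough composition sketch, assuming SequentialPipelineBlocks.from_blocks_dict accepts a plain dict of blocks; the downstream block is only indicated by a comment:

from diffusers.modular_pipelines import SequentialPipelineBlocks
from qwen_prompt_expander import QwenPromptExpander

# Order matters: the expander rewrites "prompt" before any later block reads it.
blocks = SequentialPipelineBlocks.from_blocks_dict({
    "expand_prompt": QwenPromptExpander(),
    # "text_encoder": <a QwenImage text-encoder block would slot in here>,
})
pipe = blocks.init_pipeline()
output = pipe(prompt="a dog sitting by the river, watching the sunset")
print(output.values["prompt"], output.values["old_prompt"])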
test.py ADDED
@@ -0,0 +1,8 @@
+ from qwen_prompt_expander import QwenPromptExpander
+
+ # expander = QwenPromptExpander().init_pipeline()
+ # output = expander(prompt="a dog sitting by the river, watching the sunset")
+ # print(f"{output.values['prompt']=}")
+
+
+ QwenPromptExpander().save_pretrained(".")
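
save_pretrained(".") serializes the block configuration (the modular_config.json above) into the current folder; per the commit message, the folder was then uploaded with huggingface_hub. A minimal upload sketch (the repo id is a placeholder):

from huggingface_hub import upload_folder

upload_folder(
    folder_path=".",          # the folder produced by save_pretrained(".")
    repo_id="<user>/<repo>",  # placeholder repo id
)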