sayakpaul HF Staff committed
Commit 1b761f5 · verified · 1 Parent(s): 0afe5e3

Upload encoders.py with huggingface_hub

Files changed (1)
  1. encoders.py +169 -0
encoders.py ADDED
@@ -0,0 +1,169 @@
+ # Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Optional, List
+ from diffusers.modular_pipelines import (
+     ModularPipelineBlocks,
+     ComponentSpec,
+     InputParam,
+     OutputParam,
+     ModularPipeline,
+     PipelineState,
+ )
+ from diffusers.guiders import ClassifierFreeGuidance
+ from transformers import UMT5EncoderModel, AutoTokenizer
+ from diffusers.image_processor import PipelineImageInput
+ import torch
+ from diffusers.modular_pipelines.wan.encoders import WanTextEncoderStep
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
+ from diffusers.video_processor import VideoProcessor
+ from diffusers.configuration_utils import FrozenDict
+
+
+ class ChronoEditImageEncoderStep(ModularPipelineBlocks):
+     model_name = "chronoedit"
+
+     @property
+     def expected_components(self) -> List[ComponentSpec]:
+         return [
+             ComponentSpec("image_processor", CLIPImageProcessor),
+             ComponentSpec("image_encoder", CLIPVisionModelWithProjection),
+         ]
+
+     @property
+     def inputs(self) -> List[InputParam]:
+         return [InputParam("image", type_hint=PipelineImageInput)]
+
+     @property
+     def intermediate_outputs(self) -> List[OutputParam]:
+         return [
+             OutputParam(
+                 "image_embeds",
+                 type_hint=torch.Tensor,
+                 description="Image embeddings to use as conditions during the denoising process.",
+             )
+         ]
+
+     @staticmethod
+     def encode_image(components, image: PipelineImageInput, device: Optional[torch.device] = None):
+         device = device or components.image_encoder.device
+         image = components.image_processor(images=image, return_tensors="pt").to(device)
+         image_embeds = components.image_encoder(**image, output_hidden_states=True)
+         return image_embeds.hidden_states[-2]
+
+     @torch.no_grad()
+     def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
+         block_state = self.get_block_state(state)
+         block_state.image_embeds = self.encode_image(components, block_state.image, components._execution_device)
+         self.set_block_state(state, block_state)
+         return components, state
+
+
+ class ChronoEditProcessImageStep(ModularPipelineBlocks):
+     model_name = "chronoedit"
+
+     @property
+     def inputs(self) -> List[InputParam]:
+         return [
+             InputParam("image", type_hint=PipelineImageInput),
+             InputParam("image_embeds", type_hint=torch.Tensor, required=False),
+             InputParam("batch_size", type_hint=int, required=False),
+             InputParam("height", type_hint=int),
+             InputParam("width", type_hint=int),
+         ]
+
+     @property
+     def intermediate_outputs(self) -> List[OutputParam]:
+         return [
+             OutputParam("processed_image", type_hint=PipelineImageInput),
+             OutputParam("image_embeds", type_hint=torch.Tensor),
+         ]
+
+     @property
+     def expected_components(self) -> List[ComponentSpec]:
+         return [
+             ComponentSpec(
+                 "video_processor",
+                 VideoProcessor,
+                 config=FrozenDict({"vae_scale_factor": 8}),
+                 default_creation_method="from_config",
+             )
+         ]
+
+     @torch.no_grad()
+     def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
+         block_state = self.get_block_state(state)
+         image = block_state.image
+         device = components._execution_device
+
+         block_state.processed_image = components.video_processor.preprocess(
+             image, height=block_state.height, width=block_state.width
+         ).to(device, dtype=torch.bfloat16)
+
+         if block_state.image_embeds is not None:
+             image_embeds = block_state.image_embeds
+             batch_size = block_state.batch_size
+             block_state.image_embeds = image_embeds.repeat(batch_size, 1, 1).to(torch.bfloat16)
+
+         self.set_block_state(state, block_state)
+
+         return components, state
+
+
+ # Configure CFG with a guidance scale of 1.
+ class ChronoEditTextEncoderStep(WanTextEncoderStep):
+     model_name = "chronoedit"
+
+     @property
+     def expected_components(self) -> List[ComponentSpec]:
+         return [
+             ComponentSpec("text_encoder", UMT5EncoderModel),
+             ComponentSpec("tokenizer", AutoTokenizer),
+             ComponentSpec(
+                 "guider",
+                 ClassifierFreeGuidance,
+                 config=FrozenDict({"guidance_scale": 1.0}),
+                 default_creation_method="from_config",
+             ),
+         ]
+
+     @torch.no_grad()
+     def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
+         # Get inputs and intermediates
+         block_state = self.get_block_state(state)
+         self.check_inputs(block_state)
+
+         block_state.prepare_unconditional_embeds = components.guider.num_conditions > 1
+         block_state.device = components._execution_device
+
+         block_state.negative_prompt_embeds = None
+         # Encode input prompt
+         (
+             block_state.prompt_embeds,
+             block_state.negative_prompt_embeds,
+         ) = self.encode_prompt(
+             components,
+             block_state.prompt,
+             block_state.device,
+             1,
+             block_state.prepare_unconditional_embeds,
+             block_state.negative_prompt,
+             prompt_embeds=None,
+             negative_prompt_embeds=block_state.negative_prompt_embeds,
+         )
+
+         # Add outputs
+         self.set_block_state(state, block_state)
+         return components, state
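
For reference, a minimal sketch of how these three blocks might be composed and run with diffusers' modular-pipelines API. The block ordering, the checkpoint id, and the call inputs beyond the params declared above are illustrative assumptions, not part of the uploaded file:

# Hypothetical usage sketch; checkpoint id and block ordering are assumptions.
import torch
from PIL import Image
from diffusers.modular_pipelines import SequentialPipelineBlocks
from encoders import (
    ChronoEditImageEncoderStep,
    ChronoEditProcessImageStep,
    ChronoEditTextEncoderStep,
)

# Chain the steps into one block; they communicate through the shared
# PipelineState (image -> image_embeds -> processed_image, prompt -> prompt_embeds).
blocks = SequentialPipelineBlocks.from_blocks_dict(
    {
        "text_encoder": ChronoEditTextEncoderStep(),
        "image_encoder": ChronoEditImageEncoderStep(),
        "process_image": ChronoEditProcessImageStep(),
    }
)

# init_pipeline resolves each ComponentSpec (tokenizer, text_encoder,
# image_encoder, ...) against a model repo; the id below is a placeholder.
pipe = blocks.init_pipeline("nvidia/ChronoEdit-14B-Diffusers")
pipe.load_components(torch_dtype=torch.bfloat16)
pipe.to("cuda")

input_image = Image.open("input.png").convert("RGB")
state = pipe(
    prompt="make the car red",
    image=input_image,
    height=480,
    width=832,
    batch_size=1,
)
# The encoded tensors ("prompt_embeds", "image_embeds", "processed_image")
# now live on the returned PipelineState for downstream denoising blocks.

One note on the guider spec: with guidance_scale=1.0, ClassifierFreeGuidance should report num_conditions == 1, so prepare_unconditional_embeds evaluates to False and the negative-prompt branch of encode_prompt is skipped; raising the scale at runtime would re-enable it.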