sayakpaul (HF Staff) committed
Commit 1a0cfc6 · verified · 1 parent: 65b9422

Upload decoders.py with huggingface_hub

Files changed (1):
  decoders.py (+96, -0)
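
The commit message indicates the file was pushed with the huggingface_hub client. A minimal sketch of such an upload, assuming a placeholder repo id (the actual destination repo is not shown on this page):

from huggingface_hub import upload_file

# Placeholder repo id; the real destination repo is an assumption here.
upload_file(
    path_or_fileobj="decoders.py",
    path_in_repo="decoders.py",
    repo_id="<user-or-org>/<repo>",
    commit_message="Upload decoders.py with huggingface_hub",
)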
decoders.py ADDED
# Copyright (c) 2025 Wan and Hugging Face Teams. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

import numpy as np
import PIL.Image
import torch

from diffusers import AutoencoderKLWan
from diffusers.configuration_utils import FrozenDict
from diffusers.modular_pipelines import (
    ComponentSpec,
    InputParam,
    ModularPipelineBlocks,
    OutputParam,
    PipelineState,
)
from diffusers.video_processor import VideoProcessor


class ChronoEditDecodeStep(ModularPipelineBlocks):
    model_name = "chronoedit"

    @property
    def expected_components(self) -> List[ComponentSpec]:
        return [
            ComponentSpec("vae", AutoencoderKLWan),
            ComponentSpec(
                "video_processor",
                VideoProcessor,
                config=FrozenDict({"vae_scale_factor": 8}),
                default_creation_method="from_config",
            ),
        ]

    @property
    def description(self) -> str:
        return "Step that decodes the denoised latents into videos"

    @property
    def inputs(self) -> List[InputParam]:
        return [
            InputParam(
                "latents",
                required=True,
                type_hint=torch.Tensor,
                description="The denoised latents from the denoising step",
            ),
            InputParam("output_type", default="pil"),
        ]

    @property
    def intermediate_outputs(self) -> List[OutputParam]:
        return [
            OutputParam(
                "videos",
                type_hint=Union[List[List[PIL.Image.Image]], List[torch.Tensor], List[np.ndarray]],
                description="The generated videos, as lists of PIL.Image.Image frames, torch.Tensors, or numpy arrays",
            )
        ]

    @torch.no_grad()
    def __call__(self, components, state: PipelineState) -> PipelineState:
        block_state = self.get_block_state(state)
        vae_dtype = components.vae.dtype

        if block_state.output_type != "latent":
            latents = block_state.latents
            # Undo the per-channel normalization applied when the latents were
            # created: `latents_std` holds reciprocal stds, so dividing by it
            # rescales by the true stds before shifting by the means.
            latents_mean = (
                torch.tensor(components.vae.config.latents_mean)
                .view(1, components.vae.config.z_dim, 1, 1, 1)
                .to(latents.device, latents.dtype)
            )
            latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
                1, components.vae.config.z_dim, 1, 1, 1
            ).to(latents.device, latents.dtype)
            latents = latents / latents_std + latents_mean
            latents = latents.to(vae_dtype)
            block_state.videos = components.vae.decode(latents, return_dict=False)[0]
        else:
            # Caller asked for raw latents; pass them through unchanged.
            block_state.videos = block_state.latents

        block_state.videos = components.video_processor.postprocess_video(
            block_state.videos, output_type=block_state.output_type
        )

        self.set_block_state(state, block_state)

        return components, state
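
As a sanity check on the latent denormalization in __call__: because latents_std stores reciprocal standard deviations, latents / latents_std + latents_mean reduces to latents * std + mean, the inverse of the (x - mean) / std scaling applied on the encode side. A minimal sketch with made-up stats (not the real Wan VAE values):

import torch

# Hypothetical stats standing in for vae.config.latents_mean / latents_std.
mean, std = torch.tensor([0.5]), torch.tensor([2.0])
x = torch.tensor([1.0])                     # a raw latent value
normalized = (x - mean) / std               # encode-side scaling
recip_std = 1.0 / std                       # mirrors `latents_std` in the block
assert torch.allclose(normalized / recip_std + mean, x)  # decode-side inverse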