Commit 8f812c4
Parent(s): 2f3fed1
Update CUDA device usage in app.py
app.py CHANGED
@@ -74,7 +74,7 @@ def encode_cropped_prompt_77tokens(txt: str):
                          padding="max_length",
                          max_length=tokenizer.model_max_length,
                          truncation=True,
-                         return_tensors="pt").input_ids.to(device=
+                         return_tensors="pt").input_ids.to(device="cuda")
     text_cond = text_encoder(cond_ids, attention_mask=None).last_hidden_state
     return text_cond
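For orientation, here is a self-contained sketch of the encode-prompt pattern this hunk edits, reconstructed from the context lines. The checkpoint id, the encode_prompt_77tokens name, and the cond_ids = tokenizer(txt, opening are assumptions for illustration; only lines 74-80 of the real function appear in the diff.

import torch
from transformers import CLIPTokenizer, CLIPTextModel

# Illustrative checkpoint; the Space wires up its own tokenizer/text_encoder.
tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to("cuda")

@torch.inference_mode()
def encode_prompt_77tokens(txt: str) -> torch.Tensor:
    # Pad/truncate to the fixed 77-token CLIP window, then move the ids
    # to the GPU, mirroring the "+" line in the hunk above.
    cond_ids = tokenizer(txt,
                         padding="max_length",
                         max_length=tokenizer.model_max_length,
                         truncation=True,
                         return_tensors="pt").input_ids.to(device="cuda")
    # The last hidden state is the text conditioning tensor fed downstream.
    return text_encoder(cond_ids, attention_mask=None).last_hidden_state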
@@ -117,15 +117,15 @@ def process(input_fg, prompt, input_undo_steps, image_width, image_height, seed,
     rng = torch.Generator(device="cuda").manual_seed(int(seed))

     fg = resize_and_center_crop(input_fg, image_width, image_height)
-    concat_conds = numpy2pytorch([fg]).to(device=
+    concat_conds = numpy2pytorch([fg]).to(device="cuda", dtype=vae.dtype)
     concat_conds = vae.encode(concat_conds).latent_dist.mode() * vae.config.scaling_factor

     conds = encode_cropped_prompt_77tokens(prompt)
     unconds = encode_cropped_prompt_77tokens(n_prompt)

-    fs = torch.tensor(input_undo_steps).to(device=
+    fs = torch.tensor(input_undo_steps).to(device="cuda", dtype=torch.long)
     initial_latents = torch.zeros_like(concat_conds)
-    concat_conds = concat_conds.to(device=
+    concat_conds = concat_conds.to(device="cuda", dtype=unet.dtype)
     latents = k_sampler(
         initial_latent=initial_latents,
         strength=1.0,
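This hunk casts each tensor to its consumer's dtype as well as its device: pixels to vae.dtype before VAE encoding, latents to unet.dtype before sampling, while the undo-step tensor fs stays integral. The cast matters because fp16/bf16 models raise dtype-mismatch errors when fed float32 inputs. A toy sketch of the pattern; the fp16 dtypes below are assumptions, since the Space's actual precisions are not shown in this diff.

import torch

vae_dtype = torch.float16    # assumed precision of the VAE
unet_dtype = torch.float16   # assumed precision of the UNet

fg = torch.rand(1, 3, 64, 64)                                    # float32, CPU
concat_conds = fg.to(device="cuda", dtype=vae_dtype)             # match the VAE
fs = torch.tensor([5]).to(device="cuda", dtype=torch.long)       # keep integral
concat_conds = concat_conds.to(device="cuda", dtype=unet_dtype)  # match the UNet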
@@ -169,13 +169,13 @@ def process_video_inner(image_1, image_2, prompt, seed=123, steps=25, cfg_scale=
     positive_text_cond = video_pipe.encode_cropped_prompt_77tokens(prompt)
     negative_text_cond = video_pipe.encode_cropped_prompt_77tokens("")

-    input_frames = input_frames.to(device=
+    input_frames = input_frames.to(device="cuda", dtype=video_pipe.image_encoder.dtype)
     positive_image_cond = video_pipe.encode_clip_vision(input_frames)
     positive_image_cond = video_pipe.image_projection(positive_image_cond)
     negative_image_cond = video_pipe.encode_clip_vision(torch.zeros_like(input_frames))
     negative_image_cond = video_pipe.image_projection(negative_image_cond)

-    input_frames = input_frames.to(device=
+    input_frames = input_frames.to(device="cuda", dtype=video_pipe.vae.dtype)
     input_frame_latents, vae_hidden_states = video_pipe.encode_latents(input_frames, return_hidden_states=True)
     first_frame = input_frame_latents[:, :, 0]
     last_frame = input_frame_latents[:, :, 1]
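The video path casts the same input_frames twice, once per consumer: first to the CLIP vision encoder's dtype, then to the VAE's dtype, because the two submodules can be loaded in different precisions. A toy stand-in showing the per-consumer cast; the nn.Linear modules and dtypes are placeholders, not the Space's real encoder and VAE.

import torch
import torch.nn as nn

image_encoder = nn.Linear(8, 8).to("cuda", torch.float16)  # placeholder module
vae = nn.Linear(8, 8).to("cuda", torch.bfloat16)           # placeholder module

frames = torch.rand(2, 8)  # float32, CPU

# Each module receives inputs in its own precision, as in the diff above.
clip_features = image_encoder(frames.to(device="cuda", dtype=image_encoder.weight.dtype))
latents = vae(frames.to(device="cuda", dtype=vae.weight.dtype))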
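All three hunks hardcode device="cuda", which assumes the Space always runs with a GPU attached. A common defensive variant, not part of this commit, resolves the device once and falls back to CPU:

import torch

# Resolve the device once at startup instead of hardcoding "cuda" at every call site.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

cond_ids = torch.tensor([[49406, 49407]])  # toy token ids
cond_ids = cond_ids.to(device=device)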