Spaces: Running on Zero

wuwenxu.01 committed
Commit: c62efeb
1 Parent(s): 9d31e57

fix: remove unused parameters

Files changed:
- app.py +19 -17
- assets/examples/3one2one/config.json +0 -8
- assets/examples/3one2one/ref1.png +0 -3
- assets/examples/3one2one/result.png +0 -3
- assets/examples/{5two2one → 3two2one}/config.json +1 -1
- assets/examples/{5two2one → 3two2one}/ref1.png +0 -0
- assets/examples/{5two2one → 3two2one}/ref2.png +0 -0
- assets/examples/{5two2one → 3two2one}/result.png +0 -0
- assets/examples/4two2one/config.json +1 -1
- assets/examples/{6many2one → 5many2one}/config.json +0 -0
- assets/examples/{6many2one → 5many2one}/ref1.png +0 -0
- assets/examples/{6many2one → 5many2one}/ref2.png +0 -0
- assets/examples/{6many2one → 5many2one}/ref3.png +0 -0
- assets/examples/{6many2one → 5many2one}/result.png +0 -0
- assets/examples/{7t2i → 6t2i}/config.json +0 -0
- assets/examples/{7t2i → 6t2i}/result.png +0 -0
- uno/flux/pipeline.py +2 -22
- uno/flux/sampling.py +0 -19
- uno/flux/util.py +9 -3
app.py
CHANGED

@@ -44,7 +44,6 @@ def get_examples(examples_dir: str = "assets/examples") -> list:
         example_list.append(None)

         example_list.append(example_dict["seed"])
-        example_list.append(example_dict["ref_long_side"])

         ans.append(example_list)
     return ans
@@ -58,23 +57,27 @@ def create_demo(
     pipeline = UNOPipeline(model_type, device, offload, only_lora=True, lora_rank=512)
     pipeline.gradio_generate = spaces.GPU(duratioin=120)(pipeline.gradio_generate)

+
+    badges_text = r"""
+    <div style="text-align: center; display: flex; justify-content: left; gap: 5px;">
+        <a href="https://bytedance.github.io/UNO/"><img alt="Build" src="https://img.shields.io/badge/Project%20Page-UNO-yellow"></a>
+        <a href="https://arxiv.org/abs/2504.02160"><img alt="Build" src="https://img.shields.io/badge/arXiv%20paper-UNO-b31b1b.svg"></a>
+        <a href="https://huggingface.co/bytedance-research/UNO"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=Model&color=orange"></a>
+        <a href="https://huggingface.co/spaces/bytedance-research/UNO-FLUX"><img src="https://img.shields.io/static/v1?label=%F0%9F%A4%97%20Hugging%20Face&message=demo&color=orange"></a>
+    </div>
+    """.strip()
+
     with gr.Blocks() as demo:
         gr.Markdown(f"# UNO by UNO team")
+        gr.Markdown(badges_text)
         with gr.Row():
             with gr.Column():
                 prompt = gr.Textbox(label="Prompt", value="handsome woman in the city")
                 with gr.Row():
-                    image_prompt1 = gr.Image(label="
-                    image_prompt2 = gr.Image(label="
-                    image_prompt3 = gr.Image(label="
-                    image_prompt4 = gr.Image(label="
-
-                with gr.Row():
-                    with gr.Column():
-                        ref_long_side = gr.Slider(128, 512, 512, step=16, label="Long side of Ref Images")
-                    with gr.Column():
-                        gr.Markdown("📌 **The recommended ref scale** is related to the ref img number.\n")
-                        gr.Markdown(" 1->512 / 2,3,4->320")
+                    image_prompt1 = gr.Image(label="Ref Img1", visible=True, interactive=True, type="pil")
+                    image_prompt2 = gr.Image(label="Ref Img2", visible=True, interactive=True, type="pil")
+                    image_prompt3 = gr.Image(label="Ref Img3", visible=True, interactive=True, type="pil")
+                    image_prompt4 = gr.Image(label="Ref img4", visible=True, interactive=True, type="pil")

                 with gr.Row():
                     with gr.Column():
@@ -87,7 +90,7 @@ def create_demo(
                     " and the higher size gives a better visual effect but is less stable"
                 )

-                with gr.Accordion("
+                with gr.Accordion("Advanced Options", open=False):
                     with gr.Row():
                         num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
                         guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance", interactive=True)
@@ -102,7 +105,7 @@ def create_demo(

         inputs = [
             prompt, width, height, guidance, num_steps,
-            seed,
+            seed, image_prompt1, image_prompt2, image_prompt3, image_prompt4
         ]
         generate_btn.click(
             fn=pipeline.gradio_generate,
@@ -118,11 +121,10 @@ def create_demo(
             inputs=[
                 example_text, prompt,
                 image_prompt1, image_prompt2, image_prompt3, image_prompt4,
-                seed,
+                seed, output_image
             ],
         )

-
     return demo

 if __name__ == "__main__":
@@ -145,4 +147,4 @@ if __name__ == "__main__":
     args = args_tuple[0]

     demo = create_demo(args.name, args.device, args.offload)
-    demo.launch(server_port=args.port)
+    demo.launch(server_port=args.port)
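Note on the new wiring: the four gr.Image components are now passed positionally through `inputs`, so Gradio hands them to `UNOPipeline.gradio_generate` in the same order as its parameters. Below is a minimal sketch of that pattern, using a hypothetical `generate` handler and simplified components with illustrative ranges, not the real app:

import gradio as gr

# Hypothetical stand-in for UNOPipeline.gradio_generate: Gradio maps the
# `inputs` list onto these parameters positionally, so the list order must
# match the signature exactly.
def generate(prompt, width, height, guidance, num_steps, seed,
             image_prompt1, image_prompt2, image_prompt3, image_prompt4):
    refs = [im for im in (image_prompt1, image_prompt2, image_prompt3, image_prompt4)
            if im is not None]
    return f"prompt={prompt!r}, {len(refs)} reference image(s), seed={seed}"

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    image_prompts = [gr.Image(label=f"Ref Img{i}", type="pil") for i in range(1, 5)]
    width = gr.Slider(512, 2048, 512, step=16, label="Width")
    height = gr.Slider(512, 2048, 512, step=16, label="Height")
    guidance = gr.Slider(1.0, 5.0, 4.0, step=0.1, label="Guidance")
    num_steps = gr.Slider(1, 50, 25, step=1, label="Number of steps")
    seed = gr.Number(-1, label="Seed (-1 for random)")
    status = gr.Textbox(label="Status")
    gr.Button("Generate").click(
        fn=generate,
        inputs=[prompt, width, height, guidance, num_steps, seed, *image_prompts],
        outputs=[status],
    )

if __name__ == "__main__":
    demo.launch()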
assets/examples/3one2one/config.json
DELETED

@@ -1,8 +0,0 @@
-{
-    "prompt": "3d cartoon style, a woman.",
-    "seed": 2,
-    "ref_long_side": 512,
-    "useage": "one2one",
-    "image_ref1": "./ref1.png",
-    "image_result": "./result.png"
-}
assets/examples/3one2one/ref1.png
DELETED (Git LFS)

assets/examples/3one2one/result.png
DELETED (Git LFS)
assets/examples/{5two2one → 3two2one}/config.json
RENAMED

@@ -1,6 +1,6 @@
 {
     "prompt": "The figurine is in the crystal ball",
-    "seed":
+    "seed": 0,
     "ref_long_side": 320,
     "useage": "two2one",
     "image_ref1": "./ref1.png",
assets/examples/{5two2one → 3two2one}/ref1.png
RENAMED (file without changes)

assets/examples/{5two2one → 3two2one}/ref2.png
RENAMED (file without changes)

assets/examples/{5two2one → 3two2one}/result.png
RENAMED (file without changes)
assets/examples/4two2one/config.json
CHANGED

@@ -1,6 +1,6 @@
 {
     "prompt": "The logo is printed on the cup",
-    "seed":
+    "seed": 61733557,
     "ref_long_side": 320,
     "useage": "two2one",
     "image_ref1": "./ref1.png",
assets/examples/{6many2one → 5many2one}/config.json
RENAMED (file without changes)

assets/examples/{6many2one → 5many2one}/ref1.png
RENAMED (file without changes)

assets/examples/{6many2one → 5many2one}/ref2.png
RENAMED (file without changes)

assets/examples/{6many2one → 5many2one}/ref3.png
RENAMED (file without changes)

assets/examples/{6many2one → 5many2one}/result.png
RENAMED (file without changes)

assets/examples/{7t2i → 6t2i}/config.json
RENAMED (file without changes)

assets/examples/{7t2i → 6t2i}/result.png
RENAMED (file without changes)
uno/flux/pipeline.py
CHANGED

@@ -27,7 +27,7 @@ from uno.flux.modules.layers import (
     SingleStreamBlockLoraProcessor,
     SingleStreamBlockProcessor,
 )
-from uno.flux.sampling import denoise, get_noise, get_schedule,
+from uno.flux.sampling import denoise, get_noise, get_schedule, prepare_multi_ip, unpack
 from uno.flux.util import (
     get_lora_rank,
     load_ae,
@@ -185,10 +185,6 @@ class UNOPipeline:
         guidance: float = 4,
         num_steps: int = 50,
         seed: int = 123456789,
-        true_gs: float = 3,
-        neg_prompt: str = '',
-        neg_image_prompt: Image = None,
-        timestep_to_start_cfg: int = 0,
         **kwargs
     ):
         width = 16 * (width // 16)
@@ -201,9 +197,6 @@ class UNOPipeline:
             guidance,
             num_steps,
             seed,
-            timestep_to_start_cfg=timestep_to_start_cfg,
-            true_gs=true_gs,
-            neg_prompt=neg_prompt,
             **kwargs
         )

@@ -216,7 +209,6 @@ class UNOPipeline:
         guidance: float,
         num_steps: int,
         seed: int,
-        ref_long_side: int,
         image_prompt1: Image.Image,
         image_prompt2: Image.Image,
         image_prompt3: Image.Image,
@@ -224,6 +216,7 @@ class UNOPipeline:
     ):
         ref_imgs = [image_prompt1, image_prompt2, image_prompt3, image_prompt4]
         ref_imgs = [img for img in ref_imgs if isinstance(img, Image.Image)]
+        ref_long_side = 512 if len(ref_imgs) <= 1 else 320
         ref_imgs = [preprocess_ref(img, ref_long_side) for img in ref_imgs]

         seed = seed if seed != -1 else torch.randint(0, 10 ** 8, (1,)).item()
@@ -250,9 +243,6 @@ class UNOPipeline:
         guidance: float,
         num_steps: int,
         seed: int,
-        timestep_to_start_cfg: int = 1e5,  # TODO: unused, delete
-        true_gs: float = 3.5,
-        neg_prompt: str = "",
         ref_imgs: list[Image.Image] | None = None,
         pe: Literal['d', 'h', 'w', 'o'] = 'd',
     ):
@@ -283,11 +273,6 @@ class UNOPipeline:
             img=x,
             prompt=prompt, ref_imgs=x_1_refs, pe=pe
         )
-        neg_inp_cond = prepare_multi_ip(
-            t5=self.t5, clip=self.clip,
-            img=x,
-            prompt=neg_prompt, ref_imgs=x_1_refs, pe=pe
-        )

         if self.offload:
             self.offload_model_to_cpu(self.t5, self.clip)
@@ -298,11 +283,6 @@ class UNOPipeline:
             **inp_cond,
             timesteps=timesteps,
             guidance=guidance,
-            timestep_to_start_cfg=timestep_to_start_cfg,
-            neg_txt=neg_inp_cond['txt'],
-            neg_txt_ids=neg_inp_cond['txt_ids'],
-            neg_vec=neg_inp_cond['vec'],
-            true_gs=true_gs,
         )

         if self.offload:
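With `ref_long_side` gone from the signature, `gradio_generate` now derives it from the number of reference images (512 for a single reference, 320 for two or more) before resizing. A minimal sketch of what that step amounts to, using a hypothetical `resize_long_side` helper in place of the real `preprocess_ref`, whose implementation is not part of this diff:

from PIL import Image

def resize_long_side(img: Image.Image, long_side: int) -> Image.Image:
    # Scale so the longer edge equals `long_side`, preserving aspect ratio.
    w, h = img.size
    scale = long_side / max(w, h)
    return img.resize((int(w * scale), int(h * scale)), Image.LANCZOS)

def prepare_refs(ref_imgs: list[Image.Image]) -> list[Image.Image]:
    # Mirrors the new rule in UNOPipeline.gradio_generate:
    # one reference image -> long side 512, two or more -> long side 320.
    long_side = 512 if len(ref_imgs) <= 1 else 320
    return [resize_long_side(img, long_side) for img in ref_imgs]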
uno/flux/sampling.py
CHANGED

@@ -215,14 +215,9 @@ def denoise(
     txt: Tensor,
     txt_ids: Tensor,
     vec: Tensor,
-    neg_txt: Tensor,
-    neg_txt_ids: Tensor,
-    neg_vec: Tensor,
     # sampling parameters
     timesteps: list[float],
     guidance: float = 4.0,
-    true_gs = 1,
-    timestep_to_start_cfg=0,
     ref_img: Tensor=None,
     ref_img_ids: Tensor=None,
 ):
@@ -241,20 +236,6 @@ def denoise(
             timesteps=t_vec,
             guidance=guidance_vec
         )
-        if i >= timestep_to_start_cfg:
-            # not test
-            neg_pred = model(
-                img=img,
-                img_ids=img_ids,
-                ref_img=ref_img,  # TODO: neg img embedding
-                ref_img_ids=ref_img_ids,
-                txt=neg_txt,
-                txt_ids=neg_txt_ids,
-                y=neg_vec,
-                timesteps=t_vec,
-                guidance=guidance_vec,
-            )
-            pred = neg_pred + true_gs * (pred - neg_pred)
         img = img + (t_prev - t_curr) * pred
         i += 1
     return img
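With the negative-prompt branch removed, `denoise` is a plain Euler integration of the predicted velocity over the timestep schedule. A framework-level sketch of that loop, with `model_fn` standing in for the full FLUX transformer call and its text, vector, and reference-image inputs:

import torch

def euler_denoise(model_fn, img: torch.Tensor, timesteps: list[float],
                  guidance: float = 4.0) -> torch.Tensor:
    # Each step moves `img` along the predicted velocity from t_curr to t_prev;
    # this is the loop that remains once the classifier-free-guidance branch is gone.
    for t_curr, t_prev in zip(timesteps[:-1], timesteps[1:]):
        t_vec = torch.full((img.shape[0],), t_curr, dtype=img.dtype, device=img.device)
        guidance_vec = torch.full((img.shape[0],), guidance, dtype=img.dtype, device=img.device)
        pred = model_fn(img, t_vec, guidance_vec)
        img = img + (t_prev - t_curr) * pred
    return img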
uno/flux/util.py
CHANGED

@@ -271,7 +271,11 @@ def load_flow_model_only_lora(
         ckpt_path = hf_hub_download(configs[name].repo_id, configs[name].repo_flow.replace("sft", "safetensors"))

     if hf_download:
-        lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
+        # lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
+        try:
+            lora_ckpt_path = hf_hub_download("bytedance-research/UNO", "dit_lora.safetensors")
+        except:
+            lora_ckpt_path = os.environ.get("LORA", None)
     else:
         lora_ckpt_path = os.environ.get("LORA", None)

@@ -362,10 +366,12 @@ def load_flow_model_quintized(name: str, device: str | torch.device = "cuda", hf

 def load_t5(device: str | torch.device = "cuda", max_length: int = 512) -> HFEmbedder:
     # max length 64, 128, 256 and 512 should work (if your sequence is short enough)
-
+    version = os.environ.get("T5", "xlabs-ai/xflux_text_encoders")
+    return HFEmbedder(version, max_length=max_length, torch_dtype=torch.bfloat16).to(device)

 def load_clip(device: str | torch.device = "cuda") -> HFEmbedder:
-
+    version = os.environ.get("CLIP", "openai/clip-vit-large-patch14")
+    return HFEmbedder(version, max_length=77, torch_dtype=torch.bfloat16).to(device)


 def load_ae(name: str, device: str | torch.device = "cuda", hf_download: bool = True) -> AutoEncoder:
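The encoder loaders now resolve their checkpoints through environment variables, falling back to the public Hugging Face repos when the variables are unset, and `load_flow_model_only_lora` falls back to a local `LORA` path if the hub download fails. A short usage sketch, with hypothetical local paths:

import os

# Hypothetical local mirrors; if a variable is unset, the loaders fall back to
# "xlabs-ai/xflux_text_encoders" and "openai/clip-vit-large-patch14".
os.environ["T5"] = "/models/xflux_text_encoders"
os.environ["CLIP"] = "/models/clip-vit-large-patch14"
os.environ["LORA"] = "/models/uno/dit_lora.safetensors"  # used only if the hub download fails

from uno.flux.util import load_clip, load_t5

t5 = load_t5(device="cuda", max_length=512)
clip = load_clip(device="cuda")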