Spaces:
Runtime error
Runtime error
wondervictor
commited on
Commit
·
b962858
1
Parent(s):
d8de5a4
add requirements
Browse files
model.py
CHANGED
|
@@ -59,7 +59,7 @@ class Model:
|
|
| 59 |
map_location="cpu")
|
| 60 |
vq_model.load_state_dict(checkpoint["model"])
|
| 61 |
del checkpoint
|
| 62 |
-
print(
|
| 63 |
return vq_model
|
| 64 |
|
| 65 |
def load_gpt(self, condition_type='canny'):
|
|
@@ -76,14 +76,14 @@ class Model:
|
|
| 76 |
model_weight = load_file(gpt_ckpt)
|
| 77 |
gpt_model.load_state_dict(model_weight, strict=False)
|
| 78 |
gpt_model.eval()
|
| 79 |
-
print(
|
| 80 |
return gpt_model
|
| 81 |
|
| 82 |
def load_t5(self):
|
| 83 |
precision = torch.bfloat16
|
| 84 |
t5_model = T5Embedder(
|
| 85 |
device=self.device,
|
| 86 |
-
local_cache=
|
| 87 |
cache_dir='checkpoints/flan-t5-xl',
|
| 88 |
dir_or_name='flan-t5-xl',
|
| 89 |
torch_dtype=precision,
|
|
|
|
| 59 |
map_location="cpu")
|
| 60 |
vq_model.load_state_dict(checkpoint["model"])
|
| 61 |
del checkpoint
|
| 62 |
+
print("image tokenizer is loaded")
|
| 63 |
return vq_model
|
| 64 |
|
| 65 |
def load_gpt(self, condition_type='canny'):
|
|
|
|
| 76 |
model_weight = load_file(gpt_ckpt)
|
| 77 |
gpt_model.load_state_dict(model_weight, strict=False)
|
| 78 |
gpt_model.eval()
|
| 79 |
+
print("gpt model is loaded")
|
| 80 |
return gpt_model
|
| 81 |
|
| 82 |
def load_t5(self):
|
| 83 |
precision = torch.bfloat16
|
| 84 |
t5_model = T5Embedder(
|
| 85 |
device=self.device,
|
| 86 |
+
local_cache=False,
|
| 87 |
cache_dir='checkpoints/flan-t5-xl',
|
| 88 |
dir_or_name='flan-t5-xl',
|
| 89 |
torch_dtype=precision,
|