Spaces:
Runtime error
Runtime error
Update model.py
Browse files
model.py
CHANGED
|
@@ -92,7 +92,7 @@ class Model:
|
|
| 92 |
preprocessor_name: str,
|
| 93 |
) -> list[PIL.Image.Image]:
|
| 94 |
self.t5_model.model.to('cuda').to(torch.bfloat16)
|
| 95 |
-
self.
|
| 96 |
self.vq_model.to('cuda')
|
| 97 |
if isinstance(image, np.ndarray):
|
| 98 |
image = Image.fromarray(image)
|
|
@@ -147,6 +147,7 @@ class Model:
|
|
| 147 |
top_k=top_k,
|
| 148 |
top_p=top_p,
|
| 149 |
sample_logits=True,
|
|
|
|
| 150 |
)
|
| 151 |
sampling_time = time.time() - t1
|
| 152 |
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
|
@@ -183,7 +184,7 @@ class Model:
|
|
| 183 |
control_strength: float,
|
| 184 |
preprocessor_name: str
|
| 185 |
) -> list[PIL.Image.Image]:
|
| 186 |
-
self.
|
| 187 |
self.t5_model.model.to(self.device)
|
| 188 |
self.gpt_model_depth.to(self.device)
|
| 189 |
self.get_control_depth.model.to(self.device)
|
|
@@ -237,6 +238,7 @@ class Model:
|
|
| 237 |
top_k=top_k,
|
| 238 |
top_p=top_p,
|
| 239 |
sample_logits=True,
|
|
|
|
| 240 |
)
|
| 241 |
sampling_time = time.time() - t1
|
| 242 |
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
|
|
|
| 92 |
preprocessor_name: str,
|
| 93 |
) -> list[PIL.Image.Image]:
|
| 94 |
self.t5_model.model.to('cuda').to(torch.bfloat16)
|
| 95 |
+
self.gpt_model_edge.to('cuda').to(torch.bfloat16)
|
| 96 |
self.vq_model.to('cuda')
|
| 97 |
if isinstance(image, np.ndarray):
|
| 98 |
image = Image.fromarray(image)
|
|
|
|
| 147 |
top_k=top_k,
|
| 148 |
top_p=top_p,
|
| 149 |
sample_logits=True,
|
| 150 |
+
control_strength=control_strength,
|
| 151 |
)
|
| 152 |
sampling_time = time.time() - t1
|
| 153 |
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|
|
|
|
| 184 |
control_strength: float,
|
| 185 |
preprocessor_name: str
|
| 186 |
) -> list[PIL.Image.Image]:
|
| 187 |
+
self.gpt_model_edge.to('cpu')
|
| 188 |
self.t5_model.model.to(self.device)
|
| 189 |
self.gpt_model_depth.to(self.device)
|
| 190 |
self.get_control_depth.model.to(self.device)
|
|
|
|
| 238 |
top_k=top_k,
|
| 239 |
top_p=top_p,
|
| 240 |
sample_logits=True,
|
| 241 |
+
control_strength=control_strength,
|
| 242 |
)
|
| 243 |
sampling_time = time.time() - t1
|
| 244 |
print(f"Full sampling takes about {sampling_time:.2f} seconds.")
|