Spaces: Runtime error
Update model.py
model.py CHANGED
@@ -153,29 +153,29 @@ class Model:
         qzshape = [len(c_indices), 8, H // 16, W // 16]
         t1 = time.time()
         print(caption_embs.device)
+        index_sample = generate(
+            self.gpt_model,
+            c_indices,
+            (H // 16) * (W // 16),
+            c_emb_masks,
+            condition=condition_img,
+            cfg_scale=cfg_scale,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            sample_logits=True,
+            control_strength=control_strength,
+        )
+        sampling_time = time.time() - t1
+        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

+        t2 = time.time()
+        print(index_sample.shape)
+        samples = self.vq_model.decode_code(
+            index_sample, qzshape)  # output value is between [-1, 1]
+        decoder_time = time.time() - t2
+        print(f"decoder takes about {decoder_time:.2f} seconds.")
-        samples = condition_img[0:1]
+        # samples = condition_img[0:1]
         samples = torch.cat((condition_img[0:1], samples), dim=0)
         samples = 255 * (samples * 0.5 + 0.5)
         samples = [
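The restored code follows the usual two-stage autoregressive image pipeline: `generate` samples a sequence of VQ token indices from the GPT prior under classifier-free guidance, and `self.vq_model.decode_code` decodes those indices back into an image tensor in [-1, 1]. The sketch below restates that flow as a free-standing function for reference; it is illustrative only, with `generate_fn` and `decode_fn` standing in for the repo's `generate` and `vq_model.decode_code`.

    import time
    from typing import Callable, List

    import torch

    def sample_and_decode(
        generate_fn: Callable[..., torch.Tensor],  # stands in for the repo's generate()
        decode_fn: Callable[..., torch.Tensor],    # stands in for vq_model.decode_code()
        gpt_model: torch.nn.Module,
        c_indices: torch.Tensor,
        c_emb_masks: torch.Tensor,
        condition_img: torch.Tensor,
        H: int, W: int,
        cfg_scale: float, temperature: float,
        top_k: int, top_p: float,
        control_strength: float,
    ) -> torch.Tensor:
        # Latent grid is (H/16) x (W/16) with 8 channels, matching qzshape in the diff.
        qzshape: List[int] = [len(c_indices), 8, H // 16, W // 16]

        t1 = time.time()
        index_sample = generate_fn(
            gpt_model,
            c_indices,
            (H // 16) * (W // 16),  # number of tokens to sample
            c_emb_masks,
            condition=condition_img,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_k=top_k,
            top_p=top_p,
            sample_logits=True,
            control_strength=control_strength,
        )
        print(f"Full sampling takes about {time.time() - t1:.2f} seconds.")

        t2 = time.time()
        samples = decode_fn(index_sample, qzshape)  # output values lie in [-1, 1]
        print(f"decoder takes about {time.time() - t2:.2f} seconds.")
        return samples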
@@ -247,31 +247,31 @@ class Model:
         c_emb_masks = new_emb_masks
         qzshape = [len(c_indices), 8, H // 16, W // 16]
         t1 = time.time()
+        index_sample = generate(
+            self.gpt_model,
+            c_indices,
+            (H // 16) * (W // 16),
+            c_emb_masks,
+            condition=condition_img,
+            cfg_scale=cfg_scale,
+            temperature=temperature,
+            top_k=top_k,
+            top_p=top_p,
+            sample_logits=True,
+            control_strength=control_strength,
+        )
+        sampling_time = time.time() - t1
+        print(f"Full sampling takes about {sampling_time:.2f} seconds.")

+        t2 = time.time()
+        print(index_sample.shape)
+        samples = self.vq_model.decode_code(index_sample, qzshape)
+        decoder_time = time.time() - t2
+        print(f"decoder takes about {decoder_time:.2f} seconds.")
+        condition_img = condition_img.cpu()
+        samples = samples.cpu()

-        samples = condition_img[0:1]
+        # samples = condition_img[0:1]
         samples = torch.cat((condition_img[0:1], samples), dim=0)
         samples = 255 * (samples * 0.5 + 0.5)
         samples = [
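The second hunk makes the same restoration and additionally moves `condition_img` and `samples` to the CPU before the `torch.cat`. That extra step is the most likely fix for the Space's runtime error: `torch.cat` raises a device-mismatch error when one input is still on the GPU and the other is already on the CPU. Below is a minimal sketch of the shared post-processing, assuming both inputs are float tensors in [-1, 1]; the final uint8 conversion is an assumption, since the diff truncates at `samples = [`.

    import torch

    def postprocess(condition_img: torch.Tensor, samples: torch.Tensor) -> torch.Tensor:
        # Move both tensors to CPU first; torch.cat fails if its inputs
        # sit on different devices (e.g. condition_img on CUDA).
        condition_img = condition_img.cpu()
        samples = samples.cpu()
        # Prepend the conditioning image, then map [-1, 1] -> [0, 255].
        samples = torch.cat((condition_img[0:1], samples), dim=0)
        samples = 255 * (samples * 0.5 + 0.5)
        return samples.clamp(0, 255).to(torch.uint8)  # assumed final conversion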