Spaces: Running on Zero

Commit: add image prompt option
app.py CHANGED
@@ -40,6 +40,9 @@ def main():
                     label='Only First Stage',
                     value=only_first_stage,
                     visible=not only_first_stage)
+                image_prompt = gr.Image(type="filepath",
+                                        label="Image Prompt",
+                                        value=None)
                 run_button = gr.Button('Run')

         with gr.Column():
@@ -50,10 +53,10 @@ def main():
             result_video = gr.Video(show_label=False)

         examples = gr.Examples(
-            examples=[['骑滑板的皮卡丘', False, 1234, True],
-                      ['a cat playing chess', True, 1253, True]],
+            examples=[['骑滑板的皮卡丘', False, 1234, True, None],
+                      ['a cat playing chess', True, 1253, True, None]],
             fn=model.run_with_translation,
-            inputs=[text, translate, seed, only_first_stage],
+            inputs=[text, translate, seed, only_first_stage, image_prompt],
             outputs=[translated_text, result_video],
             cache_examples=True)

@@ -66,6 +69,7 @@ def main():
                 translate,
                 seed,
                 only_first_stage,
+                image_prompt,
             ],
             outputs=[translated_text, result_video])
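For reference, the pattern this commit applies in app.py (an optional image input threaded through gr.Examples and the click handler) looks roughly like the self-contained sketch below. fake_run is a placeholder for the Space's actual model call and is not part of the commit:

import gradio as gr

def fake_run(text: str, seed: int, image_prompt: str | None) -> str:
    # With type="filepath", Gradio passes the upload as a path string,
    # or None when the field is left empty.
    source = image_prompt if image_prompt is not None else 'text only'
    return f'prompt={text!r}, seed={seed}, first frame from: {source}'

with gr.Blocks() as demo:
    text = gr.Textbox(label='Input Text')
    seed = gr.Slider(0, 100000, step=1, value=1234, label='Seed')
    # The new optional input; value=None lets the existing text-only
    # examples keep working by appending None to each example row.
    image_prompt = gr.Image(type='filepath', label='Image Prompt', value=None)
    run_button = gr.Button('Run')
    result = gr.Textbox(label='Result')

    gr.Examples(examples=[['a cat playing chess', 1253, None]],
                inputs=[text, seed, image_prompt])
    run_button.click(fn=fake_run,
                     inputs=[text, seed, image_prompt],
                     outputs=result)

demo.launch()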
model.py CHANGED
@@ -796,7 +796,8 @@ class Model:
                 video_raw_text=None,
                 video_guidance_text='视频',
                 image_text_suffix='',
-                batch_size=1):
+                batch_size=1,
+                image_prompt=None):
         process_start_time = time.perf_counter()

         generate_frame_num = self.args.generate_frame_num
@@ -828,33 +829,36 @@ class Model:

         seq_1st = torch.tensor(seq_1st, dtype=torch.long,
                                device=self.device).unsqueeze(0)
-        output_list_1st = []
-        for tim in range(max(batch_size // mbz, 1)):
-            start_time = time.perf_counter()
-            output_list_1st.append(
-                my_filling_sequence(
-                    model,
-                    tokenizer,
-                    self.args,
-                    seq_1st.clone(),
-                    batch_size=min(batch_size, mbz),
-                    get_masks_and_position_ids=
-                    get_masks_and_position_ids_stage1,
-                    text_len=text_len_1st,
-                    frame_len=frame_len,
-                    strategy=self.strategy_cogview2,
-                    strategy2=self.strategy_cogvideo,
-                    log_text_attention_weights=1.4,
-                    enforce_no_swin=True,
-                    mode_stage1=True,
-                )[0])
-            elapsed = time.perf_counter() - start_time
-            logger.info(f'[First Frame] Elapsed: {elapsed:.2f}')
-        output_tokens_1st = torch.cat(output_list_1st, dim=0)
-        given_tokens = output_tokens_1st[:, text_len_1st + 1:text_len_1st +
-                                         401].unsqueeze(
-                                             1
-                                         )  # given_tokens.shape: [bs, frame_num, 400]
+        if image_prompt is None:
+            output_list_1st = []
+            for tim in range(max(batch_size // mbz, 1)):
+                start_time = time.perf_counter()
+                output_list_1st.append(
+                    my_filling_sequence(
+                        model,
+                        tokenizer,
+                        self.args,
+                        seq_1st.clone(),
+                        batch_size=min(batch_size, mbz),
+                        get_masks_and_position_ids=
+                        get_masks_and_position_ids_stage1,
+                        text_len=text_len_1st,
+                        frame_len=frame_len,
+                        strategy=self.strategy_cogview2,
+                        strategy2=self.strategy_cogvideo,
+                        log_text_attention_weights=1.4,
+                        enforce_no_swin=True,
+                        mode_stage1=True,
+                    )[0])
+                elapsed = time.perf_counter() - start_time
+                logger.info(f'[First Frame] Elapsed: {elapsed:.2f}')
+            output_tokens_1st = torch.cat(output_list_1st, dim=0)
+            given_tokens = output_tokens_1st[:, text_len_1st + 1:text_len_1st +
+                                             401].unsqueeze(
+                                                 1
+                                             )  # given_tokens.shape: [bs, frame_num, 400]
+        else:
+            given_tokens = tokenizer.encode(image_path=image_prompt, image_size=160).repeat(batch_size, 1).unsqueeze(1)

         # generate subsequent frames:
         total_frames = generate_frame_num
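Both branches of the hunk above must produce given_tokens with the same shape, [batch_size, 1, 400]: 400 token ids per sample for the single given first frame, which for image_size=160 presumably corresponds to a 20x20 grid under the tokenizer's 8x spatial compression. A shape-only sketch with dummy tensors (not the real model or tokenizer):

import torch

batch_size, text_len_1st = 2, 30

# Branch 1 (no image prompt): the first frame's tokens are sampled by the
# stage-1 model; dummy stand-in for its output, [bs, text_len_1st + 1 + 400].
output_tokens_1st = torch.randint(0, 20000, (batch_size, text_len_1st + 1 + 400))
given_tokens = output_tokens_1st[:, text_len_1st + 1:text_len_1st + 401].unsqueeze(1)
print(given_tokens.shape)  # torch.Size([2, 1, 400])

# Branch 2 (image prompt): tokenizer.encode(image_path=..., image_size=160)
# yields one row of 400 ids; repeat over the batch and add the frame axis.
encoded = torch.randint(0, 20000, (1, 400))  # dummy stand-in for the encoder
given_tokens = encoded.repeat(batch_size, 1).unsqueeze(1)
print(given_tokens.shape)  # torch.Size([2, 1, 400])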
@@ -1167,7 +1171,7 @@ class Model:
                 1, 2, 0).to(torch.uint8).numpy()

     def run(self, text: str, seed: int,
-            only_first_stage: bool) -> list[np.ndarray]:
+            only_first_stage: bool, image_prompt: None) -> list[np.ndarray]:
         logger.info('==================== run ====================')
         start = time.perf_counter()

@@ -1188,7 +1192,8 @@ class Model:
             video_raw_text=text,
             video_guidance_text='视频',
             image_text_suffix=' 高清摄影',
-            batch_size=self.args.batch_size)
+            batch_size=self.args.batch_size,
+            image_prompt=image_prompt)
         if not only_first_stage:
             _, res = self.process_stage2(
                 self.model_stage2,
@@ -1226,12 +1231,13 @@ class AppModel(Model):

     def run_with_translation(
             self, text: str, translate: bool, seed: int,
-            only_first_stage: bool) -> tuple[str | None, str | None]:
-        logger.info(f'{text=}, {translate=}, {seed=}, {only_first_stage=}')
+            only_first_stage: bool,
+            image_prompt: None) -> tuple[str | None, str | None]:
+        logger.info(f'{text=}, {translate=}, {seed=}, {only_first_stage=}, {image_prompt=}')
         if translate:
             text = translated_text = self.translator(text)
         else:
             translated_text = None
-        frames = self.run(text, seed, only_first_stage)
+        frames = self.run(text, seed, only_first_stage, image_prompt)
         video_path = self.to_video(frames)
         return translated_text, video_path
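Taken together, the model.py changes thread the new argument down the call chain (run_with_translation -> run -> process), where a non-None image_prompt replaces first-stage sampling with direct tokenization of the supplied file. A hypothetical driver follows; AppModel construction is elided and 'cat.png' is a placeholder path, neither being part of the commit:

model = AppModel(args)  # assumes the Space's existing setup; real signature may differ

# Text-only: image_prompt=None, so the first frame is sampled as before.
_, video_path = model.run_with_translation(
    'a cat playing chess', translate=False, seed=1253,
    only_first_stage=True, image_prompt=None)

# Image-prompted: the file is encoded into the 400 first-frame tokens and
# the model generates the remaining frames from it.
_, video_path = model.run_with_translation(
    'a cat playing chess', translate=False, seed=1253,
    only_first_stage=True, image_prompt='cat.png')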