Spaces:
Runtime error
Runtime error
Commit
·
e137273
1
Parent(s):
01b5c05
make token streaming work
Browse files
app.py
CHANGED
|
@@ -139,9 +139,6 @@ def model_inference(
|
|
| 139 |
|
| 140 |
streamer = TextIteratorStreamer(
|
| 141 |
PROCESSOR.tokenizer,
|
| 142 |
-
decode_kwargs=dict(
|
| 143 |
-
skip_special_tokens=True
|
| 144 |
-
),
|
| 145 |
skip_prompt=True,
|
| 146 |
)
|
| 147 |
generation_kwargs = dict(
|
|
@@ -150,6 +147,16 @@ def model_inference(
|
|
| 150 |
max_length=4096,
|
| 151 |
streamer=streamer,
|
| 152 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 153 |
thread = Thread(
|
| 154 |
target=MODEL.generate,
|
| 155 |
kwargs=generation_kwargs,
|
|
@@ -157,15 +164,14 @@ def model_inference(
|
|
| 157 |
thread.start()
|
| 158 |
generated_text = ""
|
| 159 |
for new_text in streamer:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 160 |
generated_text += new_text
|
| 161 |
-
|
| 162 |
-
# yield generated_text, image
|
| 163 |
-
print("after yield")
|
| 164 |
|
| 165 |
-
# Sanity hack
|
| 166 |
-
generated_text = generated_text.replace("</s>", "")
|
| 167 |
-
rendered_page = render_webpage(generated_text)
|
| 168 |
-
return generated_text, rendered_page
|
| 169 |
|
| 170 |
generated_html = gr.Code(
|
| 171 |
label="Extracted HTML",
|
|
@@ -234,26 +240,22 @@ with gr.Blocks(title="Screenshot to HTML", theme=gr.themes.Base(), css=css) as d
|
|
| 234 |
fn=model_inference,
|
| 235 |
inputs=[imagebox],
|
| 236 |
outputs=[generated_html, rendered_html],
|
| 237 |
-
queue=False,
|
| 238 |
)
|
| 239 |
regenerate_btn.click(
|
| 240 |
fn=model_inference,
|
| 241 |
inputs=[imagebox],
|
| 242 |
outputs=[generated_html, rendered_html],
|
| 243 |
-
queue=False,
|
| 244 |
)
|
| 245 |
template_gallery.select(
|
| 246 |
fn=add_file_gallery,
|
| 247 |
inputs=[template_gallery],
|
| 248 |
outputs=[imagebox],
|
| 249 |
-
queue=False,
|
| 250 |
).success(
|
| 251 |
fn=model_inference,
|
| 252 |
inputs=[imagebox],
|
| 253 |
outputs=[generated_html, rendered_html],
|
| 254 |
-
queue=False,
|
| 255 |
)
|
| 256 |
-
demo.load(
|
| 257 |
|
| 258 |
demo.queue(max_size=40, api_open=False)
|
| 259 |
demo.launch(max_threads=400)
|
|
|
|
| 139 |
|
| 140 |
streamer = TextIteratorStreamer(
|
| 141 |
PROCESSOR.tokenizer,
|
|
|
|
|
|
|
|
|
|
| 142 |
skip_prompt=True,
|
| 143 |
)
|
| 144 |
generation_kwargs = dict(
|
|
|
|
| 147 |
max_length=4096,
|
| 148 |
streamer=streamer,
|
| 149 |
)
|
| 150 |
+
# Regular generation version
|
| 151 |
+
# generation_kwargs.pop("streamer")
|
| 152 |
+
# generated_ids = MODEL.generate(**generation_kwargs)
|
| 153 |
+
# generated_text = PROCESSOR.batch_decode(
|
| 154 |
+
# generated_ids,
|
| 155 |
+
# skip_special_tokens=True
|
| 156 |
+
# )[0]
|
| 157 |
+
# rendered_page = render_webpage(generated_text)
|
| 158 |
+
# return generated_text, rendered_page
|
| 159 |
+
# Token streaming version
|
| 160 |
thread = Thread(
|
| 161 |
target=MODEL.generate,
|
| 162 |
kwargs=generation_kwargs,
|
|
|
|
| 164 |
thread.start()
|
| 165 |
generated_text = ""
|
| 166 |
for new_text in streamer:
|
| 167 |
+
if "</s>" in new_text:
|
| 168 |
+
new_text = new_text.replace("</s>", "")
|
| 169 |
+
rendered_image = render_webpage(generated_text)
|
| 170 |
+
else:
|
| 171 |
+
rendered_image = None
|
| 172 |
generated_text += new_text
|
| 173 |
+
yield generated_text, rendered_image
|
|
|
|
|
|
|
| 174 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 175 |
|
| 176 |
generated_html = gr.Code(
|
| 177 |
label="Extracted HTML",
|
|
|
|
| 240 |
fn=model_inference,
|
| 241 |
inputs=[imagebox],
|
| 242 |
outputs=[generated_html, rendered_html],
|
|
|
|
| 243 |
)
|
| 244 |
regenerate_btn.click(
|
| 245 |
fn=model_inference,
|
| 246 |
inputs=[imagebox],
|
| 247 |
outputs=[generated_html, rendered_html],
|
|
|
|
| 248 |
)
|
| 249 |
template_gallery.select(
|
| 250 |
fn=add_file_gallery,
|
| 251 |
inputs=[template_gallery],
|
| 252 |
outputs=[imagebox],
|
|
|
|
| 253 |
).success(
|
| 254 |
fn=model_inference,
|
| 255 |
inputs=[imagebox],
|
| 256 |
outputs=[generated_html, rendered_html],
|
|
|
|
| 257 |
)
|
| 258 |
+
demo.load()
|
| 259 |
|
| 260 |
demo.queue(max_size=40, api_open=False)
|
| 261 |
demo.launch(max_threads=400)
|