Spaces:
Paused
Paused
Add EOS token interruption
Browse files
app.py
CHANGED
|
@@ -44,6 +44,10 @@ def generate_and_visualize(prompt, num_tokens=10):
|
|
| 44 |
|
| 45 |
input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
|
| 46 |
input_embeds = model.get_input_embeddings()(input_ids)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
input_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(input_ids[0]))
|
| 48 |
generated_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(generated_tokens_ids))
|
| 49 |
|
|
@@ -198,10 +202,10 @@ How can you build tools simply in transformers? Just use the decorator""",
|
|
| 198 |
]
|
| 199 |
|
| 200 |
with gr.Blocks(css=css) as demo:
|
| 201 |
-
gr.Markdown("#
|
| 202 |
|
| 203 |
input_text = gr.Textbox(label="Enter your prompt:", lines=10, value=examples[0])
|
| 204 |
-
num_tokens = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Number of tokens to generate")
|
| 205 |
generate_button = gr.Button("Generate")
|
| 206 |
|
| 207 |
output_html = gr.HTML(label="Generated Output")
|
|
|
|
| 44 |
|
| 45 |
input_ids = torch.cat([input_ids, next_token.unsqueeze(0)], dim=1)
|
| 46 |
input_embeds = model.get_input_embeddings()(input_ids)
|
| 47 |
+
|
| 48 |
+
if next_token.item() == tokenizer.eos_token_id:
|
| 49 |
+
print("EOS token generated, stopping generation.")
|
| 50 |
+
break
|
| 51 |
input_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(input_ids[0]))
|
| 52 |
generated_tokens = really_clean_tokens(tokenizer.convert_ids_to_tokens(generated_tokens_ids))
|
| 53 |
|
|
|
|
| 202 |
]
|
| 203 |
|
| 204 |
with gr.Blocks(css=css) as demo:
|
| 205 |
+
gr.Markdown("# Syntax highlighted text generation - for RAG applications")
|
| 206 |
|
| 207 |
input_text = gr.Textbox(label="Enter your prompt:", lines=10, value=examples[0])
|
| 208 |
+
num_tokens = gr.Slider(minimum=1, maximum=100, value=20, step=1, label="Number of tokens to generate (while no EOS token)")
|
| 209 |
generate_button = gr.Button("Generate")
|
| 210 |
|
| 211 |
output_html = gr.HTML(label="Generated Output")
|