enabled streaming
app.py
CHANGED
@@ -165,8 +165,63 @@ def chatbot_response(message, history):
     return response + metrics
 
 
+def process_dialog_streaming(message, history):
+    terminators = [  # stop at either the model EOS or Llama 3's end-of-turn token
+        tokenizer.eos_token_id,
+        tokenizer.convert_tokens_to_ids("<|eot_id|>"),
+    ]
+
+    dialog = [
+        {"role": role, "content": msg}  # flatten (user, assistant) pairs into turns
+        for pair in history for role, msg in zip(("user", "assistant"), pair)
+    ]
+    dialog.append({"role": "user", "content": message})
+
+    prompt = tokenizer.apply_chat_template(
+        dialog, tokenize=False, add_generation_prompt=True
+    )
+    tokenized_input_prompt_ids = tokenizer(
+        prompt, return_tensors="pt"
+    ).input_ids.to(model.device)
+
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+
+    generation_kwargs = dict(
+        inputs=tokenized_input_prompt_ids,
+        streamer=streamer,
+        max_new_tokens=512,
+        temperature=0.4,
+        do_sample=True,
+        eos_token_id=terminators,
+        pad_token_id=tokenizer.pad_token_id,
+    )
+
+    start_time = time.time()
+    total_tokens = 0
+
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)  # generate() blocks, so run it in the background
+    thread.start()
+
+    generated_text = ""
+    for new_text in streamer:  # yields decoded text as tokens arrive
+        generated_text += new_text
+        total_tokens += 1
+        current_time = time.time()
+        elapsed_time = current_time - start_time
+        tokens_per_second = total_tokens / elapsed_time if elapsed_time > 0 else 0
+        print(f"Tokens per second: {tokens_per_second:.2f}", end="\r")
+        yield generated_text, elapsed_time, tokens_per_second
+
+    thread.join()
+
+def chatbot_response_streaming(message, history):
+    for response, generation_time, tokens_per_second in process_dialog_streaming(message, history):
+        metrics = f"\n\n---\n\n **Metrics**\t*Answer Generation Time:* `{generation_time:.2f} sec`\t*Tokens per Second:* `{tokens_per_second:.2f}`\n\n"
+        yield response + metrics
+
+
 demo = gr.ChatInterface(
-    fn=chatbot_response,
+    fn=chatbot_response_streaming,
     examples=["Hello", "How are you?", "Tell me a joke"],
     title="Chat with xMAD's: 1-bit-Llama-3-8B-Instruct Model",
     description="Contact support@xmad.ai to set up a demo",
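
For reference, the pattern this commit adopts is the standard transformers streaming recipe: `model.generate()` blocks until generation finishes, so it runs in a background thread with a `TextIteratorStreamer`, while a generator function drains the streamer and yields partial text that `gr.ChatInterface` re-renders on each yield. Below is a minimal, self-contained sketch of that wiring under assumptions: the checkpoint id is a placeholder (not xMAD's 1-bit model), history handling is omitted, and `stream_reply` is a hypothetical name for illustration.

```python
# Minimal sketch of thread + TextIteratorStreamer streaming into Gradio.
# MODEL_ID is a placeholder checkpoint, not xMAD's quantized model.
from threading import Thread

import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

MODEL_ID = "meta-llama/Meta-Llama-3-8B-Instruct"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")

def stream_reply(message, history):
    # Prompt built from the latest user message only, for brevity.
    input_ids = tokenizer.apply_chat_template(
        [{"role": "user", "content": message}],
        add_generation_prompt=True,
        return_tensors="pt",
    ).to(model.device)

    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks until done; run it off the main thread and
    # consume the streamer here as tokens are decoded.
    Thread(
        target=model.generate,
        kwargs=dict(inputs=input_ids, streamer=streamer, max_new_tokens=256),
    ).start()

    partial = ""
    for chunk in streamer:
        partial += chunk
        yield partial  # each yield re-renders the growing assistant reply

gr.ChatInterface(fn=stream_reply).launch()
```

Because `fn` is a generator, Gradio streams the reply token by token instead of waiting for the full response, which is what distinguishes `chatbot_response_streaming` from the original `chatbot_response`.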