black formatting
app_dialogue.py  +70 -21  CHANGED
@@ -32,7 +32,12 @@ EOS_TOKENS = "</s>;User"
 import logging
 
 from accelerate.utils import get_max_memory
-from transformers import AutoTokenizer, AutoProcessor, AutoConfig, IdeficsForVisionText2Text
+from transformers import (
+    AutoTokenizer,
+    AutoProcessor,
+    AutoConfig,
+    IdeficsForVisionText2Text,
+)
 
 
 TOKENIZER_FAST = True
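For context, a minimal sketch of how the classes imported above are typically used to load IDEFICS; the checkpoint name and dtype here are assumptions for illustration, not values taken from this Space.

import torch
from transformers import AutoProcessor, IdeficsForVisionText2Text

checkpoint = "HuggingFaceM4/idefics-9b-instruct"  # assumed checkpoint, not read from the app
processor = AutoProcessor.from_pretrained(checkpoint)
model = IdeficsForVisionText2Text.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,  # assumption: reduced precision to fit on one GPU
    device_map="auto",
)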
@@ -56,7 +61,9 @@ def load_processor_tokenizer_model(model_name):
     )
     # tokenizer.padding_side = "left" -> we don't need that, do we?
 
-    config = AutoConfig.from_pretrained(model_name, use_auth_token=os.getenv("HF_AUTH_TOKEN", True))
+    config = AutoConfig.from_pretrained(
+        model_name, use_auth_token=os.getenv("HF_AUTH_TOKEN", True)
+    )
     max_memory_map = get_max_memory()
 
     for key in max_memory_map.keys():
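A rough sketch of the pattern this hunk sits in: trim the per-device memory map returned by `get_max_memory()` and hand it to `from_pretrained`. The 20% headroom factor and the checkpoint name are assumptions; the Space's actual loop body is not part of this diff.

from accelerate.utils import get_max_memory
from transformers import IdeficsForVisionText2Text

max_memory_map = get_max_memory()  # e.g. {0: <free bytes on GPU 0>, "cpu": <free bytes>}
for key in max_memory_map.keys():
    # assumption: keep some headroom for activations and the KV cache
    max_memory_map[key] = int(max_memory_map[key] * 0.80)

model = IdeficsForVisionText2Text.from_pretrained(
    "HuggingFaceM4/idefics-9b-instruct",  # assumed checkpoint
    device_map="auto",
    max_memory=max_memory_map,
)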
@@ -92,6 +99,7 @@ def split_prompt_into_list(prompt_str):
         prompt_list.append(ps)
     return prompt_list
 
+
 def model_generation(
     prompt,
     processor,
@@ -117,7 +125,8 @@ def model_generation(
         [split_prompt_into_list(prompt)],
         eval_mode=True,
         truncation=True,
-        max_length=MAX_SEQ_LEN - 512,  # TODO: replace the 512 value with `max_new_tokens`
+        max_length=MAX_SEQ_LEN
+        - 512,  # TODO: replace the 512 value with `max_new_tokens`
         padding=True,
     )
     for k, v in input_args.items():
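The TODO above amounts to a simple token budget: the prompt may only occupy whatever part of the context window is not reserved for generation. A sketch with assumed numbers (neither value is confirmed by this diff):

MAX_SEQ_LEN = 2048  # assumed context window used by the app
max_new_tokens = 512  # tokens reserved for the reply (the hard-coded 512 above)
prompt_max_length = MAX_SEQ_LEN - max_new_tokens  # what `max_length` should become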
@@ -145,7 +154,9 @@ def model_generation(
     if len(eos_tokens) > 0:
         eos_token_ids = []
         for eos_token in eos_tokens:
-            tokenized_eos_token = tokenizer(eos_token, add_special_tokens=False).input_ids
+            tokenized_eos_token = tokenizer(
+                eos_token, add_special_tokens=False
+            ).input_ids
             if len(tokenized_eos_token) > 1:
                 raise ValueError(
                     f"eos_tokens should be one token, here {eos_token} is {len(tokenized_eos_token)} tokens:"
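A sketch of how the per-token EOS ids built by this loop are typically consumed: `generate` accepts a list of `eos_token_id`s and stops at whichever appears first. The checkpoint is an assumption, and the two markers come from the `EOS_TOKENS = "</s>;User"` constant visible in the first hunk (apparently split on `;`).

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("HuggingFaceM4/idefics-9b-instruct")  # assumed checkpoint
eos_tokens = ["</s>", "User"]  # EOS_TOKENS split on ";"

eos_token_ids = []
for eos_token in eos_tokens:
    ids = tokenizer(eos_token, add_special_tokens=False).input_ids
    if len(ids) != 1:
        raise ValueError(f"{eos_token!r} must map to exactly one token, got {len(ids)}")
    eos_token_ids.append(ids[0])

# later, e.g.: model.generate(**inputs, eos_token_id=eos_token_ids, max_new_tokens=512)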
@@ -203,13 +214,17 @@ def model_generation(
 
     tokens = tokenizer.convert_ids_to_tokens(generated_tokens[0])
     decoded_skip_special_tokens = repr(
-        tokenizer.batch_decode(generated_tokens, skip_special_tokens=hide_special_tokens)[0]
+        tokenizer.batch_decode(
+            generated_tokens, skip_special_tokens=hide_special_tokens
+        )[0]
     )
 
     actual_generated_tokens = generated_tokens[:, input_args["input_ids"].shape[-1] :]
     first_end_token = len(actual_generated_tokens[0])
     actual_generated_tokens = actual_generated_tokens[:, :first_end_token]
-    generated_text = tokenizer.batch_decode(actual_generated_tokens, skip_special_tokens=hide_special_tokens)[0]
+    generated_text = tokenizer.batch_decode(
+        actual_generated_tokens, skip_special_tokens=hide_special_tokens
+    )[0]
 
     logger.info(
         "Result: \n"
@@ -252,7 +267,9 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
         show_label=False,
         container=False,
     )
-    processor, tokenizer, model = load_processor_tokenizer_model(
+    processor, tokenizer, model = load_processor_tokenizer_model(
+        model_selector.value
+    )
 
     imagebox = gr.Image(
         type="pil",
@@ -394,26 +411,30 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
     # inputs = [chatbot]
     # )
 
-    def format_prompt_with_history_and_system_conditioning(current_user_prompt, history):
+    def format_prompt_with_history_and_system_conditioning(
+        current_user_prompt, history
+    ):
         resulting_text = SYSTEM_PROMPT
         for turn in history:
             user_utterance, assistant_utterance = turn
-            resulting_text += f"\nUser: {user_utterance}</s>\nAssistant: {assistant_utterance}</s>"
+            resulting_text += (
+                f"\nUser: {user_utterance}</s>\nAssistant: {assistant_utterance}</s>"
+            )
         resulting_text += f"\nUser: {current_user_prompt}</s>\nAssistant:"
         return resulting_text
 
     def model_inference(
         user_prompt,
         chat_history,
-        temperature
-        no_repeat_ngram_size
-        max_new_tokens
-        min_length
-        repetition_penalty
-        length_penalty
-        top_k
-        top_p
-        penalty_alpha
+        temperature=1.0,
+        no_repeat_ngram_size=0,
+        max_new_tokens=512,
+        min_length=16,
+        repetition_penalty=1.0,
+        length_penalty=1.0,
+        top_k=50,
+        top_p=0.95,
+        penalty_alpha=0.95,
     ):
         global processor, model, tokenizer
         # temperature = 1.0
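To make the prompt format concrete, here is the helper from this hunk exercised on a tiny history; the `SYSTEM_PROMPT` text is a placeholder, not the Space's actual system prompt.

SYSTEM_PROMPT = "The following is a conversation between a user and Assistant."  # placeholder

def format_prompt_with_history_and_system_conditioning(current_user_prompt, history):
    resulting_text = SYSTEM_PROMPT
    for turn in history:
        user_utterance, assistant_utterance = turn
        resulting_text += (
            f"\nUser: {user_utterance}</s>\nAssistant: {assistant_utterance}</s>"
        )
    resulting_text += f"\nUser: {current_user_prompt}</s>\nAssistant:"
    return resulting_text

print(
    format_prompt_with_history_and_system_conditioning(
        "Describe this image.",
        history=[("Hello!", "Hi! How can I help?")],
    )
)
# Prints:
# The following is a conversation between a user and Assistant.
# User: Hello!</s>
# Assistant: Hi! How can I help?</s>
# User: Describe this image.</s>
# Assistant: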
@@ -462,13 +483,41 @@ with gr.Blocks(title="IDEFICS", theme=gr.themes.Base()) as demo:
 
     textbox.submit(
         fn=model_inference,
-        inputs=[
+        inputs=[
+            textbox,
+            chatbot,
+            temperature,
+        ],
         outputs=[textbox, chatbot],
     )
     submit_btn.click(
         fn=model_inference,
-        inputs=[
-
+        inputs=[
+            textbox,
+            chatbot,
+            temperature,
+            no_repeat_ngram_size,
+            max_new_tokens,
+            min_length,
+            repetition_penalty,
+            length_penalty,
+            top_k,
+            top_p,
+            penalty_alpha,
+        ],
+        outputs=[
+            textbox,
+            chatbot,
+            temperature,
+            no_repeat_ngram_size,
+            max_new_tokens,
+            min_length,
+            repetition_penalty,
+            length_penalty,
+            top_k,
+            top_p,
+            penalty_alpha,
+        ],
     )
 
 demo.queue()
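Finally, a self-contained sketch of the `gr.Blocks` event-wiring pattern used above, with a toy echo function standing in for `model_inference` and only a few of the app's components (the real handlers take and return many more sliders):

import gradio as gr

def toy_inference(user_prompt, chat_history, temperature):
    # mimic model_inference's contract: clear the textbox, append a turn to the chatbot
    reply = f"(temperature={temperature}) you said: {user_prompt}"
    chat_history = (chat_history or []) + [(user_prompt, reply)]
    return "", chat_history

with gr.Blocks(title="toy") as demo:
    chatbot = gr.Chatbot()
    textbox = gr.Textbox(show_label=False, container=False)
    temperature = gr.Slider(minimum=0.0, maximum=2.0, value=1.0, label="temperature")
    submit_btn = gr.Button("Submit")

    textbox.submit(
        fn=toy_inference,
        inputs=[textbox, chatbot, temperature],
        outputs=[textbox, chatbot],
    )
    submit_btn.click(
        fn=toy_inference,
        inputs=[textbox, chatbot, temperature],
        outputs=[textbox, chatbot],
    )

demo.queue()
# demo.launch()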