Load model after preset
app.py (CHANGED)
@@ -24,7 +24,7 @@ from presets import (
 from style import custom_css
 from utils import get_formatted_attribute_context_results
 
-from inseq import list_feature_attribution_methods, list_step_functions
+from inseq import list_feature_attribution_methods, list_step_functions
 from inseq.commands.attribute_context.attribute_context import (
     AttributeContextArgs,
     attribute_context_with_model,
@@ -65,7 +65,7 @@ def pecore(
     )
     if loaded_model is None or model_name_or_path != loaded_model.model_name:
         gr.Info("Loading model...")
-        loaded_model =
+        loaded_model = HuggingfaceModel.load(
             model_name_or_path,
             attribution_method,
             model_kwargs=json.loads(model_kwargs),
@@ -130,7 +130,7 @@ def preload_model(
     global loaded_model
     if loaded_model is None or model_name_or_path != loaded_model.model_name:
         gr.Info("Loading model...")
-        loaded_model =
+        loaded_model = HuggingfaceModel.load(
             model_name_or_path,
             attribution_method,
             model_kwargs=json.loads(model_kwargs),
@@ -192,7 +192,9 @@ with gr.Blocks(css=custom_css) as demo:
             outputs=pecore_output_highlights,
         )
     with gr.Tab("⚙️ Parameters") as params_tab:
-        gr.Markdown(
+        gr.Markdown(
+            "## ✨ Presets\nSelect a preset to load default parameters into the fields below."
+        )
         with gr.Row(equal_height=True):
             with gr.Column():
                 default_preset = gr.Button("Default", variant="secondary")
@@ -218,7 +220,7 @@ with gr.Blocks(css=custom_css) as demo:
                     "Present for multilingual MT models such as <a href='https://huggingface.co/facebook/nllb-200-distilled-600M' target='_blank'>NLLB</a> and <a href='https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt' target='_blank'>mBART</a> using language tags."
                 )
             with gr.Column(scale=1):
-                chatml_template = gr.Button("ChatML
+                chatml_template = gr.Button("Qwen ChatML", variant="secondary")
                 gr.Markdown(
                     "Preset for models using the <a href='https://github.com/MicrosoftDocs/azure-docs/blob/main/articles/ai-services/openai/includes/chat-markup-language.md' target='_blank'>ChatML conversational template</a>.\nUses <code><|im_start|></code>, <code><|im_end|></code> special tokens."
                 )
@@ -401,6 +403,15 @@ with gr.Blocks(css=custom_css) as demo:
     gr.Markdown(how_to_use)
     gr.Markdown(citation)
 
+    # Main logic
+
+    load_model_args = [
+        model_name_or_path,
+        attribution_method,
+        model_kwargs,
+        tokenizer_kwargs,
+    ]
+
     attribute_input_button.click(
         pecore,
         inputs=[
@@ -435,7 +446,7 @@ with gr.Blocks(css=custom_css) as demo:
 
     load_model_button.click(
         preload_model,
-        inputs=
+        inputs=load_model_args,
         outputs=[],
     )
 
@@ -461,11 +472,13 @@ with gr.Blocks(css=custom_css) as demo:
 
     # Presets
 
-    default_preset.click(**reset_kwargs)
+    default_preset.click(**reset_kwargs).success(preload_model, inputs=load_model_args)
+
     cora_preset.click(**reset_kwargs).then(
         set_cora_preset,
         outputs=[model_name_or_path, input_template, contextless_input_current_text],
-    )
+    ).success(preload_model, inputs=load_model_args)
+
     zephyr_preset.click(**reset_kwargs).then(
         set_zephyr_preset,
         outputs=[
@@ -474,11 +487,13 @@ with gr.Blocks(css=custom_css) as demo:
             contextless_input_current_text,
             decoder_input_output_separator,
         ],
-    )
+    ).success(preload_model, inputs=load_model_args)
+
     multilingual_mt_template.click(**reset_kwargs).then(
         set_mmt_preset,
         outputs=[model_name_or_path, input_template, output_template, tokenizer_kwargs],
-    )
+    ).success(preload_model, inputs=load_model_args)
+
     chatml_template.click(**reset_kwargs).then(
         set_chatml_preset,
         outputs=[
@@ -488,7 +503,8 @@ with gr.Blocks(css=custom_css) as demo:
             decoder_input_output_separator,
             special_tokens_to_keep,
         ],
-    )
+    ).success(preload_model, inputs=load_model_args)
+
     towerinstruct_template.click(**reset_kwargs).then(
         set_towerinstruct_preset,
         outputs=[
@@ -497,6 +513,6 @@ with gr.Blocks(css=custom_css) as demo:
             contextless_input_current_text,
             decoder_input_output_separator,
         ],
-    )
+    ).success(preload_model, inputs=load_model_args)
 
 demo.launch(allowed_paths=["outputs/"])
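
For reference, the model handling in pecore and preload_model boils down to a module-level cache keyed on the model name, so switching presets only reloads when a different model is requested. A minimal sketch of that pattern, assuming the public inseq.load_model entry point and that it forwards model_kwargs/tokenizer_kwargs (the app itself calls HuggingfaceModel.load directly):

import json

import inseq

# Module-level cache, mirroring the global used in app.py.
loaded_model = None


def preload_model(model_name_or_path, attribution_method, model_kwargs="{}", tokenizer_kwargs="{}"):
    """Load the model once and reuse it until a different model name is requested."""
    global loaded_model
    if loaded_model is None or model_name_or_path != loaded_model.model_name:
        # Assumption: inseq.load_model forwards these kwargs to the underlying HF loaders.
        loaded_model = inseq.load_model(
            model_name_or_path,
            attribution_method,
            model_kwargs=json.loads(model_kwargs),
            tokenizer_kwargs=json.loads(tokenizer_kwargs),
        )
    return loaded_model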
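
The preset buttons are wired as chained Gradio events: click() resets the fields, then() applies the preset values, and success() preloads the model only if the previous step raised no error, so a failing preset never triggers a load. A self-contained sketch of that chain with hypothetical stand-in components rather than the app's full field list:

import gradio as gr


def set_demo_preset():
    # Hypothetical preset: returns new values for the two bound components.
    return "gpt2", "saliency"


def preload_model(model_name_or_path, attribution_method):
    gr.Info(f"Loading {model_name_or_path}...")
    # ... load and cache the model as in the sketch above ...


with gr.Blocks() as demo:
    model_name_or_path = gr.Textbox(label="Model")
    attribution_method = gr.Dropdown(["saliency", "attention"], label="Attribution method")
    demo_preset = gr.Button("Demo preset")

    # .then() runs after the reset finishes; .success() fires only if the
    # preset step completed without error, mirroring the preset wiring above.
    demo_preset.click(
        lambda: ("", "saliency"), outputs=[model_name_or_path, attribution_method]
    ).then(
        set_demo_preset, outputs=[model_name_or_path, attribution_method]
    ).success(
        preload_model, inputs=[model_name_or_path, attribution_method]
    )

demo.launch()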