Commit b43c18e
Parent(s): 5b1d132
Update with h2oGPT hash 05d3ad444971c24fb021ea80c27f867c7a953699
Files changed:
- client_test.py +4 -2
- finetune.py +60 -10
- generate.py +98 -83
- gradio_runner.py +26 -10
- prompter.py +6 -5
- requirements.txt +1 -1
- stopping.py +49 -6
client_test.py
CHANGED
@@ -53,13 +53,16 @@ def get_client():
 
 
 def test_client_basic():
+    return run_client_basic(instruction_nochat='Who are you?', prompt_type='human_bot')
+
+
+def run_client_basic(instruction_nochat, prompt_type):
     instruction = ''  # only for chat=True
     iinput = ''  # only for chat=True
     context = ''
     # streaming output is supported, loops over and outputs each generation in streaming mode
     # but leave stream_output=False for simple input/output mode
     stream_output = False
-    prompt_type = 'human_bot'
     temperature = 0.1
     top_p = 0.75
     top_k = 40
@@ -73,7 +76,6 @@ def test_client_basic():
     do_sample = True
     # only these 2 below used if pass chat=False
     chat = False
-    instruction_nochat = "Who are you?"
     iinput_nochat = ''
 
     args = [instruction,
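Note: the refactor above turns test_client_basic into a thin wrapper around a reusable run_client_basic helper. As a minimal sketch (the extra question below is purely illustrative and not part of the commit), another test could reuse the helper like this:

def test_client_basic_other_prompt():
    # hypothetical extra test; reuses run_client_basic exactly as defined in the diff above
    return run_client_basic(instruction_nochat='Why is the sky blue?', prompt_type='human_bot')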
finetune.py
CHANGED
@@ -28,6 +28,8 @@ class PromptType(Enum):
     instruct_vicuna = 7
     instruct_with_end = 8
     human_bot_orig = 9
+    prompt_answer = 10
+    open_assistant = 11
 
 
 prompt_type_to_model_name = {
@@ -46,6 +48,14 @@ prompt_type_to_model_name = {
         'philschmid/flan-t5-base-samsum',
         'gpt2',
         'distilgpt2',
+        'mosaicml/mpt-7b-storywriter',
+        'mosaicml/mpt-7b-instruct',  # internal code handles instruct
+        'mosaicml/mpt-7b-chat',  # NC, internal code handles instruct
+    ],
+    'prompt_answer': [
+        'h2oai/h2ogpt-gm-oasst1-en-1024-20b',
+        'h2oai/h2ogpt-gm-oasst1-en-1024-12b',
+        'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b',
     ],
     'instruct': [],
     'instruct_with_end': ['databricks/dolly-v2-12b'],
@@ -61,14 +71,12 @@ prompt_type_to_model_name = {
     'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'],
     'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b'],
     'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'],
+    "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'],
 }
 
 inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l}
 inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l}
 
-human = '<human>:'
-bot = "<bot>:"
-
 prompt_types_strings = []
 for p in PromptType:
     prompt_types_strings.extend([p.name])
@@ -277,8 +285,13 @@ def train(
         layer_norm_names=["layer_norm", "layernorm"],  # keep all layer norms in higher precision
     )
 
-    from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
-    [removed line not legible in the rendered diff]
+    from peft import LoraConfig, get_peft_model, set_peft_model_state_dict
+    try:
+        from peft import utils
+        lora_mappings = utils.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
+    except AttributeError:
+        from peft import mapping
+        lora_mappings = mapping.TRANSFORMERS_MODELS_TO_LORA_TARGET_MODULES_MAPPING.copy()
     lora_mappings['distilgpt2'] = ["c_attn"]
 
     if lora_weights:
@@ -730,10 +743,10 @@ def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=F
     assert prompt_type is not None
     assert cutoff_len is not None
     assert tokenizer is not None
-    full_prompt, _, _ = generate_prompt(data_point, prompt_type, False, False)
+    full_prompt, _, _, _ = generate_prompt(data_point, prompt_type, False, False)
     tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
     if not train_on_inputs:
-        user_prompt, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
+        user_prompt, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False)
         tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token)
         user_prompt_len = len(tokenized_user_prompt["input_ids"])
         if add_eos_token:
@@ -752,9 +765,11 @@ def get_prompt(prompt_type, chat, context, reduced):
     if prompt_type in [-1, "-1", "plain"]:
         promptA = promptB = PreInstruct = PreInput = PreResponse = ''
         terminate_response = []
+        chat_sep = ''
     elif prompt_type == 'simple_instruct':
         promptA = promptB = PreInstruct = PreInput = PreResponse = None
         terminate_response = []
+        chat_sep = '\n'
     elif prompt_type in [0, "0", "instruct"] or prompt_type in [7, "7", "instruct_with_end"]:
         promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
         promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not (chat and reduced) else ''
@@ -774,6 +789,7 @@ def get_prompt(prompt_type, chat, context, reduced):
             terminate_response = ['### End']
         else:
             terminate_response = None
+        chat_sep = '\n'
     elif prompt_type in [1, "1", "quality"]:
         promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not (chat and reduced) else ''
         promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not (chat and reduced) else ''
@@ -790,7 +806,10 @@ def get_prompt(prompt_type, chat, context, reduced):
 ### Response:
 """
         terminate_response = None
+        chat_sep = '\n'
     elif prompt_type in [2, "2", "human_bot", 9, "9", "human_bot_orig"]:
+        human = '<human>:'
+        bot = "<bot>:"
         if reduced or context or prompt_type in [2, "2", "human_bot"]:
             preprompt = ''
         else:
@@ -819,6 +838,7 @@ Current Time: {}
         PreResponse = bot
 
         terminate_response = [start, PreResponse]
+        chat_sep = '\n'
     elif prompt_type in [3, "3", "dai_faq"]:
         promptA = ''
         promptB = 'Answer the following Driverless AI question.\n'
@@ -833,11 +853,13 @@ Current Time: {}
 ### Driverless AI documentation answer:
 """
         terminate_response = ['\n\n']
+        chat_sep = terminate_response
     elif prompt_type in [5, "5", "summarize"]:
         promptA = promptB = PreInput = ''
         PreInstruct = '## Main Text\n\n'
         PreResponse = '\n\n## Summary\n\n'
         terminate_response = None
+        chat_sep = '\n'
     elif prompt_type in [6, "6", "instruct_vicuna"]:
         promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. " \
                             "The assistant gives helpful, detailed, and polite answers to the human's questions." if not (chat and reduced) else ''
@@ -852,10 +874,37 @@ Current Time: {}
 ### Assistant:
 """
         terminate_response = ['### Human:']  # but only allow terminate after prompt is found correctly, else can't terminate
+        chat_sep = '\n'
+    elif prompt_type in [10, "10", "prompt_answer"]:
+        preprompt = ''
+        prompt_tokens = "<|prompt|>"
+        answer_tokens = "<|answer|>"
+        start = prompt_tokens
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = ""
+        PreInput = None
+        PreResponse = answer_tokens
+        eos = '<|endoftext|>'  # neox eos
+        terminate_response = [start, PreResponse, eos]
+        chat_sep = eos
+    elif prompt_type in [11, "11", "open_assistant"]:
+        # From added_tokens.json
+        preprompt = ''
+        prompt_tokens = "<|prompter|>"
+        answer_tokens = "<|assistant|>"
+        start = prompt_tokens
+        promptB = promptA = '%s%s' % (preprompt, start)
+        PreInstruct = ""
+        PreInput = None
+        PreResponse = answer_tokens
+        pend = "<|prefix_end|>"
+        eos = "</s>"
+        terminate_response = [start, PreResponse, pend, eos]
+        chat_sep = eos
     else:
         raise RuntimeError("No such prompt_type=%s" % prompt_type)
 
-    return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response
+    return promptA, promptB, PreInstruct, PreInput, PreResponse, terminate_response, chat_sep
 
 
 def generate_prompt(data_point, prompt_type, chat, reduced):
@@ -867,7 +916,8 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
     output = data_point.get('output')
     prompt_type = data_point.get('prompt_type', prompt_type)
     assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type
-    promptA, promptB, PreInstruct, PreInput, PreResponse,
+    promptA, promptB, PreInstruct, PreInput, PreResponse, \
+        terminate_response, chat_sep = get_prompt(prompt_type, chat, context, reduced)
 
     prompt = context if not reduced else ''
 
@@ -919,7 +969,7 @@ def generate_prompt(data_point, prompt_type, chat, reduced):
     if output:
         prompt += f"""{output}"""
 
-    return prompt, pre_response, terminate_response
+    return prompt, pre_response, terminate_response, chat_sep
 
 
 def inject_newline(prompt_type, prompt):
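Note: the main behavioral change in finetune.py is that get_prompt and generate_prompt now also return a chat_sep separator (a 7-tuple and a 4-tuple respectively), and two new prompt types, prompt_answer and open_assistant, are defined. A minimal sketch of consuming the new return value, assuming this commit's finetune.py is importable; the exact prompt text depends on the templates above:

from finetune import generate_prompt

data_point = dict(instruction='Who are you?', input='', output='')
prompt, pre_response, terminate_response, chat_sep = generate_prompt(
    data_point, 'prompt_answer', chat=False, reduced=False)
# For 'prompt_answer' models the prompt is roughly "<|prompt|>Who are you?<|answer|>",
# and chat turns are separated by the NeoX EOS token '<|endoftext|>' (chat_sep).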
generate.py
CHANGED
@@ -9,7 +9,7 @@ from datetime import datetime
 import filelock
 import psutil
 
-from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread
+from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash
 
 SEED = 1236
 set_seed(SEED)
@@ -22,13 +22,13 @@ import pandas as pd
 import fire
 import torch
 from peft import PeftModel
-from transformers import GenerationConfig,
+from transformers import GenerationConfig, AutoModel, TextIteratorStreamer
 from accelerate import init_empty_weights, infer_auto_device_map
 
 from prompter import Prompter
 
-from finetune import get_loaders, example_data_points, generate_prompt,
-from stopping import
+from finetune import get_loaders, example_data_points, generate_prompt, inv_prompt_type_to_model_lower
+from stopping import get_stopping
 
 eval_extra_columns = ['prompt', 'response', 'score']
 
@@ -62,6 +62,7 @@ def main(
         local_files_only: bool = False,
         resume_download: bool = True,
         use_auth_token: Union[str, bool] = False,
+        trust_remote_code: Union[str, bool] = True,
 
         src_lang: str = "English",
         tgt_lang: str = "Russian",
@@ -124,6 +125,7 @@ def main(
     :param local_files_only: whether to only use local files instead of doing to HF for models
     :param resume_download: whether to resume downloads from HF for models
     :param use_auth_token: whether to use HF auth token (requires CLI did huggingface-cli login before)
+    :param trust_remote_code: whether to use trust any code needed for HF model
     :param src_lang: source languages to include if doing translation (None = all)
     :param tgt_lang: target languages to include if doing translation (None = all)
    :param gradio: whether to enable gradio, or to enable benchmark mode
@@ -168,15 +170,22 @@ def main(
 
     if is_public:
         input_lines = 1  # ensure set, for ease of use
-        temperature = 0.2
-        top_p = 0.85
-        top_k = 70
-        [removed line not legible in the rendered diff]
+        temperature = 0.2 if temperature is None else temperature
+        top_p = 0.85 if top_p is None else top_p
+        top_k = 70 if top_k is None else top_k
+        if is_hf:
+            do_sample = True if do_sample is None else do_sample
+        else:
+            # by default don't sample, too chatty
+            do_sample = False if do_sample is None else do_sample
+
     if is_low_mem:
-        [removed line not legible in the rendered diff]
-        [removed line not legible in the rendered diff]
+        if not base_model:
+            base_model = 'h2oai/h2ogpt-oasst1-512-12b'
+            # don't set load_8bit if passed base_model, doesn't always work so can't just override
+            load_8bit = True
     else:
-        base_model = 'h2oai/h2ogpt-oasst1-512-20b'
+        base_model = 'h2oai/h2ogpt-oasst1-512-20b' if not base_model else base_model
     if is_low_mem:
         load_8bit = True
     if is_hf:
@@ -229,6 +238,11 @@ def main(
                           do_sample,
                           )
 
+    locals_dict = locals()
+    locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()])
+    print(f"Generating model with params:\n{locals_print}", flush=True)
+    print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()), flush=True)
+
     if not gradio:
         if eval_sharegpt_prompts_only > 0:
             # override default examples with shareGPT ones for human-level eval purposes only
@@ -416,7 +430,11 @@ def get_device():
 
 def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
                        gpu_id=0,
-                       use_auth_token=False
+                       use_auth_token=False,
+                       trust_remote_code=True,
+                       triton_attn=False,
+                       long_sequence=True,
+                       ):
     """
     Ensure model gets on correct device
     :param base_model:
@@ -426,29 +444,47 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
     :param reward_type:
     :param gpu_id:
     :param use_auth_token:
+    :param trust_remote_code:
+    :param triton_attn:
+    :param long_sequence:
     :return:
     """
     with init_empty_weights():
         from transformers import AutoConfig
-        config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token
-        [14 more removed lines not legible in the rendered diff]
+        config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
+                                            trust_remote_code=trust_remote_code)
+        if triton_attn and 'mpt-' in base_model.lower():
+            config.attn_config['attn_impl'] = 'triton'
+        if long_sequence:
+            if 'mpt-7b-storywriter' in base_model.lower():
+                config.update({"max_seq_len": 83968})
+            if 'mosaicml/mpt-7b-chat' in base_model.lower():
+                config.update({"max_seq_len": 4096})
+        if issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())):
+            model = AutoModel.from_config(
+                config,
+            )
+        else:
+            # can't infer
+            model = None
+
+    if model is not None:
+        # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model
+        # NOTE: Some models require avoiding sharding some layers,
+        # then would pass no_split_module_classes and give list of those layers.
+        device_map = infer_auto_device_map(
+            model,
             dtype=torch.float16 if load_half else torch.float32,
         )
-        [2 removed lines not legible in the rendered diff]
+        if hasattr(model, 'model'):
+            device_map_model = infer_auto_device_map(
+                model.model,
+                dtype=torch.float16 if load_half else torch.float32,
+            )
+            device_map.update(device_map_model)
+        print('device_map: %s' % device_map, flush=True)
+    else:
+        device_map = "auto"
 
     n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0
 
@@ -472,11 +508,13 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward
     if load_in_8bit or not load_half:
         model = model_loader.from_pretrained(
             base_model,
+            config=config,
             **model_kwargs,
         )
     else:
         model = model_loader.from_pretrained(
             base_model,
+            config=config,
            **model_kwargs,
         ).half()
     return model
@@ -495,6 +533,7 @@ def get_model(
         local_files_only: bool = False,
         resume_download: bool = True,
         use_auth_token: Union[str, bool] = False,
+        trust_remote_code: bool = True,
         compile: bool = True,
         **kwargs,
 ):
@@ -513,6 +552,7 @@ def get_model(
     :param local_files_only: use local files instead of from HF
     :param resume_download: resume downloads from HF
     :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo
+    :param trust_remote_code: trust code needed by model
     :param compile: whether to compile torch model
     :param kwargs:
     :return:
@@ -531,7 +571,8 @@ def get_model(
     )
 
     from transformers import AutoConfig
-    config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token
+    config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token,
+                                        trust_remote_code=trust_remote_code)
     llama_type_from_config = 'llama' in str(config).lower()
     llama_type_from_name = "llama" in base_model.lower()
     llama_type = llama_type_from_config or llama_type_from_name
@@ -548,6 +589,7 @@ def get_model(
                                                      local_files_only=local_files_only,
                                                      resume_download=resume_download,
                                                      use_auth_token=use_auth_token,
+                                                     trust_remote_code=trust_remote_code,
                                                      )
     else:
         tokenizer = tokenizer_loader
@@ -563,13 +605,18 @@ def get_model(
     model_kwargs = dict(local_files_only=local_files_only,
                         torch_dtype=torch.float16 if device == 'cuda' else torch.float32,
                         resume_download=resume_download,
-                        use_auth_token=use_auth_token
-    [removed line not legible in the rendered diff]
+                        use_auth_token=use_auth_token,
+                        trust_remote_code=trust_remote_code,
+                        )
+    if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower():
         model_kwargs.update(dict(load_in_8bit=load_8bit,
                                  device_map={"": 0} if load_8bit and device == 'cuda' else "auto",
                                  ))
+    if 'mpt-' in base_model.lower() and gpu_id >= 0:
+        model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu"))
+
     if 'OpenAssistant/reward-model'.lower() in base_model.lower():
-        # could put on other GPUs
+        # FIXME: could put on other GPUs
         model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'}
         model_kwargs.pop('torch_dtype', None)
 
@@ -577,7 +624,10 @@ def get_model(
     with torch.device(device):
         if infer_devices:
             model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type,
-                                       gpu_id=gpu_id,
+                                       gpu_id=gpu_id,
+                                       use_auth_token=use_auth_token,
+                                       trust_remote_code=trust_remote_code,
+                                       )
         else:
             if load_half and not load_8bit:
                 model = model_loader.from_pretrained(
@@ -599,6 +649,7 @@ def get_model(
                     local_files_only=local_files_only,
                     resume_download=resume_download,
                    use_auth_token=use_auth_token,
+                    trust_remote_code=trust_remote_code,
                    device_map={"": 0} if device == 'cuda' else {"": 'cpu'},  # seems to be required
                 )
             else:
@@ -614,6 +665,7 @@ def get_model(
                     local_files_only=local_files_only,
                    resume_download=resume_download,
                    use_auth_token=use_auth_token,
+                    trust_remote_code=trust_remote_code,
                    device_map="auto",
                 )
             if load_half:
@@ -782,49 +834,7 @@ def evaluate(
     if chat:
         # override, ignore user change
        num_return_sequences = 1
-    [removed line not legible in the rendered diff]
-        if prompt_type == 'human_bot':
-            # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
-            # stopping only starts once output is beyond prompt
-            # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
-            stop_words = [human, bot, '\n' + human, '\n' + bot]
-            encounters = [1, 2]
-        elif prompt_type == 'instruct_vicuna':
-            # even below is not enough, generic strings and many ways to encode
-            stop_words = [
-                '### Human:',
-                """
-### Human:""",
-                """
-### Human:
-""",
-                '### Assistant:',
-                """
-### Assistant:""",
-                """
-### Assistant:
-""",
-            ]
-            encounters = [1, 2]
-        else:
-            # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
-            stop_words = ['### End']
-            encounters = [1]
-        stop_words_ids = [
-            tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
-        # handle single token case
-        stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
-        stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
-        # avoid padding in front of tokens
-        if tokenizer.pad_token:
-            stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
-        # handle fake \n added
-        stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
-        # build stopper
-        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
-    else:
-        stopping_criteria = StoppingCriteriaList()
-
+    stopping_criteria = get_stopping(prompt_type, tokenizer, device)
     # help to avoid errors like:
     # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3
     # RuntimeError: expected scalar type Half but found Float
@@ -903,7 +913,10 @@ def evaluate(
             prompt = inputs_decoded
         elif inputs_decoded_raw == prompt:
             # some models specify special tokens that are part of normal prompt, so can't skip them
-            [removed line not legible in the rendered diff]
+            inputs_decoded = prompt = inputs_decoded_raw
+            decoder = decoder_raw
+        elif inputs_decoded_raw.replace("<unk> ", "").replace("<unk>", "").replace('\n', ' ').replace(' ', '') == prompt.replace('\n', ' ').replace(' ', ''):
+            inputs_decoded = prompt = inputs_decoded_raw
             decoder = decoder_raw
         else:
            print("WARNING: Special characters in prompt", flush=True)
@@ -1046,6 +1059,7 @@ def get_generate_params(model_lower, chat,
 
     if not prompt_type and model_lower in inv_prompt_type_to_model_lower:
         prompt_type = inv_prompt_type_to_model_lower[model_lower]
+        print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True)
 
     # examples at first don't include chat, instruction_nochat, iinput_nochat, added at end
     if show_examples is None:
@@ -1104,7 +1118,8 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
         placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter."
         placeholder_input = ""
     if model_lower:
-        [removed line not legible in the rendered diff]
+        # default is plain, because might relly upon trust_remote_code to handle prompting
+        prompt_type = prompt_type or 'plain'
     else:
        prompt_type = ''
     examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else '', "",
@@ -1133,9 +1148,9 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
         num_return_sequences = min(num_beams, num_return_sequences or 1)
         do_sample = False if do_sample is None else do_sample
     else:
-        temperature = 0.
-        top_p = 0.
-        top_k =
+        temperature = 0.1 if temperature is None else temperature
+        top_p = 0.75 if top_p is None else top_p
+        top_k = 40 if top_k is None else top_k
         if chat:
             num_beams = num_beams or 1
         else:
@@ -1143,7 +1158,7 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa
         max_new_tokens = max_new_tokens or 256
         repetition_penalty = repetition_penalty or 1.07
         num_return_sequences = min(num_beams, num_return_sequences or 1)
-        do_sample =
+        do_sample = False if do_sample is None else do_sample
     # doesn't include chat, instruction_nochat, iinput_nochat, added later
     params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens,
                    early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample]
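Note: the new device-map logic in get_non_lora_model builds an empty-weight model from the config and lets accelerate plan placement, falling back to device_map="auto" for remote-code architectures such as MPT that AutoModel cannot instantiate from a config. A standalone sketch of that decision, with an illustrative model name (requires transformers and accelerate; it is not the exact function from the commit):

import torch
from transformers import AutoConfig, AutoModel
from accelerate import init_empty_weights, infer_auto_device_map

base_model = 'h2oai/h2ogpt-oasst1-512-12b'  # illustrative
config = AutoConfig.from_pretrained(base_model, trust_remote_code=True)

with init_empty_weights():
    if issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())):
        # architecture known to transformers: build an empty-weight model so
        # accelerate can plan per-device placement without loading weights
        model = AutoModel.from_config(config)
        device_map = infer_auto_device_map(model, dtype=torch.float16)
    else:
        # e.g. remote-code models such as MPT: cannot infer, fall back to "auto"
        device_map = "auto"
print(device_map)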
gradio_runner.py
CHANGED
@@ -5,6 +5,7 @@ import os
 import sys
 
 from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
+from prompter import Prompter
 from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
     ping
 from finetune import prompt_type_to_model_name, prompt_types_strings, generate_prompt, inv_prompt_type_to_model_lower
@@ -49,6 +50,7 @@ def go_gradio(**kwargs):
         """
     else:
         description = "For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio)<br>"
+        description += "If this host is busy, try [gpt.h2o.ai 20B](https://gpt.h2o.ai) and [30B](http://gpu.hopto.org) and [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) and [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>"
     description += """<p>By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/tos.md)</p>"""
 
     if kwargs['verbose']:
@@ -389,6 +391,7 @@ def go_gradio(**kwargs):
             .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False)
 
     # Get inputs to evaluate()
+    # don't deepcopy, can contain model itself
     all_kwargs = kwargs.copy()
     all_kwargs.update(locals())
     inputs_list = get_inputs_list(all_kwargs, kwargs['model_lower'])
@@ -516,9 +519,12 @@
         :return:
         """
         args_list = list(args)
-        user_message = args_list[
-        input1 = args_list[
-        context1 = args_list[
+        user_message = args_list[eval_func_param_names.index('instruction')]  # chat only
+        input1 = args_list[eval_func_param_names.index('iinput')]  # chat only
+        context1 = args_list[eval_func_param_names.index('context')]
+        prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
+        chat1 = args_list[eval_func_param_names.index('chat')]
+        stream_output1 = args_list[eval_func_param_names.index('stream_output')]
         if input1 and not user_message.endswith(':'):
             user_message1 = user_message + ":" + input1
         elif input1:
@@ -528,6 +534,8 @@
         if sanitize_user_prompt:
             from better_profanity import profanity
             user_message1 = profanity.censor(user_message1)
+        # FIXME: WIP to use desired seperator when user enters nothing
+        prompter = Prompter(prompt_type1, debug=kwargs['debug'], chat=chat1, stream_output=stream_output1)
         if user_message1 in ['']:
             # e.g. when user just hits enter in textbox,
             # else will have <human>: <bot>: on single line, which seems to be "ok" for LLM but not usual
@@ -559,7 +567,8 @@
         :param retry:
         :return:
         """
-        [removed line not legible in the rendered diff]
+        # don't deepcopy, can contain model itself
+        args_list = list(args).copy()
         history = args_list[-1]  # model_state is -2
         if retry and history:
             history.pop()
@@ -580,12 +589,18 @@
         context1 = ''
         for histi in range(len(history) - 1):
             data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
-            [6 removed lines not legible in the rendered diff]
+            prompt, pre_response, terminate_response, chat_sep = generate_prompt(data_point, prompt_type1,
+                                                                                 chat1, reduced=True)
+            # md -> back to text, maybe not super improtant if model trained enough
+            prompt = prompt.replace('<br>', chat_sep)
+            context1 += prompt
+            if not context1.endswith(chat_sep):
+                context1 += chat_sep
+
+        _, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, chat1,
+                                                                        reduced=True)
+        if context1 and not context1.endswith(chat_sep):
+            context1 += chat_sep  # ensure if terminates abruptly, then human continues on next line
         args_list[0] = instruction1  # override original instruction with history from user
         # only include desired chat history
         args_list[2] = context1[-kwargs['chat_history']:]
@@ -767,6 +782,7 @@
             lora_weights = no_lora_str
             return [None, None, None, model_name], model_name, lora_weights, prompt_type_old
 
+        # don't deepcopy, can contain model itself
         all_kwargs1 = all_kwargs.copy()
         all_kwargs1['base_model'] = model_name.strip()
         all_kwargs1['load_8bit'] = load_8bit
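Note: the chat handler now rebuilds the conversation context from the Gradio history with generate_prompt and the new chat_sep instead of hard-coded <human>/<bot> markers. A toy sketch of that loop with a fabricated two-turn history (assumes this commit's finetune.generate_prompt; the history values are illustrative):

from finetune import generate_prompt

history = [['Hi', 'Hello, how can I help?'], ['Tell me a joke', None]]
prompt_type1, chat1 = 'human_bot', True

context1 = ''
for histi in range(len(history) - 1):
    data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
    prompt, pre_response, terminate_response, chat_sep = generate_prompt(
        data_point, prompt_type1, chat1, reduced=True)
    prompt = prompt.replace('<br>', chat_sep)  # markdown line breaks back to plain text
    context1 += prompt
    if not context1.endswith(chat_sep):
        context1 += chat_sep
print(context1)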
prompter.py
CHANGED
@@ -6,7 +6,8 @@ class Prompter(object):
                  allowed_repeat_line_length=10):
         self.prompt_type = prompt_type
         data_point = dict(instruction='', input='', output='')
-        _, self.pre_response, self.terminate_response
+        _, self.pre_response, self.terminate_response, self.chat_sep = \
+            generate_prompt(data_point, prompt_type, chat, False)
         self.debug = debug
         self.chat = chat
         self.stream_output = stream_output
@@ -15,7 +16,7 @@ class Prompter(object):
 
     def generate_prompt(self, data_point):
         reduced = False
-        prompt, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
+        prompt, _, _, _ = generate_prompt(data_point, self.prompt_type, self.chat, reduced)
         if self.debug:
             print("prompt: ", prompt, flush=True)
         self.prompt = prompt
@@ -25,12 +26,12 @@ class Prompter(object):
         if isinstance(outputs, str):
             outputs = [outputs]
         if self.debug:
-            print("output
+            print("output:\n", '\n\n'.join(outputs), flush=True)
         if prompt is not None:
             self.prompt = prompt
 
         def clean_response(response):
-            meaningless_words = ['<pad>', '</s>', '<|endoftext|>'
+            meaningless_words = ['<pad>', '</s>', '<|endoftext|>']
             for word in meaningless_words:
                 response = response.replace(word, "")
             if sanitize_bot_response:
@@ -103,5 +104,5 @@ class Prompter(object):
         # join all outputs, only one extra new line between outputs
         output = '\n'.join(outputs)
         if self.debug:
-            print("outputclean
+            print("outputclean:\n", '\n\n'.join(outputs), flush=True)
         return output
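Note: Prompter now unpacks the 4-tuple from generate_prompt and exposes chat_sep alongside pre_response and terminate_response. A minimal usage sketch (the constructor arguments mirror how gradio_runner.py instantiates it; the printed values depend on the prompt templates in finetune.py):

from prompter import Prompter

prompter = Prompter('human_bot', debug=False, chat=True, stream_output=False)
prompt = prompter.generate_prompt(dict(instruction='Who are you?', input='', output=''))
print(prompter.pre_response, prompter.terminate_response, prompter.chat_sep)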
requirements.txt
CHANGED
@@ -19,7 +19,7 @@ pandas==2.0.0
 matplotlib==3.7.1
 loralib==0.1.1
 bitsandbytes==0.38.1
-git+https://github.com/huggingface/peft.git@
+git+https://github.com/huggingface/peft.git@098962fa6515f2e4fe83a757f5995d3ffbb1c373
 transformers==4.28.1
 tokenizers==0.13.3
 APScheduler==3.10.1
stopping.py
CHANGED
@@ -1,10 +1,5 @@
-import traceback
-from queue import Queue
-from threading import Thread
-import collections.abc
-
 import torch
-from transformers import StoppingCriteria
+from transformers import StoppingCriteria, StoppingCriteriaList
 
 
 class StoppingCriteriaSub(StoppingCriteria):
@@ -21,7 +16,55 @@ class StoppingCriteriaSub(StoppingCriteria):
             if torch.all((stop == input_ids[0][-len(stop):])).item():
                 self.num_stops[stopi] += 1
                 if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]:
+                    # print("Stopped", flush=True)
                     return True
         # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True)
         # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True)
         return False
+
+
+def get_stopping(prompt_type, tokenizer, device, human='<human>:', bot="<bot>:"):
+    if prompt_type in ['human_bot', 'instruct_vicuna', 'instruct_with_end']:
+        if prompt_type == 'human_bot':
+            # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1]
+            # stopping only starts once output is beyond prompt
+            # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added
+            stop_words = [human, bot, '\n' + human, '\n' + bot]
+            encounters = [1, 2]
+        elif prompt_type == 'instruct_vicuna':
+            # even below is not enough, generic strings and many ways to encode
+            stop_words = [
+                '### Human:',
+                """
+### Human:""",
+                """
+### Human:
+""",
+                '### Assistant:',
+                """
+### Assistant:""",
+                """
+### Assistant:
+""",
+            ]
+            encounters = [1, 2]
+        else:
+            # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise
+            stop_words = ['### End']
+            encounters = [1]
+        stop_words_ids = [
+            tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words]
+        # handle single token case
+        stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids]
+        stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0]
+        # avoid padding in front of tokens
+        if tokenizer.pad_token:
+            stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids]
+        # handle fake \n added
+        stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)]
+        # build stopper
+        stopping_criteria = StoppingCriteriaList(
+            [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device)])
+    else:
+        stopping_criteria = StoppingCriteriaList()
+    return stopping_criteria
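Note: get_stopping packages the stop-word logic that previously lived inline in generate.evaluate. A hedged sketch of wiring it into transformers generation; the tiny model is only to keep the example light, and the prompt text is illustrative:

from transformers import AutoModelForCausalLM, AutoTokenizer
from stopping import get_stopping

tokenizer = AutoTokenizer.from_pretrained('distilgpt2')
model = AutoModelForCausalLM.from_pretrained('distilgpt2')

# stop once <human>:/<bot>: markers appear in the generated continuation
stopping_criteria = get_stopping('human_bot', tokenizer, device='cpu')
inputs = tokenizer('<human>: Who are you?\n<bot>:', return_tensors='pt')
outputs = model.generate(**inputs, max_new_tokens=64, stopping_criteria=stopping_criteria)
print(tokenizer.decode(outputs[0]))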