diff --git a/client_test.py b/client_test.py index d9c45abce0ed25576e19e095fd9c542602b2c28d..2553b0f11cfdc32217f5ac425ebcf0ec9521a78c 100644 --- a/client_test.py +++ b/client_test.py @@ -48,6 +48,8 @@ import markdown # pip install markdown import pytest from bs4 import BeautifulSoup # pip install beautifulsoup4 +from enums import DocumentChoices + debug = False os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' @@ -62,7 +64,10 @@ def get_client(serialize=True): return client -def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_tokens=50, langchain_mode='Disabled'): +def get_args(prompt, prompt_type, chat=False, stream_output=False, + max_new_tokens=50, + top_k_docs=3, + langchain_mode='Disabled'): from collections import OrderedDict kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True iinput='', # only for chat=True @@ -71,6 +76,7 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token # but leave stream_output=False for simple input/output mode stream_output=stream_output, prompt_type=prompt_type, + prompt_dict='', temperature=0.1, top_p=0.75, top_k=40, @@ -86,9 +92,13 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token instruction_nochat=prompt if not chat else '', iinput_nochat='', # only for chat=False langchain_mode=langchain_mode, - top_k_docs=4, - document_choice=['All'], + top_k_docs=top_k_docs, + chunk=True, + chunk_size=512, + document_choice=[DocumentChoices.All_Relevant.name], ) + from generate import eval_func_param_names + assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0 if chat: # add chatbot output on end. Assumes serialize=False kwargs.update(dict(chatbot=[])) @@ -97,8 +107,8 @@ def get_args(prompt, prompt_type, chat=False, stream_output=False, max_new_token @pytest.mark.skip(reason="For manual use against some server, no server launched") -def test_client_basic(): - return run_client_nochat(prompt='Who are you?', prompt_type='human_bot', max_new_tokens=50) +def test_client_basic(prompt_type='human_bot'): + return run_client_nochat(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50) def run_client_nochat(prompt, prompt_type, max_new_tokens): @@ -112,15 +122,110 @@ def run_client_nochat(prompt, prompt_type, max_new_tokens): ) print("Raw client result: %s" % res, flush=True) res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'], + response=md_to_text(res)) + print(res_dict) + return res_dict, client + + +@pytest.mark.skip(reason="For manual use against some server, no server launched") +def test_client_basic_api(prompt_type='human_bot'): + return run_client_nochat_api(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50) + + +def run_client_nochat_api(prompt, prompt_type, max_new_tokens): + kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens) + + api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing + client = get_client(serialize=True) + res = client.predict( + str(dict(kwargs)), + api_name=api_name, + ) + print("Raw client result: %s" % res, flush=True) + res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'], + response=md_to_text(ast.literal_eval(res)['response']), + sources=ast.literal_eval(res)['sources']) + print(res_dict) + return res_dict, client + + +@pytest.mark.skip(reason="For manual use against some server, no server launched") +def 
test_client_basic_api_lean(prompt_type='human_bot'): + return run_client_nochat_api_lean(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50) + + +def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens): + kwargs = dict(instruction_nochat=prompt) + + api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing + client = get_client(serialize=True) + res = client.predict( + str(dict(kwargs)), + api_name=api_name, + ) + print("Raw client result: %s" % res, flush=True) + res_dict = dict(prompt=kwargs['instruction_nochat'], + response=md_to_text(ast.literal_eval(res)['response']), + sources=ast.literal_eval(res)['sources']) + print(res_dict) + return res_dict, client + + +@pytest.mark.skip(reason="For manual use against some server, no server launched") +def test_client_basic_api_lean_morestuff(prompt_type='human_bot'): + return run_client_nochat_api_lean_morestuff(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50) + + +def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_new_tokens=512): + kwargs = dict( + instruction='', + iinput='', + context='', + stream_output=False, + prompt_type=prompt_type, + temperature=0.1, + top_p=0.75, + top_k=40, + num_beams=1, + max_new_tokens=256, + min_new_tokens=0, + early_stopping=False, + max_time=20, + repetition_penalty=1.0, + num_return_sequences=1, + do_sample=True, + chat=False, + instruction_nochat=prompt, + iinput_nochat='', + langchain_mode='Disabled', + top_k_docs=4, + document_choice=['All'], + ) + + api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing + client = get_client(serialize=True) + res = client.predict( + str(dict(kwargs)), + api_name=api_name, + ) + print("Raw client result: %s" % res, flush=True) + res_dict = dict(prompt=kwargs['instruction_nochat'], response=md_to_text(ast.literal_eval(res)['response']), sources=ast.literal_eval(res)['sources']) print(res_dict) - return res_dict + return res_dict, client + + +@pytest.mark.skip(reason="For manual use against some server, no server launched") +def test_client_chat(prompt_type='human_bot'): + return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50, + langchain_mode='Disabled') @pytest.mark.skip(reason="For manual use against some server, no server launched") -def test_client_chat(): - return run_client_chat(prompt='Who are you?', prompt_type='human_bot', stream_output=False, max_new_tokens=50, +def test_client_chat_stream(prompt_type='human_bot'): + return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type, + stream_output=True, max_new_tokens=512, langchain_mode='Disabled') @@ -133,6 +238,7 @@ def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, langchai def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False): + assert kwargs['chat'], "Chat mode only" res = client.predict(*tuple(args), api_name='/instruction') args[-1] += [res[-1]] @@ -166,6 +272,46 @@ def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False): return res_dict, client +@pytest.mark.skip(reason="For manual use against some server, no server launched") +def test_client_nochat_stream(prompt_type='human_bot'): + return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type, + stream_output=True, max_new_tokens=512, + langchain_mode='Disabled') + + +def 
run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, langchain_mode): + client = get_client(serialize=False) + + kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output, + max_new_tokens=max_new_tokens, langchain_mode=langchain_mode) + return run_client_gen(client, prompt, args, kwargs) + + +def run_client_gen(client, prompt, args, kwargs, do_md_to_text=True, verbose=False): + res_dict = kwargs + res_dict['prompt'] = prompt + if not kwargs['stream_output']: + res = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api') + res_dict['response'] = res[0] + print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text)) + return res_dict, client + else: + job = client.submit(str(dict(kwargs)), api_name='/submit_nochat_api') + while not job.done(): + outputs_list = job.communicator.job.outputs + if outputs_list: + res = job.communicator.job.outputs[-1] + res_dict = ast.literal_eval(res) + print('Stream: %s' % res_dict['response']) + time.sleep(0.1) + res_list = job.outputs() + assert len(res_list) > 0, "No response, check server" + res = res_list[-1] + res_dict = ast.literal_eval(res) + print('Final: %s' % res_dict['response']) + return res_dict, client + + def md_to_text(md, do_md_to_text=True): if not do_md_to_text: return md @@ -175,5 +321,16 @@ def md_to_text(md, do_md_to_text=True): return soup.get_text() +def run_client_many(prompt_type='human_bot'): + ret1, _ = test_client_chat(prompt_type=prompt_type) + ret2, _ = test_client_chat_stream(prompt_type=prompt_type) + ret3, _ = test_client_nochat_stream(prompt_type=prompt_type) + ret4, _ = test_client_basic(prompt_type=prompt_type) + ret5, _ = test_client_basic_api(prompt_type=prompt_type) + ret6, _ = test_client_basic_api_lean(prompt_type=prompt_type) + ret7, _ = test_client_basic_api_lean_morestuff(prompt_type=prompt_type) + return ret1, ret2, ret3, ret4, ret5, ret6, ret7 + + if __name__ == '__main__': - test_client_basic() + run_client_many() diff --git a/create_data.py b/create_data.py index cb7385c2bf583278bf436cf9cf6ccc6ac6aac984..f16c519dcdd6b07dfd09f824e670401887f6eeaa 100644 --- a/create_data.py +++ b/create_data.py @@ -567,7 +567,7 @@ def test_show_prompts(): from prompter import generate_prompt for data_points in file_points: for data_point in data_points: - print(generate_prompt(data_point, 'plain', False, False)[0]) + print(generate_prompt(data_point, 'plain', '', False, False, False)[0]) def test_get_open_datasets(): @@ -1571,7 +1571,7 @@ def test_check_stats_data(): llama_type = False tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b' - model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False) + model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type) local_files_only = False resume_download = True use_auth_token = False diff --git a/enums.py b/enums.py new file mode 100644 index 0000000000000000000000000000000000000000..f0a7dc3a94c55793ee18ddb42219a84a30e43796 --- /dev/null +++ b/enums.py @@ -0,0 +1,84 @@ +from enum import Enum + + +class PromptType(Enum): + custom = -1 + plain = 0 + instruct = 1 + quality = 2 + human_bot = 3 + dai_faq = 4 + summarize = 5 + simple_instruct = 6 + instruct_vicuna = 7 + instruct_with_end = 8 + human_bot_orig = 9 + prompt_answer = 10 + open_assistant = 11 + wizard_lm = 12 + wizard_mega = 13 + instruct_vicuna2 = 14 + instruct_vicuna3 = 15 + wizard2 = 16 + wizard3 = 17 + instruct_simple = 18 + wizard_vicuna = 19 + 
openai = 20
+    openai_chat = 21
+    gptj = 22
+    prompt_answer_openllama = 23
+    vicuna11 = 24
+
+
+class DocumentChoices(Enum):
+    All_Relevant = 0
+    All_Relevant_Only_Sources = 1
+    Only_All_Sources = 2
+    Just_LLM = 3
+
+
+class LangChainMode(Enum):
+    """LangChain mode"""
+
+    DISABLED = "Disabled"
+    CHAT_LLM = "ChatLLM"
+    LLM = "LLM"
+    ALL = "All"
+    WIKI = "wiki"
+    WIKI_FULL = "wiki_full"
+    USER_DATA = "UserData"
+    MY_DATA = "MyData"
+    GITHUB_H2OGPT = "github h2oGPT"
+    H2O_DAI_DOCS = "DriverlessAI docs"
+
+
+no_server_str = no_lora_str = no_model_str = '[None/Remove]'
+
+
+# from site-packages/langchain/llms/openai.py, but needed since ChatOpenAI doesn't have this information
+model_token_mapping = {
+    "gpt-4": 8192,
+    "gpt-4-0314": 8192,
+    "gpt-4-32k": 32768,
+    "gpt-4-32k-0314": 32768,
+    "gpt-3.5-turbo": 4096,
+    "gpt-3.5-turbo-16k": 16*1024,
+    "gpt-3.5-turbo-0301": 4096,
+    "text-ada-001": 2049,
+    "ada": 2049,
+    "text-babbage-001": 2040,
+    "babbage": 2049,
+    "text-curie-001": 2049,
+    "curie": 2049,
+    "davinci": 2049,
+    "text-davinci-003": 4097,
+    "text-davinci-002": 4097,
+    "code-davinci-002": 8001,
+    "code-davinci-001": 8001,
+    "code-cushman-002": 2048,
+    "code-cushman-001": 2048,
+}
+
+
+source_prefix = "Sources [Score | Link]:"
+source_postfix = "End Sources
" diff --git a/finetune.py b/finetune.py index 2724198bed4896261f352a8407403a18998f3721..88417a2d068fe32a809077ac9f25bccd50fdf880 100644 --- a/finetune.py +++ b/finetune.py @@ -5,8 +5,11 @@ from typing import List, Union import fire import numpy as np +if os.path.dirname(os.path.abspath(__file__)) not in sys.path: + sys.path.append(os.path.dirname(os.path.abspath(__file__))) + from loaders import get_loaders, get_tokenizer -from prompter import generate_prompt, prompt_types +from prompter import generate_prompt, prompt_types, PromptType from utils import get_githash, copy_code import torch @@ -104,7 +107,6 @@ def train( save_total_limit: int = 3, add_eos_token: bool = False, ): - if llama_flash_attn: # Need to call this before importing transformers. from llama_flash_attn_monkey_patch import replace_llama_attn_with_flash_attn @@ -129,10 +131,12 @@ def train( if not output_dir: output_dir = f"{base_model.split('/')[-1]}.{data_path.replace('/', '')}.{num_epochs}_epochs.{get_githash() or 'nogit'}.{run_id}" if os.path.exists(output_dir) and not resume_from_checkpoint: - raise FileExistsError(f"output_dir {output_dir} based on run_id {run_id} already exists. Please pick a different run_id.") + raise FileExistsError( + f"output_dir {output_dir} based on run_id {run_id} already exists. Please pick a different run_id.") else: if os.path.exists(output_dir) and not resume_from_checkpoint: - raise FileExistsError(f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.") + raise FileExistsError( + f"output_dir {output_dir} already exists. Please pick a different output_dir, or specify a run_id instead.") device_map = "auto" if save_code: @@ -181,7 +185,7 @@ def train( log("num_gpus: %d" % gpus) log("max mem: %s" % max_memory) - model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=False) + model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type) model = model_loader.from_pretrained( base_model, @@ -398,7 +402,8 @@ def train( if train_data_mix_in: train_data = concatenate_datasets([train_data, train_data_mix_in]) log("Tokenizing %s training rows" % train_data.num_rows) - train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count() // torch.cuda.device_count()) + train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun, + num_proc=os.cpu_count() // torch.cuda.device_count()) if drop_truncations: log("avoid keeping truncated cases to avoid contaminating model with truncation cases. 
Original size: %s" % train_data.num_rows) prune_long_sequences_func = partial(prune_long_sequences, cutoff_len=cutoff_len) @@ -413,7 +418,8 @@ def train( if valid_data: log("Tokenizing %s validation rows" % valid_data.num_rows) - valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count() // torch.cuda.device_count()) + valid_data = valid_data.shuffle().map(generate_and_tokenize_prompt_fun, + num_proc=os.cpu_count() // torch.cuda.device_count()) val_set_size = len(valid_data) else: val_set_size = 0 @@ -468,7 +474,7 @@ def train( elif save_steps > eval_steps: # save steps must be round multiple of eval_steps save_steps0 = save_steps - save_steps = max(1, (save_steps//eval_steps)) * eval_steps + save_steps = max(1, (save_steps // eval_steps)) * eval_steps if save_steps0 != save_steps: log("Auto converted save_steps from %s to %s" % (save_steps0, save_steps)) @@ -478,21 +484,21 @@ def train( label_ids = eval_preds.label_ids predictions = eval_preds.predictions - #inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id) - #decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True) - #decoded_inputs = [pred.strip() for pred in decoded_inputs] + # inputs = np.where(inputs != -100, inputs, tokenizer.pad_token_id) + # decoded_inputs = tokenizer.batch_decode(inputs, skip_special_tokens=True) + # decoded_inputs = [pred.strip() for pred in decoded_inputs] label_ids = np.where(label_ids != -100, label_ids, tokenizer.pad_token_id) # tokenizer behavior like generate time decoded_labels = tokenizer.batch_decode(label_ids, skip_special_tokens=True, - clean_up_tokenization_spaces=True) + clean_up_tokenization_spaces=True) decoded_labels = [pred.strip() for pred in decoded_labels] predictions = np.argmax(predictions, -1) predictions = np.where(predictions != -100, predictions, tokenizer.pad_token_id) # tokenizer behavior like generate time decoded_predictions = tokenizer.batch_decode(predictions, skip_special_tokens=True, - clean_up_tokenization_spaces=True) + clean_up_tokenization_spaces=True) decoded_predictions = [pred.strip() for pred in decoded_predictions] result = {} @@ -541,8 +547,8 @@ def train( load_best_model_at_end=True if val_set_size > 0 else False, ddp_find_unused_parameters=False if ddp else None, group_by_length=group_by_length, - #fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None, - #fsdp_min_num_params=20000 if gpus > 1 and not ddp else None, + # fsdp="shard_grad_op auto_wrap" if gpus > 1 and not ddp else None, + # fsdp_min_num_params=20000 if gpus > 1 and not ddp else None, report_to='tensorboard' if not neptune_run else 'neptune', ), data_collator=transformers.DataCollatorForSeq2Seq( @@ -553,13 +559,6 @@ def train( ) model.config.use_cache = False - old_state_dict = model.state_dict - from peft import get_peft_model_state_dict - - model.state_dict = ( - lambda self, *_, **__: get_peft_model_state_dict(self, old_state_dict()) - ).__get__(model, type(model)) - if torch.__version__ >= "2" and sys.platform != "win32": model = torch.compile(model) # WIP (not generally replacing layers until pytorch 2.1) @@ -616,10 +615,12 @@ def generate_and_tokenize_prompt(data_point, prompt_type=None, train_on_inputs=F assert prompt_type is not None assert cutoff_len is not None assert tokenizer is not None - full_prompt, _, _, _ = generate_prompt(data_point, prompt_type, False, False) + prompt_dict = '' # only for custom prompt_type + assert prompt_type != PromptType.custom.name, "custom not setup for finetune" + full_prompt, _, _, 
_, _ = generate_prompt(data_point, prompt_type, prompt_dict, False, False, False) tokenized_full_prompt = tokenize(full_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token) if not train_on_inputs: - user_prompt, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, False, False) + user_prompt, _, _, _, _ = generate_prompt({**data_point, "output": ""}, prompt_type, prompt_dict, False, False, False) tokenized_user_prompt = tokenize(user_prompt, tokenizer, cutoff_len, add_eos_token=add_eos_token) user_prompt_len = len(tokenized_user_prompt["input_ids"]) if add_eos_token: @@ -638,7 +639,7 @@ def test_debug(): fire.Fire(train) -if __name__ == "__main__": +def entrypoint_main(): CONFIG = "NCCL_P2P_LEVEL=LOC WORLD_SIZE=5 torchrun --nnodes=5 --master_addr=10.10.10.2 --master_port=1111 --nproc_per_node=1" CMD = "finetune.py --data_path=config.json --num_epochs=1 --base_model=decapoda-research/llama-13b-hf" log(f""" @@ -665,6 +666,11 @@ NCCL_P2P_LEVEL=LOC WORLD_SIZE=7 CUDA_VISIBLE_DEVICES="0,1" torchrun --node_rank if os.environ.get("LOCAL_RANK") is None: # then not using torchrun, so can't do distributed, ensure CVD set - assert os.environ.get("CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU" + assert os.environ.get( + "CUDA_VISIBLE_DEVICES") is not None, "Run python script using: torchrun finetune.py OR set CUDA_VISIBLE_DEVICES to single GPU" fire.Fire(train) + + +if __name__ == "__main__": + entrypoint_main() diff --git a/generate.py b/generate.py index e69ac534fe413b4f645387d392f0fe153d82dbf6..09e71f61a614d5e9fba2b96b30150a5fb05ad577 100644 --- a/generate.py +++ b/generate.py @@ -1,27 +1,39 @@ import ast +import copy import functools import glob import inspect import queue -import shutil import sys import os import time import traceback +import types import typing import warnings from datetime import datetime import filelock +import requests import psutil +from requests import ConnectTimeout, JSONDecodeError +from urllib3.exceptions import ConnectTimeoutError, MaxRetryError, ConnectionError +from requests.exceptions import ConnectionError as ConnectionError2 +from requests.exceptions import ReadTimeout as ReadTimeout2 + +if os.path.dirname(os.path.abspath(__file__)) not in sys.path: + sys.path.append(os.path.dirname(os.path.abspath(__file__))) os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' os.environ['BITSANDBYTES_NOWELCOME'] = '1' warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') +from enums import DocumentChoices, LangChainMode, no_lora_str, model_token_mapping, no_model_str, source_prefix, \ + source_postfix from loaders import get_loaders from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \ - import_matplotlib, get_device, makedirs, get_kwargs + import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove +start_faulthandler() import_matplotlib() SEED = 1236 @@ -31,17 +43,14 @@ from typing import Union import fire import torch -from peft import PeftModel from transformers import GenerationConfig, AutoModel, TextIteratorStreamer -from accelerate import init_empty_weights, infer_auto_device_map -from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types +from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt, generate_prompt from stopping import get_stopping eval_extra_columns 
= ['prompt', 'response', 'score'] -langchain_modes = ['Disabled', 'ChatLLM', 'LLM', 'All', 'wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', - 'DriverlessAI docs'] +langchain_modes = [x.value for x in list(LangChainMode)] scratch_base_dir = '/tmp/' @@ -56,8 +65,15 @@ def main( lora_weights: str = "", gpu_id: int = 0, compile_model: bool = True, - + use_cache: bool = None, + inference_server: str = "", prompt_type: Union[int, str] = None, + prompt_dict: typing.Dict = None, + + model_lock: typing.List[typing.Dict[str, str]] = None, + model_lock_columns: int = None, + fail_if_cannot_connect: bool = False, + # input to generation temperature: float = None, top_p: float = None, @@ -87,14 +103,13 @@ def main( cli: bool = False, cli_loop: bool = True, gradio: bool = True, - gradio_avoid_processing_markdown: bool = False, gradio_offline_level: int = 0, chat: bool = True, chat_context: bool = False, stream_output: bool = True, show_examples: bool = None, verbose: bool = False, - h2ocolors: bool = False, + h2ocolors: bool = True, height: int = 600, show_lora: bool = True, login_mode_if_model0: bool = False, @@ -103,16 +118,19 @@ def main( api_open: bool = False, allow_api: bool = True, input_lines: int = 1, + gradio_size: str = None, auth: typing.List[typing.Tuple[str, str]] = None, + max_max_time=None, + max_max_new_tokens=None, - sanitize_user_prompt: bool = True, - sanitize_bot_response: bool = True, + sanitize_user_prompt: bool = False, + sanitize_bot_response: bool = False, extra_model_options: typing.List[str] = [], extra_lora_options: typing.List[str] = [], + extra_server_options: typing.List[str] = [], score_model: str = 'OpenAssistant/reward-model-deberta-v3-large-v2', - auto_score: bool = True, eval_filename: str = None, eval_prompts_only_num: int = 0, @@ -120,8 +138,9 @@ def main( eval_as_output: bool = False, langchain_mode: str = 'Disabled', + force_langchain_evaluate: bool = False, visible_langchain_modes: list = ['UserData', 'MyData'], - document_choice: list = ['All'], + document_choice: list = [DocumentChoices.All_Relevant.name], user_path: str = None, detect_user_path_changes_every_query: bool = False, load_db_if_exists: bool = True, @@ -129,7 +148,7 @@ def main( db_type: str = 'chroma', use_openai_embedding: bool = False, use_openai_model: bool = False, - hf_embedding_model: str = "sentence-transformers/all-MiniLM-L6-v2", + hf_embedding_model: str = None, allow_upload_to_user_data: bool = True, allow_upload_to_my_data: bool = True, enable_url_upload: bool = True, @@ -137,7 +156,10 @@ def main( enable_sources_list: bool = True, chunk: bool = True, chunk_size: int = 512, - top_k_docs: int = 3, # FIXME: Can go back to 4 once https://github.com/h2oai/h2ogpt/issues/192 fixed + top_k_docs: int = None, + reverse_docs: bool = True, + auto_reduce_chunks: bool = True, + max_chunks: int = 100, n_jobs: int = -1, enable_captions: bool = True, captions_model: str = "Salesforce/blip-image-captioning-base", @@ -156,7 +178,31 @@ def main( :param lora_weights: LORA weights path/HF link :param gpu_id: if infer_devices, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1 :param compile_model Whether to compile the model + :param use_cache: Whether to use caching in model (some models fail when multiple threads use) + :param inference_server: Consume base_model as type of model at this address + Address can be text-generation-server hosting that base_model + e.g. 
python generate.py --inference_server="http://192.168.1.46:6112" --base_model=h2oai/h2ogpt-oasst1-512-12b + Or Address can be "openai_chat" or "openai" for OpenAI API + e.g. python generate.py --inference_server="openai_chat" --base_model=gpt-3.5-turbo + e.g. python generate.py --inference_server="openai" --base_model=text-davinci-003 :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model + :param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True) + :param model_lock: Lock models to specific combinations, for ease of use and extending to many models + Only used if gradio = True + List of dicts, each dict has base_model, tokenizer_base_model, lora_weights, inference_server, prompt_type, and prompt_dict + If all models have same prompt_type, and prompt_dict, can still specify that once in CLI outside model_lock as default for dict + Can specify model_lock instead of those items on CLI + As with CLI itself, base_model can infer prompt_type and prompt_dict if in prompter.py. + Also, tokenizer_base_model and lora_weights are optional. + Also, inference_server is optional if loading model from local system. + All models provided will automatically appear in compare model mode + Model loading-unloading and related choices will be disabled. Model/lora/server adding will be disabled + :param model_lock_columns: How many columns to show if locking models (and so showing all at once) + If None, then defaults to up to 3 + if -1, then all goes into 1 row + Maximum value is 4 due to non-dynamic gradio rendering elements + :param fail_if_cannot_connect: if doing model locking (e.g. with many models), fail if True. Otherwise ignore. + Useful when many endpoints and want to just see what works, but still have to wait for timeout. :param temperature: generation temperature :param top_p: generation top_p :param top_k: generation top_k @@ -182,13 +228,13 @@ def main( :param cli: whether to use CLI (non-gradio) interface. :param cli_loop: whether to loop for CLI (False usually only for testing) :param gradio: whether to enable gradio, or to enable benchmark mode - :param gradio_avoid_processing_markdown: :param gradio_offline_level: > 0, then change fonts so full offline == 1 means backend won't need internet for fonts, but front-end UI might if font not cached == 2 means backend and frontend don't need internet to download any fonts. Note: Some things always disabled include HF telemetry, gradio telemetry, chromadb posthog that involve uploading. This option further disables google fonts for downloading, which is less intrusive than uploading, but still required in air-gapped case. The fonts don't look as nice as google fonts, but ensure full offline behavior. + Also set --share=False to avoid sharing a gradio live link. :param chat: whether to enable chat mode with chat history :param chat_context: whether to use extra helpful context if human_bot :param stream_output: whether to stream output from generate @@ -203,20 +249,25 @@ def main( :param api_open: If False, don't let API calls skip gradio queue :param allow_api: whether to allow API calls at all to gradio server :param input_lines: how many input lines to show for chat box (>1 forces shift-enter for submit, else enter is submit) + :param gradio_size: Overall size of text and spaces: "xsmall", "small", "medium", "large". 
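To make the model_lock shape described above concrete, a minimal sketch (illustrative only, not part of the patch; the endpoint address and model names are placeholders reused from the docstring examples, and it assumes generate.py is importable):

from generate import main

model_lock = [
    # plain local/HF model; prompt_type can be inferred from prompter.py if omitted
    dict(base_model='h2oai/h2ogpt-oasst1-512-12b'),
    # same UI also consuming a remote text-generation-inference server
    dict(base_model='h2oai/h2ogpt-oasst1-512-12b',
         inference_server='http://192.168.1.46:6112'),
    # OpenAI chat API treated as an inference server
    dict(base_model='gpt-3.5-turbo', inference_server='openai_chat'),
]

# main() also checks a 'model_lock' environment variable via ast.literal_eval,
# so setting os.environ['model_lock'] = str(model_lock) is an equivalent way to pass it.
main(model_lock=model_lock, gradio=True, chat=True, fail_if_cannot_connect=False)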
+ Small useful for many chatbots in model_lock mode :param auth: gradio auth for launcher in form [(user1, pass1), (user2, pass2), ...] e.g. --auth=[('jon','password')] with no spaces - :param sanitize_user_prompt: whether to remove profanity from user input - :param sanitize_bot_response: whether to remove profanity and repeat lines from bot output + :param max_max_time: Maximum max_time for gradio slider + :param max_max_new_tokens: Maximum max_new_tokens for gradio slider + :param sanitize_user_prompt: whether to remove profanity from user input (slows down input processing) + :param sanitize_bot_response: whether to remove profanity and repeat lines from bot output (about 2x slower generation for long streaming cases due to better_profanity being slow) :param extra_model_options: extra models to show in list in gradio :param extra_lora_options: extra LORA to show in list in gradio + :param extra_server_options: extra servers to show in list in gradio :param score_model: which model to score responses (None means no scoring) - :param auto_score: whether to automatically score responses :param eval_filename: json file to use for evaluation, if None is sharegpt :param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples :param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling :param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py. WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present. + :param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing. :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode. If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources :param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes). @@ -234,6 +285,10 @@ def main( :param use_openai_embedding: Whether to use OpenAI embeddings for vector db :param use_openai_model: Whether to use OpenAI model for use with vector db :param hf_embedding_model: Which HF embedding model to use for vector db + Default is instructor-large with 768 parameters per embedding if have GPUs, else all-MiniLM-L6-v1 if no GPUs + Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2" + Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl' + We support automatically changing of embeddings for chroma, with a backup of db made if this is done :param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db :param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db :param enable_url_upload: Whether to allow upload from URL @@ -242,12 +297,17 @@ def main( :param chunk: Whether to chunk data (True unless know data is already optimally chunked) :param chunk_size: Size of chunks, with typically top-4 passed to LLM, so neesd to be in context length :param top_k_docs: number of chunks to give LLM + :param reverse_docs: whether to reverse docs order so most relevant is closest to question. 
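As a rough illustration of how top_k_docs, chunk_size and auto_reduce_chunks interact, a small sketch of the budgeting idea (a hypothetical helper, assuming character-based chunks at roughly 4 characters per token; the repository's actual reduction happens in the langchain code path, not here):

def fit_top_k_docs(top_k_docs, chunk_size, prompt_tokens, model_max_length,
                   chars_per_token=4, max_new_tokens=256):
    # rough token estimate for one chunk (chunk_size is in characters)
    chunk_tokens = chunk_size // chars_per_token
    budget = model_max_length - prompt_tokens - max_new_tokens
    while top_k_docs > 1 and top_k_docs * chunk_tokens > budget:
        top_k_docs -= 1
    return top_k_docs

# e.g. 4 chunks of 512 chars (~128 tokens each) easily fit a 2048-token context
assert fit_top_k_docs(4, 512, prompt_tokens=200, model_max_length=2048) == 4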
+ Best choice for sufficiently smart model, and truncation occurs for oldest context, so best then too. + But smaller 6_9 models fail to use newest context and can get stuck on old information. + :param auto_reduce_chunks: Whether to automatically reduce top_k_docs to fit context given prompt + :param max_chunks: If top_k_docs=-1, maximum number of chunks to allow :param n_jobs: Number of processors to use when consuming documents (-1 = all, is default) :param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model :param captions_model: Which model to use for captions. - captions_model: int = "Salesforce/blip-image-captioning-base", # continue capable + captions_model: str = "Salesforce/blip-image-captioning-base", # continue capable captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state - captions_model: int = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state + captions_model: str = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions :param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader parallel loading disabled if preload and have images, to prevent deadlocking on cuda context @@ -256,8 +316,34 @@ def main( :param enable_ocr: Whether to support OCR on images :return: """ - is_hf = bool(os.getenv("HUGGINGFACE_SPACES")) - is_gpth2oai = bool(os.getenv("GPT_H2O_AI")) + if base_model is None: + base_model = '' + if tokenizer_base_model is None: + tokenizer_base_model = '' + if lora_weights is None: + lora_weights = '' + if inference_server is None: + inference_server = '' + + # listen to env if set + model_lock = os.getenv('model_lock', str(model_lock)) + model_lock = ast.literal_eval(model_lock) + + if model_lock: + assert gradio, "model_lock only supported for gradio=True" + if len(model_lock) > 1: + assert chat, "model_lock only works for multiple models for chat=True" + assert not cli, "model_lock only supported for cli=False" + assert not (not cli and not gradio), "model_lock only supported for eval (cli=gradio=False)" + assert not base_model, "Don't specify model_lock and base_model" + assert not tokenizer_base_model, "Don't specify model_lock and tokenizer_base_model" + assert not lora_weights, "Don't specify model_lock and lora_weights" + assert not inference_server, "Don't specify model_lock and inference_server" + # assert not prompt_type, "Don't specify model_lock and prompt_type" + # assert not prompt_dict, "Don't specify model_lock and prompt_dict" + + is_hf = bool(int(os.getenv("HUGGINGFACE_SPACES", '0'))) + is_gpth2oai = bool(int(os.getenv("GPT_H2O_AI", '0'))) is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer if memory_restriction_level is None: memory_restriction_level = 2 if is_hf else 0 # 2 assumes run on 24GB consumer GPU @@ -270,9 +356,11 @@ def main( # allow set token directly use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token) - allow_upload_to_user_data = bool(os.environ.get("allow_upload_to_user_data", allow_upload_to_user_data)) - allow_upload_to_my_data = bool(os.environ.get("allow_upload_to_my_data", allow_upload_to_my_data)) - height = os.environ.get("HEIGHT", height) + allow_upload_to_user_data = bool( + int(os.environ.get("allow_upload_to_user_data", str(int(allow_upload_to_user_data))))) + allow_upload_to_my_data = 
bool(int(os.environ.get("allow_upload_to_my_data", str(int(allow_upload_to_my_data))))) + height = int(os.environ.get("HEIGHT", height)) + h2ocolors = bool(int(os.getenv('h2ocolors', h2ocolors))) # allow enabling langchain via ENV # FIRST PLACE where LangChain referenced, but no imports related to it @@ -282,6 +370,12 @@ def main( if langchain_mode not in visible_langchain_modes and langchain_mode in langchain_modes: visible_langchain_modes += [langchain_mode] + # if specifically chose not to show My or User Data, disable upload, so gradio elements are simpler + if LangChainMode.MY_DATA.value not in visible_langchain_modes: + allow_upload_to_my_data = False + if LangChainMode.USER_DATA.value not in visible_langchain_modes: + allow_upload_to_user_data = False + if is_public: allow_upload_to_user_data = False input_lines = 1 # ensure set, for ease of use @@ -290,31 +384,57 @@ def main( top_k = 70 if top_k is None else top_k if is_hf: do_sample = True if do_sample is None else do_sample + top_k_docs = 3 if top_k_docs is None else top_k_docs else: # by default don't sample, too chatty do_sample = False if do_sample is None else do_sample + top_k_docs = 4 if top_k_docs is None else top_k_docs if memory_restriction_level == 2: - if not base_model: + if not base_model and not inference_server: base_model = 'h2oai/h2ogpt-oasst1-512-12b' # don't set load_8bit if passed base_model, doesn't always work so can't just override load_8bit = True load_4bit = False # FIXME - consider using 4-bit instead of 8-bit - else: - base_model = 'h2oai/h2ogpt-oasst1-512-20b' if not base_model else base_model + elif not inference_server: + top_k_docs = 10 if top_k_docs is None else top_k_docs if memory_restriction_level >= 2: load_8bit = True load_4bit = False # FIXME - consider using 4-bit instead of 8-bit + if hf_embedding_model is None: + hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2" + top_k_docs = 3 if top_k_docs is None else top_k_docs + if top_k_docs is None: + top_k_docs = 3 + if is_public: + if not max_time: + max_time = 60 * 2 + if not max_max_time: + max_max_time = max_time + if not max_new_tokens: + max_new_tokens = 256 + if not max_max_new_tokens: + max_max_new_tokens = 256 + else: + if not max_max_time: + max_max_time = 60 * 20 + if not max_max_new_tokens: + max_max_new_tokens = 512 if is_hf: # must override share if in spaces share = False + if not max_time: + max_time = 60 * 1 + if not max_max_time: + max_max_time = max_time + # HF accounted for later in get_max_max_new_tokens() save_dir = os.getenv('SAVE_DIR', save_dir) score_model = os.getenv('SCORE_MODEL', score_model) if score_model == 'None' or score_model is None: score_model = '' concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count)) - api_open = bool(int(os.getenv('API_OPEN', api_open))) - allow_api = bool(int(os.getenv('ALLOW_API', allow_api))) + api_open = bool(int(os.getenv('API_OPEN', str(int(api_open))))) + allow_api = bool(int(os.getenv('ALLOW_API', str(int(allow_api))))) n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0 if n_gpus == 0: @@ -326,10 +446,17 @@ def main( torch.backends.cudnn.benchmark = True torch.backends.cudnn.enabled = False torch.set_default_dtype(torch.float32) - if psutil.virtual_memory().available < 94 * 1024 ** 3: + if psutil.virtual_memory().available < 94 * 1024 ** 3 and not inference_server: # 12B uses ~94GB # 6.9B uses ~47GB base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' if not base_model else base_model + if hf_embedding_model is None: + # if no GPUs, use 
simpler embedding model to avoid cost in time + hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2" + else: + if hf_embedding_model is None: + # if still None, then set default + hf_embedding_model = 'hkunlp/instructor-large' # get defaults model_lower = base_model.lower() @@ -344,10 +471,13 @@ def main( if offload_folder: makedirs(offload_folder) + if user_path: + makedirs(user_path) placeholder_instruction, placeholder_input, \ stream_output, show_examples, \ - prompt_type, temperature, top_p, top_k, num_beams, \ + prompt_type, prompt_dict, \ + temperature, top_p, top_k, num_beams, \ max_new_tokens, min_new_tokens, early_stopping, max_time, \ repetition_penalty, num_return_sequences, \ do_sample, \ @@ -356,19 +486,23 @@ def main( task_info = \ get_generate_params(model_lower, chat, stream_output, show_examples, - prompt_type, temperature, top_p, top_k, num_beams, + prompt_type, prompt_dict, + temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens, early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample, top_k_docs, + chunk, + chunk_size, verbose, ) + git_hash = get_githash() locals_dict = locals() locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()]) if verbose: print(f"Generating model with params:\n{locals_print}", flush=True) - print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), get_githash()), flush=True) + print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), git_hash), flush=True) if langchain_mode != "Disabled": # SECOND PLACE where LangChain referenced, but all imports are kept local so not required @@ -382,18 +516,22 @@ def main( for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)): if os.path.isdir(gpath1): print("Removing old MyData: %s" % gpath1, flush=True) - shutil.rmtree(gpath1) + remove(gpath1) continue if langchain_mode1 in ['All']: # FIXME: All should be avoided until scans over each db, shouldn't be separate db continue persist_directory1 = 'db_dir_%s' % langchain_mode1 # single place, no special names for each case - db = prep_langchain(persist_directory1, - load_db_if_exists, - db_type, use_openai_embedding, - langchain_mode1, user_path, - hf_embedding_model, - kwargs_make_db=locals()) + try: + db = prep_langchain(persist_directory1, + load_db_if_exists, + db_type, use_openai_embedding, + langchain_mode1, user_path, + hf_embedding_model, + kwargs_make_db=locals()) + finally: + # in case updated embeddings or created new embeddings + clear_torch_cache() dbs[langchain_mode1] = db # remove None db's so can just rely upon k in dbs for if hav db dbs = {k: v for k, v in dbs.items() if v is not None} @@ -404,6 +542,10 @@ def main( assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have" assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have" + model_state_none = dict(model=None, tokenizer=None, device=None, + base_model=None, tokenizer_base_model=None, lora_weights=None, + inference_server=None, prompt_type=None, prompt_dict=None) + if cli: from cli import run_cli return run_cli(**get_kwargs(run_cli, exclude_names=['model_state0'], **locals())) @@ -415,20 +557,68 @@ def main( from gradio_runner import go_gradio # get default model - all_kwargs = locals().copy() - if all_kwargs.get('base_model') and not all_kwargs['login_mode_if_model0']: - model0, tokenizer0, device = get_model(reward_type=False, - **get_kwargs(get_model, exclude_names=['reward_type'], **all_kwargs)) - else: - # if empty 
model, then don't load anything, just get gradio up - model0, tokenizer0, device = None, None, None - model_state0 = [model0, tokenizer0, device, all_kwargs['base_model']] + model_states = [] + model_list = [dict(base_model=base_model, tokenizer_base_model=tokenizer_base_model, lora_weights=lora_weights, + inference_server=inference_server, prompt_type=prompt_type, prompt_dict=prompt_dict)] + model_list0 = copy.deepcopy(model_list) # just strings, safe to deepcopy + model_state0 = model_state_none.copy() + assert len(model_state_none) == len(model_state0) + if model_lock: + model_list = model_lock + for model_dict in reversed(model_list): + # do reverse, so first is default base_model etc., so some logic works in go_gradio() more easily + # handles defaults user didn't have to pass + model_dict['base_model'] = base_model = model_dict.get('base_model', '') + model_dict['tokenizer_base_model'] = tokenizer_base_model = model_dict.get('tokenizer_base_model', '') + model_dict['lora_weights'] = lora_weights = model_dict.get('lora_weights', '') + model_dict['inference_server'] = inference_server = model_dict.get('inference_server', '') + prompt_type = model_dict.get('prompt_type', model_list0[0]['prompt_type']) # don't use mutated value + # try to infer, ignore empty initial state leading to get_generate_params -> 'plain' + if model_dict.get('prompt_type') is None: + model_lower = base_model.lower() + if model_lower in inv_prompt_type_to_model_lower: + prompt_type = inv_prompt_type_to_model_lower[model_lower] + prompt_dict, error0 = get_prompt(prompt_type, '', + chat=False, context='', reduced=False, making_context=False, + return_dict=True) + model_dict['prompt_type'] = prompt_type + model_dict['prompt_dict'] = prompt_dict = model_dict.get('prompt_dict', prompt_dict) + all_kwargs = locals().copy() + if base_model and not login_mode_if_model0: + model0, tokenizer0, device = get_model(reward_type=False, + **get_kwargs(get_model, exclude_names=['reward_type'], + **all_kwargs)) + else: + # if empty model, then don't load anything, just get gradio up + model0, tokenizer0, device = None, None, None + if model0 is None: + if fail_if_cannot_connect: + raise RuntimeError("Could not connect, see logs") + # skip + if isinstance(model_lock, list): + model_lock.remove(model_dict) + continue + model_state_trial = dict(model=model0, tokenizer=tokenizer0, device=device) + model_state_trial.update(model_dict) + assert len(model_state_none) == len(model_state_trial) + print("Model %s" % model_dict, flush=True) + if model_lock: + # last in iteration will be first + model_states.insert(0, model_state_trial) + # fill model_state0 so go_gradio() easier, manage model_states separately + model_state0 = model_state_trial.copy() + else: + model_state0 = model_state_trial.copy() + assert len(model_state_none) == len(model_state0) # get score model + all_kwargs = locals().copy() smodel, stokenizer, sdevice = get_score_model(reward_type=True, **get_kwargs(get_score_model, exclude_names=['reward_type'], **all_kwargs)) - score_model_state0 = [smodel, stokenizer, sdevice, score_model] + score_model_state0 = dict(model=smodel, tokenizer=stokenizer, device=sdevice, + base_model=score_model, tokenizer_base_model='', lora_weights='', + inference_server='', prompt_type='', prompt_dict='') if enable_captions: if pre_load_caption_model: @@ -443,34 +633,33 @@ def main( go_gradio(**locals()) -def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type, - gpu_id=0, - use_auth_token=False, - 
trust_remote_code=True, - offload_folder=None, - triton_attn=False, - long_sequence=True, - ): - """ - Ensure model gets on correct device - :param base_model: - :param model_loader: - :param load_half: - :param model_kwargs: - :param reward_type: - :param gpu_id: - :param use_auth_token: - :param trust_remote_code: - :param offload_folder: - :param triton_attn: - :param long_sequence: - :return: - """ +def get_config(base_model, + use_auth_token=False, + trust_remote_code=True, + offload_folder=None, + triton_attn=False, + long_sequence=True, + return_model=False, + raise_exception=False, + ): + from accelerate import init_empty_weights with init_empty_weights(): from transformers import AutoConfig - config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder) + try: + config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token, + trust_remote_code=trust_remote_code, + offload_folder=offload_folder) + except OSError as e: + if raise_exception: + raise + if 'not a local folder and is not a valid model identifier listed on' in str( + e) or '404 Client Error' in str(e): + # e.g. llama, gpjt, etc. + # e.g. HF TGI but not model on HF or private etc. + # HF TGI server only should really require prompt_type, not HF model state + return None, None + else: + raise if triton_attn and 'mpt-' in base_model.lower(): config.attn_config['attn_impl'] = 'triton' if long_sequence: @@ -478,18 +667,36 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward config.update({"max_seq_len": 83968}) if 'mosaicml/mpt-7b-chat' in base_model.lower(): config.update({"max_seq_len": 4096}) - if issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())): + if 'mpt-30b' in base_model.lower(): + config.update({"max_seq_len": 2 * 8192}) + if return_model and \ + issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())): model = AutoModel.from_config( config, + trust_remote_code=trust_remote_code, ) else: # can't infer model = None + if 'falcon' in base_model.lower(): + config.use_cache = False + + return config, model + + +def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type, + config, model, + gpu_id=0, + ): + """ + Ensure model gets on correct device + """ if model is not None: # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model # NOTE: Some models require avoiding sharding some layers, # then would pass no_split_module_classes and give list of those layers. 
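# Illustrative aside (not part of the patch): the meta-weights plus device-map
# pattern used by get_config()/get_non_lora_model() above, in standalone form.
# "gpt2" is only a small stand-in checkpoint; real calls also pass max_memory
# and, for models whose layers must not be split, no_split_module_classes.
import torch
from transformers import AutoConfig, AutoModelForCausalLM
from accelerate import init_empty_weights, infer_auto_device_map

sketch_config = AutoConfig.from_pretrained("gpt2")
with init_empty_weights():
    # the module tree is built on the "meta" device, so no real weights are allocated
    sketch_model = AutoModelForCausalLM.from_config(sketch_config)
sketch_device_map = infer_auto_device_map(sketch_model, dtype=torch.float16)
# sketch_device_map can then be passed to from_pretrained(..., device_map=...)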
+ from accelerate import infer_auto_device_map device_map = infer_auto_device_map( model, dtype=torch.float16 if load_half else torch.float32, @@ -541,12 +748,59 @@ def get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward return model +def get_client_from_inference_server(inference_server, raise_connection_exception=False): + inference_server, headers = get_hf_server(inference_server) + # preload client since slow for gradio case especially + from gradio_utils.grclient import GradioClient + gr_client = None + hf_client = None + if headers is None: + try: + print("GR Client Begin: %s" % inference_server, flush=True) + # first do sanity check if alive, else gradio client takes too long by default + requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT', '30'))) + gr_client = GradioClient(inference_server) + print("GR Client End: %s" % inference_server, flush=True) + except (OSError, ValueError) as e: + # Occurs when wrong endpoint and should have been HF client, so don't hard raise, just move to HF + gr_client = None + print("GR Client Failed %s: %s" % (inference_server, str(e)), flush=True) + except (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2, + JSONDecodeError, ReadTimeout2, KeyError) as e: + t, v, tb = sys.exc_info() + ex = ''.join(traceback.format_exception(t, v, tb)) + print("GR Client Failed %s: %s" % (inference_server, str(ex)), flush=True) + if raise_connection_exception: + raise + + if gr_client is None: + res = None + from text_generation import Client as HFClient + print("HF Client Begin: %s" % inference_server) + try: + hf_client = HFClient(inference_server, headers=headers, timeout=int(os.getenv('REQUEST_TIMEOUT', '30'))) + # quick check valid TGI endpoint + res = hf_client.generate('What?', max_new_tokens=1) + hf_client = HFClient(inference_server, headers=headers, timeout=300) + except (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2, + JSONDecodeError, ReadTimeout2, KeyError) as e: + hf_client = None + t, v, tb = sys.exc_info() + ex = ''.join(traceback.format_exception(t, v, tb)) + print("HF Client Failed %s: %s" % (inference_server, str(ex))) + if raise_connection_exception: + raise + print("HF Client End: %s %s" % (inference_server, res)) + return inference_server, gr_client, hf_client + + def get_model( load_8bit: bool = False, load_4bit: bool = False, load_half: bool = True, infer_devices: bool = True, base_model: str = '', + inference_server: str = "", tokenizer_base_model: str = '', lora_weights: str = "", gpu_id: int = 0, @@ -570,6 +824,7 @@ def get_model( For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches So it is not the default :param base_model: name/path of base model + :param inference_server: whether base_model is hosted locally ('') or via http (url) :param tokenizer_base_model: name/path of tokenizer :param lora_weights: name/path :param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1) @@ -585,11 +840,120 @@ def get_model( """ if verbose: print("Get %s model" % base_model, flush=True) + + triton_attn = False + long_sequence = True + config_kwargs = dict(use_auth_token=use_auth_token, + trust_remote_code=trust_remote_code, + offload_folder=offload_folder, + triton_attn=triton_attn, + long_sequence=long_sequence) + config, _ = get_config(base_model, **config_kwargs, raise_exception=False) + + if base_model in non_hf_types: + assert config is None, "Expected config 
None for %s" % base_model + + llama_type_from_config = 'llama' in str(config).lower() + llama_type_from_name = "llama" in base_model.lower() + llama_type = llama_type_from_config or llama_type_from_name + if "xgen" in base_model.lower(): + llama_type = False + if llama_type: + if verbose: + print("Detected as llama type from" + " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True) + + model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type) + + tokenizer_kwargs = dict(local_files_only=local_files_only, + resume_download=resume_download, + use_auth_token=use_auth_token, + trust_remote_code=trust_remote_code, + offload_folder=offload_folder, + padding_side='left', + config=config, + ) + if not tokenizer_base_model: + tokenizer_base_model = base_model + + if config is not None and tokenizer_loader is not None and not isinstance(tokenizer_loader, str): + tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model, **tokenizer_kwargs) + # sets raw (no cushion) limit + set_model_max_len(config, tokenizer, verbose=False) + # if using fake tokenizer, not really accurate when lots of numbers, give a bit of buffer, else get: + # Generation Failed: Input validation error: `inputs` must have less than 2048 tokens. Given: 2233 + tokenizer.model_max_length = tokenizer.model_max_length - 50 + else: + tokenizer = FakeTokenizer() + + if isinstance(inference_server, str) and inference_server.startswith("http"): + inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server) + client = gr_client or hf_client + # Don't return None, None for model, tokenizer so triggers + return client, tokenizer, 'http' + if isinstance(inference_server, str) and inference_server.startswith('openai'): + assert os.getenv('OPENAI_API_KEY'), "Set environment for OPENAI_API_KEY" + # Don't return None, None for model, tokenizer so triggers + # include small token cushion + tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 50) + return inference_server, tokenizer, inference_server + assert not inference_server, "Malformed inference_server=%s" % inference_server if base_model in non_hf_types: from gpt4all_llm import get_model_tokenizer_gpt4all model, tokenizer, device = get_model_tokenizer_gpt4all(base_model) return model, tokenizer, device + # get local torch-HF model + return get_hf_model(load_8bit=load_8bit, + load_4bit=load_4bit, + load_half=load_half, + infer_devices=infer_devices, + base_model=base_model, + tokenizer_base_model=tokenizer_base_model, + lora_weights=lora_weights, + gpu_id=gpu_id, + + reward_type=reward_type, + local_files_only=local_files_only, + resume_download=resume_download, + use_auth_token=use_auth_token, + trust_remote_code=trust_remote_code, + offload_folder=offload_folder, + compile_model=compile_model, + + llama_type=llama_type, + config_kwargs=config_kwargs, + tokenizer_kwargs=tokenizer_kwargs, + + verbose=verbose) + + +def get_hf_model(load_8bit: bool = False, + load_4bit: bool = False, + load_half: bool = True, + infer_devices: bool = True, + base_model: str = '', + tokenizer_base_model: str = '', + lora_weights: str = "", + gpu_id: int = 0, + + reward_type: bool = None, + local_files_only: bool = False, + resume_download: bool = True, + use_auth_token: Union[str, bool] = False, + trust_remote_code: bool = True, + offload_folder: str = None, + compile_model: bool = True, + + llama_type: bool = False, + config_kwargs=None, + 
tokenizer_kwargs=None, + + verbose: bool = False, + ): + assert config_kwargs is not None + assert tokenizer_kwargs is not None + if lora_weights is not None and lora_weights.strip(): if verbose: print("Get %s lora weights" % lora_weights, flush=True) @@ -604,30 +968,13 @@ def get_model( "Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)" ) - from transformers import AutoConfig - config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder) - llama_type_from_config = 'llama' in str(config).lower() - llama_type_from_name = "llama" in base_model.lower() - llama_type = llama_type_from_config or llama_type_from_name - if llama_type: - if verbose: - print("Detected as llama type from" - " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True) + model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type) - model_loader, tokenizer_loader = get_loaders(llama_type=llama_type, model_name=base_model, reward_type=reward_type) - if not tokenizer_base_model: - tokenizer_base_model = base_model + config, _ = get_config(base_model, return_model=False, raise_exception=True, **config_kwargs) if tokenizer_loader is not None and not isinstance(tokenizer_loader, str): tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model, - local_files_only=local_files_only, - resume_download=resume_download, - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - ) + **tokenizer_kwargs) else: tokenizer = tokenizer_loader @@ -651,7 +998,7 @@ def get_model( load_in_4bit=load_4bit, device_map={"": 0} if (load_8bit or load_4bit) and device == 'cuda' else "auto", )) - if 'mpt-' in base_model.lower() and gpu_id >= 0: + if 'mpt-' in base_model.lower() and gpu_id is not None and gpu_id >= 0: model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu")) if 'OpenAssistant/reward-model'.lower() in base_model.lower(): @@ -662,27 +1009,33 @@ def get_model( if not lora_weights: with torch.device(device): + if infer_devices: + config, model = get_config(base_model, return_model=True, raise_exception=True, **config_kwargs) model = get_non_lora_model(base_model, model_loader, load_half, model_kwargs, reward_type, + config, model, gpu_id=gpu_id, - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, ) else: + config, _ = get_config(base_model, **config_kwargs) if load_half and not (load_8bit or load_4bit): model = model_loader.from_pretrained( base_model, + config=config, **model_kwargs).half() else: model = model_loader.from_pretrained( base_model, + config=config, **model_kwargs) elif load_8bit or load_4bit: + config, _ = get_config(base_model, **config_kwargs) model = model_loader.from_pretrained( base_model, + config=config, **model_kwargs ) + from peft import PeftModel # loads cuda, so avoid in global scope model = PeftModel.from_pretrained( model, lora_weights, @@ -696,10 +1049,13 @@ def get_model( ) else: with torch.device(device): + config, _ = get_config(base_model, raise_exception=True, **config_kwargs) model = model_loader.from_pretrained( base_model, + config=config, **model_kwargs ) + from peft import PeftModel # loads cuda, so avoid in global scope model = PeftModel.from_pretrained( model, lora_weights, @@ -730,13 +1086,27 @@ def get_model( if torch.__version__ >= "2" and sys.platform != 
"win32" and compile_model: model = torch.compile(model) - if hasattr(config, 'max_position_embeddings') and isinstance(config.max_position_embeddings, int): + set_model_max_len(config, tokenizer, verbose=False, reward_type=reward_type) + + return model, tokenizer, device + + +def set_model_max_len(config, tokenizer, verbose=False, reward_type=False): + if reward_type: + # limit deberta, else uses too much memory and not worth response score + tokenizer.model_max_length = 512 + if hasattr(config, 'max_seq_len') and isinstance(config.max_seq_len, int): + tokenizer.model_max_length = config.max_seq_len + elif hasattr(config, 'max_position_embeddings') and isinstance(config.max_position_embeddings, int): # help automatically limit inputs to generate tokenizer.model_max_length = config.max_position_embeddings else: + if verbose: + print("Could not determine model_max_length, setting to 2048", flush=True) + tokenizer.model_max_length = 2048 + # for bug in HF transformers + if tokenizer.model_max_length > 100000000: tokenizer.model_max_length = 2048 - - return model, tokenizer, device def pop_unused_model_kwargs(model_kwargs): @@ -758,6 +1128,7 @@ def get_score_model(score_model: str = None, load_half: bool = True, infer_devices: bool = True, base_model: str = '', + inference_server: str = '', tokenizer_base_model: str = '', lora_weights: str = "", gpu_id: int = 0, @@ -779,6 +1150,7 @@ def get_score_model(score_model: str = None, base_model = score_model.strip() tokenizer_base_model = '' lora_weights = '' + inference_server = '' llama_type = False compile_model = False smodel, stokenizer, sdevice = get_model(reward_type=True, @@ -788,30 +1160,155 @@ def get_score_model(score_model: str = None, return smodel, stokenizer, sdevice +no_default_param_names = [ + 'instruction', + 'iinput', + 'context', + 'instruction_nochat', + 'iinput_nochat', +] + +gen_hyper = ['temperature', + 'top_p', + 'top_k', + 'num_beams', + 'max_new_tokens', + 'min_new_tokens', + 'early_stopping', + 'max_time', + 'repetition_penalty', + 'num_return_sequences', + 'do_sample', + ] + eval_func_param_names = ['instruction', 'iinput', 'context', 'stream_output', 'prompt_type', - 'temperature', - 'top_p', - 'top_k', - 'num_beams', - 'max_new_tokens', - 'min_new_tokens', - 'early_stopping', - 'max_time', - 'repetition_penalty', - 'num_return_sequences', - 'do_sample', - 'chat', + 'prompt_dict'] + \ + gen_hyper + \ + ['chat', 'instruction_nochat', 'iinput_nochat', 'langchain_mode', 'top_k_docs', + 'chunk', + 'chunk_size', 'document_choice', ] +# form evaluate defaults for submit_nochat_api +eval_func_param_names_defaults = eval_func_param_names.copy() +for k in no_default_param_names: + if k in eval_func_param_names_defaults: + eval_func_param_names_defaults.remove(k) + + +def evaluate_from_str( + model_state, + my_db_state, + # START NOTE: Examples must have same order of parameters + user_kwargs, + # END NOTE: Examples must have same order of parameters + default_kwargs=None, + src_lang=None, + tgt_lang=None, + debug=False, + concurrency_count=None, + save_dir=None, + sanitize_bot_response=False, + model_state0=None, + memory_restriction_level=None, + max_max_new_tokens=None, + is_public=None, + max_max_time=None, + raise_generate_gpu_exceptions=None, + chat_context=None, + lora_weights=None, + load_db_if_exists=True, + dbs=None, + user_path=None, + detect_user_path_changes_every_query=None, + use_openai_embedding=None, + use_openai_model=None, + hf_embedding_model=None, + db_type=None, + n_jobs=None, + first_para=None, + 
text_limit=None, + verbose=False, + cli=False, + reverse_docs=True, + use_cache=None, + auto_reduce_chunks=None, + max_chunks=None, + model_lock=None, + force_langchain_evaluate=None, + model_state_none=None, +): + if isinstance(user_kwargs, str): + user_kwargs = ast.literal_eval(user_kwargs) + # only used for submit_nochat_api + user_kwargs['chat'] = False + if 'stream_output' not in user_kwargs: + user_kwargs['stream_output'] = False + if 'langchain_mode' not in user_kwargs: + # if user doesn't specify, then assume disabled, not use default + user_kwargs['langchain_mode'] = 'Disabled' + + assert set(list(default_kwargs.keys())) == set(eval_func_param_names) + # correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get() + args_list = [user_kwargs[k] if k in user_kwargs else default_kwargs[k] for k in eval_func_param_names] + + ret = evaluate( + model_state, + my_db_state, + # START NOTE: Examples must have same order of parameters + *tuple(args_list), + # END NOTE: Examples must have same order of parameters + src_lang=src_lang, + tgt_lang=tgt_lang, + debug=debug, + concurrency_count=concurrency_count, + save_dir=save_dir, + sanitize_bot_response=sanitize_bot_response, + model_state0=model_state0, + memory_restriction_level=memory_restriction_level, + max_max_new_tokens=max_max_new_tokens, + is_public=is_public, + max_max_time=max_max_time, + raise_generate_gpu_exceptions=raise_generate_gpu_exceptions, + chat_context=chat_context, + lora_weights=lora_weights, + load_db_if_exists=load_db_if_exists, + dbs=dbs, + user_path=user_path, + detect_user_path_changes_every_query=detect_user_path_changes_every_query, + use_openai_embedding=use_openai_embedding, + use_openai_model=use_openai_model, + hf_embedding_model=hf_embedding_model, + db_type=db_type, + n_jobs=n_jobs, + first_para=first_para, + text_limit=text_limit, + verbose=verbose, + cli=cli, + reverse_docs=reverse_docs, + use_cache=use_cache, + auto_reduce_chunks=auto_reduce_chunks, + max_chunks=max_chunks, + model_lock=model_lock, + force_langchain_evaluate=force_langchain_evaluate, + model_state_none=model_state_none, + ) + try: + for ret1 in ret: + yield ret1 + finally: + # clear before return, in finally in case GPU OOM exception + clear_torch_cache() + def evaluate( model_state, @@ -822,6 +1319,7 @@ def evaluate( context, stream_output, prompt_type, + prompt_dict, temperature, top_p, top_k, @@ -838,6 +1336,8 @@ def evaluate( iinput_nochat, langchain_mode, top_k_docs, + chunk, + chunk_size, document_choice, # END NOTE: Examples must have same order of parameters src_lang=None, @@ -845,9 +1345,12 @@ def evaluate( debug=False, concurrency_count=None, save_dir=None, - sanitize_bot_response=True, + sanitize_bot_response=False, model_state0=None, memory_restriction_level=None, + max_max_new_tokens=None, + is_public=None, + max_max_time=None, raise_generate_gpu_exceptions=None, chat_context=None, lora_weights=None, @@ -858,14 +1361,19 @@ def evaluate( use_openai_embedding=None, use_openai_model=None, hf_embedding_model=None, - chunk=None, - chunk_size=None, db_type=None, n_jobs=None, first_para=None, text_limit=None, verbose=False, cli=False, + reverse_docs=True, + use_cache=None, + auto_reduce_chunks=None, + max_chunks=None, + model_lock=None, + force_langchain_evaluate=None, + model_state_none=None, ): # ensure passed these assert concurrency_count is not None @@ -875,10 +1383,10 @@ def evaluate( assert use_openai_embedding is not None assert use_openai_model is not None assert 
hf_embedding_model is not None - assert chunk is not None - assert chunk_size is not None assert db_type is not None assert top_k_docs is not None and isinstance(top_k_docs, int) + assert chunk is not None and isinstance(chunk, bool) + assert chunk_size is not None and isinstance(chunk_size, int) assert n_jobs is not None assert first_para is not None @@ -886,29 +1394,58 @@ def evaluate( locals_dict = locals().copy() locals_dict.pop('model_state', None) locals_dict.pop('model_state0', None) + locals_dict.pop('model_states', None) print(locals_dict) - no_model_msg = "Please choose a base model with --base_model (CLI) or load in Models Tab (gradio).\nThen start New Conversation" + no_model_msg = "Please choose a base model with --base_model (CLI) or load in Models Tab (gradio).\n" \ + "Then start New Conversation" + if model_state is None: + model_state = model_state_none.copy() if model_state0 is None: # e.g. for no gradio case, set dummy value, else should be set - model_state0 = [None, None, None, None] - - if model_state is not None and len(model_state) == 4 and not isinstance(model_state[0], str): - # try to free-up original model (i.e. list was passed as reference) - if model_state0 is not None and model_state0[0] is not None: - model_state0[0].cpu() - model_state0[0] = None - # try to free-up original tokenizer (i.e. list was passed as reference) - if model_state0 is not None and model_state0[1] is not None: - model_state0[1] = None - clear_torch_cache() - model, tokenizer, device, base_model = model_state - elif model_state0 is not None and len(model_state0) == 4 and model_state0[0] is not None: - assert isinstance(model_state[0], str) - model, tokenizer, device, base_model = model_state0 + model_state0 = model_state_none.copy() + + # model_state['model] is only 'model' if should use model_state0 + # model could also be None + have_model_lock = model_lock is not None + have_fresh_model = model_state['model'] not in [None, 'model', no_model_str] + # for gradio UI control, expect model_state and model_state0 to match, so if have_model_lock=True, then should have_fresh_model=True + # but gradio API control will only use nochat api etc. and won't use fresh model, so can't assert in general + # if have_model_lock: + # assert have_fresh_model, "Expected model_state and model_state0 to match if have_model_lock" + have_cli_model = model_state0['model'] not in [None, 'model', no_model_str] + + if have_fresh_model: + # USE FRESH MODEL + if not have_model_lock: + # model_state0 is just one of model_state if model_lock, so don't nuke + # try to free-up original model (i.e. list was passed as reference) + if model_state0['model'] and hasattr(model_state0['model'], 'cpu'): + model_state0['model'].cpu() + model_state0['model'] = None + # try to free-up original tokenizer (i.e. 
list was passed as reference) + if model_state0['tokenizer']: + model_state0['tokenizer'] = None + clear_torch_cache() + chosen_model_state = model_state + elif have_cli_model: + # USE MODEL SETUP AT CLI + assert isinstance(model_state['model'], str) # expect no fresh model + chosen_model_state = model_state0 else: raise AssertionError(no_model_msg) + # get variables + model = chosen_model_state['model'] + tokenizer = chosen_model_state['tokenizer'] + device = chosen_model_state['device'] + base_model = chosen_model_state['base_model'] + tokenizer_base_model = chosen_model_state['tokenizer_base_model'] + lora_weights = chosen_model_state['lora_weights'] + inference_server = chosen_model_state['inference_server'] + # prefer use input from API over model state + prompt_type = prompt_type or chosen_model_state['prompt_type'] + prompt_dict = prompt_dict or chosen_model_state['prompt_dict'] if base_model is None: raise AssertionError(no_model_msg) @@ -922,11 +1459,49 @@ def evaluate( instruction = instruction_nochat iinput = iinput_nochat + # in some cases, like lean nochat API, don't want to force sending prompt_type, allow default choice + model_lower = base_model.lower() + if not prompt_type and model_lower in inv_prompt_type_to_model_lower: + prompt_type = inv_prompt_type_to_model_lower[model_lower] + if verbose: + print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True) + assert prompt_type is not None, "prompt_type was None" + + # Control generation hyperparameters + # adjust for bad inputs, e.g. in case also come from API that doesn't get constrained by gradio sliders + # below is for TGI server, not required for HF transformers + # limits are chosen similar to gradio_runner.py sliders/numbers + top_p = min(max(1e-3, top_p), 1.0 - 1e-3) + top_k = min(max(1, int(top_k)), 100) + temperature = min(max(0.01, temperature), 2.0) + # FIXME: https://github.com/h2oai/h2ogpt/issues/106 + num_beams = 1 if stream_output else num_beams # See max_beams in gradio_runner + max_max_new_tokens = get_max_max_new_tokens(chosen_model_state, + memory_restriction_level=memory_restriction_level, + max_new_tokens=max_new_tokens, + max_max_new_tokens=max_max_new_tokens) + model_max_length = get_model_max_length(chosen_model_state) + max_new_tokens = min(max(1, int(max_new_tokens)), max_max_new_tokens) + min_new_tokens = min(max(0, int(min_new_tokens)), max_new_tokens) + max_time = min(max(0, max_time), max_max_time) + repetition_penalty = min(max(0.01, repetition_penalty), 3.0) + num_return_sequences = 1 if chat else min(max(1, int(num_return_sequences)), 10) + min_top_k_docs, max_top_k_docs, label_top_k_docs = get_minmax_top_k_docs(is_public) + top_k_docs = min(max(min_top_k_docs, int(top_k_docs)), max_top_k_docs) + chunk_size = min(max(128, int(chunk_size)), 2048) if not context: # get hidden context if have one context = get_context(chat_context, prompt_type) - prompter = Prompter(prompt_type, debug=debug, chat=chat, stream_output=stream_output) + # restrict instruction, typically what has large input + from h2oai_pipeline import H2OTextGenerationPipeline + instruction, num_prompt_tokens1 = H2OTextGenerationPipeline.limit_prompt(instruction, tokenizer) + context, num_prompt_tokens2 = H2OTextGenerationPipeline.limit_prompt(context, tokenizer) + iinput, num_prompt_tokens3 = H2OTextGenerationPipeline.limit_prompt(iinput, tokenizer) + num_prompt_tokens = (num_prompt_tokens1 or 0) + (num_prompt_tokens2 or 0) + (num_prompt_tokens3 or 0) + + # get prompt + prompter = 
Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output) data_point = dict(context=context, instruction=instruction, input=iinput) prompt = prompter.generate_prompt(data_point) @@ -938,20 +1513,36 @@ def evaluate( db1 = dbs[langchain_mode] else: db1 = None - if langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] and db1 is not None or base_model in non_hf_types: + do_langchain_path = langchain_mode not in [False, 'Disabled', 'ChatLLM', 'LLM'] and \ + db1 is not None or \ + base_model in non_hf_types or \ + force_langchain_evaluate + if do_langchain_path: query = instruction if not iinput else "%s\n%s" % (instruction, iinput) outr = "" # use smaller cut_distanct for wiki_full since so many matches could be obtained, and often irrelevant unless close from gpt_langchain import run_qa_db + gen_hyper_langchain = dict(do_sample=do_sample, + temperature=temperature, + repetition_penalty=repetition_penalty, + top_k=top_k, + top_p=top_p, + num_beams=num_beams, + min_new_tokens=min_new_tokens, + max_new_tokens=max_new_tokens, + early_stopping=early_stopping, + max_time=max_time, + num_return_sequences=num_return_sequences, + ) for r in run_qa_db(query=query, model_name=base_model, model=model, tokenizer=tokenizer, + inference_server=inference_server, stream_output=stream_output, prompter=prompter, load_db_if_exists=load_db_if_exists, db=db1, user_path=user_path, detect_user_path_changes_every_query=detect_user_path_changes_every_query, - max_new_tokens=max_new_tokens, cut_distanct=1.1 if langchain_mode in ['wiki_full'] else 1.64, # FIXME, too arbitrary use_openai_embedding=use_openai_embedding, use_openai_model=use_openai_model, @@ -963,20 +1554,33 @@ def evaluate( langchain_mode=langchain_mode, document_choice=document_choice, db_type=db_type, - k=top_k_docs, - temperature=temperature, - repetition_penalty=repetition_penalty, - top_k=top_k, - top_p=top_p, + top_k_docs=top_k_docs, + + **gen_hyper_langchain, + prompt_type=prompt_type, + prompt_dict=prompt_dict, n_jobs=n_jobs, verbose=verbose, cli=cli, + sanitize_bot_response=sanitize_bot_response, + reverse_docs=reverse_docs, + + lora_weights=lora_weights, + + auto_reduce_chunks=auto_reduce_chunks, + max_chunks=max_chunks, ): outr, extra = r # doesn't accumulate, new answer every yield, so only save that full answer yield dict(response=outr, sources=extra) if save_dir: - save_generate_output(output=outr, base_model=base_model, save_dir=save_dir) + extra_dict = gen_hyper_langchain.copy() + extra_dict.update(prompt_type=prompt_type, inference_server=inference_server, + langchain_mode=langchain_mode, document_choice=document_choice, + num_prompt_tokens=num_prompt_tokens) + save_generate_output(prompt=query, output=outr, base_model=base_model, save_dir=save_dir, + where_from='run_qa_db', + extra_dict=extra_dict) if verbose: print( 'Post-Generate Langchain: %s decoded_output: %s' % (str(datetime.now()), len(outr) if outr else -1), @@ -985,8 +1589,270 @@ def evaluate( # if got no response (e.g. 
not showing sources and got no sources, # so nothing to give to LLM), then slip through and ask LLM # Or if llama/gptj, then just return since they had no response and can't go down below code path + # clear before return, since .then() never done if from API + clear_torch_cache() return + if inference_server.startswith('openai') or inference_server.startswith('http'): + if inference_server.startswith('openai'): + import openai + where_from = "openai_client" + + openai.api_key = os.getenv("OPENAI_API_KEY") + stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse])) + # OpenAI will complain if ask for too many new tokens, takes it as min in some sense, wrongly so. + max_new_tokens_openai = min(max_new_tokens, model_max_length - num_prompt_tokens) + gen_server_kwargs = dict(temperature=temperature if do_sample else 0, + max_tokens=max_new_tokens_openai, + top_p=top_p if do_sample else 1, + frequency_penalty=0, + n=num_return_sequences, + presence_penalty=1.07 - repetition_penalty + 0.6, # so good default + ) + if inference_server == 'openai': + response = openai.Completion.create( + model=base_model, + prompt=prompt, + **gen_server_kwargs, + stop=stop_sequences, + stream=stream_output, + ) + if not stream_output: + text = response['choices'][0]['text'] + yield dict(response=prompter.get_response(prompt + text, prompt=prompt, + sanitize_bot_response=sanitize_bot_response), + sources='') + else: + collected_events = [] + text = '' + for event in response: + collected_events.append(event) # save the event response + event_text = event['choices'][0]['text'] # extract the text + text += event_text # append the text + yield dict(response=prompter.get_response(prompt + text, prompt=prompt, + sanitize_bot_response=sanitize_bot_response), + sources='') + elif inference_server == 'openai_chat': + response = openai.ChatCompletion.create( + model=base_model, + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {'role': 'user', + 'content': prompt, + } + ], + stream=stream_output, + **gen_server_kwargs, + ) + if not stream_output: + text = response["choices"][0]["message"]["content"] + yield dict(response=prompter.get_response(prompt + text, prompt=prompt, + sanitize_bot_response=sanitize_bot_response), + sources='') + else: + text = "" + for chunk in response: + delta = chunk["choices"][0]["delta"] + if 'content' in delta: + text += delta['content'] + yield dict(response=prompter.get_response(prompt + text, prompt=prompt, + sanitize_bot_response=sanitize_bot_response), + sources='') + else: + raise RuntimeError("No such OpenAI mode: %s" % inference_server) + elif inference_server.startswith('http'): + inference_server, headers = get_hf_server(inference_server) + from gradio_utils.grclient import GradioClient + from text_generation import Client as HFClient + if isinstance(model, GradioClient): + gr_client = model + hf_client = None + elif isinstance(model, HFClient): + gr_client = None + hf_client = model + else: + inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server) + + # quick sanity check to avoid long timeouts, just see if can reach server + requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10'))) + + if gr_client is not None: + # Note: h2oGPT gradio server could handle input token size issues for prompt, + # but best to handle here so send less data to server + + chat_client = False + where_from = "gr_client" + client_langchain_mode = 'Disabled' + gen_server_kwargs = 
dict(temperature=temperature, + top_p=top_p, + top_k=top_k, + num_beams=num_beams, + max_new_tokens=max_new_tokens, + min_new_tokens=min_new_tokens, + early_stopping=early_stopping, + max_time=max_time, + repetition_penalty=repetition_penalty, + num_return_sequences=num_return_sequences, + do_sample=do_sample, + chat=chat_client, + ) + # account for gradio into gradio that handles prompting, avoid duplicating prompter prompt injection + if prompt_type in [None, '', PromptType.plain.name, PromptType.plain.value, + str(PromptType.plain.value)]: + # if our prompt is plain, assume either correct or gradio server knows different prompt type, + # so pass empty prompt_Type + gr_prompt_type = '' + gr_prompt_dict = '' + gr_prompt = prompt # already prepared prompt + gr_context = '' + gr_iinput = '' + else: + # if already have prompt_type that is not plain, None, or '', then already applied some prompting + # But assume server can handle prompting, and need to avoid double-up. + # Also assume server can do better job of using stopping.py to stop early, so avoid local prompting, let server handle + # So avoid "prompt" and let gradio server reconstruct from prompt_type we passed + # Note it's ok that prompter.get_response() has prompt+text, prompt=prompt passed, + # because just means extra processing and removal of prompt, but that has no human-bot prompting doesn't matter + # since those won't appear + gr_context = context + gr_prompt = instruction + gr_iinput = iinput + gr_prompt_type = prompt_type + gr_prompt_dict = prompt_dict + client_kwargs = dict(instruction=gr_prompt if chat_client else '', # only for chat=True + iinput=gr_iinput, # only for chat=True + context=gr_context, + # streaming output is supported, loops over and outputs each generation in streaming mode + # but leave stream_output=False for simple input/output mode + stream_output=stream_output, + + **gen_server_kwargs, + + prompt_type=gr_prompt_type, + prompt_dict=gr_prompt_dict, + + instruction_nochat=gr_prompt if not chat_client else '', + iinput_nochat=gr_iinput, # only for chat=False + langchain_mode=client_langchain_mode, + top_k_docs=top_k_docs, + chunk=chunk, + chunk_size=chunk_size, + document_choice=[DocumentChoices.All_Relevant.name], + ) + api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing + if not stream_output: + res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name) + res_dict = ast.literal_eval(res) + text = res_dict['response'] + sources = res_dict['sources'] + yield dict(response=prompter.get_response(prompt + text, prompt=prompt, + sanitize_bot_response=sanitize_bot_response), + sources=sources) + else: + job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name) + text = '' + sources = '' + res_dict = dict(response=text, sources=sources) + while not job.done(): + outputs_list = job.communicator.job.outputs + if outputs_list: + res = job.communicator.job.outputs[-1] + res_dict = ast.literal_eval(res) + text = res_dict['response'] + sources = res_dict['sources'] + if gr_prompt_type == 'plain': + # then gradio server passes back full prompt + text + prompt_and_text = text + else: + prompt_and_text = prompt + text + yield dict(response=prompter.get_response(prompt_and_text, prompt=prompt, + sanitize_bot_response=sanitize_bot_response), + sources=sources) + time.sleep(0.01) + # ensure get last output to avoid race + res_all = job.outputs() + if len(res_all) > 0: + res = res_all[-1] + res_dict = ast.literal_eval(res) + text = res_dict['response'] + 
sources = res_dict['sources'] + else: + # go with old text if last call didn't work + e = job.future._exception + if e is not None: + stre = str(e) + strex = ''.join(traceback.format_tb(e.__traceback__)) + else: + stre = '' + strex = '' + + print("Bad final response: %s %s %s %s %s: %s %s" % (base_model, inference_server, + res_all, prompt, text, stre, strex), + flush=True) + if gr_prompt_type == 'plain': + # then gradio server passes back full prompt + text + prompt_and_text = text + else: + prompt_and_text = prompt + text + yield dict(response=prompter.get_response(prompt_and_text, prompt=prompt, + sanitize_bot_response=sanitize_bot_response), + sources=sources) + elif hf_client: + # HF inference server needs control over input tokens + where_from = "hf_client" + + # prompt must include all human-bot like tokens, already added by prompt + # https://github.com/huggingface/text-generation-inference/tree/main/clients/python#types + stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse])) + gen_server_kwargs = dict(do_sample=do_sample, + max_new_tokens=max_new_tokens, + # best_of=None, + repetition_penalty=repetition_penalty, + return_full_text=True, + seed=SEED, + stop_sequences=stop_sequences, + temperature=temperature, + top_k=top_k, + top_p=top_p, + # truncate=False, # behaves oddly + # typical_p=top_p, + # watermark=False, + # decoder_input_details=False, + ) + # work-around for timeout at constructor time, will be issue if multi-threading, + # so just do something reasonable or max_time if larger + # lower bound because client is re-used if multi-threading + hf_client.timeout = max(300, max_time) + if not stream_output: + text = hf_client.generate(prompt, **gen_server_kwargs).generated_text + yield dict(response=prompter.get_response(text, prompt=prompt, + sanitize_bot_response=sanitize_bot_response), + sources='') + else: + text = "" + for response in hf_client.generate_stream(prompt, **gen_server_kwargs): + if not response.token.special: + # stop_sequences + text_chunk = response.token.text + text += text_chunk + yield dict(response=prompter.get_response(prompt + text, prompt=prompt, + sanitize_bot_response=sanitize_bot_response), + sources='') + else: + raise RuntimeError("Failed to get client: %s" % inference_server) + else: + raise RuntimeError("No such inference_server %s" % inference_server) + + if save_dir and text: + # save prompt + new text + extra_dict = gen_server_kwargs.copy() + extra_dict.update(dict(inference_server=inference_server, num_prompt_tokens=num_prompt_tokens)) + save_generate_output(prompt=prompt, output=text, base_model=base_model, save_dir=save_dir, + where_from=where_from, extra_dict=extra_dict) + return + else: + assert not inference_server, "inferene_server=%s not supported" % inference_server + if isinstance(tokenizer, str): # pipeline if tokenizer == "summarization": @@ -1000,36 +1866,37 @@ def evaluate( assert src_lang is not None tokenizer.src_lang = languages_covered()[src_lang] - if chat: - # override, ignore user change - num_return_sequences = 1 - stopping_criteria = get_stopping(prompt_type, tokenizer, device) - _, _, max_length_tokenize, max_prompt_length = get_cutoffs(memory_restriction_level, model_max_length=tokenizer.model_max_length) - prompt = prompt[-max_prompt_length:] - inputs = tokenizer(prompt, - return_tensors="pt", - truncation=True, - max_length=max_length_tokenize) - if inputs['input_ids'].shape[1] >= max_length_tokenize - 1: - print("Cutting off input: %s %s" % (inputs['input_ids'].shape[1], 
max_length_tokenize), flush=True) + stopping_criteria = get_stopping(prompt_type, prompt_dict, tokenizer, device, + model_max_length=tokenizer.model_max_length) + + inputs = tokenizer(prompt, return_tensors="pt") if debug and len(inputs["input_ids"]) > 0: print('input_ids length', len(inputs["input_ids"][0]), flush=True) input_ids = inputs["input_ids"].to(device) # CRITICAL LIMIT else will fail max_max_tokens = tokenizer.model_max_length - max_input_tokens = max_max_tokens - max_new_tokens + max_input_tokens = max_max_tokens - min_new_tokens + # NOTE: Don't limit up front due to max_new_tokens, let go up to max or reach max_max_tokens in stopping.py input_ids = input_ids[:, -max_input_tokens:] - generation_config = GenerationConfig( - temperature=float(temperature), - top_p=float(top_p), - top_k=top_k, - num_beams=num_beams, - do_sample=do_sample, - repetition_penalty=float(repetition_penalty), - num_return_sequences=num_return_sequences, - renormalize_logits=True, - remove_invalid_values=True, - ) + # required for falcon if multiple threads or asyncio accesses to model during generation + if use_cache is None: + use_cache = False if 'falcon' in base_model else True + gen_config_kwargs = dict(temperature=float(temperature), + top_p=float(top_p), + top_k=top_k, + num_beams=num_beams, + do_sample=do_sample, + repetition_penalty=float(repetition_penalty), + num_return_sequences=num_return_sequences, + renormalize_logits=True, + remove_invalid_values=True, + use_cache=use_cache, + ) + token_ids = ['eos_token_id', 'pad_token_id', 'bos_token_id', 'cls_token_id', 'sep_token_id'] + for token_id in token_ids: + if hasattr(tokenizer, token_id) and getattr(tokenizer, token_id) is not None: + gen_config_kwargs.update({token_id: getattr(tokenizer, token_id)}) + generation_config = GenerationConfig(**gen_config_kwargs) gen_kwargs = dict(input_ids=input_ids, generation_config=generation_config, @@ -1048,7 +1915,10 @@ def evaluate( tgt_lang = languages_covered()[tgt_lang] gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang])) else: - gen_kwargs.update(dict(pad_token_id=tokenizer.eos_token_id)) + token_ids = ['eos_token_id', 'bos_token_id', 'pad_token_id'] + for token_id in token_ids: + if hasattr(tokenizer, token_id) and getattr(tokenizer, token_id) is not None: + gen_kwargs.update({token_id: getattr(tokenizer, token_id)}) decoder_kwargs = dict(skip_special_tokens=True, clean_up_tokenization_spaces=True) @@ -1064,7 +1934,8 @@ def evaluate( ) with torch.no_grad(): - context_class_cast = NullContext if device == 'cpu' or lora_weights else torch.autocast + have_lora_weights = lora_weights not in [no_lora_str, '', None] + context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast with context_class_cast(device): # protection for gradio not keeping track of closed users, # else hit bitsandbytes lack of thread safety: @@ -1127,6 +1998,8 @@ def evaluate( raise thread.exc raise finally: + # clear before return, since .then() never done if from API + clear_torch_cache() # in case no exception and didn't join with thread yet, then join if not thread.exc: thread.join() @@ -1135,14 +2008,21 @@ def evaluate( raise thread.exc decoded_output = outputs else: - outputs = model.generate(**gen_kwargs) + try: + outputs = model.generate(**gen_kwargs) + finally: + clear_torch_cache() # has to be here for API submit_nochat_api since.then() not called outputs = [decoder(s) for s in outputs.sequences] yield dict(response=prompter.get_response(outputs, 
prompt=inputs_decoded, sanitize_bot_response=sanitize_bot_response), sources='') if outputs and len(outputs) >= 1: decoded_output = prompt + outputs[0] if save_dir and decoded_output: - save_generate_output(output=decoded_output, base_model=base_model, save_dir=save_dir) + extra_dict = gen_config_kwargs.copy() + extra_dict.update(dict(num_prompt_tokens=num_prompt_tokens)) + save_generate_output(prompt=prompt, output=decoded_output, base_model=base_model, save_dir=save_dir, + where_from="evaluate_%s" % str(stream_output), + extra_dict=gen_config_kwargs) if verbose: print('Post-Generate: %s decoded_output: %s' % ( str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True) @@ -1161,6 +2041,7 @@ def get_cutoffs(memory_restriction_level, for_context=False, model_max_length=20 if memory_restriction_level > 0: max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256 else: + # at least give room for 1 paragraph output max_length_tokenize = model_max_length - 256 cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens output_smallest = 30 * 4 @@ -1251,18 +2132,22 @@ def generate_with_exceptions(func, *args, prompt='', inputs_decoded='', raise_ge def get_generate_params(model_lower, chat, stream_output, show_examples, - prompt_type, temperature, top_p, top_k, num_beams, + prompt_type, prompt_dict, + temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens, early_stopping, max_time, repetition_penalty, num_return_sequences, - do_sample, k, verbose): + do_sample, + top_k_docs, chunk, chunk_size, + verbose): use_defaults = False use_default_examples = True examples = [] - task_info = f"{prompt_type}" + task_info = 'LLM' if model_lower: print(f"Using Model {model_lower}", flush=True) else: - print("No model defined yet", flush=True) + if verbose: + print("No model defined yet", flush=True) min_new_tokens = min_new_tokens if min_new_tokens is not None else 0 early_stopping = early_stopping if early_stopping is not None else False @@ -1288,15 +2173,13 @@ Jeff: and how can I get started? Jeff: where can I find documentation? Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face""" + use_placeholder_instruction_as_example = False if 'bart-large-cnn-samsum' in model_lower or 'flan-t5-base-samsum' in model_lower: placeholder_instruction = summarize_example1 placeholder_input = "" use_defaults = True use_default_examples = False - examples += [ - [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults, - 1.0, 1, - False]] + use_placeholder_instruction_as_example = True task_info = "Summarization" elif 't5-' in model_lower or 't5' == model_lower or 'flan-' in model_lower: placeholder_instruction = "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?" @@ -1309,29 +2192,25 @@ Philipp: ok, ok you can find everything here. 
https://huggingface.co/blog/the-pa placeholder_input = "" use_defaults = True use_default_examples = False - examples += [ - [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults, - 1.0, 1, - False]] + use_placeholder_instruction_as_example = True elif 'gpt2' in model_lower: placeholder_instruction = "The sky is" placeholder_input = "" prompt_type = prompt_type or 'plain' use_default_examples = True # some will be odd "continuations" but can be ok - examples += [ - [placeholder_instruction, "", "", stream_output, 'plain', 1.0, 1.0, 50, 1, 128, 0, False, max_time_defaults, - 1.0, 1, - False]] + use_placeholder_instruction_as_example = True task_info = "Auto-complete phrase, code, etc." use_defaults = True else: if chat: - placeholder_instruction = "Enter a question or imperative." + placeholder_instruction = "" else: placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter." placeholder_input = "" - if model_lower: - # default is plain, because might relly upon trust_remote_code to handle prompting + if model_lower in inv_prompt_type_to_model_lower: + prompt_type = inv_prompt_type_to_model_lower[model_lower] + elif model_lower: + # default is plain, because might rely upon trust_remote_code to handle prompting prompt_type = prompt_type or 'plain' else: prompt_type = '' @@ -1361,18 +2240,22 @@ Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-pa temperature = 0.1 if temperature is None else temperature top_p = 0.75 if top_p is None else top_p top_k = 40 if top_k is None else top_k - if chat: - num_beams = num_beams or 1 - else: - num_beams = num_beams or 4 + num_beams = num_beams or 1 max_new_tokens = max_new_tokens or 256 repetition_penalty = repetition_penalty or 1.07 num_return_sequences = min(num_beams, num_return_sequences or 1) do_sample = False if do_sample is None else do_sample # doesn't include chat, instruction_nochat, iinput_nochat, added later - params_list = ["", stream_output, prompt_type, temperature, top_p, top_k, num_beams, max_new_tokens, min_new_tokens, + params_list = ["", + stream_output, + prompt_type, prompt_dict, + temperature, top_p, top_k, num_beams, + max_new_tokens, min_new_tokens, early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample] + if use_placeholder_instruction_as_example: + examples += [[placeholder_instruction, ''] + params_list] + if use_default_examples: examples += [ ["Translate English to French", "Good morning"] + params_list, @@ -1410,14 +2293,15 @@ y = np.random.randint(0, 1, 100) # fit random forest classifier with 20 estimators""", ''] + params_list, ] # add summary example - examples += [[summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else ''] + params_list] + examples += [ + [summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else ''] + params_list] src_lang = "English" tgt_lang = "Russian" # move to correct position for example in examples: - example += [chat, '', '', 'Disabled', k, ['All']] + example += [chat, '', '', 'Disabled', top_k_docs, chunk, chunk_size, [DocumentChoices.All_Relevant.name]] # adjust examples if non-chat mode if not chat: example[eval_func_param_names.index('instruction_nochat')] = example[ @@ -1429,9 +2313,19 @@ y = np.random.randint(0, 1, 100) assert len(example) == len(eval_func_param_names), "Wrong example: %s %s" % ( len(example), len(eval_func_param_names)) + if prompt_type == PromptType.custom.name and not 
prompt_dict: + raise ValueError("Unexpected to get non-empty prompt_dict=%s for prompt_type=%s" % (prompt_dict, prompt_type)) + + # get prompt_dict from prompt_type, so user can see in UI etc., or for custom do nothing except check format + prompt_dict, error0 = get_prompt(prompt_type, prompt_dict, + chat=False, context='', reduced=False, making_context=False, return_dict=True) + if error0: + raise RuntimeError("Prompt wrong: %s" % error0) + return placeholder_instruction, placeholder_input, \ stream_output, show_examples, \ - prompt_type, temperature, top_p, top_k, num_beams, \ + prompt_type, prompt_dict, \ + temperature, top_p, top_k, num_beams, \ max_new_tokens, min_new_tokens, early_stopping, max_time, \ repetition_penalty, num_return_sequences, \ do_sample, \ @@ -1477,7 +2371,8 @@ def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_l if 'Expected all tensors to be on the same device' in str(e) or \ 'expected scalar type Half but found Float' in str(e) or \ 'probability tensor contains either' in str(e) or \ - 'cublasLt ran into an error!' in str(e): + 'cublasLt ran into an error!' in str(e) or \ + 'device-side assert triggered' in str(e): print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True) traceback.print_exc() @@ -1491,12 +2386,7 @@ def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_l def check_locals(**kwargs): # ensure everything in evaluate is here - can_skip_because_locally_generated = [ # evaluate - 'instruction', - 'iinput', - 'context', - 'instruction_nochat', - 'iinput_nochat', + can_skip_because_locally_generated = no_default_param_names + [ # get_model: 'reward_type' ] @@ -1515,7 +2405,101 @@ def check_locals(**kwargs): assert k in kwargs, "Missing %s" % k -if __name__ == "__main__": +def get_model_max_length(model_state): + if not isinstance(model_state['tokenizer'], (str, types.NoneType)): + return model_state['tokenizer'].model_max_length + else: + return 2048 + + +def get_max_max_new_tokens(model_state, **kwargs): + if not isinstance(model_state['tokenizer'], (str, types.NoneType)): + max_max_new_tokens = model_state['tokenizer'].model_max_length + else: + max_max_new_tokens = None + + if kwargs['max_max_new_tokens'] is not None and max_max_new_tokens is not None: + return min(max_max_new_tokens, kwargs['max_max_new_tokens']) + elif kwargs['max_max_new_tokens'] is not None: + return kwargs['max_max_new_tokens'] + elif kwargs['memory_restriction_level'] == 1: + return 768 + elif kwargs['memory_restriction_level'] == 2: + return 512 + elif kwargs['memory_restriction_level'] >= 3: + return 256 + else: + # FIXME: Need to update after new model loaded, so user can control with slider + return 2048 + + +def get_minmax_top_k_docs(is_public): + if is_public: + min_top_k_docs = 1 + max_top_k_docs = 3 + label_top_k_docs = "Number of document chunks" + else: + min_top_k_docs = -1 + max_top_k_docs = 100 + label_top_k_docs = "Number of document chunks (-1 = auto fill model context)" + return min_top_k_docs, max_top_k_docs, label_top_k_docs + + +def history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, chat1, model_max_length1, + memory_restriction_level1, keep_sources_in_context1): + """ + consumes all history up to (but not including) latest history item that is presumed to be an [instruction, None] pair + :param history: + :param langchain_mode1: + :param prompt_type1: + :param prompt_dict1: + :param chat1: + :param model_max_length1: + :param 
memory_restriction_level1:
+ :param keep_sources_in_context1:
+ :return:
+ """
+ # ensure output will be unique to models
+ _, _, _, max_prompt_length = get_cutoffs(memory_restriction_level1,
+ for_context=True, model_max_length=model_max_length1)
+ context1 = ''
+ if max_prompt_length is not None and langchain_mode1 not in ['LLM']:
+ context1 = ''
+ # - 1 below because current instruction already in history from user()
+ for histi in range(0, len(history) - 1):
+ data_point = dict(instruction=history[histi][0], input='', output=history[histi][1])
+ prompt, pre_response, terminate_response, chat_sep, chat_turn_sep = generate_prompt(data_point,
+ prompt_type1,
+ prompt_dict1,
+ chat1,
+ reduced=True,
+ making_context=True)
+ # md -> back to text, maybe not super important if model trained enough
+ if not keep_sources_in_context1 and langchain_mode1 != 'Disabled' and prompt.find(source_prefix) >= 0:
+ # FIXME: This is relatively slow even for small amount of text, like 0.3s each history item
+ import re
+ prompt = re.sub(f'{re.escape(source_prefix)}.*?{re.escape(source_postfix)}', '', prompt,
+ flags=re.DOTALL)
+ if prompt.endswith('\n<p>'):
+ prompt = prompt[:-4]
+ prompt = prompt.replace('<br>', chat_turn_sep)
+ if not prompt.endswith(chat_turn_sep):
+ prompt += chat_turn_sep
+ # most recent first, add older if can
+ # only include desired chat history
+ if len(prompt + context1) > max_prompt_length:
+ break
+ context1 += prompt
+
+ _, pre_response, terminate_response, chat_sep, chat_turn_sep = generate_prompt({}, prompt_type1, prompt_dict1,
+ chat1, reduced=True,
+ making_context=True)
+ if context1 and not context1.endswith(chat_turn_sep):
+ context1 += chat_turn_sep # ensure if terminates abruptly, then human continues on next line
+ return context1
+
+
+def entrypoint_main():
"""
Examples:
@@ -1546,3 +2530,7 @@ if __name__ == "__main__":
python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b
"""
fire.Fire(main)
+
+
+if __name__ == "__main__":
+ entrypoint_main()
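The history_to_context() addition above folds prior (instruction, response) pairs into one context string, oldest first, stopping once the character budget from get_cutoffs() is exceeded. A minimal sketch of that folding idea follows; the "<human>:/<bot>:" template and plain newline separator are illustrative stand-ins for the repo's generate_prompt() machinery, not the real prompting.

```python
# Minimal sketch of the turn-folding in history_to_context(): prior
# (instruction, response) pairs become one context string, and the loop stops
# once the character budget is exceeded. The template and separator below are
# illustrative stand-ins for generate_prompt(), not the real prompting.
def fold_history(history, max_prompt_length, chat_turn_sep='\n'):
    context = ''
    # skip the last item: it is the pending instruction with no response yet
    for instruction, response in history[:-1]:
        turn = "<human>: %s\n<bot>: %s" % (instruction, response or '')
        if not turn.endswith(chat_turn_sep):
            turn += chat_turn_sep
        if len(turn + context) > max_prompt_length:
            break
        context += turn
    return context


history = [["Hi", "Hello!"], ["What is 2+2?", "4"], ["And 3+3?", None]]
print(fold_history(history, max_prompt_length=512))
```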
diff --git a/gpt4all_llm.py b/gpt4all_llm.py
index 64d7f220f68cfcd80a543b9de58c383ff18229e8..7febb9b374a246b09deef0dfc4b4adf3966a37fd 100644
--- a/gpt4all_llm.py
+++ b/gpt4all_llm.py
@@ -1,23 +1,13 @@
import inspect
import os
-import sys
+from functools import partial
from typing import Dict, Any, Optional, List
from langchain.callbacks.manager import CallbackManagerForLLMRun
from pydantic import root_validator
from langchain.llms import gpt4all
from dotenv import dotenv_values
-
-class FakeTokenizer:
-
- def encode(self, x, *args, **kwargs):
- return dict(input_ids=[x])
-
- def decode(self, x, *args, **kwargs):
- return x
-
- def __call__(self, x, *args, **kwargs):
- return self.encode(x, *args, **kwargs)
+from utils import FakeTokenizer
def get_model_tokenizer_gpt4all(base_model, **kwargs):
@@ -73,9 +63,9 @@ class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler):
pass
-def get_model_kwargs(env_kwargs, default_kwargs, cls):
+def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=[]):
# default from class
- model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items()}
+ model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items() if k not in exclude_list}
# from our defaults
model_kwargs.update(default_kwargs)
# from user defaults
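get_model_kwargs() now takes an exclude_list so constructor-only internals (like lc_kwargs) never reach the backend class, while the layering stays the same: class-signature defaults, then the caller's defaults, then .env overrides. A small sketch of that layering under those assumptions (the Demo class and the exact override order shown here are illustrative, not the function itself):

```python
import inspect

# Sketch of the kwargs layering in get_model_kwargs(); Demo stands in for a
# real backend class such as H2OLlamaCpp.
def layered_kwargs(cls, default_kwargs, env_kwargs, exclude_list=()):
    kwargs = {k: v.default for k, v in inspect.signature(cls).parameters.items()
              if k not in exclude_list}
    kwargs.update(default_kwargs)  # our defaults override class defaults
    kwargs.update(env_kwargs)      # .env values override everything else
    return kwargs


class Demo:
    def __init__(self, n_ctx=2048, temperature=0.8, lc_kwargs=None):
        self.n_ctx, self.temperature = n_ctx, temperature


print(layered_kwargs(Demo, dict(temperature=0.1), dict(n_ctx=1024), exclude_list=['lc_kwargs']))
# -> {'n_ctx': 1024, 'temperature': 0.1}
```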
@@ -93,10 +83,14 @@ def get_llm_gpt4all(model_name,
repetition_penalty=1.0,
top_k=40,
top_p=0.7,
- verbose=False):
+ streaming=False,
+ callbacks=None,
+ prompter=None,
+ verbose=False,
+ ):
+ assert prompter is not None
env_gpt4all_file = ".env_gpt4all"
env_kwargs = dotenv_values(env_gpt4all_file)
- callbacks = [H2OStreamingStdOutCallbackHandler()]
n_ctx = env_kwargs.pop('n_ctx', 2048 - max_new_tokens)
default_kwargs = dict(context_erase=0.5,
n_batch=1,
@@ -113,21 +107,23 @@ def get_llm_gpt4all(model_name,
if model_name == 'llama':
cls = H2OLlamaCpp
model_path = env_kwargs.pop('model_path_llama') if model is None else model
- model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
- model_kwargs.update(dict(model_path=model_path, callbacks=callbacks))
+ model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
+ model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming, prompter=prompter))
llm = cls(**model_kwargs)
llm.client.verbose = verbose
elif model_name == 'gpt4all_llama':
cls = H2OGPT4All
model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model
- model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
- model_kwargs.update(dict(model=model_path, backend='llama', callbacks=callbacks))
+ model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
+ model_kwargs.update(
+ dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming, prompter=prompter))
llm = cls(**model_kwargs)
elif model_name == 'gptj':
cls = H2OGPT4All
model_path = env_kwargs.pop('model_path_gptj') if model is None else model
- model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls)
- model_kwargs.update(dict(model=model_path, backend='gptj', callbacks=callbacks))
+ model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs'])
+ model_kwargs.update(
+ dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming, prompter=prompter))
llm = cls(**model_kwargs)
else:
raise RuntimeError("No such model_name %s" % model_name)
@@ -136,6 +132,7 @@ def get_llm_gpt4all(model_name,
class H2OGPT4All(gpt4all.GPT4All):
model: Any
+ prompter: Any
"""Path to the pre-trained GPT4All model file."""
@root_validator()
@@ -155,9 +152,16 @@ class H2OGPT4All(gpt4all.GPT4All):
model_type=values["backend"],
allow_download=False,
)
+ if values["n_threads"] is not None:
+ # set n_threads
+ values["client"].model.set_thread_count(values["n_threads"])
else:
values["client"] = values["model"]
- values["backend"] = values["client"].model.model_type
+ try:
+ values["backend"] = values["client"].model_type
+ except AttributeError:
+ # The below is for compatibility with GPT4All Python bindings <= 0.2.3.
+ values["backend"] = values["client"].model.model_type
except ImportError:
raise ValueError(
@@ -171,12 +175,19 @@ class H2OGPT4All(gpt4all.GPT4All):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs,
) -> str:
# Roughly 4 chars per token if natural language
prompt = prompt[-self.n_ctx * 4:]
+
+ # use instruct prompting
+ data_point = dict(context='', instruction=prompt, input='')
+ prompt = self.prompter.generate_prompt(data_point)
+
verbose = False
if verbose:
print("_call prompt: %s" % prompt, flush=True)
+ # FIXME: GPT4All doesn't support yield during generate, so cannot support streaming except via its own stdout output
return super()._call(prompt, stop=stop, run_manager=run_manager)
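Because the GPT4All binding has no token-level truncation, _call above first clips the prompt to roughly n_ctx tokens using a ~4 characters-per-token heuristic, then applies the instruct template via the prompter. The clipping step alone, as a small sketch (the 4:1 ratio is the heuristic from the comment above, not a tokenizer measurement):

```python
# Sketch of the ~4 characters-per-token clipping used above; the ratio is the
# heuristic stated in the code comment, not an exact tokenizer count.
def truncate_by_char_budget(prompt: str, n_ctx: int, chars_per_token: int = 4) -> str:
    budget = n_ctx * chars_per_token
    return prompt[-budget:] if len(prompt) > budget else prompt


print(len(truncate_by_char_budget("x" * 10000, n_ctx=2048)))  # -> 8192
```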
@@ -185,6 +196,7 @@ from langchain.llms import LlamaCpp
class H2OLlamaCpp(LlamaCpp):
model_path: Any
+ prompter: Any
"""Path to the pre-trained GPT4All model file."""
@root_validator()
@@ -236,9 +248,12 @@ class H2OLlamaCpp(LlamaCpp):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs,
) -> str:
verbose = False
# tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate
+ # still have to avoid extreme prompt sizes, else hit "llama_tokenize: too many tokens" -- may still occur, but not fatal
+ prompt = prompt[-self.n_ctx * 4:]
prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8"))
num_prompt_tokens = len(prompt_tokens)
if num_prompt_tokens > self.n_ctx:
@@ -250,6 +265,33 @@ class H2OLlamaCpp(LlamaCpp):
prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8"))
num_prompt_tokens2 = len(prompt_tokens2)
print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True)
+
+ # use instruct prompting
+ data_point = dict(context='', instruction=prompt, input='')
+ prompt = self.prompter.generate_prompt(data_point)
+
if verbose:
print("_call prompt: %s" % prompt, flush=True)
- return super()._call(prompt, stop=stop, run_manager=run_manager)
+
+ if self.streaming:
+ text_callback = None
+ if run_manager:
+ text_callback = partial(
+ run_manager.on_llm_new_token, verbose=self.verbose
+ )
+ # parent handler of the streamer expects to see the prompt first, else output="" and the text is lost when prompt=None in the prompter
+ if text_callback:
+ text_callback(prompt)
+ text = ""
+ for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager):
+ text_chunk = token["choices"][0]["text"]
+ # self.stream already calls text_callback
+ # if text_callback:
+ # text_callback(text_chunk)
+ text += text_chunk
+ return text
+ else:
+ params = self._get_parameters(stop)
+ params = {**params, **kwargs}
+ result = self.client(prompt=prompt, **params)
+ return result["choices"][0]["text"]
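The streaming branch added to H2OLlamaCpp._call walks the llama-cpp token stream, appends each chunk to the running text, and leaves incremental display to the callback chain. A backend-agnostic sketch of that accumulate-while-streaming pattern; the chunk shape and on_token callback mirror what the code above expects, but the generator here is a stand-in, not llama-cpp:

```python
from typing import Callable, Iterable, Optional


# Sketch of the accumulate-while-streaming pattern in the branch above. The
# chunk shape mimics llama-cpp's {"choices": [{"text": ...}]} dicts and
# on_token stands in for run_manager.on_llm_new_token; both are assumptions.
def consume_stream(chunks: Iterable[dict], on_token: Optional[Callable[[str], None]] = None) -> str:
    text = ""
    for chunk in chunks:
        piece = chunk["choices"][0]["text"]
        if on_token:
            on_token(piece)  # incremental display (UI, stdout, ...)
        text += piece        # full completion returned to the caller
    return text


fake_stream = ({"choices": [{"text": t}]} for t in ["Hello", ", ", "world"])
print(consume_stream(fake_stream, on_token=lambda s: print(s, end="", flush=True)))
```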
diff --git a/gpt_langchain.py b/gpt_langchain.py
index c35ff897b6efc3197a362abc1e3b41835d6e0407..b9fcc947fcc4816eeb9a9eab96c92070128360a3 100644
--- a/gpt_langchain.py
+++ b/gpt_langchain.py
@@ -1,27 +1,34 @@
+import ast
import glob
import inspect
import os
import pathlib
import pickle
-import queue
import shutil
import subprocess
-import sys
import tempfile
+import time
import traceback
+import types
import uuid
import zipfile
from collections import defaultdict
from datetime import datetime
from functools import reduce
from operator import concat
+import filelock
-from joblib import Parallel, delayed
+from joblib import delayed
+from langchain.callbacks import streaming_stdout
+from langchain.embeddings import HuggingFaceInstructEmbeddings
from tqdm import tqdm
-from prompter import non_hf_types
+from enums import DocumentChoices, no_lora_str, model_token_mapping, source_prefix, source_postfix
+from generate import gen_hyper, get_model, SEED
+from prompter import non_hf_types, PromptType, Prompter
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
- get_device, ProgressParallel, remove, hash_file
+ get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer
+from utils_langchain import StreamingGradioCallbackHandler
import_matplotlib()
@@ -36,19 +43,22 @@ from langchain.chains.qa_with_sources import load_qa_with_sources_chain
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
- UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
-from langchain.text_splitter import RecursiveCharacterTextSplitter
+ UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader
+from langchain.text_splitter import RecursiveCharacterTextSplitter, Language
from langchain.chains.question_answering import load_qa_chain
from langchain.docstore.document import Document
-from langchain import PromptTemplate
+from langchain import PromptTemplate, HuggingFaceTextGenInference
from langchain.vectorstores import Chroma
-def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directory="db_dir", langchain_mode='notset',
+def get_db(sources, use_openai_embedding=False, db_type='faiss',
+ persist_directory="db_dir", load_db_if_exists=True,
+ langchain_mode='notset',
collection_name=None,
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
if not sources:
return None
+
# get embedding model
embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
assert collection_name is not None or langchain_mode != 'notset'
@@ -59,29 +69,41 @@ def get_db(sources, use_openai_embedding=False, db_type='faiss', persist_directo
if db_type == 'faiss':
from langchain.vectorstores import FAISS
db = FAISS.from_documents(sources, embedding)
-
elif db_type == 'weaviate':
import weaviate
from weaviate.embedded import EmbeddedOptions
from langchain.vectorstores import Weaviate
- # TODO: add support for connecting via docker compose
- client = weaviate.Client(
- embedded_options=EmbeddedOptions()
- )
+ if os.getenv('WEAVIATE_URL', None):
+ client = _create_local_weaviate_client()
+ else:
+ client = weaviate.Client(
+ embedded_options=EmbeddedOptions()
+ )
index_name = collection_name.capitalize()
db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False,
index_name=index_name)
-
elif db_type == 'chroma':
assert persist_directory is not None
os.makedirs(persist_directory, exist_ok=True)
- db = Chroma.from_documents(documents=sources,
- embedding=embedding,
- persist_directory=persist_directory,
- collection_name=collection_name,
- anonymized_telemetry=False)
- db.persist()
+
+ # see if already actually have persistent db, and deal with possible changes in embedding
+ db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
+ hf_embedding_model, verbose=False)
+ if db is None:
+ db = Chroma.from_documents(documents=sources,
+ embedding=embedding,
+ persist_directory=persist_directory,
+ collection_name=collection_name,
+ anonymized_telemetry=False)
+ db.persist()
+ clear_embedding(db)
+ save_embed(db, use_openai_embedding, hf_embedding_model)
+ else:
+ # then just add
+ db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
+ use_openai_embedding=use_openai_embedding,
+ hf_embedding_model=hf_embedding_model)
else:
raise RuntimeError("No such db_type=%s" % db_type)
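For db_type == 'chroma', get_db() now tries to reuse an already-persisted store before building a new one. The repo's get_existing_db() also validates the saved embedding settings and handles locking; the simplified sketch below only shows the basic branch, and the directory check used as the "already exists" test is an assumption:

```python
import os

from langchain.vectorstores import Chroma


# Simplified sketch of "reuse persisted Chroma if present, else build and persist".
def load_or_build_chroma(sources, embedding, persist_directory, collection_name):
    if os.path.isdir(persist_directory) and os.listdir(persist_directory):
        # reuse what is already on disk
        return Chroma(collection_name=collection_name,
                      embedding_function=embedding,
                      persist_directory=persist_directory)
    os.makedirs(persist_directory, exist_ok=True)
    db = Chroma.from_documents(documents=sources,
                               embedding=embedding,
                               persist_directory=persist_directory,
                               collection_name=collection_name)
    db.persist()
    return db
```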
@@ -104,7 +126,10 @@ def _get_unique_sources_in_weaviate(db):
def add_to_db(db, sources, db_type='faiss',
avoid_dup_by_file=False,
- avoid_dup_by_content=True):
+ avoid_dup_by_content=True,
+ use_openai_embedding=False,
+ hf_embedding_model=None):
+ assert hf_embedding_model is not None
num_new_sources = len(sources)
if not sources:
return db, num_new_sources, []
@@ -120,7 +145,7 @@ def add_to_db(db, sources, db_type='faiss',
return db, num_new_sources, []
db.add_documents(documents=sources)
elif db_type == 'chroma':
- collection = db.get()
+ collection = get_documents(db)
# files we already have:
metadata_files = set([x['source'] for x in collection['metadatas']])
if avoid_dup_by_file:
@@ -135,11 +160,15 @@ def add_to_db(db, sources, db_type='faiss',
[x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
# avoid sources with same hash
sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
+ num_nohash = len([x for x in sources if not x.metadata.get('hashid')])
+ print("Found %s new sources (%d have no hash in original source,"
+ " so have to reprocess for migration to sources with hash)" % (len(sources), num_nohash), flush=True)
# get new file names that match existing file names. delete existing files we are overridding
dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
print("Removing %s duplicate files from db because ingesting those as new documents" % len(
dup_metadata_files), flush=True)
- client_collection = db._client.get_collection(name=db._collection.name)
+ client_collection = db._client.get_collection(name=db._collection.name,
+ embedding_function=db._collection._embedding_function)
for dup_file in dup_metadata_files:
dup_file_meta = dict(source=dup_file)
try:
@@ -151,6 +180,8 @@ def add_to_db(db, sources, db_type='faiss',
return db, num_new_sources, []
db.add_documents(documents=sources)
db.persist()
+ clear_embedding(db)
+ save_embed(db, use_openai_embedding, hf_embedding_model)
else:
raise RuntimeError("No such db_type=%s" % db_type)
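The chroma branch of add_to_db() skips any incoming source whose content hash (hashid) already appears in the collection metadata. A stripped-down sketch of that filter, with a tiny Doc class standing in for langchain's Document:

```python
from dataclasses import dataclass, field


@dataclass
class Doc:  # stand-in for langchain.docstore.document.Document
    metadata: dict = field(default_factory=dict)


def dedup_by_hashid(new_sources, existing_metadatas):
    # hashes already present in the db collection
    existing_ids = {m['hashid'] for m in existing_metadatas
                    if m.get('hashid') not in ("None", None)}
    # keep only sources whose hash is unseen (or that have no hash yet)
    kept = [s for s in new_sources if s.metadata.get('hashid') not in existing_ids]
    num_nohash = len([s for s in kept if not s.metadata.get('hashid')])
    print("Found %d new sources (%d have no hash and must be reprocessed)" % (len(kept), num_nohash))
    return kept


print(len(dedup_by_hashid([Doc({'hashid': 'a'}), Doc({'hashid': 'b'})],
                          [{'hashid': 'a'}])))  # -> 1
```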
@@ -165,10 +196,13 @@ def create_or_update_db(db_type, persist_directory, collection_name,
import weaviate
from weaviate.embedded import EmbeddedOptions
- # TODO: add support for connecting via docker compose
- client = weaviate.Client(
- embedded_options=EmbeddedOptions()
- )
+ if os.getenv('WEAVIATE_URL', None):
+ client = _create_local_weaviate_client()
+ else:
+ client = weaviate.Client(
+ embedded_options=EmbeddedOptions()
+ )
+
index_name = collection_name.replace(' ', '_').capitalize()
if client.schema.exists(index_name) and not add_if_exists:
client.schema.delete_class(index_name)
@@ -205,14 +239,20 @@ def get_embedding(use_openai_embedding, hf_embedding_model="sentence-transformer
if use_openai_embedding:
assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
from langchain.embeddings import OpenAIEmbeddings
- embedding = OpenAIEmbeddings()
+ embedding = OpenAIEmbeddings(disallowed_special=())
else:
# to ensure can fork without deadlock
from langchain.embeddings import HuggingFaceEmbeddings
device, torch_dtype, context_class = get_device_dtype()
model_kwargs = dict(device=device)
- embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs)
+ if 'instructor' in hf_embedding_model:
+ # instructor-style models need the separate HuggingFaceInstructEmbeddings wrapper (InstructorEmbedding package)
+ from langchain.embeddings import HuggingFaceInstructEmbeddings
+ encode_kwargs = {'normalize_embeddings': True}
+ embedding = HuggingFaceInstructEmbeddings(model_name=hf_embedding_model,
+ model_kwargs=model_kwargs,
+ encode_kwargs=encode_kwargs)
+ else:
+ embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs)
return embedding
@@ -226,63 +266,481 @@ def get_answer_from_sources(chain, sources, question):
)["output_text"]
-def get_llm(use_openai_model=False, model_name=None, model=None,
- tokenizer=None, stream_output=False,
- max_new_tokens=256,
+"""Wrapper around Huggingface text generation inference API."""
+from functools import partial
+from typing import Any, Dict, List, Optional, Set
+
+from pydantic import Extra, Field, root_validator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+
+"""Wrapper around Huggingface text generation inference API."""
+from functools import partial
+from typing import Any, Dict, List, Optional
+
+from pydantic import Extra, Field, root_validator
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+
+
+class GradioInference(LLM):
+ """
+ Gradio generation inference API.
+ """
+ inference_server_url: str = ""
+
+ temperature: float = 0.8
+ top_p: Optional[float] = 0.95
+ top_k: Optional[int] = None
+ num_beams: Optional[int] = 1
+ max_new_tokens: int = 512
+ min_new_tokens: int = 1
+ early_stopping: bool = False
+ max_time: int = 180
+ repetition_penalty: Optional[float] = None
+ num_return_sequences: Optional[int] = 1
+ do_sample: bool = False
+ chat_client: bool = False
+
+ return_full_text: bool = True
+ stream: bool = False
+ sanitize_bot_response: bool = False
+
+ prompter: Any = None
+ client: Any = None
+
+ class Config:
+ """Configuration for this pydantic object."""
+
+ extra = Extra.forbid
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that python package exists in environment."""
+
+ try:
+ if values['client'] is None:
+ import gradio_client
+ values["client"] = gradio_client.Client(
+ values["inference_server_url"]
+ )
+ except ImportError:
+ raise ImportError(
+ "Could not import gradio_client python package. "
+ "Please install it with `pip install gradio_client`."
+ )
+ return values
+
+ @property
+ def _llm_type(self) -> str:
+ """Return type of llm."""
+ return "gradio_inference"
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ # NOTE: prompt here has no prompt_type (e.g. human: bot:) prompt injection,
+ # so server should get prompt_type or '', not plain
+ # This is good, so gradio server can also handle stopping.py conditions
+ # this is different than TGI server that uses prompter to inject prompt_type prompting
+ stream_output = self.stream
+ gr_client = self.client
+ client_langchain_mode = 'Disabled'
+ top_k_docs = 1
+ chunk = True
+ chunk_size = 512
+ client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True
+ iinput='', # only for chat=True
+ context='',
+ # streaming output is supported, loops over and outputs each generation in streaming mode
+ # but leave stream_output=False for simple input/output mode
+ stream_output=stream_output,
+ prompt_type=self.prompter.prompt_type,
+ prompt_dict='',
+
+ temperature=self.temperature,
+ top_p=self.top_p,
+ top_k=self.top_k,
+ num_beams=self.num_beams,
+ max_new_tokens=self.max_new_tokens,
+ min_new_tokens=self.min_new_tokens,
+ early_stopping=self.early_stopping,
+ max_time=self.max_time,
+ repetition_penalty=self.repetition_penalty,
+ num_return_sequences=self.num_return_sequences,
+ do_sample=self.do_sample,
+ chat=self.chat_client,
+
+ instruction_nochat=prompt if not self.chat_client else '',
+ iinput_nochat='', # only for chat=False
+ langchain_mode=client_langchain_mode,
+ top_k_docs=top_k_docs,
+ chunk=chunk,
+ chunk_size=chunk_size,
+ document_choice=[DocumentChoices.All_Relevant.name],
+ )
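+ # NOTE: every eval parameter the server expects must be spelled out here, since the whole
+ # dict is serialized to a string for the stable nochat API below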
+ api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing
+ if not stream_output:
+ res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name)
+ res_dict = ast.literal_eval(res)
+ text = res_dict['response']
+ return self.prompter.get_response(prompt + text, prompt=prompt,
+ sanitize_bot_response=self.sanitize_bot_response)
+ else:
+ text_callback = None
+ if run_manager:
+ text_callback = partial(
+ run_manager.on_llm_new_token, verbose=self.verbose
+ )
+
+ job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name)
+ text0 = ''
+ while not job.done():
+ outputs_list = job.communicator.job.outputs
+ if outputs_list:
+ res = job.communicator.job.outputs[-1]
+ res_dict = ast.literal_eval(res)
+ text = res_dict['response']
+ text = self.prompter.get_response(prompt + text, prompt=prompt,
+ sanitize_bot_response=self.sanitize_bot_response)
+ # FIXME: derive chunk from full for now
+ text_chunk = text[len(text0):]
+ # save old
+ text0 = text
+
+ if text_callback:
+ text_callback(text_chunk)
+
+ time.sleep(0.01)
+
+ # ensure get last output to avoid race
+ res_all = job.outputs()
+ if len(res_all) > 0:
+ res = res_all[-1]
+ res_dict = ast.literal_eval(res)
+ text = res_dict['response']
+ # FIXME: derive chunk from full for now
+ else:
+ # go with old if failure
+ text = text0
+ text_chunk = text[len(text0):]
+ if text_callback:
+ text_callback(text_chunk)
+ return self.prompter.get_response(prompt + text, prompt=prompt,
+ sanitize_bot_response=self.sanitize_bot_response)
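+ # Rough usage sketch (hypothetical URL and prompter construction; for illustration only, not exercised here):
+ # llm = GradioInference(inference_server_url="http://localhost:7860",
+ # prompter=Prompter('human_bot', ''), stream=False)
+ # print(llm("Why is the sky blue?"))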
+
+
+class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference):
+ max_new_tokens: int = 512
+ do_sample: bool = False
+ top_k: Optional[int] = None
+ top_p: Optional[float] = 0.95
+ typical_p: Optional[float] = 0.95
+ temperature: float = 0.8
+ repetition_penalty: Optional[float] = None
+ return_full_text: bool = False
+ stop_sequences: List[str] = Field(default_factory=list)
+ seed: Optional[int] = None
+ inference_server_url: str = ""
+ timeout: int = 300
+ headers: dict = None
+ stream: bool = False
+ sanitize_bot_response: bool = False
+ prompter: Any = None
+ tokenizer: Any = None
+ client: Any = None
+
+ @root_validator()
+ def validate_environment(cls, values: Dict) -> Dict:
+ """Validate that python package exists in environment."""
+
+ try:
+ if values['client'] is None:
+ import text_generation
+
+ values["client"] = text_generation.Client(
+ values["inference_server_url"],
+ timeout=values["timeout"],
+ headers=values["headers"],
+ )
+ except ImportError:
+ raise ImportError(
+ "Could not import text_generation python package. "
+ "Please install it with `pip install text_generation`."
+ )
+ return values
+
+ def _call(
+ self,
+ prompt: str,
+ stop: Optional[List[str]] = None,
+ run_manager: Optional[CallbackManagerForLLMRun] = None,
+ **kwargs: Any,
+ ) -> str:
+ if stop is None:
+ stop = self.stop_sequences
+ else:
+ stop += self.stop_sequences
+
+ # HF inference server needs control over input tokens
+ assert self.tokenizer is not None
+ from h2oai_pipeline import H2OTextGenerationPipeline
+ prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer)
+
+ # NOTE: TGI server does not add prompting, so must do here
+ data_point = dict(context='', instruction=prompt, input='')
+ prompt = self.prompter.generate_prompt(data_point)
+
+ gen_server_kwargs = dict(do_sample=self.do_sample,
+ stop_sequences=stop,
+ max_new_tokens=self.max_new_tokens,
+ top_k=self.top_k,
+ top_p=self.top_p,
+ typical_p=self.typical_p,
+ temperature=self.temperature,
+ repetition_penalty=self.repetition_penalty,
+ return_full_text=self.return_full_text,
+ seed=self.seed,
+ )
+ gen_server_kwargs.update(kwargs)
+
+ # lower bound because client is re-used if multi-threading
+ self.client.timeout = max(300, self.timeout)
+
+ if not self.stream:
+ res = self.client.generate(
+ prompt,
+ **gen_server_kwargs,
+ )
+ if self.return_full_text:
+ gen_text = res.generated_text[len(prompt):]
+ else:
+ gen_text = res.generated_text
+ # remove stop sequences from the end of the generated text
+ for stop_seq in stop:
+ if stop_seq in gen_text:
+ gen_text = gen_text[:gen_text.index(stop_seq)]
+ text = prompt + gen_text
+ text = self.prompter.get_response(text, prompt=prompt,
+ sanitize_bot_response=self.sanitize_bot_response)
+ else:
+ text_callback = None
+ if run_manager:
+ text_callback = partial(
+ run_manager.on_llm_new_token, verbose=self.verbose
+ )
+ # parent handler of streamer expects to see the prompt first, else output is "" and gets lost when prompt=None in prompter
+ if text_callback:
+ text_callback(prompt)
+ text = ""
+ # Note: Streaming ignores return_full_text=True
+ for response in self.client.generate_stream(prompt, **gen_server_kwargs):
+ text_chunk = response.token.text
+ text += text_chunk
+ text = self.prompter.get_response(prompt + text, prompt=prompt,
+ sanitize_bot_response=self.sanitize_bot_response)
+ # stream part
+ is_stop = False
+ for stop_seq in stop:
+ if stop_seq in response.token.text:
+ is_stop = True
+ break
+ if is_stop:
+ break
+ if not response.token.special:
+ if text_callback:
+ text_callback(response.token.text)
+ return text
+
+
+from langchain.chat_models import ChatOpenAI
+
+
+class H2OChatOpenAI(ChatOpenAI):
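+ # presumably widens the set of recognized fields so top_p/frequency_penalty/presence_penalty are
+ # accepted as first-class parameters rather than shuffled into model_kwargs by langchain's validators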
+ @classmethod
+ def all_required_field_names(cls) -> Set:
+ all_required_field_names = super(ChatOpenAI, cls).all_required_field_names()
+ all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty'})
+ return all_required_field_names
+
+
+def get_llm(use_openai_model=False,
+ model_name=None,
+ model=None,
+ tokenizer=None,
+ inference_server=None,
+ stream_output=False,
+ do_sample=False,
temperature=0.1,
- repetition_penalty=1.0,
top_k=40,
top_p=0.7,
+ num_beams=1,
+ max_new_tokens=256,
+ min_new_tokens=1,
+ early_stopping=False,
+ max_time=180,
+ repetition_penalty=1.0,
+ num_return_sequences=1,
prompt_type=None,
+ prompt_dict=None,
prompter=None,
+ sanitize_bot_response=False,
verbose=False,
):
- if use_openai_model:
- from langchain.llms import OpenAI
- llm = OpenAI(temperature=0)
- model_name = 'openai'
- streamer = None
- prompt_type = 'plain'
+ if use_openai_model or inference_server in ['openai', 'openai_chat']:
+ if use_openai_model and model_name is None:
+ model_name = "gpt-3.5-turbo"
+ if inference_server == 'openai':
+ from langchain.llms import OpenAI
+ cls = OpenAI
+ else:
+ cls = H2OChatOpenAI
+ callbacks = [StreamingGradioCallbackHandler()]
+ llm = cls(model_name=model_name,
+ temperature=temperature if do_sample else 0,
+ # FIXME: Need to count tokens and reduce max_new_tokens to fit like in generate.py
+ max_tokens=max_new_tokens,
+ top_p=top_p if do_sample else 1,
+ frequency_penalty=0,
+ presence_penalty=1.07 - repetition_penalty + 0.6, # maps the default repetition_penalty=1.07 to presence_penalty=0.6, a good default
+ callbacks=callbacks if stream_output else None,
+ )
+ streamer = callbacks[0] if stream_output else None
+ if inference_server in ['openai', 'openai_chat']:
+ prompt_type = inference_server
+ else:
+ prompt_type = prompt_type or 'plain'
+ elif inference_server:
+ assert inference_server.startswith(
+ 'http'), "Malformed inference_server=%s. Did you add http:// in front?" % inference_server
+
+ from gradio_utils.grclient import GradioClient
+ from text_generation import Client as HFClient
+ if isinstance(model, GradioClient):
+ gr_client = model
+ hf_client = None
+ else:
+ gr_client = None
+ hf_client = model
+ assert isinstance(hf_client, HFClient)
+
+ inference_server, headers = get_hf_server(inference_server)
+
+ # quick sanity check to avoid long timeouts, just see if can reach server
+ requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10')))
+
+ callbacks = [StreamingGradioCallbackHandler()]
+ assert prompter is not None
+ stop_sequences = list(set(prompter.terminate_response + [prompter.PreResponse]))
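+ # stop on any of the prompt format's terminator strings, plus the bot-response marker itself,
+ # so remote generation halts at the end of the current turn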
+
+ if gr_client:
+ chat_client = False
+ llm = GradioInference(
+ inference_server_url=inference_server,
+ return_full_text=True,
+
+ temperature=temperature,
+ top_p=top_p,
+ top_k=top_k,
+ num_beams=num_beams,
+ max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ early_stopping=early_stopping,
+ max_time=max_time,
+ repetition_penalty=repetition_penalty,
+ num_return_sequences=num_return_sequences,
+ do_sample=do_sample,
+ chat_client=chat_client,
+
+ callbacks=callbacks if stream_output else None,
+ stream=stream_output,
+ prompter=prompter,
+ client=gr_client,
+ sanitize_bot_response=sanitize_bot_response,
+ )
+ elif hf_client:
+ llm = H2OHuggingFaceTextGenInference(
+ inference_server_url=inference_server,
+ do_sample=do_sample,
+ max_new_tokens=max_new_tokens,
+ repetition_penalty=repetition_penalty,
+ return_full_text=True,
+ seed=SEED,
+
+ stop_sequences=stop_sequences,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ # typical_p=top_p,
+ callbacks=callbacks if stream_output else None,
+ stream=stream_output,
+ prompter=prompter,
+ tokenizer=tokenizer,
+ client=hf_client,
+ timeout=max_time,
+ sanitize_bot_response=sanitize_bot_response,
+ )
+ else:
+ raise RuntimeError("No defined client")
+ streamer = callbacks[0] if stream_output else None
elif model_name in non_hf_types:
+ if model_name == 'llama':
+ callbacks = [StreamingGradioCallbackHandler()]
+ streamer = callbacks[0] if stream_output else None
+ else:
+ # stream_output = False
+ # doesn't stream properly as a generator, but at least streams to stdout
+ callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()]
+ streamer = None
+ if prompter:
+ prompt_type = prompter.prompt_type
+ else:
+ prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=False, stream_output=stream_output)
+ pass # assume inputted prompt_type is correct
from gpt4all_llm import get_llm_gpt4all
llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
temperature=temperature,
repetition_penalty=repetition_penalty,
top_k=top_k,
top_p=top_p,
+ callbacks=callbacks,
verbose=verbose,
+ streaming=stream_output,
+ prompter=prompter,
)
- streamer = None
- prompt_type = 'plain'
else:
- from transformers import AutoTokenizer, AutoModelForCausalLM
-
if model is None:
# only used if didn't pass model in
- assert model_name is None
assert tokenizer is None
prompt_type = 'human_bot'
- model_name = 'h2oai/h2ogpt-oasst1-512-12b'
- # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
- # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
- tokenizer = AutoTokenizer.from_pretrained(model_name)
- device, torch_dtype, context_class = get_device_dtype()
-
- with context_class(device):
- load_8bit = True
- # FIXME: for now not to spread across hetero GPUs
- # device_map={"": 0} if load_8bit and device == 'cuda' else "auto"
- device_map = {"": 0} if device == 'cuda' else "auto"
- model = AutoModelForCausalLM.from_pretrained(model_name,
- device_map=device_map,
- torch_dtype=torch_dtype,
- load_in_8bit=load_8bit)
+ if model_name is None:
+ model_name = 'h2oai/h2ogpt-oasst1-512-12b'
+ # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
+ # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
+ inference_server = ''
+ model, tokenizer, device = get_model(load_8bit=True, base_model=model_name,
+ inference_server=inference_server, gpu_id=0)
max_max_tokens = tokenizer.model_max_length
- gen_kwargs = dict(max_new_tokens=max_new_tokens,
+ gen_kwargs = dict(do_sample=do_sample,
+ temperature=temperature,
+ top_k=top_k,
+ top_p=top_p,
+ num_beams=num_beams,
+ max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ early_stopping=early_stopping,
+ max_time=max_time,
+ repetition_penalty=repetition_penalty,
+ num_return_sequences=num_return_sequences,
return_full_text=True,
- early_stopping=False,
- handle_long_generation='hole')
+ handle_long_generation=None)
+ assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0
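+ # sanity check: every expected generation hyperparameter name (gen_hyper) is present in gen_kwargs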
if stream_output:
skip_prompt = False
@@ -297,10 +755,12 @@ def get_llm(use_openai_model=False, model_name=None, model=None,
pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
prompter=prompter,
prompt_type=prompt_type,
- sanitize_bot_response=True,
+ prompt_dict=prompt_dict,
+ sanitize_bot_response=sanitize_bot_response,
chat=False, stream_output=stream_output,
tokenizer=tokenizer,
- max_input_tokens=max_max_tokens - max_new_tokens,
+ # leave some room for 1 paragraph, even if min_new_tokens=0
+ max_input_tokens=max_max_tokens - max(min_new_tokens, 256),
**gen_kwargs)
# pipe.task = "text-generation"
# below makes it listen only to our prompt removal,
@@ -345,7 +805,7 @@ def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
data = json.load(open(filename, "rt"))
page_content = list(data["query"]["pages"].values())[0]["extract"]
if take_head is not None and text_limit is not None:
- page_content = page_content[:text_limit] if take_head else page_content[:-text_limit]
+ page_content = page_content[:text_limit] if take_head else page_content[-text_limit:]
title_url = str(title).replace(' ', '_')
return Document(
page_content=page_content,
@@ -467,6 +927,21 @@ try:
except (pkg_resources.DistributionNotFound, AssertionError):
have_pymupdf = False
+try:
+ assert pkg_resources.get_distribution('selenium') is not None
+ have_selenium = True
+except (pkg_resources.DistributionNotFound, AssertionError):
+ have_selenium = False
+
+try:
+ assert pkg_resources.get_distribution('playwright') is not None
+ have_playwright = True
+except (pkg_resources.DistributionNotFound, AssertionError):
+ have_playwright = False
+
+# disable, hangs too often
+have_playwright = False
+
image_types = ["png", "jpg", "jpeg"]
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
"md", "html",
@@ -484,12 +959,13 @@ file_types = non_image_types + image_types
def add_meta(docs1, file):
file_extension = pathlib.Path(file).suffix
hashid = hash_file(file)
- if not isinstance(docs1, list):
+ if not isinstance(docs1, (list, tuple, types.GeneratorType)):
docs1 = [docs1]
[x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid)) for x in docs1]
-def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, chunk=True, chunk_size=512,
+def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
+ chunk=True, chunk_size=512,
is_url=False, is_txt=False,
enable_captions=True,
captions_model=None,
@@ -525,9 +1001,25 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
else:
docs1 = []
else:
+ if not (file.startswith("http://") or file.startswith("file://") or file.startswith("https://")):
+ file = 'http://' + file
docs1 = UnstructuredURLLoader(urls=[file]).load()
+ if len(docs1) == 0 and have_playwright:
+ # then something went wrong, try another loader:
+ from langchain.document_loaders import PlaywrightURLLoader
+ docs1 = PlaywrightURLLoader(urls=[file]).load()
+ if len(docs1) == 0 and have_selenium:
+ # then something went wrong, try another loader:
+ # but requires Chrome binary, else get: selenium.common.exceptions.WebDriverException: Message: unknown error: cannot find Chrome binary
+ from langchain.document_loaders import SeleniumURLLoader
+ from selenium.common.exceptions import WebDriverException
+ try:
+ docs1 = SeleniumURLLoader(urls=[file]).load()
+ except WebDriverException as e:
+ print("No web driver: %s" % str(e), flush=True)
[x.metadata.update(dict(input_type='url', date=str(datetime.now()))) for x in docs1]
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
+ docs1 = clean_doc(docs1)
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
elif is_txt:
base_path = "user_paste"
source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
@@ -536,44 +1028,49 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
f.write(file)
metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
doc1 = Document(page_content=file, metadata=metadata)
+ doc1 = clean_doc(doc1)
elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
docs1 = UnstructuredHTMLLoader(file_path=file).load()
add_meta(docs1, file)
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
+ docs1 = clean_doc(docs1)
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML)
elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
add_meta(docs1, file)
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
elif file.lower().endswith('.odt'):
docs1 = UnstructuredODTLoader(file_path=file).load()
add_meta(docs1, file)
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
docs1 = UnstructuredPowerPointLoader(file_path=file).load()
add_meta(docs1, file)
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
+ docs1 = clean_doc(docs1)
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
elif file.lower().endswith('.txt'):
# use UnstructuredFileLoader ?
docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
# makes just one, but big one
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
+ doc1 = clean_doc(doc1)
add_meta(doc1, file)
elif file.lower().endswith('.rtf'):
docs1 = UnstructuredRTFLoader(file).load()
add_meta(docs1, file)
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
elif file.lower().endswith('.md'):
docs1 = UnstructuredMarkdownLoader(file).load()
add_meta(docs1, file)
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
+ docs1 = clean_doc(docs1)
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.MARKDOWN)
elif file.lower().endswith('.enex'):
docs1 = EverNoteLoader(file).load()
add_meta(docs1, file)
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
elif file.lower().endswith('.epub'):
docs1 = UnstructuredEPubLoader(file).load()
add_meta(docs1, file)
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
elif file.lower().endswith('.jpeg') or file.lower().endswith('.jpg') or file.lower().endswith('.png'):
docs1 = []
if have_tesseract and enable_ocr:
@@ -603,7 +1100,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
doci.metadata['source'] = doci.metadata['image_path']
doci.metadata['hash'] = hash_file(doci.metadata['source'])
if docs1:
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
elif file.lower().endswith('.msg'):
raise RuntimeError("Not supported, GPL3 license")
# docs1 = OutlookMessageLoader(file).load()
@@ -612,14 +1109,14 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
try:
docs1 = UnstructuredEmailLoader(file).load()
add_meta(docs1, file)
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
except ValueError as e:
if 'text/html content not found in email' in str(e):
# e.g. text/plain dict key exists, but text/html does not
# doc1 = TextLoader(file, encoding="utf8").load()
docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load()
add_meta(docs1, file)
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
else:
raise
# elif file.lower().endswith('.gcsdir'):
@@ -630,6 +1127,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
with open(file, "r") as f:
doc1 = Document(page_content=f.read(), metadata={"source": file})
add_meta(doc1, file)
+ doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size, language=Language.RST)
elif file.lower().endswith('.pdf'):
env_gpt4all_file = ".env_gpt4all"
from dotenv import dotenv_values
@@ -638,11 +1136,19 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
# GPL, only use if installed
from langchain.document_loaders import PyMuPDFLoader
- doc1 = PyMuPDFLoader(file).load_and_split()
+ # load() still chunks by pages, but every page has title at start to help
+ doc1 = PyMuPDFLoader(file).load()
+ doc1 = clean_doc(doc1)
+ elif pdf_class_name == 'UnstructuredPDFLoader':
+ doc1 = UnstructuredPDFLoader(file).load()
+ # seems to not need cleaning in most cases
else:
# open-source fallback
- doc1 = PyPDFLoader(file).load_and_split()
+ # load() still chunks by pages, but every page has title at start to help
+ doc1 = PyPDFLoader(file).load()
+ doc1 = clean_doc(doc1)
# Some PDFs return nothing or junk from PDFMinerLoader
+ doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
add_meta(doc1, file)
elif file.lower().endswith('.csv'):
doc1 = CSVLoader(file).load()
@@ -650,6 +1156,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
elif file.lower().endswith('.py'):
doc1 = PythonLoader(file).load()
add_meta(doc1, file)
+ doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size, language=Language.PYTHON)
elif file.lower().endswith('.toml'):
doc1 = TomlLoader(file).load()
add_meta(doc1, file)
@@ -657,7 +1164,7 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
with open(file, "r") as f:
docs1 = UnstructuredURLLoader(urls=f.readlines()).load()
add_meta(docs1, file)
- doc1 = chunk_sources(docs1, chunk_size=chunk_size)
+ doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
elif file.lower().endswith('.zip'):
with zipfile.ZipFile(file, 'r') as zip_ref:
# don't put into temporary path, since want to keep references to docs inside zip
@@ -672,12 +1179,12 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
# if list of length one, don't trust and chunk it
if not isinstance(doc1, list):
if chunk:
- docs = chunk_sources([doc1], chunk_size=chunk_size)
+ docs = chunk_sources([doc1], chunk=chunk, chunk_size=chunk_size)
else:
docs = [doc1]
elif isinstance(doc1, list) and len(doc1) == 1:
if chunk:
- docs = chunk_sources(doc1, chunk_size=chunk_size)
+ docs = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
else:
docs = doc1
else:
@@ -687,7 +1194,8 @@ def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, c
return docs
-def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True, chunk=True, chunk_size=512,
+def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True,
+ chunk=True, chunk_size=512,
is_url=False, is_txt=False,
enable_captions=True,
captions_model=None,
@@ -739,15 +1247,16 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
existing_files=[],
existing_hash_ids={},
):
+ # path_or_paths could be str, list, tuple, generator
globs_image_types = []
globs_non_image_types = []
if not path_or_paths and not url and not text:
return []
elif url:
- globs_non_image_types = [url]
+ globs_non_image_types = url if isinstance(url, (list, tuple, types.GeneratorType)) else [url]
elif text:
- globs_non_image_types = [text]
- elif isinstance(path_or_paths, str):
+ globs_non_image_types = text if isinstance(text, (list, tuple, types.GeneratorType)) else [text]
+ elif isinstance(path_or_paths, str) and os.path.isdir(path_or_paths):
# single path, only consume allowed files
path = path_or_paths
# Below globs should match patterns in file_to_doc()
@@ -756,8 +1265,11 @@ def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=
[globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
for ftype in non_image_types]
else:
+ if isinstance(path_or_paths, str) and (os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths)):
+ path_or_paths = [path_or_paths]
# list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
- assert isinstance(path_or_paths, (list, tuple)), "Wrong type for path_or_paths: %s" % type(path_or_paths)
+ assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), "Wrong type for path_or_paths: %s" % type(
+ path_or_paths)
# reform out of allowed types
globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
# could do below:
@@ -861,12 +1373,12 @@ def prep_langchain(persist_directory,
if db_dir_exists and user_path is None:
print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
- db = get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
+ db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
hf_embedding_model)
else:
if db_dir_exists and user_path is not None:
print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % (
- persist_directory, user_path), flush=True)
+ persist_directory, user_path), flush=True)
elif not db_dir_exists:
print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
db = None
@@ -912,24 +1424,78 @@ class FakeConsumer(object):
posthog.Consumer = FakeConsumer
-def get_existing_db(persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
- hf_embedding_model):
+def check_update_chroma_embedding(db, use_openai_embedding, hf_embedding_model, langchain_mode):
+ changed_db = False
+ if load_embed(db) != (use_openai_embedding, hf_embedding_model):
+ print("Detected new embedding, updating db: %s" % langchain_mode, flush=True)
+ # handle embedding changes
+ db_get = get_documents(db)
+ sources = [Document(page_content=result[0], metadata=result[1] or {})
+ for result in zip(db_get['documents'], db_get['metadatas'])]
+ # delete index, has to be redone
+ persist_directory = db._persist_directory
+ shutil.move(persist_directory, persist_directory + "_" + str(uuid.uuid4()) + ".bak")
+ db_type = 'chroma'
+ load_db_if_exists = False
+ db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
+ persist_directory=persist_directory, load_db_if_exists=load_db_if_exists,
+ langchain_mode=langchain_mode,
+ collection_name=None,
+ hf_embedding_model=hf_embedding_model)
+ if False:
+ # below doesn't work if db already in memory, so have to switch to new db as above
+ # upsert does new embedding, but if index already in memory, complains about size mismatch etc.
+ client_collection = db._client.get_collection(name=db._collection.name,
+ embedding_function=db._collection._embedding_function)
+ client_collection.upsert(ids=db_get['ids'], metadatas=db_get['metadatas'], documents=db_get['documents'])
+ changed_db = True
+ print("Done updating db for new embedding: %s" % langchain_mode, flush=True)
+
+ return db, changed_db
+
+
+def get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
+ hf_embedding_model, verbose=False, check_embedding=True):
if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
os.path.join(persist_directory, 'index')):
- print("DO Loading db: %s" % langchain_mode, flush=True)
- embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
- from chromadb.config import Settings
- client_settings = Settings(anonymized_telemetry=False,
- chroma_db_impl="duckdb+parquet",
- persist_directory=persist_directory)
- db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
- collection_name=langchain_mode.replace(' ', '_'),
- client_settings=client_settings)
- print("DONE Loading db: %s" % langchain_mode, flush=True)
+ if db is None:
+ if verbose:
+ print("DO Loading db: %s" % langchain_mode, flush=True)
+ embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
+ from chromadb.config import Settings
+ client_settings = Settings(anonymized_telemetry=False,
+ chroma_db_impl="duckdb+parquet",
+ persist_directory=persist_directory)
+ db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
+ collection_name=langchain_mode.replace(' ', '_'),
+ client_settings=client_settings)
+ if verbose:
+ print("DONE Loading db: %s" % langchain_mode, flush=True)
+ else:
+ if verbose:
+ print("USING already-loaded db: %s" % langchain_mode, flush=True)
+ if check_embedding:
+ db_trial, changed_db = check_update_chroma_embedding(db, use_openai_embedding, hf_embedding_model,
+ langchain_mode)
+ if changed_db:
+ db = db_trial
+ # only call persist if really changed db, else takes too long for large db
+ if db is not None:
+ db.persist()
+ clear_embedding(db)
+ save_embed(db, use_openai_embedding, hf_embedding_model)
return db
return None
+def clear_embedding(db):
+ if db is None:
+ return
+ # don't keep the embedding model on GPU (wastes memory); push it back to CPU and move to GPU again only when we next embed
+ db._embedding_function.client.cpu()
+ clear_torch_cache()
+
+
def make_db(**langchain_kwargs):
func_names = list(inspect.signature(_make_db).parameters)
missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
@@ -945,9 +1511,33 @@ def make_db(**langchain_kwargs):
return _make_db(**langchain_kwargs)
+def save_embed(db, use_openai_embedding, hf_embedding_model):
+ if db is not None:
+ embed_info_file = os.path.join(db._persist_directory, 'embed_info')
+ with open(embed_info_file, 'wb') as f:
+ pickle.dump((use_openai_embedding, hf_embedding_model), f)
+ return use_openai_embedding, hf_embedding_model
+
+
+def load_embed(db):
+ embed_info_file = os.path.join(db._persist_directory, 'embed_info')
+ if os.path.isfile(embed_info_file):
+ with open(embed_info_file, 'rb') as f:
+ use_openai_embedding, hf_embedding_model = pickle.load(f)
+ else:
+ # migration, assume defaults
+ use_openai_embedding, hf_embedding_model = False, "sentence-transformers/all-MiniLM-L6-v2"
+ return use_openai_embedding, hf_embedding_model
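+ # NOTE: embed_info is just a pickled 2-tuple (use_openai_embedding, hf_embedding_model) stored beside the
+ # chroma files; check_update_chroma_embedding() compares it to the requested embedding so a changed model
+ # triggers a rebuild rather than querying with mismatched vectors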
+
+
+def get_persist_directory(langchain_mode):
+ return 'db_dir_%s' % langchain_mode # single place, no special names for each case
+
+
def _make_db(use_openai_embedding=False,
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
- first_para=False, text_limit=None, chunk=False, chunk_size=1024,
+ first_para=False, text_limit=None,
+ chunk=True, chunk_size=512,
langchain_mode=None,
user_path=None,
db_type='faiss',
@@ -955,19 +1545,13 @@ def _make_db(use_openai_embedding=False,
db=None,
n_jobs=-1,
verbose=False):
- persist_directory = 'db_dir_%s' % langchain_mode # single place, no special names for each case
- if not db and load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
- os.path.join(persist_directory, 'index')):
- assert langchain_mode not in ['MyData'], "Should not load MyData db this way"
- print("Loading existing db", flush=True)
- embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
- from chromadb.config import Settings
- client_settings = Settings(anonymized_telemetry=False,
- chroma_db_impl="duckdb+parquet",
- persist_directory=persist_directory)
- db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
- collection_name=langchain_mode.replace(' ', '_'),
- client_settings=client_settings)
+ persist_directory = get_persist_directory(langchain_mode)
+ # see if can get persistent chroma db
+ db_trial = get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
+ hf_embedding_model, verbose=verbose)
+ if db_trial is not None:
+ db = db_trial
+
sources = []
if not db and langchain_mode not in ['MyData'] or \
user_path is not None and \
@@ -992,24 +1576,24 @@ def _make_db(use_openai_embedding=False,
sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2)
print("Got new wiki", flush=True)
if chunk:
- sources1 = chunk_sources(sources1, chunk_size=chunk_size)
+ sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
print("Chunked new wiki", flush=True)
sources.extend(sources1)
if langchain_mode in ['wiki', 'All', "'All'"]:
sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
if chunk:
- sources1 = chunk_sources(sources1, chunk_size=chunk_size)
+ sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
sources.extend(sources1)
if langchain_mode in ['github h2oGPT', 'All', "'All'"]:
# sources = get_github_docs("dagster-io", "dagster")
sources1 = get_github_docs("h2oai", "h2ogpt")
# FIXME: always chunk for now
- sources1 = chunk_sources(sources1, chunk_size=chunk_size)
+ sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
sources.extend(sources1)
if langchain_mode in ['DriverlessAI docs', 'All', "'All'"]:
sources1 = get_dai_docs(from_hf=True)
if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit
- sources1 = chunk_sources(sources1, chunk_size=chunk_size)
+ sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
sources.extend(sources1)
if langchain_mode in ['All', 'UserData']:
if user_path:
@@ -1023,6 +1607,8 @@ def _make_db(use_openai_embedding=False,
existing_files = []
existing_hash_ids = []
# chunk internally for speed over multiple docs
+ # FIXME: If sources originally had hashid=None and the embedding is switched,
+ # we re-embed, then hit here and reload so sources gain a hash, and then re-embed again.
sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
existing_files=existing_files, existing_hash_ids=existing_hash_ids)
new_metadata_sources = set([x.metadata['source'] for x in sources1])
@@ -1066,7 +1652,9 @@ def _make_db(use_openai_embedding=False,
new_sources_metadata = [x.metadata for x in sources]
elif user_path is not None and langchain_mode in ['UserData']:
print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
- db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type)
+ db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
+ use_openai_embedding=use_openai_embedding,
+ hf_embedding_model=hf_embedding_model)
print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True)
else:
new_sources_metadata = [x.metadata for x in sources]
@@ -1074,63 +1662,140 @@ def _make_db(use_openai_embedding=False,
return db, len(new_sources_metadata), new_sources_metadata
+def get_metadatas(db):
+ from langchain.vectorstores import FAISS
+ if isinstance(db, FAISS):
+ metadatas = [v.metadata for k, v in db.docstore._dict.items()]
+ elif isinstance(db, Chroma):
+ metadatas = get_documents(db)['metadatas']
+ else:
+ # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947
+ # seems no way to get all metadata, so need to avoid this approach for weaviate
+ metadatas = [x.metadata for x in db.similarity_search("", k=10000)]
+ return metadatas
+
+
+def get_documents(db):
+ if hasattr(db, '_persist_directory'):
+ name_path = os.path.basename(db._persist_directory)
+ base_path = 'locks'
+ makedirs(base_path)
+ with filelock.FileLock(os.path.join(base_path, "getdb_%s.lock" % name_path)):
+ # get segfaults and other errors when multiple threads access this
+ return _get_documents(db)
+ else:
+ return _get_documents(db)
+
+
+def _get_documents(db):
+ from langchain.vectorstores import FAISS
+ if isinstance(db, FAISS):
+ documents = [v for k, v in db.docstore._dict.items()]
+ elif isinstance(db, Chroma):
+ documents = db.get()
+ else:
+ # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947
+ # seems no way to get all metadata, so need to avoid this approach for weaviate
+ documents = [x for x in db.similarity_search("", k=10000)]
+ return documents
+
+
+def get_docs_and_meta(db, top_k_docs, filter_kwargs={}):
+ if hasattr(db, '_persist_directory'):
+ name_path = os.path.basename(db._persist_directory)
+ base_path = 'locks'
+ makedirs(base_path)
+ with filelock.FileLock(os.path.join(base_path, "getdb_%s.lock" % name_path)):
+ return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
+ else:
+ return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
+
+
+def _get_docs_and_meta(db, top_k_docs, filter_kwargs={}):
+ from langchain.vectorstores import FAISS
+ if isinstance(db, Chroma):
+ db_get = db._collection.get(where=filter_kwargs.get('filter'))
+ db_metadatas = db_get['metadatas']
+ db_documents = db_get['documents']
+ elif isinstance(db, FAISS):
+ import itertools
+ db_metadatas = get_metadatas(db)
+ # FIXME: FAISS has no filter
+ # slice dict first
+ db_documents = list(dict(itertools.islice(db.docstore._dict.items(), top_k_docs)).values())
+ else:
+ db_metadatas = get_metadatas(db)
+ db_documents = get_documents(db)
+ return db_documents, db_metadatas
+
+
def get_existing_files(db):
- collection = db.get()
- metadata_sources = set([x['source'] for x in collection['metadatas']])
+ metadatas = get_metadatas(db)
+ metadata_sources = set([x['source'] for x in metadatas])
return metadata_sources
def get_existing_hash_ids(db):
- collection = db.get()
+ metadatas = get_metadatas(db)
# assume consistency, that any prior hashed source was single hashed file at the time among all source chunks
- metadata_hash_ids = {x['source']: x.get('hashid') for x in collection['metadatas']}
+ metadata_hash_ids = {x['source']: x.get('hashid') for x in metadatas}
return metadata_hash_ids
-source_prefix = "Sources [Score | Link]:"
-source_postfix = "End Sources
"
-
-
def run_qa_db(**kwargs):
func_names = list(inspect.signature(_run_qa_db).parameters)
# hard-coded defaults
kwargs['answer_with_sources'] = True
- kwargs['sanitize_bot_response'] = True
kwargs['show_rank'] = False
missing_kwargs = [x for x in func_names if x not in kwargs]
assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
# only keep actual used
kwargs = {k: v for k, v in kwargs.items() if k in func_names}
- return _run_qa_db(**kwargs)
+ try:
+ return _run_qa_db(**kwargs)
+ finally:
+ clear_torch_cache()
def _run_qa_db(query=None,
use_openai_model=False, use_openai_embedding=False,
- first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
+ first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
user_path=None,
detect_user_path_changes_every_query=False,
db_type='faiss',
- model_name=None, model=None, tokenizer=None,
+ model_name=None, model=None, tokenizer=None, inference_server=None,
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
stream_output=False,
prompter=None,
prompt_type=None,
+ prompt_dict=None,
answer_with_sources=True,
cut_distanct=1.1,
- sanitize_bot_response=True,
+ sanitize_bot_response=False,
show_rank=False,
load_db_if_exists=False,
db=None,
- max_new_tokens=256,
+ do_sample=False,
temperature=0.1,
- repetition_penalty=1.0,
top_k=40,
top_p=0.7,
+ num_beams=1,
+ max_new_tokens=256,
+ min_new_tokens=1,
+ early_stopping=False,
+ max_time=180,
+ repetition_penalty=1.0,
+ num_return_sequences=1,
langchain_mode=None,
- document_choice=['All'],
+ document_choice=[DocumentChoices.All_Relevant.name],
n_jobs=-1,
verbose=False,
- cli=False):
+ cli=False,
+ reverse_docs=True,
+ lora_weights='',
+ auto_reduce_chunks=True,
+ max_chunks=100,
+ ):
"""
:param query:
@@ -1149,39 +1814,63 @@ def _run_qa_db(query=None,
:param answer_with_sources
:return:
"""
+ if model is not None:
+ assert model_name is not None # require so can make decisions
assert query is not None
assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate
if prompter is not None:
prompt_type = prompter.prompt_type
+ prompt_dict = prompter.prompt_dict
if model is not None:
assert prompt_type is not None
+ if prompt_type == PromptType.custom.name:
+ assert prompt_dict is not None # should at least be {} or ''
+ else:
+ prompt_dict = ''
+ assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
- model=model, tokenizer=tokenizer,
+ model=model,
+ tokenizer=tokenizer,
+ inference_server=inference_server,
stream_output=stream_output,
- max_new_tokens=max_new_tokens,
+ do_sample=do_sample,
temperature=temperature,
- repetition_penalty=repetition_penalty,
top_k=top_k,
top_p=top_p,
+ num_beams=num_beams,
+ max_new_tokens=max_new_tokens,
+ min_new_tokens=min_new_tokens,
+ early_stopping=early_stopping,
+ max_time=max_time,
+ repetition_penalty=repetition_penalty,
+ num_return_sequences=num_return_sequences,
prompt_type=prompt_type,
+ prompt_dict=prompt_dict,
prompter=prompter,
+ sanitize_bot_response=sanitize_bot_response,
verbose=verbose,
)
- if model_name in non_hf_types:
- # FIXME: for now, streams to stdout/stderr currently
- stream_output = False
-
use_context = False
scores = []
chain = None
+ if isinstance(document_choice, str):
+ # support string as well
+ document_choice = [document_choice]
+ # get first DocumentChoices as command to use, ignore others
+ doc_choices_set = set([x.name for x in list(DocumentChoices)])
+ cmd = [x for x in document_choice if x in doc_choices_set]
+ cmd = None if len(cmd) == 0 else cmd[0]
+ # now have cmd, filter out for only docs
+ document_choice = [x for x in document_choice if x not in doc_choices_set]
+
func_names = list(inspect.signature(get_similarity_chain).parameters)
sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
missing_kwargs = [x for x in func_names if x not in sim_kwargs]
assert not missing_kwargs, "Missing: %s" % missing_kwargs
docs, chain, scores, use_context = get_similarity_chain(**sim_kwargs)
- if len(document_choice) > 0 and document_choice[0] == 'Only':
+ if cmd in [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]:
formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
yield formatted_doc_chunks, ''
return
@@ -1189,43 +1878,49 @@ def _run_qa_db(query=None,
# can only return if HF type
return
- if stream_output:
- answer = None
- assert streamer is not None
- import queue
- bucket = queue.Queue()
- thread = EThread(target=chain, streamer=streamer, bucket=bucket)
- thread.start()
- outputs = ""
- prompt = None # FIXME
- try:
- for new_text in streamer:
- # print("new_text: %s" % new_text, flush=True)
- if bucket.qsize() > 0 or thread.exc:
- thread.join()
- outputs += new_text
- if prompter: # and False: # FIXME: pipeline can already use prompter
- output1 = prompter.get_response(outputs, prompt=prompt,
- sanitize_bot_response=sanitize_bot_response)
- yield output1, ''
- else:
- yield outputs, ''
- except BaseException:
- # if any exception, raise that exception if was from thread, first
- if thread.exc:
- raise thread.exc
- raise
- finally:
- # in case no exception and didn't join with thread yet, then join
- if not thread.exc:
- answer = thread.join()
- # in case raise StopIteration or broke queue loop in streamer, but still have exception
- if thread.exc:
- raise thread.exc
- # FIXME: answer is not string outputs from streamer. How to get actual final output?
- # answer = outputs
- else:
- answer = chain()
+ # context stuff similar to used in evaluate()
+ import torch
+ device, torch_dtype, context_class = get_device_dtype()
+ with torch.no_grad():
+ have_lora_weights = lora_weights not in [no_lora_str, '', None]
+ context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast
+ with context_class_cast(device):
+ if stream_output and streamer:
+ answer = None
+ import queue
+ bucket = queue.Queue()
+ thread = EThread(target=chain, streamer=streamer, bucket=bucket)
+ thread.start()
+ outputs = ""
+ prompt = None # FIXME
+ try:
+ for new_text in streamer:
+ # print("new_text: %s" % new_text, flush=True)
+ if bucket.qsize() > 0 or thread.exc:
+ thread.join()
+ outputs += new_text
+ if prompter: # and False: # FIXME: pipeline can already use prompter
+ output1 = prompter.get_response(outputs, prompt=prompt,
+ sanitize_bot_response=sanitize_bot_response)
+ yield output1, ''
+ else:
+ yield outputs, ''
+ except BaseException:
+ # if any exception, raise that exception if was from thread, first
+ if thread.exc:
+ raise thread.exc
+ raise
+ finally:
+ # in case no exception and didn't join with thread yet, then join
+ if not thread.exc:
+ answer = thread.join()
+ # in case raise StopIteration or broke queue loop in streamer, but still have exception
+ if thread.exc:
+ raise thread.exc
+ # FIXME: answer is not string outputs from streamer. How to get actual final output?
+ # answer = outputs
+ else:
+ answer = chain()
if not use_context:
ret = answer['output_text']
@@ -1239,22 +1934,31 @@ def _run_qa_db(query=None,
def get_similarity_chain(query=None,
use_openai_model=False, use_openai_embedding=False,
- first_para=False, text_limit=None, k=4, chunk=False, chunk_size=1024,
+ first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
user_path=None,
detect_user_path_changes_every_query=False,
db_type='faiss',
model_name=None,
+ inference_server='',
hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
prompt_type=None,
+ prompt_dict=None,
cut_distanct=1.1,
load_db_if_exists=False,
db=None,
langchain_mode=None,
- document_choice=['All'],
+ document_choice=[DocumentChoices.All_Relevant.name],
n_jobs=-1,
# beyond run_db_query:
llm=None,
+ tokenizer=None,
verbose=False,
+ cmd=None,
+ reverse_docs=True,
+
+ # local
+ auto_reduce_chunks=True,
+ max_chunks=100,
):
# determine whether use of context out of docs is planned
if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
@@ -1266,10 +1970,14 @@ def get_similarity_chain(query=None,
use_context = True
# https://github.com/hwchase17/langchain/issues/1946
- # FIXME: Seems to way to get size of chroma db to limit k to avoid
+ # FIXME: Seems no way to get size of chroma db to limit top_k_docs to avoid
# Chroma collection MyData contains fewer than 4 elements.
# type logger error
- k_db = 1000 if db_type == 'chroma' else k # k=100 works ok too for
+ if top_k_docs == -1:
+ k_db = 1000 if db_type == 'chroma' else 100
+ else:
+ # top_k_docs=100 works ok too
+ k_db = 1000 if db_type == 'chroma' else top_k_docs
# FIXME: For All just go over all dbs instead of a separate db for All
if not detect_user_path_changes_every_query and db is not None:
@@ -1279,7 +1987,8 @@ def get_similarity_chain(query=None,
user_path = None
db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding,
hf_embedding_model=hf_embedding_model,
- first_para=first_para, text_limit=text_limit, chunk=chunk,
+ first_para=first_para, text_limit=text_limit,
+ chunk=chunk,
chunk_size=chunk_size,
langchain_mode=langchain_mode,
user_path=user_path,
@@ -1289,37 +1998,133 @@ def get_similarity_chain(query=None,
n_jobs=n_jobs,
verbose=verbose)
+ if 'falcon' in model_name:
+ extra = "According to only the information in the document sources provided within the context above, "
+ prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends."
+ elif inference_server in ['openai', 'openai_chat']:
+ extra = "According to (primarily) the information in the document sources provided within context above, "
+ prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends. If the answer cannot be primarily obtained from information within the context, then respond that the answer does not appear in the context of the documents."
+ else:
+ extra = ""
+ prefix = ""
+ if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
+ template_if_no_docs = template = """%s{context}{question}""" % prefix
+ else:
+ template = """%s
+\"\"\"
+{context}
+\"\"\"
+%s{question}""" % (prefix, extra)
+ template_if_no_docs = """%s{context}%s{question}""" % (prefix, extra)
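+ # rendered prompt is roughly: <prefix>\n"""\n<retrieved chunks>\n"""\n<extra><question>, with the quoted
+ # block separating retrieved context from the question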
+ if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
+ use_template = True
+ else:
+ use_template = False
+
if db and use_context:
- if isinstance(document_choice, str):
- # support string as well
- document_choice = [document_choice]
- if not isinstance(db, Chroma) or \
- len(document_choice) == 0 or \
- len(document_choice) <= 1 and document_choice[0] == 'All':
- # treat empty list as All for now, not 'None'
- filter_kwargs = {}
- elif len(document_choice) > 0 and document_choice[0] == 'Only':
- # Only means All docs, but only will return sources, not LLM response
+ if not isinstance(db, Chroma):
+ # only chroma supports filtering
filter_kwargs = {}
else:
+ # if here then some cmd + documents selected or just documents selected
if len(document_choice) >= 2:
or_filter = [{"source": {"$eq": x}} for x in document_choice]
filter_kwargs = dict(filter={"$or": or_filter})
- elif len(document_choice) > 0:
+ elif len(document_choice) == 1:
+ # degenerate UX bug in chroma
one_filter = [{"source": {"$eq": x}} for x in document_choice][0]
filter_kwargs = dict(filter=one_filter)
else:
+ # shouldn't reach
filter_kwargs = {}
- if len(document_choice) == 1 and document_choice[0] == 'None':
- k_db = 1
- k = 0
- docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:k]
- # cut off so no high distance docs/sources considered
- docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
- scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
- if len(scores) > 0 and verbose:
- print("Distance: min: %s max: %s mean: %s median: %s" %
- (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
+ if cmd == DocumentChoices.Just_LLM.name:
+ docs = []
+ scores = []
+ elif cmd == DocumentChoices.Only_All_Sources.name:
+ db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs)
+ # similar to langchain's chroma's _results_to_docs_and_scores
+ docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0)
+ for result in zip(db_documents, db_metadatas)][:top_k_docs]
+ docs = [x[0] for x in docs_with_score]
+ scores = [x[1] for x in docs_with_score]
+ else:
+ if top_k_docs == -1 or auto_reduce_chunks:
+ # docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
+ top_k_docs_tokenize = 100
+ base_path = 'locks'
+ makedirs(base_path)
+ if hasattr(db, '_persist_directory'):
+ name_path = "sim_%s.lock" % os.path.basename(db._persist_directory)
+ else:
+ name_path = "sim.lock"
+ with filelock.FileLock(os.path.join(base_path, name_path)):
+ docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[
+ :top_k_docs_tokenize]
+ if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'):
+ # more accurate
+ tokens = [len(llm.pipeline.tokenizer(x[0].page_content)['input_ids']) for x in docs_with_score]
+ template_tokens = len(llm.pipeline.tokenizer(template)['input_ids'])
+ elif inference_server in ['openai', 'openai_chat'] or use_openai_model or db_type in ['faiss',
+ 'weaviate']:
+ # use tiktoken for faiss since embedding is called differently
+ tokens = [llm.get_num_tokens(x[0].page_content) for x in docs_with_score]
+ template_tokens = llm.get_num_tokens(template)
+ elif isinstance(tokenizer, FakeTokenizer):
+ tokens = [tokenizer.num_tokens_from_string(x[0].page_content) for x in docs_with_score]
+ template_tokens = tokenizer.num_tokens_from_string(template)
+ else:
+ # in case model is not our pipeline with HF tokenizer
+ tokens = [db._embedding_function.client.tokenize([x[0].page_content])['input_ids'].shape[1] for x in
+ docs_with_score]
+ template_tokens = db._embedding_function.client.tokenize([template])['input_ids'].shape[1]
+ tokens_cumsum = np.cumsum(tokens)
+ if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'max_input_tokens'):
+ max_input_tokens = llm.pipeline.max_input_tokens
+ elif inference_server in ['openai']:
+ max_tokens = llm.modelname_to_contextsize(model_name)
+ # leave some room for 1 paragraph, even if min_new_tokens=0
+ max_input_tokens = max_tokens - 256
+ elif inference_server in ['openai_chat']:
+ max_tokens = model_token_mapping[model_name]
+ # leave some room for 1 paragraph, even if min_new_tokens=0
+ max_input_tokens = max_tokens - 256
+ elif isinstance(tokenizer, FakeTokenizer):
+ max_input_tokens = tokenizer.model_max_length - 256
+ else:
+ # leave some room for 1 paragraph, even if min_new_tokens=0
+ max_input_tokens = 2048 - 256
+ max_input_tokens -= template_tokens
+ # FIXME: Doesn't account for query, == context, or new lines between contexts
+ where_res = np.where(tokens_cumsum < max_input_tokens)[0]
+ if where_res.shape[0] == 0:
+ # then no chunk can fit, still do first one
+ top_k_docs_trial = 1
+ else:
+ top_k_docs_trial = 1 + where_res[-1]
+ if 0 < top_k_docs_trial < max_chunks:
+ # avoid craziness
+ if top_k_docs == -1:
+ top_k_docs = top_k_docs_trial
+ else:
+ top_k_docs = min(top_k_docs, top_k_docs_trial)
+ if top_k_docs == -1:
+ # if here, means 0 and just do best with 1 doc
+ print("Unexpected large chunks and can't add to context, will add 1 anyways", flush=True)
+ top_k_docs = 1
+ docs_with_score = docs_with_score[:top_k_docs]
+ else:
+ docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs]
+ # put most relevant chunks closest to question,
+ # esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated
+ # BUT: for small models, e.g. 6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest
+ if reverse_docs:
+ docs_with_score.reverse()
+ # cut off so no high distance docs/sources considered
+ docs = [x[0] for x in docs_with_score if x[1] < cut_distanct]
+ scores = [x[1] for x in docs_with_score if x[1] < cut_distanct]
+ if len(scores) > 0 and verbose:
+ print("Distance: min: %s max: %s mean: %s median: %s" %
+ (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True)
else:
docs = []
scores = []
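# Illustrative sketch of the chunk-budget rule used in the hunk above: cumulative token
# counts decide how many retrieved chunks fit once the prompt template is subtracted.
# Function and variable names here are illustrative, not the patch's API.
import numpy as np

def fit_top_k_docs(chunk_token_counts, max_input_tokens, template_tokens, max_chunks=100):
    budget = max_input_tokens - template_tokens
    tokens_cumsum = np.cumsum(chunk_token_counts)
    where_res = np.where(tokens_cumsum < budget)[0]
    # if even the first chunk does not fit, still keep one chunk, as the patch does
    top_k = 1 if where_res.shape[0] == 0 else 1 + int(where_res[-1])
    return min(top_k, max_chunks)

# e.g. fit_top_k_docs([300, 400, 500, 800], max_input_tokens=2048, template_tokens=200) == 3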
@@ -1328,7 +2133,7 @@ def get_similarity_chain(query=None,
# if HF type and have no docs, can bail out
return docs, None, [], False
- if len(document_choice) > 0 and document_choice[0] == 'Only':
+ if cmd in [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]:
# no LLM use
return docs, None, [], False
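# Hedged sketch (assumption): the cmd comparisons above use member names of the
# DocumentChoices enum from enums.py, which this diff does not show; the integer
# values below are placeholders, only the member names appear in the patch.
from enum import Enum

class DocumentChoices(Enum):
    All_Relevant = 0                # retrieve docs and ask the LLM
    All_Relevant_Only_Sources = 1   # return sources only, skip the LLM
    Only_All_Sources = 2            # dump everything in the db, skip the LLM
    Just_LLM = 3                    # ignore the document collection entirely

# the code above compares by name, e.g. cmd == DocumentChoices.Just_LLM.name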
@@ -1348,19 +2153,11 @@ def get_similarity_chain(query=None,
if len(docs) == 0:
# avoid context == in prompt then
use_context = False
+ template = template_if_no_docs
- if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types:
+ if use_template:
# instruct-like, rather than few-shot prompt_type='plain' as default
# but then sources confuse the model with how inserted among rest of text, so avoid
- prefix = ""
- if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context:
- template = """%s{context}{question}""" % prefix
- else:
- template = """%s
-==
-{context}
-==
-{question}""" % prefix
prompt = PromptTemplate(
# input_variables=["summaries", "question"],
input_variables=["context", "question"],
@@ -1420,15 +2217,32 @@ def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, ve
return ret, extra
-def chunk_sources(sources, chunk_size=1024):
- source_chunks = []
- # Below for known separator
- # splitter = CharacterTextSplitter(separator=" ", chunk_size=chunk_size, chunk_overlap=0)
- splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0)
- for source in sources:
- # print(source.metadata['source'], flush=True)
- for chunky in splitter.split_text(source.page_content):
- source_chunks.append(Document(page_content=chunky, metadata=source.metadata))
+def clean_doc(docs1):
+ if not isinstance(docs1, (list, tuple, types.GeneratorType)):
+ docs1 = [docs1]
+ for doci, doc in enumerate(docs1):
+ docs1[doci].page_content = '\n'.join([x.strip() for x in doc.page_content.split("\n") if x.strip()])
+ return docs1
+
+
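# Illustrative usage of clean_doc() above (assumes langchain is installed and the
# function is in scope): blank lines and edge whitespace are stripped per Document.
from langchain.docstore.document import Document

doc = Document(page_content="  Title \n\n\n  body line  \n", metadata={"source": "x.txt"})
assert clean_doc([doc])[0].page_content == "Title\nbody line"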
+def chunk_sources(sources, chunk=True, chunk_size=512, language=None):
+ if not chunk:
+ return sources
+ if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources):
+ # if just one document
+ sources = [sources]
+ if language and False:
+ # Bug in langchain, keep separator=True not working
+ # https://github.com/hwchase17/langchain/issues/2836
+ # so avoid this for now
+ keep_separator = True
+ separators = RecursiveCharacterTextSplitter.get_separators_for_language(language)
+ else:
+ separators = ["\n\n", "\n", " ", ""]
+ keep_separator = False
+ splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator,
+ separators=separators)
+ source_chunks = splitter.split_documents(sources)
return source_chunks
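# Illustrative usage of chunk_sources() above (assumes langchain is installed and the
# function is in scope): chunks inherit the source metadata and respect chunk_size.
from langchain.docstore.document import Document

long_doc = Document(page_content="word " * 1000, metadata={"source": "notes.txt"})
chunks = chunk_sources([long_doc], chunk=True, chunk_size=512)
assert all(c.metadata["source"] == "notes.txt" for c in chunks)
assert all(len(c.page_content) <= 512 for c in chunks)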
@@ -1439,6 +2253,8 @@ def get_db_from_hf(dest=".", db_dir='db_dir_DriverlessAI_docs.zip'):
path_to_zip_file = hf_hub_download('h2oai/db_dirs', db_dir, token=token, repo_type='dataset')
import zipfile
with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref:
+ persist_directory = os.path.dirname(zip_ref.namelist()[0])
+ remove(persist_directory)
zip_ref.extractall(dest)
return path_to_zip_file
@@ -1467,5 +2283,28 @@ def get_some_dbs_from_hf(dest='.', db_zips=None):
assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected
+def _create_local_weaviate_client():
+ WEAVIATE_URL = os.getenv('WEAVIATE_URL', "http://localhost:8080")
+ WEAVIATE_USERNAME = os.getenv('WEAVIATE_USERNAME')
+ WEAVIATE_PASSWORD = os.getenv('WEAVIATE_PASSWORD')
+ WEAVIATE_SCOPE = os.getenv('WEAVIATE_SCOPE', "offline_access")
+
+ resource_owner_config = None
+ try:
+ import weaviate
+ if WEAVIATE_USERNAME is not None and WEAVIATE_PASSWORD is not None:
+ resource_owner_config = weaviate.AuthClientPassword(
+ username=WEAVIATE_USERNAME,
+ password=WEAVIATE_PASSWORD,
+ scope=WEAVIATE_SCOPE
+ )
+
+ client = weaviate.Client(WEAVIATE_URL, auth_client_secret=resource_owner_config)
+ return client
+ except Exception as e:
+ print(f"Failed to create Weaviate client: {e}")
+ return None
+
+
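# Illustrative usage of _create_local_weaviate_client() above; the environment variables
# are the ones the helper reads, and the URL value here is only an example.
import os

os.environ.setdefault('WEAVIATE_URL', 'http://localhost:8080')
client = _create_local_weaviate_client()
if client is not None and client.is_ready():
    print("Weaviate is reachable")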
if __name__ == '__main__':
pass
diff --git a/gradio_runner.py b/gradio_runner.py
index 4348721acd1a0b50d0d3e7c8ff56d58ed9513d78..866d3d9862f0270d47efc207940edd5bafd63766 100644
--- a/gradio_runner.py
+++ b/gradio_runner.py
@@ -1,16 +1,25 @@
import copy
import functools
import inspect
+import itertools
import json
import os
+import pprint
import random
+import shutil
import sys
+import time
import traceback
+import typing
import uuid
import filelock
import pandas as pd
import requests
import tabulate
+from iterators import TimeoutIterator
+
+from gradio_utils.css import get_css
+from gradio_utils.prompt_form import make_prompt_form, make_chatbots
# This is a hack to prevent Gradio from phoning home when it gets imported
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
@@ -28,17 +37,55 @@ import gradio as gr
requests.get = original_get
-from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js
-from prompter import Prompter, \
- prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, generate_prompt, non_hf_types
+
+def fix_pydantic_duplicate_validators_error():
+ try:
+ from pydantic import class_validators
+
+ class_validators.in_ipython = lambda: True # type: ignore[attr-defined]
+ except ImportError:
+ pass
+
+
+fix_pydantic_duplicate_validators_error()
+
+from enums import DocumentChoices, no_model_str, no_lora_str, no_server_str, LangChainMode
+from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \
+ text_xsm
+from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \
+ get_prompt
from utils import get_githash, flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \
- ping, get_short_name, get_url, makedirs, get_kwargs
+ ping, get_short_name, get_url, makedirs, get_kwargs, remove, system_info, ping_gpu
from generate import get_model, languages_covered, evaluate, eval_func_param_names, score_qa, langchain_modes, \
- inputs_kwargs_list, get_cutoffs, scratch_base_dir
+ inputs_kwargs_list, scratch_base_dir, evaluate_from_str, no_default_param_names, \
+ eval_func_param_names_defaults, get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context
from apscheduler.schedulers.background import BackgroundScheduler
+def fix_text_for_gradio(text, fix_new_lines=False, fix_latex_dollars=True):
+ if fix_latex_dollars:
+ ts = text.split('```')
+ for parti, part in enumerate(ts):
+ inside = parti % 2 == 1
+ if not inside:
+ ts[parti] = ts[parti].replace('$', '﹩')
+ text = '```'.join(ts)
+
+ if fix_new_lines:
+ # let Gradio handle code, since got improved recently
+ ## FIXME: below conflicts with Gradio, but need to see if can handle multiple \n\n\n etc. properly as is.
+ # ensure good visually, else markdown ignores multiple \n
+ # handle code blocks
+ ts = text.split('```')
+ for parti, part in enumerate(ts):
+ inside = parti % 2 == 1
+ if not inside:
+                    ts[parti] = ts[parti].replace('\n', '<br>')
+ text = '```'.join(ts)
+ return text
+
+
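# Illustrative usage of fix_text_for_gradio() above: dollar signs outside ``` code fences
# are swapped for a full-width ﹩ so Gradio's markdown renderer does not treat them as LaTeX.
sample = 'Price is $5, but in code: ```cost = "$5"```'
assert fix_text_for_gradio(sample) == 'Price is ﹩5, but in code: ```cost = "$5"```'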
def go_gradio(**kwargs):
allow_api = kwargs['allow_api']
is_public = kwargs['is_public']
@@ -47,6 +94,7 @@ def go_gradio(**kwargs):
n_gpus = kwargs['n_gpus']
admin_pass = kwargs['admin_pass']
model_state0 = kwargs['model_state0']
+ model_states = kwargs['model_states']
score_model_state0 = kwargs['score_model_state0']
dbs = kwargs['dbs']
db_type = kwargs['db_type']
@@ -73,17 +121,9 @@ def go_gradio(**kwargs):
else:
instruction_label_nochat = "Instruction (Shift-Enter or push Submit to send message," \
" use Enter for multiple input lines)"
- if kwargs['input_lines'] > 1:
- instruction_label = "You (Shift-Enter or push Submit to send message, use Enter for multiple input lines)"
- else:
- instruction_label = "You (Enter or push Submit to send message, shift-enter for more lines)"
title = 'h2oGPT'
- if 'h2ogpt-research' in kwargs['base_model']:
- title += " [Research demonstration]"
-    more_info = """For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O-LLMStudio](https://github.com/h2oai/h2o-llmstudio)"""
- if is_public:
- more_info += """"""
+    more_info = """h2oGPT H2O LLM Studio 🤗 Models"""
if kwargs['verbose']:
description = f"""Model {kwargs['base_model']} Instruct dataset.
For more information, visit our GitHub pages: [h2oGPT](https://github.com/h2oai/h2ogpt) and [H2O LLM Studio](https://github.com/h2oai/h2o-llmstudio).
@@ -92,10 +132,10 @@ def go_gradio(**kwargs):
"""
else:
description = more_info
-    description += "If this host is busy, try [12B](https://gpt.h2o.ai), [Falcon 40B](http://falcon.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)"
-    description += """By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md)"""
+    description_bottom = "If this host is busy, try [LLaMa 65B](https://llama.h2o.ai), [Falcon 40B](https://gpt.h2o.ai), [Falcon 40B](http://falcon.h2o.ai), [HF Spaces1 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot) or [HF Spaces2 12B](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)"
+    description_bottom += """By using h2oGPT, you accept our [Terms of Service](https://github.com/h2oai/h2ogpt/blob/main/docs/tos.md)"""
     if is_hf:
-        description += '''
@@ -341,12 +342,9 @@ body.dark{#warning {background-color: #555555};}
multiselect=True,
)
with gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list):
- get_sources_btn = gr.Button(value="Get Sources",
- ).style(full_width=False, size='sm')
- show_sources_btn = gr.Button(value="Show Sources",
- ).style(full_width=False, size='sm')
- refresh_sources_btn = gr.Button(value="Refresh Sources",
- ).style(full_width=False, size='sm')
+ get_sources_btn = gr.Button(value="Get Sources", scale=0, size='sm')
+ show_sources_btn = gr.Button(value="Show Sources", scale=0, size='sm')
+ refresh_sources_btn = gr.Button(value="Refresh Sources", scale=0, size='sm')
# import control
if kwargs['langchain_mode'] != 'Disabled':
@@ -355,8 +353,8 @@ body.dark{#warning {background-color: #555555};}
have_arxiv = False
file_types = []
- upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload).style(
- equal_height=False)
+ upload_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload,
+ equal_height=False)
with upload_row:
with gr.Column():
file_types_str = '[' + ' '.join(file_types) + ']'
@@ -366,38 +364,50 @@ body.dark{#warning {background-color: #555555};}
elem_id="warning", elem_classes="feedback")
with gr.Row():
add_to_shared_db_btn = gr.Button("Add File(s) to UserData",
- visible=allow_upload_to_user_data, elem_id='small_btn')
+ visible=allow_upload_to_user_data,
+ elem_id='small_btn')
add_to_my_db_btn = gr.Button("Add File(s) to Scratch MyData",
- visible=allow_upload_to_my_data,
+ visible=allow_upload_to_my_data and
+ allow_upload_to_user_data,
elem_id='small_btn' if allow_upload_to_user_data else None,
- ).style(
- size='sm' if not allow_upload_to_user_data else None)
+ size='sm' if not allow_upload_to_user_data else None)
with gr.Column(
visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload):
url_label = 'URL (http/https) or ArXiv:' if have_arxiv else 'URL (http/https)'
- url_text = gr.Textbox(label=url_label, interactive=True)
+ url_text = gr.Textbox(label=url_label,
+ placeholder="Click Add to Submit" if
+ allow_upload_to_my_data and
+ allow_upload_to_user_data else
+ "Enter to Submit",
+ max_lines=1,
+ interactive=True)
with gr.Row():
url_user_btn = gr.Button(value='Add URL content to Shared UserData',
- visible=allow_upload_to_user_data, elem_id='small_btn')
+ visible=allow_upload_to_user_data and allow_upload_to_my_data,
+ elem_id='small_btn')
url_my_btn = gr.Button(value='Add URL content to Scratch MyData',
- visible=allow_upload_to_my_data,
+ visible=allow_upload_to_my_data and allow_upload_to_user_data,
elem_id='small_btn' if allow_upload_to_user_data else None,
- ).style(size='sm' if not allow_upload_to_user_data else None)
+ size='sm' if not allow_upload_to_user_data else None)
with gr.Column(
visible=kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload):
- user_text_text = gr.Textbox(label='Paste Text [Shift-Enter more lines]', interactive=True)
+ user_text_text = gr.Textbox(label='Paste Text [Shift-Enter more lines]',
+ placeholder="Click Add to Submit" if
+ allow_upload_to_my_data and
+ allow_upload_to_user_data else
+ "Enter to Submit, Shift-Enter for more lines",
+ interactive=True)
with gr.Row():
user_text_user_btn = gr.Button(value='Add Text to Shared UserData',
- visible=allow_upload_to_user_data,
+ visible=allow_upload_to_user_data and allow_upload_to_my_data,
elem_id='small_btn')
user_text_my_btn = gr.Button(value='Add Text to Scratch MyData',
- visible=allow_upload_to_my_data,
+ visible=allow_upload_to_my_data and allow_upload_to_user_data,
elem_id='small_btn' if allow_upload_to_user_data else None,
- ).style(
- size='sm' if not allow_upload_to_user_data else None)
+ size='sm' if not allow_upload_to_user_data else None)
with gr.Column(visible=False):
# WIP:
- with gr.Row(visible=False).style(equal_height=False):
+ with gr.Row(visible=False, equal_height=False):
github_textbox = gr.Textbox(label="Github URL")
with gr.Row(visible=True):
github_shared_btn = gr.Button(value="Add Github to Shared UserData",
@@ -405,18 +415,37 @@ body.dark{#warning {background-color: #555555};}
elem_id='small_btn')
github_my_btn = gr.Button(value="Add Github to Scratch MyData",
visible=allow_upload_to_my_data, elem_id='small_btn')
- sources_row3 = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
- equal_height=False)
- with sources_row3:
+ sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list,
+ equal_height=False)
+ with sources_row:
with gr.Column(scale=1):
file_source = gr.File(interactive=False,
label="Download File w/Sources [click get sources to make file]")
with gr.Column(scale=2):
- pass
- sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list).style(
- equal_height=False)
- with sources_row:
- sources_text = gr.HTML(label='Sources Added', interactive=False)
+ sources_text = gr.HTML(label='Sources Added', interactive=False)
+
+ with gr.TabItem("Chat History"):
+ with gr.Row():
+ if 'mbart-' in kwargs['model_lower']:
+ src_lang = gr.Dropdown(list(languages_covered().keys()),
+ value=kwargs['src_lang'],
+ label="Input Language")
+ tgt_lang = gr.Dropdown(list(languages_covered().keys()),
+ value=kwargs['tgt_lang'],
+ label="Output Language")
+ radio_chats = gr.Radio(value=None, label="Saved Chats", visible=True, interactive=True,
+ type='value')
+ with gr.Row():
+ clear_chat_btn = gr.Button(value="Clear Chat", visible=True, size='sm')
+ export_chats_btn = gr.Button(value="Export Chats to Download", size='sm')
+ remove_chat_btn = gr.Button(value="Remove Selected Chat", visible=True, size='sm')
+ add_to_chats_btn = gr.Button("Import Chats from Upload", size='sm')
+ with gr.Row():
+ chats_file = gr.File(interactive=False, label="Download Exported Chats")
+ chatsup_output = gr.File(label="Upload Chat File(s)",
+ file_types=['.json'],
+ file_count='multiple',
+ elem_id="warning", elem_classes="feedback")
with gr.TabItem("Expert"):
with gr.Row():
@@ -425,22 +454,25 @@ body.dark{#warning {background-color: #555555};}
value=kwargs['stream_output'])
prompt_type = gr.Dropdown(prompt_types_strings,
value=kwargs['prompt_type'], label="Prompt Type",
- visible=not is_public)
+ visible=not kwargs['model_lock'],
+ interactive=not is_public,
+ )
prompt_type2 = gr.Dropdown(prompt_types_strings,
value=kwargs['prompt_type'], label="Prompt Type Model 2",
- visible=not is_public and False)
+ visible=False and not kwargs['model_lock'],
+ interactive=not is_public)
do_sample = gr.Checkbox(label="Sample",
info="Enable sampler, required for use of temperature, top_p, top_k",
value=kwargs['do_sample'])
- temperature = gr.Slider(minimum=0.01, maximum=3,
+ temperature = gr.Slider(minimum=0.01, maximum=2,
value=kwargs['temperature'],
label="Temperature",
info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)")
- top_p = gr.Slider(minimum=0, maximum=1,
+ top_p = gr.Slider(minimum=1e-3, maximum=1.0 - 1e-3,
value=kwargs['top_p'], label="Top p",
info="Cumulative probability of tokens to sample from")
top_k = gr.Slider(
- minimum=0, maximum=100, step=1,
+ minimum=1, maximum=100, step=1,
value=kwargs['top_k'], label="Top k",
info='Num. tokens to sample from'
)
@@ -452,18 +484,9 @@ body.dark{#warning {background-color: #555555};}
num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1,
value=min(max_beams, kwargs['num_beams']), label="Beams",
info="Number of searches for optimal overall probability. "
- "Uses more GPU memory/compute")
- # FIXME: 2048 should be tokenizer.model_max_length, but may not even have model yet
- if kwargs['max_new_tokens']:
- max_max_new_tokens = kwargs['max_new_tokens']
- elif memory_restriction_level == 1:
- max_max_new_tokens = 768
- elif memory_restriction_level == 2:
- max_max_new_tokens = 512
- elif memory_restriction_level >= 3:
- max_max_new_tokens = 256
- else:
- max_max_new_tokens = 2048
+ "Uses more GPU memory/compute",
+ interactive=False)
+ max_max_new_tokens = get_max_max_new_tokens(model_state0, **kwargs)
max_new_tokens = gr.Slider(
minimum=1, maximum=max_max_new_tokens, step=1,
value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length",
@@ -472,13 +495,21 @@ body.dark{#warning {background-color: #555555};}
minimum=0, maximum=max_max_new_tokens, step=1,
value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length",
)
+ max_new_tokens2 = gr.Slider(
+ minimum=1, maximum=max_max_new_tokens, step=1,
+ value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length 2",
+ visible=False and not kwargs['model_lock'],
+ )
+ min_new_tokens2 = gr.Slider(
+ minimum=0, maximum=max_max_new_tokens, step=1,
+ value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length 2",
+ visible=False and not kwargs['model_lock'],
+ )
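# Hedged sketch (assumption, not the generate.py implementation): get_max_max_new_tokens()
# is expected to replace the memory_restriction_level tiers removed above, preferring the
# loaded tokenizer's model_max_length when a real tokenizer is present.
def get_max_max_new_tokens_sketch(model_state, memory_restriction_level=0, max_new_tokens=0, **kw):
    tokenizer = (model_state or {}).get('tokenizer')
    if tokenizer is not None and not isinstance(tokenizer, str):
        return int(tokenizer.model_max_length)
    if max_new_tokens:
        return max_new_tokens
    if memory_restriction_level == 1:
        return 768
    if memory_restriction_level == 2:
        return 512
    if memory_restriction_level >= 3:
        return 256
    return 2048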
early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search",
value=kwargs['early_stopping'])
- max_max_time = 60 * 5 if not is_public else 60 * 2
- if is_hf:
- max_max_time = min(max_max_time, 60 * 1)
- max_time = gr.Slider(minimum=0, maximum=max_max_time, step=1,
- value=min(max_max_time, kwargs['max_time']), label="Max. time",
+ max_time = gr.Slider(minimum=0, maximum=kwargs['max_max_time'], step=1,
+ value=min(kwargs['max_max_time'],
+ kwargs['max_time']), label="Max. time",
info="Max. time to search optimal output.")
repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0,
value=kwargs['repetition_penalty'],
@@ -486,90 +517,137 @@ body.dark{#warning {background-color: #555555};}
num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1,
value=kwargs['num_return_sequences'],
label="Number Returns", info="Must be <= num_beams",
- visible=not is_public)
+ interactive=not is_public)
iinput = gr.Textbox(lines=4, label="Input",
placeholder=kwargs['placeholder_input'],
- visible=not is_public)
+ interactive=not is_public)
context = gr.Textbox(lines=3, label="System Pre-Context",
info="Directly pre-appended without prompt processing",
- visible=not is_public)
+ interactive=not is_public)
chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'],
- visible=not is_public)
- count_chat_tokens_btn = gr.Button(value="Count Chat Tokens", visible=not is_public)
+ visible=not kwargs['model_lock'],
+ interactive=not is_public,
+ )
+ count_chat_tokens_btn = gr.Button(value="Count Chat Tokens",
+ visible=not is_public and not kwargs['model_lock'],
+ interactive=not is_public)
chat_token_count = gr.Textbox(label="Chat Token Count", value=None,
- visible=not is_public, interactive=False)
- top_k_docs = gr.Slider(minimum=0, maximum=20, step=1,
+ visible=not is_public and not kwargs['model_lock'],
+ interactive=False)
+ chunk = gr.components.Checkbox(value=kwargs['chunk'],
+ label="Whether to chunk documents",
+ info="For LangChain",
+ visible=kwargs['langchain_mode'] != 'Disabled',
+ interactive=not is_public)
+ min_top_k_docs, max_top_k_docs, label_top_k_docs = get_minmax_top_k_docs(is_public)
+ top_k_docs = gr.Slider(minimum=min_top_k_docs, maximum=max_top_k_docs, step=1,
value=kwargs['top_k_docs'],
- label="Number of document chunks",
+ label=label_top_k_docs,
info="For LangChain",
- visible=not is_public)
+ visible=kwargs['langchain_mode'] != 'Disabled',
+ interactive=not is_public)
+ chunk_size = gr.Number(value=kwargs['chunk_size'],
+ label="Chunk size for document chunking",
+ info="For LangChain (ignored if chunk=False)",
+ minimum=128,
+ maximum=2048,
+ visible=kwargs['langchain_mode'] != 'Disabled',
+ interactive=not is_public,
+ precision=0)
with gr.TabItem("Models"):
+ model_lock_msg = gr.Textbox(lines=1, label="Model Lock Notice",
+ placeholder="Started in model_lock mode, no model changes allowed.",
+ visible=bool(kwargs['model_lock']), interactive=False)
load_msg = "Load-Unload Model/LORA [unload works if did not use --base_model]" if not is_public \
else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO"
load_msg2 = "Load-Unload Model/LORA 2 [unload works if did not use --base_model]" if not is_public \
else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2"
+ variant_load_msg = 'primary' if not is_public else 'secondary'
compare_checkbox = gr.components.Checkbox(label="Compare Mode",
- value=False, visible=not is_public)
+ value=kwargs['model_lock'],
+ visible=not is_public and not kwargs['model_lock'])
with gr.Row():
n_gpus_list = [str(x) for x in list(range(-1, n_gpus))]
with gr.Column():
with gr.Row():
- with gr.Column(scale=50):
+ with gr.Column(scale=20, visible=not kwargs['model_lock']):
model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model",
value=kwargs['base_model'])
lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA",
value=kwargs['lora_weights'], visible=kwargs['show_lora'])
- with gr.Column(scale=1):
- load_model_button = gr.Button(load_msg).style(full_width=False, size='sm')
+ server_choice = gr.Dropdown(server_options_state.value[0], label="Choose Server",
+ value=kwargs['inference_server'], visible=not is_public)
+ with gr.Column(scale=1, visible=not kwargs['model_lock']):
+ load_model_button = gr.Button(load_msg, variant=variant_load_msg, scale=0,
+ size='sm', interactive=not is_public)
model_load8bit_checkbox = gr.components.Checkbox(
label="Load 8-bit [requires support]",
- value=kwargs['load_8bit'])
+ value=kwargs['load_8bit'], interactive=not is_public)
model_infer_devices_checkbox = gr.components.Checkbox(
label="Choose Devices [If not Checked, use all GPUs]",
- value=kwargs['infer_devices'])
+ value=kwargs['infer_devices'], interactive=not is_public)
model_gpu = gr.Dropdown(n_gpus_list,
label="GPU ID [-1 = all GPUs, if Choose is enabled]",
- value=kwargs['gpu_id'])
+ value=kwargs['gpu_id'], interactive=not is_public)
model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'],
interactive=False)
lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'],
visible=kwargs['show_lora'], interactive=False)
+ server_used = gr.Textbox(label="Current Server",
+ value=kwargs['inference_server'],
+ visible=bool(kwargs['inference_server']) and not is_public,
+ interactive=False)
+ prompt_dict = gr.Textbox(label="Prompt (or Custom)",
+ value=pprint.pformat(kwargs['prompt_dict'], indent=4),
+ interactive=not is_public, lines=4)
col_model2 = gr.Column(visible=False)
with col_model2:
with gr.Row():
- with gr.Column(scale=50):
+ with gr.Column(scale=20, visible=not kwargs['model_lock']):
model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2",
value=no_model_str)
lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2",
value=no_lora_str,
visible=kwargs['show_lora'])
- with gr.Column(scale=1):
- load_model_button2 = gr.Button(load_msg2).style(full_width=False, size='sm')
+ server_choice2 = gr.Dropdown(server_options_state.value[0], label="Choose Server 2",
+ value=no_server_str,
+ visible=not is_public)
+ with gr.Column(scale=1, visible=not kwargs['model_lock']):
+ load_model_button2 = gr.Button(load_msg2, variant=variant_load_msg, scale=0,
+ size='sm', interactive=not is_public)
model_load8bit_checkbox2 = gr.components.Checkbox(
label="Load 8-bit 2 [requires support]",
- value=kwargs['load_8bit'])
+ value=kwargs['load_8bit'], interactive=not is_public)
model_infer_devices_checkbox2 = gr.components.Checkbox(
label="Choose Devices 2 [If not Checked, use all GPUs]",
value=kwargs[
- 'infer_devices'])
+ 'infer_devices'], interactive=not is_public)
model_gpu2 = gr.Dropdown(n_gpus_list,
label="GPU ID 2 [-1 = all GPUs, if choose is enabled]",
- value=kwargs['gpu_id'])
+ value=kwargs['gpu_id'], interactive=not is_public)
# no model/lora loaded ever in model2 by default
- model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str)
+ model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str,
+ interactive=False)
lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str,
- visible=kwargs['show_lora'])
- with gr.Row():
+ visible=kwargs['show_lora'], interactive=False)
+ server_used2 = gr.Textbox(label="Current Server 2", value=no_server_str,
+ interactive=False,
+ visible=not is_public)
+ prompt_dict2 = gr.Textbox(label="Prompt (or Custom) 2",
+ value=pprint.pformat(kwargs['prompt_dict'], indent=4),
+ interactive=not is_public, lines=4)
+ with gr.Row(visible=not kwargs['model_lock']):
with gr.Column(scale=50):
- new_model = gr.Textbox(label="New Model HF name/path")
- with gr.Row():
- add_model_button = gr.Button("Add new model name").style(full_width=False, size='sm')
+ new_model = gr.Textbox(label="New Model name/path", interactive=not is_public)
with gr.Column(scale=50):
- new_lora = gr.Textbox(label="New LORA HF name/path", visible=kwargs['show_lora'])
+ new_lora = gr.Textbox(label="New LORA name/path", visible=kwargs['show_lora'],
+ interactive=not is_public)
+ with gr.Column(scale=50):
+ new_server = gr.Textbox(label="New Server url:port", interactive=not is_public)
with gr.Row():
- add_lora_button = gr.Button("Add new LORA name", visible=kwargs['show_lora']).style(
- full_width=False, size='sm')
+ add_model_lora_server_button = gr.Button("Add new Model, Lora, Server url:port", scale=0,
+ size='sm', interactive=not is_public)
with gr.TabItem("System"):
admin_row = gr.Row()
with admin_row:
@@ -580,8 +658,17 @@ body.dark{#warning {background-color: #555555};}
with gr.Column():
with gr.Row():
system_btn = gr.Button(value='Get System Info')
- system_text = gr.Textbox(label='System Info', interactive=False).style(
- show_copy_button=True)
+ system_text = gr.Textbox(label='System Info', interactive=False, show_copy_button=True)
+ with gr.Row():
+ system_input = gr.Textbox(label='System Info Dict Password', interactive=True,
+ visible=not is_public)
+ system_btn2 = gr.Button(value='Get System Info Dict', visible=not is_public)
+ system_text2 = gr.Textbox(label='System Info Dict', interactive=False,
+ visible=not is_public, show_copy_button=True)
+ with gr.Row():
+ system_btn3 = gr.Button(value='Get Hash', visible=not is_public)
+ system_text3 = gr.Textbox(label='Hash', interactive=False,
+ visible=not is_public, show_copy_button=True)
with gr.Row():
zip_btn = gr.Button("Zip")
@@ -601,6 +688,11 @@ body.dark{#warning {background-color: #555555};}
description += """
'):
- prompt = prompt[:-4]
- prompt = prompt.replace('
', chat_sep)
- if not prompt.endswith(chat_sep):
- prompt += chat_sep
- # most recent first, add older if can
- # only include desired chat history
- if len(prompt + context1) > max_prompt_length:
- break
- context1 = prompt + context1
+ return history_list[0]
- _, pre_response, terminate_response, chat_sep = generate_prompt({}, prompt_type1, chat1,
- reduced=True)
- if context1 and not context1.endswith(chat_sep):
- context1 += chat_sep # ensure if terminates abruptly, then human continues on next line
- return context1
+ def get_model_max_length(model_state1):
+ if model_state1 and not isinstance(model_state1["tokenizer"], str):
+ tokenizer = model_state1["tokenizer"]
+ elif model_state0 and not isinstance(model_state0["tokenizer"], str):
+ tokenizer = model_state0["tokenizer"]
+ else:
+ tokenizer = None
+ if tokenizer is not None:
+ return tokenizer.model_max_length
+ else:
+ return 2000
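# Hedged sketch (assumption): get_model_max_length() above feeds history_to_context() in
# generate.py, which is not shown; the idea is to keep only the newest turns that fit.
def trim_history_sketch(history, tokenizer, model_max_length, reserve_for_output=256):
    budget = model_max_length - reserve_for_output
    kept, used = [], 0
    for human, bot in reversed(history):
        n = len(tokenizer((human or '') + (bot or ''))['input_ids'])
        if used + n > budget:
            break
        kept.insert(0, [human, bot])
        used += n
    return kept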
- def bot(*args, retry=False):
+ def prep_bot(*args, retry=False):
"""
- bot that consumes history for user input
- instruction (from input_list) itself is not consumed by bot
+
:param args:
:param retry:
- :return:
+ :return: last element is True if should run bot, False if should just yield history
"""
# don't deepcopy, can contain model itself
args_list = list(args).copy()
model_state1 = args_list[-3]
my_db_state1 = args_list[-2]
history = args_list[-1]
+ langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
- if model_state1[0] is None or model_state1[0] == no_model_str:
- history = []
- yield history, ''
- return
+ if model_state1['model'] is None or model_state1['model'] == no_model_str:
+ return history, None, None, None
args_list = args_list[:-3] # only keep rest needed for evaluate()
- langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
- if retry and history:
- history.pop()
- if not args_list[eval_func_param_names.index('do_sample')]:
- # if was not sampling, no point in retry unless change to sample
- args_list[eval_func_param_names.index('do_sample')] = True
if not history:
print("No history", flush=True)
history = []
- yield history, ''
- return
+ return history, None, None, None
instruction1 = history[-1][0]
- if not instruction1:
- # reject empty query, can sometimes go nuts
- history = []
- yield history, ''
- return
- prompt_type1 = args_list[eval_func_param_names.index('prompt_type')]
+ if retry and history:
+ # if retry, pop history and move onto bot stuff
+ instruction1 = history[-1][0]
+ history[-1][1] = None
+ elif not instruction1:
+ # if not retrying, then reject empty query
+ return history, None, None, None
+ elif len(history) > 0 and history[-1][1] not in [None, '']:
+ # reject submit button if already filled and not retrying
+ # None when not filling with '' to keep client happy
+ return history, None, None, None
+
+ # shouldn't have to specify in API prompt_type if CLI launched model, so prefer global CLI one if have it
+ prompt_type1 = kwargs.get('prompt_type', args_list[eval_func_param_names.index('prompt_type')])
+ # prefer model specific prompt type instead of global one, and apply back to args_list for evaluate()
+ args_list[eval_func_param_names.index('prompt_type')] = prompt_type1 = \
+ model_state1.get('prompt_type', prompt_type1)
+
+ prompt_dict1 = kwargs.get('prompt_dict', args_list[eval_func_param_names.index('prompt_dict')])
+ args_list[eval_func_param_names.index('prompt_dict')] = prompt_dict1 = \
+ model_state1.get('prompt_dict', prompt_dict1)
+
chat1 = args_list[eval_func_param_names.index('chat')]
- context1 = history_to_context(history, langchain_mode1, prompt_type1, chat1)
+ model_max_length1 = get_model_max_length(model_state1)
+ context1 = history_to_context(history, langchain_mode1, prompt_type1, prompt_dict1, chat1,
+ model_max_length1, memory_restriction_level,
+ kwargs['keep_sources_in_context'])
args_list[0] = instruction1 # override original instruction with history from user
args_list[2] = context1
+
fun1 = partial(evaluate,
model_state1,
my_db_state1,
+ *tuple(args_list),
**kwargs_evaluate)
+
+ return history, fun1, langchain_mode1, my_db_state1
+
+ def get_response(fun1, history):
+ """
+ bot that consumes history for user input
+ instruction (from input_list) itself is not consumed by bot
+ :return:
+ """
+ if not fun1:
+ yield history, ''
+ return
try:
- for output_fun in fun1(*tuple(args_list)):
+ for output_fun in fun1():
output = output_fun['response']
extra = output_fun['sources'] # FIXME: can show sources in separate text box etc.
# ensure good visually, else markdown ignores multiple \n
-                    bot_message = output.replace('\n', '<br>')
+ bot_message = fix_text_for_gradio(output)
history[-1][1] = bot_message
yield history, ''
except StopIteration:
@@ -1010,8 +1244,98 @@ body.dark{#warning {background-color: #555555};}
history[-1][1] = ''
yield history, ex
raise
+ finally:
+ clear_torch_cache()
return
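# Illustrative sketch of the prep_bot()/get_response() split above: prep returns a
# zero-argument callable over a streaming generator, and the consumer only iterates it.
# fake_evaluate/prep/consume are stand-ins, not names from this codebase.
from functools import partial

def fake_evaluate(prompt):
    for i in range(3):
        yield dict(response="%s ... token %d" % (prompt, i), sources='')

def prep(prompt):
    return partial(fake_evaluate, prompt)

def consume(fun1, history):
    if not fun1:
        yield history, ''
        return
    for output_fun in fun1():
        history[-1][1] = output_fun['response']
        yield history, ''

history = [["Who are you?", None]]
for _ in consume(prep("Who are you?"), history):
    pass
assert history[-1][1].endswith("token 2")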
+ def clear_embeddings(langchain_mode1, my_db):
+ # clear any use of embedding that sits on GPU, else keeps accumulating GPU usage even if clear torch cache
+ if db_type == 'chroma' and langchain_mode1 not in ['ChatLLM', 'LLM', 'Disabled', None, '']:
+ from gpt_langchain import clear_embedding
+                db = dbs.get(langchain_mode1)
+ if db is not None and not isinstance(db, str):
+ clear_embedding(db)
+ if langchain_mode1 == LangChainMode.MY_DATA.value and my_db is not None:
+ clear_embedding(my_db[0])
+
+ def bot(*args, retry=False):
+ history, fun1, langchain_mode1, my_db_state1 = prep_bot(*args, retry=retry)
+ try:
+ for res in get_response(fun1, history):
+ yield res
+ finally:
+ clear_embeddings(langchain_mode1, my_db_state1)
+
+ def all_bot(*args, retry=False, model_states1=None):
+ args_list = list(args).copy()
+ chatbots = args_list[-len(model_states1):]
+ args_list0 = args_list[:-len(model_states1)] # same for all models
+ exceptions = []
+ stream_output1 = args_list[eval_func_param_names.index('stream_output')]
+ max_time1 = args_list[eval_func_param_names.index('max_time')]
+ langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')]
+ my_db_state1 = None # will be filled below by some bot
+ try:
+ gen_list = []
+ for chatbot1, model_state1 in zip(chatbots, model_states1):
+ args_list1 = args_list0.copy()
+ args_list1.insert(-1, model_state1) # insert at -1 so is at -2
+ # if at start, have None in response still, replace with '' so client etc. acts like normal
+ # assumes other parts of code treat '' and None as if no response yet from bot
+ # can't do this later in bot code as racy with threaded generators
+ if len(chatbot1) > 0 and len(chatbot1[-1]) == 2 and chatbot1[-1][1] is None:
+ chatbot1[-1][1] = ''
+ args_list1.append(chatbot1)
+ # so consistent with prep_bot()
+ # with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1
+ # langchain_mode1 and my_db_state1 should be same for every bot
+ history, fun1, langchain_mode1, my_db_state1 = prep_bot(*tuple(args_list1), retry=retry)
+ gen1 = get_response(fun1, history)
+ if stream_output1:
+ gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False)
+ # else timeout will truncate output for non-streaming case
+ gen_list.append(gen1)
+
+ bots_old = chatbots.copy()
+ exceptions_old = [''] * len(bots_old)
+ tgen0 = time.time()
+ for res1 in itertools.zip_longest(*gen_list):
+ if time.time() - tgen0 > max_time1:
+ break
+
+ bots = [x[0] if x is not None and not isinstance(x, BaseException) else y for x, y in
+ zip(res1, bots_old)]
+ bots_old = bots.copy()
+
+ def larger_str(x, y):
+ return x if len(x) > len(y) else y
+
+ exceptions = [x[1] if x is not None and not isinstance(x, BaseException) else larger_str(str(x), y)
+ for x, y in zip(res1, exceptions_old)]
+ exceptions_old = exceptions.copy()
+
+ def choose_exc(x):
+ # don't expose ports etc. to exceptions window
+ if is_public:
+ return "Endpoint unavailable or failed"
+ else:
+ return x
+
+ exceptions_str = '\n'.join(
+ ['Model %s: %s' % (iix, choose_exc(x)) for iix, x in enumerate(exceptions) if
+ x not in [None, '', 'None']])
+ if len(bots) > 1:
+ yield tuple(bots + [exceptions_str])
+ else:
+ yield bots[0], exceptions_str
+ if exceptions:
+ exceptions = [x for x in exceptions if x not in ['', None, 'None']]
+ if exceptions:
+ print("Generate exceptions: %s" % exceptions, flush=True)
+ finally:
+ clear_torch_cache()
+ clear_embeddings(langchain_mode1, my_db_state1)
+
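# Illustrative sketch of the multiplexing in all_bot() above: one streaming generator per
# model is interleaved with itertools.zip_longest, keeping the last non-None value per
# stream (TimeoutIterator from the iterators package is omitted here for brevity).
import itertools

def stream(name, n):
    for i in range(n):
        yield "%s:%d" % (name, i)

gens = [stream("modelA", 2), stream("modelB", 4)]
latest = [None] * len(gens)
for step in itertools.zip_longest(*gens):  # yields None once a stream is exhausted
    latest = [new if new is not None else old for new, old in zip(step, latest)]
assert latest == ["modelA:1", "modelB:3"]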
# NORMAL MODEL
user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
inputs=inputs_list + [text_output],
@@ -1025,89 +1349,196 @@ body.dark{#warning {background-color: #555555};}
inputs=inputs_list + [model_state, my_db_state] + [text_output],
outputs=[text_output, exception_text],
)
+ retry_user_args = dict(fn=functools.partial(user, retry=True),
+ inputs=inputs_list + [text_output],
+ outputs=text_output,
+ )
undo_user_args = dict(fn=functools.partial(user, undo=True),
inputs=inputs_list + [text_output],
outputs=text_output,
)
# MODEL2
- user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt'], model2=True),
- inputs=inputs_list + [text_output2],
+ user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']),
+ inputs=inputs_list2 + [text_output2],
outputs=text_output2,
)
bot_args2 = dict(fn=bot,
- inputs=inputs_list + [model_state2, my_db_state] + [text_output2],
+ inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
outputs=[text_output2, exception_text],
)
retry_bot_args2 = dict(fn=functools.partial(bot, retry=True),
- inputs=inputs_list + [model_state2, my_db_state] + [text_output2],
+ inputs=inputs_list2 + [model_state2, my_db_state] + [text_output2],
outputs=[text_output2, exception_text],
)
+ retry_user_args2 = dict(fn=functools.partial(user, retry=True),
+ inputs=inputs_list2 + [text_output2],
+ outputs=text_output2,
+ )
undo_user_args2 = dict(fn=functools.partial(user, undo=True),
- inputs=inputs_list + [text_output2],
+ inputs=inputs_list2 + [text_output2],
outputs=text_output2,
)
+ # MODEL N
+ all_user_args = dict(fn=functools.partial(all_user,
+ sanitize_user_prompt=kwargs['sanitize_user_prompt'],
+ num_model_lock=len(text_outputs),
+ ),
+ inputs=inputs_list + text_outputs,
+ outputs=text_outputs,
+ )
+ all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states),
+ inputs=inputs_list + [my_db_state] + text_outputs,
+ outputs=text_outputs + [exception_text],
+ )
+ all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True),
+ inputs=inputs_list + [my_db_state] + text_outputs,
+ outputs=text_outputs + [exception_text],
+ )
+ all_retry_user_args = dict(fn=functools.partial(all_user, retry=True,
+ sanitize_user_prompt=kwargs['sanitize_user_prompt'],
+ num_model_lock=len(text_outputs),
+ ),
+ inputs=inputs_list + text_outputs,
+ outputs=text_outputs,
+ )
+ all_undo_user_args = dict(fn=functools.partial(all_user, undo=True,
+ sanitize_user_prompt=kwargs['sanitize_user_prompt'],
+ num_model_lock=len(text_outputs),
+ ),
+ inputs=inputs_list + text_outputs,
+ outputs=text_outputs,
+ )
+
def clear_instruct():
return gr.Textbox.update(value='')
- if kwargs['auto_score']:
- score_args_submit = score_args
- score_args2_submit = score_args2
+ def deselect_radio_chats():
+ return gr.update(value=None)
+
+ def clear_all():
+ return gr.Textbox.update(value=''), gr.Textbox.update(value=''), gr.update(value=None), \
+ gr.Textbox.update(value=''), gr.Textbox.update(value='')
+
+ if kwargs['model_states']:
+ submits1 = submits2 = submits3 = []
+ submits4 = []
+
+ fun_source = [instruction.submit, submit.click, retry_btn.click]
+ fun_name = ['instruction', 'submit', 'retry']
+ user_args = [all_user_args, all_user_args, all_retry_user_args]
+ bot_args = [all_bot_args, all_bot_args, all_retry_bot_args]
+ for userargs1, botarg1, funn1, funs1 in zip(user_args, bot_args, fun_name, fun_source):
+ submit_event11 = funs1(fn=dummy_fun,
+ inputs=instruction, outputs=instruction, queue=queue)
+ submit_event1a = submit_event11.then(**userargs1, queue=queue,
+ api_name='%s' % funn1 if allow_api else None)
+ # if hit enter on new instruction for submitting new query, no longer the saved chat
+ submit_event1b = submit_event1a.then(clear_all, inputs=None,
+ outputs=[instruction, iinput, radio_chats, score_text,
+ score_text2],
+ queue=queue)
+ submit_event1c = submit_event1b.then(**botarg1,
+ api_name='%s_bot' % funn1 if allow_api else None,
+ queue=queue)
+ submit_event1d = submit_event1c.then(**all_score_args,
+ api_name='%s_bot_score' % funn1 if allow_api else None,
+ queue=queue)
+
+ submits1.extend([submit_event1a, submit_event1b, submit_event1c, submit_event1d])
+
+ # if undo, no longer the saved chat
+ submit_event4 = undo.click(fn=dummy_fun,
+ inputs=instruction, outputs=instruction, queue=queue) \
+ .then(**all_undo_user_args, api_name='undo' if allow_api else None) \
+ .then(clear_all, inputs=None, outputs=[instruction, iinput, radio_chats, score_text,
+ score_text2], queue=queue) \
+ .then(**all_score_args, api_name='undo_score' if allow_api else None)
+ submits4 = [submit_event4]
+
else:
- score_args_submit = dict(fn=lambda: None, inputs=None, outputs=None)
- score_args2_submit = dict(fn=lambda: None, inputs=None, outputs=None)
-
- # in case 2nd model, consume instruction first, so can clear quickly
- # bot doesn't consume instruction itself, just history from user, so why works
- submit_event1a = instruction.submit(**user_args, queue=queue,
- api_name='instruction' if allow_api else None)
- submit_event1b = submit_event1a.then(**user_args2, api_name='instruction2' if allow_api else None)
- submit_event1c = submit_event1b.then(clear_instruct, None, instruction) \
- .then(clear_instruct, None, iinput)
- submit_event1d = submit_event1c.then(**bot_args, api_name='instruction_bot' if allow_api else None,
- queue=queue)
- submit_event1e = submit_event1d.then(**score_args_submit,
- api_name='instruction_bot_score' if allow_api else None,
- queue=queue)
- submit_event1f = submit_event1e.then(**bot_args2, api_name='instruction_bot2' if allow_api else None,
- queue=queue)
- submit_event1g = submit_event1f.then(**score_args2_submit,
- api_name='instruction_bot_score2' if allow_api else None, queue=queue)
- submit_event1h = submit_event1g.then(clear_torch_cache)
-
- submit_event2a = submit.click(**user_args, api_name='submit' if allow_api else None)
- submit_event2b = submit_event2a.then(**user_args2, api_name='submit2' if allow_api else None)
- submit_event2c = submit_event2b.then(clear_instruct, None, instruction) \
- .then(clear_instruct, None, iinput)
- submit_event2d = submit_event2c.then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue)
- submit_event2e = submit_event2d.then(**score_args_submit, api_name='submit_bot_score' if allow_api else None,
- queue=queue)
- submit_event2f = submit_event2e.then(**bot_args2, api_name='submit_bot2' if allow_api else None, queue=queue)
- submit_event2g = submit_event2f.then(**score_args2_submit, api_name='submit_bot_score2' if allow_api else None,
- queue=queue)
- submit_event2h = submit_event2g.then(clear_torch_cache)
-
- submit_event3a = retry.click(**user_args, api_name='retry' if allow_api else None)
- submit_event3b = submit_event3a.then(**user_args2, api_name='retry2' if allow_api else None)
- submit_event3c = submit_event3b.then(clear_instruct, None, instruction) \
- .then(clear_instruct, None, iinput)
- submit_event3d = submit_event3c.then(**retry_bot_args, api_name='retry_bot' if allow_api else None,
- queue=queue)
- submit_event3e = submit_event3d.then(**score_args_submit, api_name='retry_bot_score' if allow_api else None,
- queue=queue)
- submit_event3f = submit_event3e.then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None,
- queue=queue)
- submit_event3g = submit_event3f.then(**score_args2_submit, api_name='retry_bot_score2' if allow_api else None,
- queue=queue)
- submit_event3h = submit_event3g.then(clear_torch_cache)
-
- submit_event4 = undo.click(**undo_user_args, api_name='undo' if allow_api else None) \
- .then(**undo_user_args2, api_name='undo2' if allow_api else None) \
- .then(clear_instruct, None, instruction) \
- .then(clear_instruct, None, iinput) \
- .then(**score_args_submit, api_name='undo_score' if allow_api else None) \
- .then(**score_args2_submit, api_name='undo_score2' if allow_api else None)
+ # in case 2nd model, consume instruction first, so can clear quickly
+ # bot doesn't consume instruction itself, just history from user, so why works
+ submit_event11 = instruction.submit(fn=dummy_fun,
+ inputs=instruction, outputs=instruction, queue=queue)
+ submit_event1a = submit_event11.then(**user_args, queue=queue,
+ api_name='instruction' if allow_api else None)
+ # if hit enter on new instruction for submitting new query, no longer the saved chat
+ submit_event1a2 = submit_event1a.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=queue)
+ submit_event1b = submit_event1a2.then(**user_args2, api_name='instruction2' if allow_api else None)
+ submit_event1c = submit_event1b.then(clear_instruct, None, instruction) \
+ .then(clear_instruct, None, iinput)
+ submit_event1d = submit_event1c.then(**bot_args, api_name='instruction_bot' if allow_api else None,
+ queue=queue)
+ submit_event1e = submit_event1d.then(**score_args,
+ api_name='instruction_bot_score' if allow_api else None,
+ queue=queue)
+ submit_event1f = submit_event1e.then(**bot_args2, api_name='instruction_bot2' if allow_api else None,
+ queue=queue)
+ submit_event1g = submit_event1f.then(**score_args2,
+ api_name='instruction_bot_score2' if allow_api else None, queue=queue)
+
+ submits1 = [submit_event1a, submit_event1a2, submit_event1b, submit_event1c, submit_event1d,
+ submit_event1e,
+ submit_event1f, submit_event1g]
+
+ submit_event21 = submit.click(fn=dummy_fun,
+ inputs=instruction, outputs=instruction, queue=queue)
+ submit_event2a = submit_event21.then(**user_args, api_name='submit' if allow_api else None)
+ # if submit new query, no longer the saved chat
+ submit_event2a2 = submit_event2a.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=queue)
+ submit_event2b = submit_event2a2.then(**user_args2, api_name='submit2' if allow_api else None)
+ submit_event2c = submit_event2b.then(clear_all, inputs=None,
+ outputs=[instruction, iinput, radio_chats, score_text, score_text2],
+ queue=queue)
+ submit_event2d = submit_event2c.then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue)
+ submit_event2e = submit_event2d.then(**score_args,
+ api_name='submit_bot_score' if allow_api else None,
+ queue=queue)
+ submit_event2f = submit_event2e.then(**bot_args2, api_name='submit_bot2' if allow_api else None,
+ queue=queue)
+ submit_event2g = submit_event2f.then(**score_args2,
+ api_name='submit_bot_score2' if allow_api else None,
+ queue=queue)
+
+ submits2 = [submit_event2a, submit_event2a2, submit_event2b, submit_event2c, submit_event2d,
+ submit_event2e,
+ submit_event2f, submit_event2g]
+
+ submit_event31 = retry_btn.click(fn=dummy_fun,
+ inputs=instruction, outputs=instruction, queue=queue)
+ submit_event3a = submit_event31.then(**user_args, api_name='retry' if allow_api else None)
+ # if retry, no longer the saved chat
+ submit_event3a2 = submit_event3a.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=queue)
+ submit_event3b = submit_event3a2.then(**user_args2, api_name='retry2' if allow_api else None)
+ submit_event3c = submit_event3b.then(clear_instruct, None, instruction) \
+ .then(clear_instruct, None, iinput)
+ submit_event3d = submit_event3c.then(**retry_bot_args, api_name='retry_bot' if allow_api else None,
+ queue=queue)
+ submit_event3e = submit_event3d.then(**score_args,
+ api_name='retry_bot_score' if allow_api else None,
+ queue=queue)
+ submit_event3f = submit_event3e.then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None,
+ queue=queue)
+ submit_event3g = submit_event3f.then(**score_args2,
+ api_name='retry_bot_score2' if allow_api else None,
+ queue=queue)
+
+ submits3 = [submit_event3a, submit_event3a2, submit_event3b, submit_event3c, submit_event3d,
+ submit_event3e,
+ submit_event3f, submit_event3g]
+
+ # if undo, no longer the saved chat
+ submit_event4 = undo.click(fn=dummy_fun,
+ inputs=instruction, outputs=instruction, queue=queue) \
+ .then(**undo_user_args, api_name='undo' if allow_api else None) \
+ .then(**undo_user_args2, api_name='undo2' if allow_api else None) \
+ .then(clear_all, inputs=None, outputs=[instruction, iinput, radio_chats, score_text,
+ score_text2], queue=queue) \
+ .then(**score_args, api_name='undo_score' if allow_api else None) \
+ .then(**score_args2, api_name='undo_score2' if allow_api else None)
+ submits4 = [submit_event4]
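# Hedged sketch of the .then() chaining pattern wired above, with generic component and
# handler names rather than this app's: echo the user turn, then let the bot fill it in.
import gradio as gr

with gr.Blocks() as demo:
    box = gr.Textbox(label="You")
    chat = gr.Chatbot()

    def user_turn(message, history):
        return "", history + [[message, None]]

    def bot_turn(history):
        history[-1][1] = "echo: " + history[-1][0]
        return history

    box.submit(user_turn, [box, chat], [box, chat], queue=False) \
        .then(bot_turn, chat, chat)

# demo.launch()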
# MANAGE CHATS
def dedup(short_chat, short_chats):
@@ -1133,50 +1564,80 @@ body.dark{#warning {background-color: #555555};}
#
etc. added in chat, try to remove some of that to help avoid dup entries when hit new conversation is_same = True # length of conversation has to be same + if len(x) != len(y): + return False if len(x) != len(y): return False for stepx, stepy in zip(x, y): if len(stepx) != len(stepy): # something off with a conversation return False - if len(stepx) != 2: - # something off - return False - if len(stepy) != 2: - # something off - return False - questionx = stepx[0].replace('
', '').replace('
', '') if stepx[0] is not None else None - answerx = stepx[1].replace('', '').replace('
', '') if stepx[1] is not None else None - - questiony = stepy[0].replace('', '').replace('
', '') if stepy[0] is not None else None - answery = stepy[1].replace('', '').replace('
', '') if stepy[1] is not None else None - - if questionx != questiony or answerx != answery: - return False + for stepxx, stepyy in zip(stepx, stepy): + if len(stepxx) != len(stepyy): + # something off with a conversation + return False + if len(stepxx) != 2: + # something off + return False + if len(stepyy) != 2: + # something off + return False + questionx = stepxx[0].replace('', '').replace('
', '') if stepxx[0] is not None else None + answerx = stepxx[1].replace('', '').replace('
', '') if stepxx[1] is not None else None + + questiony = stepyy[0].replace('', '').replace('
', '') if stepyy[0] is not None else None + answery = stepyy[1].replace('', '').replace('
', '') if stepyy[1] is not None else None + + if questionx != questiony or answerx != answery: + return False return is_same - def save_chat(chat1, chat2, chat_state1): + def save_chat(*args): + args_list = list(args) + chat_list = args_list[:-1] # list of chatbot histories + # remove None histories + chat_list_not_none = [x for x in chat_list if x and len(x) > 0 and len(x[0]) == 2 and x[0][1] is not None] + chat_state1 = args_list[ + -1] # dict with keys of short chat names, values of list of list of chatbot histories short_chats = list(chat_state1.keys()) - for chati in [chat1, chat2]: - if chati and len(chati) > 0 and len(chati[0]) == 2 and chati[0][1] is not None: - short_chat = get_short_chat(chati, short_chats) - if short_chat: - already_exists = any([is_chat_same(chati, x) for x in chat_state1.values()]) - if not already_exists: - chat_state1[short_chat] = chati - return chat_state1 + if len(chat_list_not_none) > 0: + # make short_chat key from only first history, based upon question that is same anyways + chat_first = chat_list_not_none[0] + short_chat = get_short_chat(chat_first, short_chats) + if short_chat: + old_chat_lists = list(chat_state1.values()) + already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists]) + if not already_exists: + chat_state1[short_chat] = chat_list.copy() + # clear chat_list so saved and then new conversation starts + chat_list = [[]] * len(chat_list) + ret_list = chat_list + [chat_state1] + return tuple(ret_list) def update_radio_chats(chat_state1): return gr.update(choices=list(chat_state1.keys()), value=None) - def deselect_radio_chats(): - return gr.update(value=None) - - def switch_chat(chat_key, chat_state1): + def switch_chat(chat_key, chat_state1, num_model_lock=0): chosen_chat = chat_state1[chat_key] - return chosen_chat, chosen_chat - - radio_chats.input(switch_chat, inputs=[radio_chats, chat_state], outputs=[text_output, text_output2]) + # deal with possible different size of chat list vs. 
current list + ret_chat = [None] * (2 + num_model_lock) + for chati in range(0, 2 + num_model_lock): + ret_chat[chati % len(ret_chat)] = chosen_chat[chati % len(chosen_chat)] + return tuple(ret_chat) + + def clear_texts(*args): + return tuple([gr.Textbox.update(value='')] * len(args)) + + def clear_scores(): + return gr.Textbox.update(value=res_value), \ + gr.Textbox.update(value='Response Score: NA'), \ + gr.Textbox.update(value='Response Score: NA') + + switch_chat_fun = functools.partial(switch_chat, num_model_lock=len(text_outputs)) + radio_chats.input(switch_chat_fun, + inputs=[radio_chats, chat_state], + outputs=[text_output, text_output2] + text_outputs) \ + .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) def remove_chat(chat_key, chat_state1): chat_state1.pop(chat_key, None) @@ -1213,9 +1674,11 @@ body.dark{#warning {background-color: #555555};} new_chats = json.loads(f.read()) for chat1_k, chat1_v in new_chats.items(): # ignore chat1_k, regenerate and de-dup to avoid loss - chat_state1 = save_chat(chat1_v, None, chat_state1) + _, chat_state1 = save_chat(chat1_v, chat_state1) except BaseException as e: - print("Add chats exception: %s" % str(e), flush=True) + t, v, tb = sys.exc_info() + ex = ''.join(traceback.format_exception(t, v, tb)) + print("Add chats exception: %s" % str(ex), flush=True) return chat_state1, add_btn # note for update_user_db_func output is ignored for db @@ -1226,51 +1689,73 @@ body.dark{#warning {background-color: #555555};} .then(clear_file_list, outputs=chatsup_output, queue=False) \ .then(update_radio_chats, inputs=chat_state, outputs=radio_chats, queue=False) - clear_chat_btn.click(lambda: None, None, text_output, queue=False, api_name='clear' if allow_api else None) \ - .then(lambda: None, None, text_output2, queue=False, api_name='clear2' if allow_api else None) \ - .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) + clear_chat_btn.click(fn=clear_texts, + inputs=[text_output, text_output2] + text_outputs, + outputs=[text_output, text_output2] + text_outputs, + queue=False, api_name='clear' if allow_api else None) \ + .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \ + .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) # does both models - clear.click(save_chat, inputs=[text_output, text_output2, chat_state], outputs=chat_state, + clear.click(save_chat, + inputs=[text_output, text_output2] + text_outputs + [chat_state], + outputs=[text_output, text_output2] + text_outputs + [chat_state], api_name='save_chat' if allow_api else None) \ .then(update_radio_chats, inputs=chat_state, outputs=radio_chats, api_name='update_chats' if allow_api else None) \ - .then(lambda: None, None, text_output, queue=False, api_name='clearB' if allow_api else None) \ - .then(lambda: None, None, text_output2, queue=False, api_name='clearB2' if allow_api else None) + .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) + # NOTE: clear of instruction/iinput for nochat has to come after score, # because score for nochat consumes actual textbox, while chat consumes chat history filled by user() - submit_event_nochat = submit_nochat.click(fun, - inputs=[model_state, my_db_state] + inputs_list, - outputs=text_output_nochat, - queue=queue, - api_name='submit_nochat' if allow_api else None) \ + no_chat_args = dict(fn=fun, + inputs=[model_state, my_db_state] + inputs_list, + outputs=text_output_nochat, + queue=queue, + ) + submit_event_nochat = 
submit_nochat.click(**no_chat_args, api_name='submit_nochat' if allow_api else None) \ + .then(clear_torch_cache) \ .then(**score_args_nochat, api_name='instruction_bot_score_nochat' if allow_api else None, queue=queue) \ .then(clear_instruct, None, instruction_nochat) \ .then(clear_instruct, None, iinput_nochat) \ .then(clear_torch_cache) + # copy of above with text box submission + submit_event_nochat2 = instruction_nochat.submit(**no_chat_args) \ + .then(clear_torch_cache) \ + .then(**score_args_nochat, queue=queue) \ + .then(clear_instruct, None, instruction_nochat) \ + .then(clear_instruct, None, iinput_nochat) \ + .then(clear_torch_cache) - def load_model(model_name, lora_weights, model_state_old, prompt_type_old, load_8bit, infer_devices, gpu_id): + submit_event_nochat_api = submit_nochat_api.click(fun_with_dict_str, + inputs=[model_state, my_db_state, inputs_dict_str], + outputs=text_output_nochat_api, + queue=True, # required for generator + api_name='submit_nochat_api' if allow_api else None) \ + .then(clear_torch_cache) + + def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit, + infer_devices, gpu_id): # ensure old model removed from GPU memory if kwargs['debug']: print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True) - model0 = model_state0[0] - if isinstance(model_state_old[0], str) and model0 is not None: + model0 = model_state0['model'] + if isinstance(model_state_old['model'], str) and model0 is not None: # best can do, move model loaded at first to CPU model0.cpu() - if model_state_old[0] is not None and not isinstance(model_state_old[0], str): + if model_state_old['model'] is not None and not isinstance(model_state_old['model'], str): try: - model_state_old[0].cpu() + model_state_old['model'].cpu() except Exception as e: # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data! 
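The /submit_nochat_api wiring above passes the whole request as a python-dict string through a single textbox. A minimal sketch of how such a payload can be decoded and merged over server defaults follows; parse_payload and the defaults shown are assumptions for illustration, not the project's fun_with_dict_str.

import ast

def parse_payload(payload_str, defaults):
    req = ast.literal_eval(payload_str)   # parses literals only, safer than eval()
    merged = dict(defaults)
    merged.update({k: v for k, v in req.items() if k in defaults})
    return merged

print(parse_payload("{'instruction_nochat': 'Who are you?'}",
                    {'instruction_nochat': '', 'max_new_tokens': 256}))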
print("Unable to put model on CPU: %s" % str(e), flush=True) - del model_state_old[0] - model_state_old[0] = None + del model_state_old['model'] + model_state_old['model'] = None - if model_state_old[1] is not None and not isinstance(model_state_old[1], str): - del model_state_old[1] - model_state_old[1] = None + if model_state_old['tokenizer'] is not None and not isinstance(model_state_old['tokenizer'], str): + del model_state_old['tokenizer'] + model_state_old['tokenizer'] = None clear_torch_cache() if kwargs['debug']: @@ -1280,7 +1765,11 @@ body.dark{#warning {background-color: #555555};} # no-op if no model, just free memory # no detranscribe needed for model, never go into evaluate lora_weights = no_lora_str - return [None, None, None, model_name], model_name, lora_weights, prompt_type_old + server_name = no_server_str + return [None, None, None, model_name, server_name], \ + model_name, lora_weights, server_name, prompt_type_old, \ + gr.Slider.update(maximum=256), \ + gr.Slider.update(maximum=256) # don't deepcopy, can contain model itself all_kwargs1 = all_kwargs.copy() @@ -1297,16 +1786,50 @@ body.dark{#warning {background-color: #555555};} # detranscribe if lora_weights == no_lora_str: lora_weights = '' - all_kwargs1['lora_weights'] = lora_weights.strip() + if server_name == no_server_str: + server_name = '' + all_kwargs1['inference_server'] = server_name.strip() + model1, tokenizer1, device1 = get_model(reward_type=False, **get_kwargs(get_model, exclude_names=['reward_type'], **all_kwargs1)) clear_torch_cache() + tokenizer_base_model = model_name + prompt_dict1, error0 = get_prompt(prompt_type1, '', + chat=False, context='', reduced=False, making_context=False, + return_dict=True) + model_state_new = dict(model=model1, tokenizer=tokenizer1, device=device1, + base_model=model_name, tokenizer_base_model=tokenizer_base_model, + lora_weights=lora_weights, inference_server=server_name, + prompt_type=prompt_type1, prompt_dict=prompt_dict1, + ) + + max_max_new_tokens1 = get_max_max_new_tokens(model_state_new, **kwargs) + if kwargs['debug']: print("Post-switch GPU memory: %s" % get_torch_allocated(), flush=True) - return [model1, tokenizer1, device1, model_name], model_name, lora_weights, prompt_type1 + return model_state_new, model_name, lora_weights, server_name, prompt_type1, \ + gr.Slider.update(maximum=max_max_new_tokens1), \ + gr.Slider.update(maximum=max_max_new_tokens1) + + def get_prompt_str(prompt_type1, prompt_dict1, which=0): + if prompt_type1 in ['', None]: + print("Got prompt_type %s: %s" % (which, prompt_type1), flush=True) + return str({}) + prompt_dict1, prompt_dict_error = get_prompt(prompt_type1, prompt_dict1, chat=False, context='', + reduced=False, making_context=False, return_dict=True) + if prompt_dict_error: + return str(prompt_dict_error) + else: + # return so user can manipulate if want and use as custom + return str(prompt_dict1) + + get_prompt_str_func1 = functools.partial(get_prompt_str, which=1) + get_prompt_str_func2 = functools.partial(get_prompt_str, which=2) + prompt_type.change(fn=get_prompt_str_func1, inputs=[prompt_type, prompt_dict], outputs=prompt_dict) + prompt_type2.change(fn=get_prompt_str_func2, inputs=[prompt_type2, prompt_dict2], outputs=prompt_dict2) def dropdown_prompt_type_list(x): return gr.Dropdown.update(value=x) @@ -1315,9 +1838,12 @@ body.dark{#warning {background-color: #555555};} return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]') load_model_args = dict(fn=load_model, - inputs=[model_choice, lora_choice, model_state, 
prompt_type, + inputs=[model_choice, lora_choice, server_choice, model_state, prompt_type, model_load8bit_checkbox, model_infer_devices_checkbox, model_gpu], - outputs=[model_state, model_used, lora_used, prompt_type]) + outputs=[model_state, model_used, lora_used, server_used, + # if prompt_type changes, prompt_dict will change via change rule + prompt_type, max_new_tokens, min_new_tokens, + ]) prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type) chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output) nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat) @@ -1329,9 +1855,12 @@ body.dark{#warning {background-color: #555555};} .then(clear_torch_cache) load_model_args2 = dict(fn=load_model, - inputs=[model_choice2, lora_choice2, model_state2, prompt_type2, + inputs=[model_choice2, lora_choice2, server_choice2, model_state2, prompt_type2, model_load8bit_checkbox2, model_infer_devices_checkbox2, model_gpu2], - outputs=[model_state2, model_used2, lora_used2, prompt_type2]) + outputs=[model_state2, model_used2, lora_used2, server_used2, + # if prompt_type2 changes, prompt_dict2 will change via change rule + prompt_type2, max_new_tokens2, min_new_tokens2 + ]) prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2) chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2) if not is_public: @@ -1341,32 +1870,51 @@ body.dark{#warning {background-color: #555555};} .then(**chatbot_update_args2) \ .then(clear_torch_cache) - def dropdown_model_list(list0, x): - new_state = [list0[0] + [x]] - new_options = [*new_state[0]] - return gr.Dropdown.update(value=x, choices=new_options), \ - gr.Dropdown.update(value=x, choices=new_options), \ - '', new_state - - add_model_event = add_model_button.click(fn=dropdown_model_list, - inputs=[model_options_state, new_model], - outputs=[model_choice, model_choice2, new_model, model_options_state], - queue=False) - - def dropdown_lora_list(list0, x, model_used1, lora_used1, model_used2, lora_used2): - new_state = [list0[0] + [x]] - new_options = [*new_state[0]] + def dropdown_model_lora_server_list(model_list0, model_x, + lora_list0, lora_x, + server_list0, server_x, + model_used1, lora_used1, server_used1, + model_used2, lora_used2, server_used2, + ): + model_new_state = [model_list0[0] + [model_x]] + model_new_options = [*model_new_state[0]] + x1 = model_x if model_used1 == no_model_str else model_used1 + x2 = model_x if model_used2 == no_model_str else model_used2 + ret1 = [gr.Dropdown.update(value=x1, choices=model_new_options), + gr.Dropdown.update(value=x2, choices=model_new_options), + '', model_new_state] + + lora_new_state = [lora_list0[0] + [lora_x]] + lora_new_options = [*lora_new_state[0]] # don't switch drop-down to added lora if already have model loaded - x1 = x if model_used1 == no_model_str else lora_used1 - x2 = x if model_used2 == no_model_str else lora_used2 - return gr.Dropdown.update(value=x1, choices=new_options), \ - gr.Dropdown.update(value=x2, choices=new_options), \ - '', new_state - - add_lora_event = add_lora_button.click(fn=dropdown_lora_list, - inputs=[lora_options_state, new_lora, model_used, lora_used, model_used2, - lora_used2], - outputs=[lora_choice, lora_choice2, new_lora, lora_options_state], + x1 = lora_x if model_used1 == no_model_str else lora_used1 + x2 = lora_x if model_used2 == no_model_str else 
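dropdown_model_lora_server_list repeats one small pattern three times, once each for model, LoRA, and inference server. A condensed sketch of that single pattern, with add_option and the no_model sentinel as stand-in names:

def add_option(options_state, new_value, current1, current2, no_model='[None]'):
    # grow the shared options list, update both dropdowns, clear the textbox;
    # keep a dropdown on its already-loaded value instead of jumping to the new one
    options = options_state[0] + [new_value]
    value1 = new_value if current1 == no_model else current1
    value2 = new_value if current2 == no_model else current2
    return value1, value2, '', [options]

print(add_option([['modelA']], 'modelB', '[None]', 'modelA'))
# -> ('modelB', 'modelA', '', [['modelA', 'modelB']])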
lora_used2 + ret2 = [gr.Dropdown.update(value=x1, choices=lora_new_options), + gr.Dropdown.update(value=x2, choices=lora_new_options), + '', lora_new_state] + + server_new_state = [server_list0[0] + [server_x]] + server_new_options = [*server_new_state[0]] + # don't switch drop-down to added server if already have model loaded + x1 = server_x if model_used1 == no_model_str else server_used1 + x2 = server_x if model_used2 == no_model_str else server_used2 + ret3 = [gr.Dropdown.update(value=x1, choices=server_new_options), + gr.Dropdown.update(value=x2, choices=server_new_options), + '', server_new_state] + + return tuple(ret1 + ret2 + ret3) + + add_model_lora_server_event = \ + add_model_lora_server_button.click(fn=dropdown_model_lora_server_list, + inputs=[model_options_state, new_model] + + [lora_options_state, new_lora] + + [server_options_state, new_server] + + [model_used, lora_used, server_used] + + [model_used2, lora_used2, server_used2], + outputs=[model_choice, model_choice2, new_model, model_options_state] + + [lora_choice, lora_choice2, new_lora, lora_options_state] + + [server_choice, server_choice2, new_server, + server_options_state], queue=False) go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None, queue=False) \ @@ -1382,16 +1930,22 @@ body.dark{#warning {background-color: #555555};} def compare_prompt_fun(x): return gr.Dropdown.update(visible=x) + def slider_fun(x): + return gr.Slider.update(visible=x) + compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2, api_name="compare_checkbox" if allow_api else None) \ .then(compare_column_fun, compare_checkbox, col_model2) \ .then(compare_prompt_fun, compare_checkbox, prompt_type2) \ - .then(compare_textbox_fun, compare_checkbox, score_text2) + .then(compare_textbox_fun, compare_checkbox, score_text2) \ + .then(slider_fun, compare_checkbox, max_new_tokens2) \ + .then(slider_fun, compare_checkbox, min_new_tokens2) # FIXME: add score_res2 in condition, but do better # callback for logging flagged input/output - callback.setup(inputs_list + [text_output, text_output2], "flagged_data_points") - flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output, text_output2], None, + callback.setup(inputs_list + [text_output, text_output2] + text_outputs, "flagged_data_points") + flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output, text_output2] + text_outputs, + None, preprocess=False, api_name='flag' if allow_api else None, queue=False) flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output_nochat], None, @@ -1399,25 +1953,64 @@ body.dark{#warning {background-color: #555555};} api_name='flag_nochat' if allow_api else None, queue=False) def get_system_info(): + if is_public: + time.sleep(10) # delay to avoid spam since queue=False return gr.Textbox.update(value=system_info_print()) system_event = system_btn.click(get_system_info, outputs=system_text, api_name='system_info' if allow_api else None, queue=False) + def get_system_info_dict(system_input1, **kwargs1): + if system_input1 != os.getenv("ADMIN_PASS", ""): + return json.dumps({}) + exclude_list = ['admin_pass', 'examples'] + sys_dict = {k: v for k, v in kwargs1.items() if + isinstance(v, (str, int, bool, float)) and k not in exclude_list} + try: + sys_dict.update(system_info()) + except Exception as e: + # protection + print("Exception: %s" % str(e), flush=True) + return json.dumps(sys_dict) + + get_system_info_dict_func = 
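get_system_info_dict above gates on an admin password and exposes only primitive, non-secret settings. A standalone sketch of that filter (safe_system_dict is a hypothetical name; the real function also merges live system_info() results):

import json

def safe_system_dict(all_kwargs, password_given, password_expected):
    if password_given != password_expected:
        return json.dumps({})
    exclude = {'admin_pass', 'examples'}
    return json.dumps({k: v for k, v in all_kwargs.items()
                       if isinstance(v, (str, int, bool, float)) and k not in exclude})

print(safe_system_dict({'h2ocolors': True, 'admin_pass': 'secret', 'examples': [1]},
                       'pw', 'pw'))   # -> {"h2ocolors": true}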
functools.partial(get_system_info_dict, **all_kwargs) + + system_dict_event = system_btn2.click(get_system_info_dict_func, + inputs=system_input, + outputs=system_text2, + api_name='system_info_dict' if allow_api else None, + queue=False, # queue to avoid spam + ) + + def get_hash(): + return kwargs['git_hash'] + + system_btn3.click(get_hash, + outputs=system_text3, + api_name='system_hash' if allow_api else None, + queue=False, + ) + # don't pass text_output, don't want to clear output, just stop it # cancel only stops outer generation, not inner generation or non-generation stop_btn.click(lambda: None, None, None, - cancels=[submit_event1d, submit_event1f, - submit_event2d, submit_event2f, - submit_event3d, submit_event3f, - submit_event_nochat], + cancels=submits1 + submits2 + submits3 + + submits4 + + [submit_event_nochat, submit_event_nochat2] + + [eventdb1, eventdb2, eventdb3, + eventdb4, eventdb5, eventdb6] + + [eventdb7, eventdb8, eventdb9] + , queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False) - def count_chat_tokens(model_state1, chat1, prompt_type1): - if model_state1 and not isinstance(model_state1[1], str): - tokenizer = model_state1[1] - elif model_state0 and not isinstance(model_state0[1], str): - tokenizer = model_state0[1] + def count_chat_tokens(model_state1, chat1, prompt_type1, prompt_dict1, + memory_restriction_level1=0, + keep_sources_in_context1=False, + ): + if model_state1 and not isinstance(model_state1['tokenizer'], str): + tokenizer = model_state1['tokenizer'] + elif model_state0 and not isinstance(model_state0['tokenizer'], str): + tokenizer = model_state0['tokenizer'] else: tokenizer = None if tokenizer is not None: @@ -1425,18 +2018,28 @@ body.dark{#warning {background-color: #555555};} # fake user message to mimic bot() chat1 = copy.deepcopy(chat1) chat1 = chat1 + [['user_message1', None]] - context1 = history_to_context(chat1, langchain_mode1, prompt_type1, chat1) + model_max_length1 = tokenizer.model_max_length + context1 = history_to_context(chat1, langchain_mode1, prompt_type1, prompt_dict1, chat1, + model_max_length1, + memory_restriction_level1, keep_sources_in_context1) return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1]) else: return "N/A" - count_chat_tokens_btn.click(fn=count_chat_tokens, inputs=[model_state, text_output, prompt_type], + count_chat_tokens_func = functools.partial(count_chat_tokens, + memory_restriction_level1=memory_restriction_level, + keep_sources_in_context1=kwargs['keep_sources_in_context']) + count_chat_tokens_btn.click(fn=count_chat_tokens, + inputs=[model_state, text_output, prompt_type, prompt_dict], outputs=chat_token_count, api_name='count_tokens' if allow_api else None) - demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] else None) + demo.load(None, None, None, _js=get_dark_js() if kwargs['h2ocolors'] and False else None) # light best demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open']) favicon_path = "h2o-logo.svg" + if not os.path.isfile(favicon_path): + print("favicon_path=%s not found" % favicon_path, flush=True) + favicon_path = None scheduler = BackgroundScheduler() scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20) @@ -1445,6 +2048,7 @@ body.dark{#warning {background-color: #555555};} # FIXME: disable for gptj, langchain or gpt4all modify print itself # FIXME: and any multi-threaded/async print will enter model output! 
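count_chat_tokens above turns the visible history into a prompt context and measures it with whichever tokenizer is loaded. A reduced sketch of the same measurement, assuming a Hugging Face tokenizer is available; the 'gpt2' checkpoint and the <human>/<bot> template here are examples only, not the project's prompt handling.

from transformers import AutoTokenizer   # pip install transformers

def count_context_tokens(history, tokenizer):
    # join finished turns into one context string, then count its tokens
    context = ''.join('<human>: %s\n<bot>: %s\n' % (q, a) for q, a in history if a is not None)
    return len(tokenizer(context)['input_ids'])

tok = AutoTokenizer.from_pretrained('gpt2')
print(count_context_tokens([['Who are you?', 'A helpful bot.']], tok))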
scheduler.add_job(func=ping, trigger="interval", seconds=60) + scheduler.add_job(func=ping_gpu, trigger="interval", seconds=60 * 10) scheduler.start() # import control @@ -1466,15 +2070,17 @@ body.dark{#warning {background-color: #555555};} input_args_list = ['model_state', 'my_db_state'] -def get_inputs_list(inputs_dict, model_lower): +def get_inputs_list(inputs_dict, model_lower, model_id=1): """ map gradio objects in locals() to inputs for evaluate(). :param inputs_dict: :param model_lower: + :param model_id: Which model (1 or 2) of 2 :return: """ inputs_list_names = list(inspect.signature(evaluate).parameters) inputs_list = [] + inputs_dict_out = {} for k in inputs_list_names: if k == 'kwargs': continue @@ -1483,8 +2089,18 @@ def get_inputs_list(inputs_dict, model_lower): continue if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']: continue + if model_id == 2: + if k == 'prompt_type': + k = 'prompt_type2' + if k == 'prompt_used': + k = 'prompt_used2' + if k == 'max_new_tokens': + k = 'max_new_tokens2' + if k == 'min_new_tokens': + k = 'min_new_tokens2' inputs_list.append(inputs_dict[k]) - return inputs_list + inputs_dict_out[k] = inputs_dict[k] + return inputs_list, inputs_dict_out def get_sources(db1, langchain_mode, dbs=None, docs_state0=None): @@ -1496,18 +2112,22 @@ def get_sources(db1, langchain_mode, dbs=None, docs_state0=None): " Ask jon.mckinney@h2o.ai for file if required." source_list = [] elif langchain_mode == 'MyData' and len(db1) > 0 and db1[0] is not None: - db_get = db1[0].get() - source_list = sorted(set([x['source'] for x in db_get['metadatas']])) + from gpt_langchain import get_metadatas + metadatas = get_metadatas(db1[0]) + source_list = sorted(set([x['source'] for x in metadatas])) source_files_added = '\n'.join(source_list) elif langchain_mode in dbs and dbs[langchain_mode] is not None: + from gpt_langchain import get_metadatas db1 = dbs[langchain_mode] - db_get = db1.get() - source_list = sorted(set([x['source'] for x in db_get['metadatas']])) + metadatas = get_metadatas(db1) + source_list = sorted(set([x['source'] for x in metadatas])) source_files_added = '\n'.join(source_list) else: source_list = [] source_files_added = "None" - sources_file = 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4())) + sources_dir = "sources_dir" + makedirs(sources_dir) + sources_file = os.path.join(sources_dir, 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))) with open(sources_file, "wt") as f: f.write(source_files_added) source_list = docs_state0 + source_list @@ -1534,21 +2154,35 @@ def update_user_db(file, db1, x, y, *args, dbs=None, langchain_mode='UserData',
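get_inputs_list above now takes a model_id and, for the second model, swaps a handful of per-model widget keys to their '2'-suffixed counterparts while also returning the components as a dict. A simplified sketch of that remapping (it omits the mbart/src_lang filtering the real function performs):

def remap_inputs(param_names, components, model_id=1):
    per_model = {'prompt_type', 'prompt_used', 'max_new_tokens', 'min_new_tokens'}
    inputs_list, inputs_dict_out = [], {}
    for k in param_names:
        key = k + '2' if model_id == 2 and k in per_model else k
        inputs_list.append(components[key])
        inputs_dict_out[key] = components[key]
    return inputs_list, inputs_dict_out

comps = {'prompt_type': 'pt1', 'prompt_type2': 'pt2', 'instruction': 'box'}
print(remap_inputs(['prompt_type', 'instruction'], comps, model_id=2))
# -> (['pt2', 'box'], {'prompt_type2': 'pt2', 'instruction': 'box'})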