diff --git a/LICENSE b/LICENSE deleted file mode 100644 index 25ae4110625608b553d170b6bb5c439215503afe..0000000000000000000000000000000000000000 --- a/LICENSE +++ /dev/null @@ -1,201 +0,0 @@ - Apache License - Version 2.0, January 2004 - http://www.apache.org/licenses/ - - TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - - 1. Definitions. - - "License" shall mean the terms and conditions for use, reproduction, - and distribution as defined by Sections 1 through 9 of this document. - - "Licensor" shall mean the copyright owner or entity authorized by - the copyright owner that is granting the License. - - "Legal Entity" shall mean the union of the acting entity and all - other entities that control, are controlled by, or are under common - control with that entity. For the purposes of this definition, - "control" means (i) the power, direct or indirect, to cause the - direction or management of such entity, whether by contract or - otherwise, or (ii) ownership of fifty percent (50%) or more of the - outstanding shares, or (iii) beneficial ownership of such entity. - - "You" (or "Your") shall mean an individual or Legal Entity - exercising permissions granted by this License. - - "Source" form shall mean the preferred form for making modifications, - including but not limited to software source code, documentation - source, and configuration files. - - "Object" form shall mean any form resulting from mechanical - transformation or translation of a Source form, including but - not limited to compiled object code, generated documentation, - and conversions to other media types. - - "Work" shall mean the work of authorship, whether in Source or - Object form, made available under the License, as indicated by a - copyright notice that is included in or attached to the work - (an example is provided in the Appendix below). - - "Derivative Works" shall mean any work, whether in Source or Object - form, that is based on (or derived from) the Work and for which the - editorial revisions, annotations, elaborations, or other modifications - represent, as a whole, an original work of authorship. For the purposes - of this License, Derivative Works shall not include works that remain - separable from, or merely link (or bind by name) to the interfaces of, - the Work and Derivative Works thereof. - - "Contribution" shall mean any work of authorship, including - the original version of the Work and any modifications or additions - to that Work or Derivative Works thereof, that is intentionally - submitted to Licensor for inclusion in the Work by the copyright owner - or by an individual or Legal Entity authorized to submit on behalf of - the copyright owner. For the purposes of this definition, "submitted" - means any form of electronic, verbal, or written communication sent - to the Licensor or its representatives, including but not limited to - communication on electronic mailing lists, source code control systems, - and issue tracking systems that are managed by, or on behalf of, the - Licensor for the purpose of discussing and improving the Work, but - excluding communication that is conspicuously marked or otherwise - designated in writing by the copyright owner as "Not a Contribution." - - "Contributor" shall mean Licensor and any individual or Legal Entity - on behalf of whom a Contribution has been received by Licensor and - subsequently incorporated within the Work. - - 2. Grant of Copyright License. 
Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - copyright license to reproduce, prepare Derivative Works of, - publicly display, publicly perform, sublicense, and distribute the - Work and such Derivative Works in Source or Object form. - - 3. Grant of Patent License. Subject to the terms and conditions of - this License, each Contributor hereby grants to You a perpetual, - worldwide, non-exclusive, no-charge, royalty-free, irrevocable - (except as stated in this section) patent license to make, have made, - use, offer to sell, sell, import, and otherwise transfer the Work, - where such license applies only to those patent claims licensable - by such Contributor that are necessarily infringed by their - Contribution(s) alone or by combination of their Contribution(s) - with the Work to which such Contribution(s) was submitted. If You - institute patent litigation against any entity (including a - cross-claim or counterclaim in a lawsuit) alleging that the Work - or a Contribution incorporated within the Work constitutes direct - or contributory patent infringement, then any patent licenses - granted to You under this License for that Work shall terminate - as of the date such litigation is filed. - - 4. Redistribution. You may reproduce and distribute copies of the - Work or Derivative Works thereof in any medium, with or without - modifications, and in Source or Object form, provided that You - meet the following conditions: - - (a) You must give any other recipients of the Work or - Derivative Works a copy of this License; and - - (b) You must cause any modified files to carry prominent notices - stating that You changed the files; and - - (c) You must retain, in the Source form of any Derivative Works - that You distribute, all copyright, patent, trademark, and - attribution notices from the Source form of the Work, - excluding those notices that do not pertain to any part of - the Derivative Works; and - - (d) If the Work includes a "NOTICE" text file as part of its - distribution, then any Derivative Works that You distribute must - include a readable copy of the attribution notices contained - within such NOTICE file, excluding those notices that do not - pertain to any part of the Derivative Works, in at least one - of the following places: within a NOTICE text file distributed - as part of the Derivative Works; within the Source form or - documentation, if provided along with the Derivative Works; or, - within a display generated by the Derivative Works, if and - wherever such third-party notices normally appear. The contents - of the NOTICE file are for informational purposes only and - do not modify the License. You may add Your own attribution - notices within Derivative Works that You distribute, alongside - or as an addendum to the NOTICE text from the Work, provided - that such additional attribution notices cannot be construed - as modifying the License. - - You may add Your own copyright statement to Your modifications and - may provide additional or different license terms and conditions - for use, reproduction, or distribution of Your modifications, or - for any such Derivative Works as a whole, provided Your use, - reproduction, and distribution of the Work otherwise complies with - the conditions stated in this License. - - 5. Submission of Contributions. 
Unless You explicitly state otherwise, - any Contribution intentionally submitted for inclusion in the Work - by You to the Licensor shall be under the terms and conditions of - this License, without any additional terms or conditions. - Notwithstanding the above, nothing herein shall supersede or modify - the terms of any separate license agreement you may have executed - with Licensor regarding such Contributions. - - 6. Trademarks. This License does not grant permission to use the trade - names, trademarks, service marks, or product names of the Licensor, - except as required for reasonable and customary use in describing the - origin of the Work and reproducing the content of the NOTICE file. - - 7. Disclaimer of Warranty. Unless required by applicable law or - agreed to in writing, Licensor provides the Work (and each - Contributor provides its Contributions) on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or - implied, including, without limitation, any warranties or conditions - of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A - PARTICULAR PURPOSE. You are solely responsible for determining the - appropriateness of using or redistributing the Work and assume any - risks associated with Your exercise of permissions under this License. - - 8. Limitation of Liability. In no event and under no legal theory, - whether in tort (including negligence), contract, or otherwise, - unless required by applicable law (such as deliberate and grossly - negligent acts) or agreed to in writing, shall any Contributor be - liable to You for damages, including any direct, indirect, special, - incidental, or consequential damages of any character arising as a - result of this License or out of the use or inability to use the - Work (including but not limited to damages for loss of goodwill, - work stoppage, computer failure or malfunction, or any and all - other commercial damages or losses), even if such Contributor - has been advised of the possibility of such damages. - - 9. Accepting Warranty or Additional Liability. While redistributing - the Work or Derivative Works thereof, You may choose to offer, - and charge a fee for, acceptance of support, warranty, indemnity, - or other liability obligations and/or rights consistent with this - License. However, in accepting such obligations, You may act only - on Your own behalf and on Your sole responsibility, not on behalf - of any other Contributor, and only if You agree to indemnify, - defend, and hold each Contributor harmless for any liability - incurred by, or claims asserted against, such Contributor by reason - of your accepting any such warranty or additional liability. - - END OF TERMS AND CONDITIONS - - APPENDIX: How to apply the Apache License to your work. - - To apply the Apache License to your work, attach the following - boilerplate notice, with the fields enclosed by brackets "[]" - replaced with your own identifying information. (Don't include - the brackets!) The text should be enclosed in the appropriate - comment syntax for the file format. We also recommend that a - file or class name and description of purpose be included on the - same "printed page" as the copyright notice for easier - identification within third-party archives. - - Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li - - Licensed under the Apache License, Version 2.0 (the "License"); - you may not use this file except in compliance with the License. 
- You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - - Unless required by applicable law or agreed to in writing, software - distributed under the License is distributed on an "AS IS" BASIS, - WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - See the License for the specific language governing permissions and - limitations under the License. diff --git a/client_test.py b/client_test.py deleted file mode 100644 index 2ba7e199b0b1e2fc662a3c5b60bb2c6c7d56cad5..0000000000000000000000000000000000000000 --- a/client_test.py +++ /dev/null @@ -1,362 +0,0 @@ -""" -Client test. - -Run server: - -python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b - -NOTE: For private models, add --use-auth_token=True - -NOTE: --use_gpu_id=True (default) must be used for multi-GPU in case see failures with cuda:x cuda:y mismatches. -Currently, this will force model to be on a single GPU. - -Then run this client as: - -python src/client_test.py - - - -For HF spaces: - -HOST="https://h2oai-h2ogpt-chatbot.hf.space" python src/client_test.py - -Result: - -Loaded as API: https://h2oai-h2ogpt-chatbot.hf.space ✔ -{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a large language model developed by LAION.', 'sources': ''} - - -For demo: - -HOST="https://gpt.h2o.ai" python src/client_test.py - -Result: - -Loaded as API: https://gpt.h2o.ai ✔ -{'instruction_nochat': 'Who are you?', 'iinput_nochat': '', 'response': 'I am h2oGPT, a chatbot created by LAION.', 'sources': ''} - -NOTE: Raw output from API for nochat case is a string of a python dict and will remain so if other entries are added to dict: - -{'response': "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", 'sources': ''} - - -""" -import ast -import time -import os -import markdown # pip install markdown -import pytest -from bs4 import BeautifulSoup # pip install beautifulsoup4 - -from enums import DocumentSubset, LangChainAction - -debug = False - -os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' - - -def get_client(serialize=True): - from gradio_client import Client - - client = Client(os.getenv('HOST', "http://localhost:7860"), serialize=serialize) - if debug: - print(client.view_api(all_endpoints=True)) - return client - - -def get_args(prompt, prompt_type, chat=False, stream_output=False, - max_new_tokens=50, - top_k_docs=3, - langchain_mode='Disabled', - add_chat_history_to_context=True, - langchain_action=LangChainAction.QUERY.value, - langchain_agents=[], - prompt_dict=None): - from collections import OrderedDict - kwargs = OrderedDict(instruction=prompt if chat else '', # only for chat=True - iinput='', # only for chat=True - context='', - # streaming output is supported, loops over and outputs each generation in streaming mode - # but leave stream_output=False for simple input/output mode - stream_output=stream_output, - prompt_type=prompt_type, - prompt_dict=prompt_dict, - temperature=0.1, - top_p=0.75, - top_k=40, - num_beams=1, - max_new_tokens=max_new_tokens, - min_new_tokens=0, - early_stopping=False, - max_time=20, - repetition_penalty=1.0, - num_return_sequences=1, - do_sample=True, - chat=chat, - instruction_nochat=prompt if not chat else '', - iinput_nochat='', # only for chat=False - langchain_mode=langchain_mode, - add_chat_history_to_context=add_chat_history_to_context, - langchain_action=langchain_action, - langchain_agents=langchain_agents, - top_k_docs=top_k_docs, - chunk=True, - chunk_size=512, - 
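                          # Assumption, not confirmed by this file: the document-selection
                          # arguments (top_k_docs, chunk/chunk_size above, document_subset and
                          # document_choice below) presumably only take effect when
                          # langchain_mode is something other than 'Disabled'.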
document_subset=DocumentSubset.Relevant.name, - document_choice=[], - ) - from evaluate_params import eval_func_param_names - assert len(set(eval_func_param_names).difference(set(list(kwargs.keys())))) == 0 - if chat: - # add chatbot output on end. Assumes serialize=False - kwargs.update(dict(chatbot=[])) - - return kwargs, list(kwargs.values()) - - -@pytest.mark.skip(reason="For manual use against some server, no server launched") -def test_client_basic(prompt_type='human_bot'): - return run_client_nochat(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50) - - -def run_client_nochat(prompt, prompt_type, max_new_tokens): - kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens) - - api_name = '/submit_nochat' - client = get_client(serialize=True) - res = client.predict( - *tuple(args), - api_name=api_name, - ) - print("Raw client result: %s" % res, flush=True) - res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'], - response=md_to_text(res)) - print(res_dict) - return res_dict, client - - -@pytest.mark.skip(reason="For manual use against some server, no server launched") -def test_client_basic_api(prompt_type='human_bot'): - return run_client_nochat_api(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50) - - -def run_client_nochat_api(prompt, prompt_type, max_new_tokens): - kwargs, args = get_args(prompt, prompt_type, chat=False, max_new_tokens=max_new_tokens) - - api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing - client = get_client(serialize=True) - res = client.predict( - str(dict(kwargs)), - api_name=api_name, - ) - print("Raw client result: %s" % res, flush=True) - res_dict = dict(prompt=kwargs['instruction_nochat'], iinput=kwargs['iinput_nochat'], - response=md_to_text(ast.literal_eval(res)['response']), - sources=ast.literal_eval(res)['sources']) - print(res_dict) - return res_dict, client - - -@pytest.mark.skip(reason="For manual use against some server, no server launched") -def test_client_basic_api_lean(prompt_type='human_bot'): - return run_client_nochat_api_lean(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50) - - -def run_client_nochat_api_lean(prompt, prompt_type, max_new_tokens): - kwargs = dict(instruction_nochat=prompt) - - api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing - client = get_client(serialize=True) - res = client.predict( - str(dict(kwargs)), - api_name=api_name, - ) - print("Raw client result: %s" % res, flush=True) - res_dict = dict(prompt=kwargs['instruction_nochat'], - response=md_to_text(ast.literal_eval(res)['response']), - sources=ast.literal_eval(res)['sources']) - print(res_dict) - return res_dict, client - - -@pytest.mark.skip(reason="For manual use against some server, no server launched") -def test_client_basic_api_lean_morestuff(prompt_type='human_bot'): - return run_client_nochat_api_lean_morestuff(prompt='Who are you?', prompt_type=prompt_type, max_new_tokens=50) - - -def run_client_nochat_api_lean_morestuff(prompt, prompt_type='human_bot', max_new_tokens=512): - kwargs = dict( - instruction='', - iinput='', - context='', - stream_output=False, - prompt_type=prompt_type, - temperature=0.1, - top_p=0.75, - top_k=40, - num_beams=1, - max_new_tokens=256, - min_new_tokens=0, - early_stopping=False, - max_time=20, - repetition_penalty=1.0, - num_return_sequences=1, - do_sample=True, - chat=False, - instruction_nochat=prompt, - iinput_nochat='', 
- langchain_mode='Disabled', - add_chat_history_to_context=True, - langchain_action=LangChainAction.QUERY.value, - langchain_agents=[], - top_k_docs=4, - document_subset=DocumentSubset.Relevant.name, - document_choice=[], - ) - - api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing - client = get_client(serialize=True) - res = client.predict( - str(dict(kwargs)), - api_name=api_name, - ) - print("Raw client result: %s" % res, flush=True) - res_dict = dict(prompt=kwargs['instruction_nochat'], - response=md_to_text(ast.literal_eval(res)['response']), - sources=ast.literal_eval(res)['sources']) - print(res_dict) - return res_dict, client - - -@pytest.mark.skip(reason="For manual use against some server, no server launched") -def test_client_chat(prompt_type='human_bot'): - return run_client_chat(prompt='Who are you?', prompt_type=prompt_type, stream_output=False, max_new_tokens=50, - langchain_mode='Disabled', - langchain_action=LangChainAction.QUERY.value, - langchain_agents=[]) - - -@pytest.mark.skip(reason="For manual use against some server, no server launched") -def test_client_chat_stream(prompt_type='human_bot'): - return run_client_chat(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type, - stream_output=True, max_new_tokens=512, - langchain_mode='Disabled', - langchain_action=LangChainAction.QUERY.value, - langchain_agents=[]) - - -def run_client_chat(prompt, prompt_type, stream_output, max_new_tokens, - langchain_mode, langchain_action, langchain_agents, - prompt_dict=None): - client = get_client(serialize=False) - - kwargs, args = get_args(prompt, prompt_type, chat=True, stream_output=stream_output, - max_new_tokens=max_new_tokens, - langchain_mode=langchain_mode, - langchain_action=langchain_action, - langchain_agents=langchain_agents, - prompt_dict=prompt_dict) - return run_client(client, prompt, args, kwargs) - - -def run_client(client, prompt, args, kwargs, do_md_to_text=True, verbose=False): - assert kwargs['chat'], "Chat mode only" - res = client.predict(*tuple(args), api_name='/instruction') - args[-1] += [res[-1]] - - res_dict = kwargs - res_dict['prompt'] = prompt - if not kwargs['stream_output']: - res = client.predict(*tuple(args), api_name='/instruction_bot') - res_dict['response'] = res[0][-1][1] - print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text)) - return res_dict, client - else: - job = client.submit(*tuple(args), api_name='/instruction_bot') - res1 = '' - while not job.done(): - outputs_list = job.communicator.job.outputs - if outputs_list: - res = job.communicator.job.outputs[-1] - res1 = res[0][-1][-1] - res1 = md_to_text(res1, do_md_to_text=do_md_to_text) - print(res1) - time.sleep(0.1) - full_outputs = job.outputs() - if verbose: - print('job.outputs: %s' % str(full_outputs)) - # ensure get ending to avoid race - # -1 means last response if streaming - # 0 means get text_output, ignore exception_text - # 0 means get list within text_output that looks like [[prompt], [answer]] - # 1 means get bot answer, so will have last bot answer - res_dict['response'] = md_to_text(full_outputs[-1][0][0][1], do_md_to_text=do_md_to_text) - return res_dict, client - - -@pytest.mark.skip(reason="For manual use against some server, no server launched") -def test_client_nochat_stream(prompt_type='human_bot'): - return run_client_nochat_gen(prompt="Tell a very long kid's story about birds.", prompt_type=prompt_type, - stream_output=True, max_new_tokens=512, - langchain_mode='Disabled', - 
langchain_action=LangChainAction.QUERY.value, - langchain_agents=[]) - - -def run_client_nochat_gen(prompt, prompt_type, stream_output, max_new_tokens, - langchain_mode, langchain_action, langchain_agents): - client = get_client(serialize=False) - - kwargs, args = get_args(prompt, prompt_type, chat=False, stream_output=stream_output, - max_new_tokens=max_new_tokens, langchain_mode=langchain_mode, - langchain_action=langchain_action, langchain_agents=langchain_agents) - return run_client_gen(client, prompt, args, kwargs) - - -def run_client_gen(client, prompt, args, kwargs, do_md_to_text=True, verbose=False): - res_dict = kwargs - res_dict['prompt'] = prompt - if not kwargs['stream_output']: - res = client.predict(str(dict(kwargs)), api_name='/submit_nochat_api') - res_dict['response'] = res[0] - print(md_to_text(res_dict['response'], do_md_to_text=do_md_to_text)) - return res_dict, client - else: - job = client.submit(str(dict(kwargs)), api_name='/submit_nochat_api') - while not job.done(): - outputs_list = job.communicator.job.outputs - if outputs_list: - res = job.communicator.job.outputs[-1] - res_dict = ast.literal_eval(res) - print('Stream: %s' % res_dict['response']) - time.sleep(0.1) - res_list = job.outputs() - assert len(res_list) > 0, "No response, check server" - res = res_list[-1] - res_dict = ast.literal_eval(res) - print('Final: %s' % res_dict['response']) - return res_dict, client - - -def md_to_text(md, do_md_to_text=True): - if not do_md_to_text: - return md - assert md is not None, "Markdown is None" - html = markdown.markdown(md) - soup = BeautifulSoup(html, features='html.parser') - return soup.get_text() - - -def run_client_many(prompt_type='human_bot'): - ret1, _ = test_client_chat(prompt_type=prompt_type) - ret2, _ = test_client_chat_stream(prompt_type=prompt_type) - ret3, _ = test_client_nochat_stream(prompt_type=prompt_type) - ret4, _ = test_client_basic(prompt_type=prompt_type) - ret5, _ = test_client_basic_api(prompt_type=prompt_type) - ret6, _ = test_client_basic_api_lean(prompt_type=prompt_type) - ret7, _ = test_client_basic_api_lean_morestuff(prompt_type=prompt_type) - return ret1, ret2, ret3, ret4, ret5, ret6, ret7 - - -if __name__ == '__main__': - run_client_many() diff --git a/create_data.py b/create_data.py deleted file mode 100644 index f16c519dcdd6b07dfd09f824e670401887f6eeaa..0000000000000000000000000000000000000000 --- a/create_data.py +++ /dev/null @@ -1,1809 +0,0 @@ -""" -Dataset creation tools. 
- -Keep to-level imports clean of non-trivial imports for specific tools, -because this file is imported for various purposes -""" - -import ast -import concurrent.futures -import contextlib -import hashlib -import json -import os -import shutil -import signal -import sys -import traceback -from concurrent.futures import ProcessPoolExecutor - -import psutil -import pytest -import pandas as pd -import numpy as np -from tqdm import tqdm - -from utils import flatten_list, remove - - -def parse_rst_file(filepath): - with open(filepath, 'r') as f: - input_data = f.read() - settings_overrides = {'initial_header_level': 2} - from docutils import core - document = core.publish_doctree( - source=input_data, - source_path=filepath, - settings_overrides=settings_overrides, - ) - qa_pairs = [] - current_section = None - current_question = "" - current_answer = "" - for node in document.traverse(): - if node.__class__.__name__ == 'section': - current_section = "" - elif current_section is not None: - if node.__class__.__name__ == 'Text': - if node.astext()[-1] == "?": - if current_question: - qa_pairs.append((current_question, current_answer)) - current_question = node.astext() - current_answer = "" - else: - current_answer += node.astext() - if current_answer: - qa_pairs.append((current_question, current_answer)) - return {k: v for k, v in qa_pairs} - - -def test_scrape_dai_docs(): - home = os.path.expanduser('~') - file = os.path.join(home, 'h2oai/docs/faq.rst') - qa_pairs = parse_rst_file(file) - prompt_type = 'human_bot' - from prompter import prompt_types - assert prompt_type in prompt_types - save_thing = [{"instruction": k, "output": v, 'prompt_type': prompt_type} for k, v in qa_pairs.items()] - output_file = "dai_faq.json" - with open(output_file, "wt") as f: - f.write(json.dumps(save_thing, indent=2)) - - -def test_scrape_dai_docs_all(): - """ - pytest create_data.py::test_scrape_dai_docs_all - """ - import glob - import nltk - nltk.download('punkt') - dd = {} - np.random.seed(1234) - home = os.path.expanduser('~') - files = list(glob.glob(os.path.join(home, "h2oai/docs/**/*rst"))) - np.random.shuffle(files) - val_count = int(0.05 * len(files)) - train_files = files[val_count:] - valid_files = files[:val_count] - things = [ - ("dai_docs.train.json", train_files), - ("dai_docs.valid.json", valid_files) - ] - for LEN in [100, 200, 500]: - for output_file, ff in things: - if output_file not in dd: - dd[output_file] = [] - for f in ff: - with open(f) as input: - blob = input.read() - blob = blob.replace("~~", "") - blob = blob.replace("==", "") - blob = blob.replace("''", "") - blob = blob.replace("--", "") - blob = blob.replace("**", "") - dd[output_file].extend(get_sentences(blob, length=LEN)) - for output_file, _ in things: - save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in dd[output_file]] - with open(output_file, "wt") as f: - f.write(json.dumps(save_thing, indent=2)) - - -def get_sentences(blob, length): - """ - break-up input text into sentences and then output list of sentences of about length in size - :param blob: - :param length: - :return: - """ - import nltk - nltk.download('punkt') - from nltk.tokenize import sent_tokenize - sentences = sent_tokenize(blob) - my_sentences = [] - my_string = "" - for sentence in sentences: - if len(my_string) + len(sentence) <= length: - if my_string: - my_string += " " + sentence - else: - my_string = sentence - else: - my_sentences.append(my_string) - my_string = "" - return my_sentences or [my_string] - - -def 
setup_dai_docs(path=None, dst="working_dir_docs", from_hf=False): - """ - Only supported if have access to source code or HF token for HF spaces and from_hf=True - :param path: - :param dst: - :param from_hf: - :return: - """ - - home = os.path.expanduser('~') - - if from_hf: - # assumes - from huggingface_hub import hf_hub_download - # True for case when locally already logged in with correct token, so don't have to set key - token = os.getenv('HUGGINGFACE_API_TOKEN', True) - path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.zip', token=token, repo_type='dataset') - path = 'h2oai' - import zipfile - with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref: - zip_ref.extractall(path) - path = os.path.join(path, 'docs/**/*') - - if path is None: - if os.path.isdir(os.path.join(home, 'h2oai')): - path = os.path.join(home, "h2oai/docs/**/*") - else: - assert os.path.isdir(os.path.join(home, 'h2oai.superclean')), '%s does not exist' % path - path = os.path.join(home, "h2oai.superclean/docs/**/*") - import glob - files = list(glob.glob(path, recursive=True)) - - # pandoc can't find include files - - remove(dst) - os.makedirs(dst) - - # copy full tree, for absolute paths in rst - for fil in files: - if os.path.isfile(fil): - shutil.copy(fil, dst) - - # hack for relative path - scorers_dir = os.path.join(dst, 'scorers') - makedirs(scorers_dir) - for fil in glob.glob(os.path.join(dst, '*.frag')): - shutil.copy(fil, scorers_dir) - - return dst - - -def rst_to_outputs(files, min_len=30, max_len=2048 // 2 - 30): - # account for sequence length (context window) including prompt and input and output - - # os.system('pandoc -f rst -t plain ./expert_settings/nlp_settings.rst') - import pypandoc - basedir = os.path.abspath(os.getcwd()) - - outputs = [] - for fil in files: - os.chdir(basedir) - os.chdir(os.path.dirname(fil)) - fil = os.path.basename(fil) - print("Processing %s" % fil, flush=True) - # out_format can be one of: asciidoc, asciidoctor, beamer, biblatex, bibtex, commonmark, commonmark_x, - # context, csljson, docbook, docbook4, docbook5, docx, dokuwiki, - # dzslides, epub, epub2, epub3, fb2, gfm, haddock, html, html4, html5, icml, - # ipynb, jats, jats_archiving, jats_articleauthoring, jats_publishing, jira, - # json, latex, man, - # markdown, markdown_github, markdown_mmd, markdown_phpextra, markdown_strict, - # mediawiki, ms, muse, native, odt, opendocument, opml, org, pdf, plain, pptx, - # revealjs, rst, rtf, s5, slideous, slidy, tei, texinfo, textile, xwiki, zimwiki - out_format = 'plain' - # avoid extra new lines injected into text - extra_args = ['--wrap=preserve', '--resource path="%s" % dst'] - - plain_list = [] - try: - # valid for expert settings - input_rst = pypandoc.convert_file(fil, 'rst') - input_list = input_rst.split('\n``') - for input_subrst in input_list: - input_plain = pypandoc.convert_text(input_subrst, format='rst', to='plain') - plain_list.append([input_plain, fil]) - except Exception as e: - print("file exception: %s %s" % (fil, str(e)), flush=True) - - if not plain_list: - # if failed to process as pieces of rst, then - output = pypandoc.convert_file(fil, out_format, extra_args=extra_args, format='rst') - outputs1 = get_sentences(output, length=max_len) - for oi, output in enumerate(outputs1): - output = output.replace('\n\n', '\n') - plain_list.append([output, fil]) - outputs.extend(plain_list) - - # report: - # [print(len(x)) for x in outputs] - - # deal with blocks longer than context size (sequence length) of 2048 - new_outputs = [] - num_truncated = 0 
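    # The loop below keeps any converted block already shorter than max_len as-is;
    # longer blocks are re-split with get_sentences(length=max_len), double newlines
    # are collapsed, and num_truncated is incremented for each resulting chunk.
    # Stripped blocks shorter than min_len are dropped at the end.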
- num_orig = len(outputs) - for output, fil in outputs: - if len(output) < max_len: - new_outputs.append([output, fil]) - continue - outputs1 = get_sentences(output, length=max_len) - for oi, output1 in enumerate(outputs1): - output1 = output1.replace('\n\n', '\n') - new_outputs.append([output1, fil]) - num_truncated += 1 - print('num_orig: %s num_truncated: %s' % (num_orig, num_truncated), flush=True) - - new_outputs = [[k.strip(), fil] for k, fil in new_outputs if len(k.strip()) > min_len] - - return new_outputs - - -def test_scrape_dai_docs_all_pandoc(): - """ - pytest -s -v create_data.py::test_scrape_dai_docs_all_pandoc - :return: - """ - - dst = setup_dai_docs() - - import glob - files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True)) - - basedir = os.path.abspath(os.getcwd()) - new_outputs = rst_to_outputs(files) - os.chdir(basedir) - - remove(dst) - save_thing = [{"output": k.strip(), 'prompt_type': 'plain'} for k in new_outputs] - output_file = "dai_docs.train_cleaned.json" - with open(output_file, "wt") as f: - f.write(json.dumps(save_thing, indent=2)) - - -def test_config_to_json(): - """ - Needs to run from Driverless AI source directory. - E.g. (base) jon@gpu:~/h2oai$ pytest -s -v /data/jon/h2ogpt/create_data.py::test_config_to_json ; cp config.json /data/jon/h2ogpt/ - :return: - """ - try: - # Arrange - import json - from h2oaicore.systemutils import config - toml_list = [] - for k, v in config.get_meta_dict().items(): - title = (v.title + ": ") if v.title else '' - comment = v.comment or '' - if not (title or comment): - continue - toml_list.extend( - [ - { - 'prompt_type': 'plain', - 'instruction': f": What does {k} do?\n: {k.replace('_', ' ')} config.toml: {comment or title}\n:".replace( - "\n", ""), - }, - { - 'prompt_type': 'plain', - 'instruction': f": Explain {k}.\n: {k.replace('_', ' ')} config.toml: {comment or title}\n:".replace( - "\n", ""), - }, - { - 'prompt_type': 'plain', - 'instruction': f": How can I do this: {title}.\n: Set the {k.replace('_', ' ')} config.toml\n:".replace( - "\n", ""), - } if title and comment else None, - { - 'prompt_type': 'human_bot', - 'instruction': f'Explain the following expert setting for Driverless AI', - 'input': f"{k}", - 'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""), - }, - { - 'prompt_type': 'human_bot', - 'instruction': f'Explain the following expert setting for Driverless AI', - 'input': f"{k}", - 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""), - }, - { - 'prompt_type': 'human_bot', - 'instruction': f'Explain the following expert setting for Driverless AI', - 'input': f"{k.replace('_', ' ')}", - 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""), - }, - { - 'prompt_type': 'human_bot', - 'instruction': f'Explain the following expert setting for Driverless AI', - 'input': f"{title}", - 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""), - }, - { - 'prompt_type': 'human_bot', - 'instruction': f'Provide a short explanation of the expert setting {k}', - 'output': f"{k.replace('_', ' ')} config.toml: {comment or title}".replace("\n", ""), - }, - { - 'prompt_type': 'human_bot', - 'instruction': f'Provide a detailed explanation of the expert setting {k}', - 'output': f"{k.replace('_', ' ')} config.toml: {title}{comment}".replace("\n", ""), - }, - ] - ) - toml_list = [x for x in toml_list if x] - with open("config.json", "wt") as f: - f.write(json.dumps(toml_list, indent=2)) - except 
Exception as e: - print("Exception: %s" % str(e), flush=True) - - -def copy_tree(src, dst, follow_symlink=False): - makedirs(dst, exist_ok=True) - for (path, dirs, files) in os.walk(src, followlinks=follow_symlink): - new_path = path.replace(src, dst) - makedirs(new_path, exist_ok=True) - for file in files: - filename = os.path.join(path, file) - new_filename = os.path.join(new_path, file) - # print("%s -> %s" % (filename, new_filename)) - try: - atomic_copy(filename, new_filename) - except FileNotFoundError: - pass - - -def atomic_move(src, dst): - try: - shutil.move(src, dst) - except (shutil.Error, FileExistsError): - pass - remove(src) - - -def atomic_copy(src=None, dst=None, with_permissions=True): - if os.path.isfile(dst): - return - import uuid - my_uuid = uuid.uuid4() - dst_tmp = dst + str(my_uuid) - makedirs(os.path.dirname(dst), exist_ok=True) - if with_permissions: - shutil.copy(src, dst_tmp) - else: - shutil.copyfile(src, dst_tmp) - atomic_move(dst_tmp, dst) - remove(dst_tmp) - - -def makedirs(path, exist_ok=True): - """ - Avoid some inefficiency in os.makedirs() - :param path: - :param exist_ok: - :return: - """ - if os.path.isdir(path) and os.path.exists(path): - assert exist_ok, "Path already exists" - return path - os.makedirs(path, exist_ok=exist_ok) - - -## Download from https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_unfiltered_cleaned_split.json -## Turn into simple instruct prompt type. No context/previous conversations. -def test_prep_instruct_vicuna(): - from datasets import load_dataset - filename = 'ShareGPT_unfiltered_cleaned_split.json' - if not os.path.exists(filename): - os.system( - 'wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/%s' % filename) - data = load_dataset("json", data_files={"train": filename})["train"] - training_rows = [] - for i in range(data.num_rows): - conversations = data[i]['conversations'] - assert isinstance(conversations, list), conversations - convo = "" - for j, conv in enumerate(conversations): - # Get ready for generate.py prompt_type=human_bot - # But train with prompt_type=plain - if conv['from'] == 'human': - FROM = ': ' - elif conv['from'] == 'gpt': - FROM = ': ' - convo += f"{FROM}" + conv['value'] + "\n" - if convo: - training_rows.append(dict(input=convo)) - with open(filename + ".generate_human_bot.train_plain.json", "wt") as f: - f.write(json.dumps(training_rows, indent=2)) - - -POSTFIX = ".generate_human_bot.train_plain.json" - -# https://bair.berkeley.edu/blog/2023/04/03/koala/ -OIG_DATASETS = [ - "unified_chip2.jsonl", - "unified_grade_school_math_instructions.jsonl", - "unified_poetry_2_song.jsonl", - "unified_plot_screenplay_books_dialog.jsonl", -] - -# hub issue: https://huggingface.co/datasets/laion/OIG/discussions/4 -ALL_OIG_DATASETS = ['unified_abstract_infill.jsonl', - 'unified_basic.jsonl', - 'unified_canadian_parliament.jsonl', - 'unified_chip2.jsonl', - 'unified_conv_finqa.jsonl', - 'unified_cuad.jsonl', - 'unified_essays.jsonl', - 'unified_flan.jsonl.gz', - 'unified_grade_school_math_instructions.jsonl', - 'unified_hc3_human.jsonl', - 'unified_image_prompts_instructions.jsonl', - 'unified_joke_explanations.jsonl', - 'unified_mathqa_flanv2_kojma_cot.jsonl', - 'unified_merged_code_xp3.jsonl', - 'unified_multi_news.jsonl', - 'unified_multi_sum.jsonl', - 'unified_ni.jsonl.gz', - 'unified_nq.jsonl', - 'unified_openai_summarize_tldr.jsonl', - 'unified_oscar_en_sample_dialog.jsonl', - 'unified_p3.jsonl.gz', - 
'unified_plot_screenplay_books_dialog.jsonl', - 'unified_poetry_2_song.jsonl', - 'unified_poetry_instructions.jsonl', - 'unified_rallio_safety_and_prosocial.jsonl', - 'unified_rallio_soda_upgraded_2048.jsonl', - 'unified_soda_dialog.jsonl', - 'unified_sqlv1.jsonl', - 'unified_sqlv2.jsonl', - 'unified_squad_v2.jsonl', - 'unified_squad_v2_more_neg.jsonl', - 'unified_ul2_plus_oscar_en_sample_dialog.jsonl', - 'unified_unifiedskg_instructions.jsonl', - 'unified_unnatural_instructions.jsonl', - 'unified_xp3_sample.jsonl'] - -useful_oig_files = ['unified_rallio_safety_and_prosocial.jsonl.parquet', - 'unified_chip2.jsonl.parquet', - 'unified_cuad.jsonl.parquet', - 'unified_essays.jsonl.parquet', - 'unified_flan.jsonl.gz.parquet', - 'unified_grade_school_math_instructions.jsonl.parquet', - 'unified_hc3_human.jsonl.parquet', - 'unified_mathqa_flanv2_kojma_cot.jsonl.parquet', - 'unified_merged_code_xp3.jsonl.parquet', - 'unified_multi_news.jsonl.parquet', - # 'unified_multi_sum.jsonl.parquet' - 'unified_ni.jsonl.gz.parquet', - 'unified_openai_summarize_tldr.jsonl.parquet', - # 'unified_oscar_en_sample_dialog.jsonl.parquet', # create text containing these N words, not specific - 'unified_plot_screenplay_books_dialog.jsonl.parquet', - 'unified_soda_dialog.jsonl.parquet', - 'unified_unnatural_instructions.jsonl.parquet', - ] - - -@pytest.mark.parametrize("filename", OIG_DATASETS) -def test_get_small_sample_oig_data(filename): - if not os.path.exists(filename): - os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename) - import json - rows = [] - with open(filename, "r") as f: - for line in f.readlines(): - row = json.loads(line) - rows.append(dict(input=row["text"])) - with open(filename + POSTFIX, "w") as f: - f.write(json.dumps(rows, indent=2)) - - -@pytest.mark.parametrize("filename", ALL_OIG_DATASETS) -def test_download_useful_data_as_parquet(filename): - dest_file = filename + '.parquet' - if dest_file not in useful_oig_files: - pytest.skip('file declared not useful') - if not os.path.exists(filename): - os.system('wget https://huggingface.co/datasets/laion/OIG/resolve/main/%s' % filename) - if not os.path.exists(dest_file): - df = pd.read_json(path_or_buf=filename, lines=True) - df.to_parquet(dest_file, index=False) - - -def test_merge_shuffle_small_sample_oig_data(): - np.random.seed(1234) - rows = [] - for filename in OIG_DATASETS: - with open(filename + POSTFIX, "r") as f: - rows.extend(json.loads(f.read())) - np.random.shuffle(rows) - with open("merged_shuffled_OIG_%s.json" % hashlib.sha256(str(OIG_DATASETS).encode()).hexdigest()[:10], "w") as f: - f.write(json.dumps(rows, indent=2)) - - -def test_join_jsons(): - files = ['config.json'] * 1 + \ - ['dai_docs.train_cleaned.json'] * 2 + \ - ['dai_faq.json'] * 3 - print(files) - lst = [] - [lst.extend(json.load(open(fil, 'rt'))) for fil in files] - print(len(lst)) - json.dump(lst, open("merged.json", "wt"), indent=2) - - -@pytest.mark.parametrize("filename", ['Anthropic/hh-rlhf']) -def test_make_rlhf_good_data(filename): - from datasets import load_dataset - rows = load_dataset(filename)["train"]["chosen"] - new_rows = [] - for row in rows: - if row[:2] == "\n\n": - row = row[2:] - row = row.replace("Human: ", ": ") - row = row.replace("Assistant: ", ": ") - new_rows.append(dict(input=row)) - with open(filename.replace("/", "_") + POSTFIX, "w") as f: - f.write(json.dumps(new_rows, indent=2)) - - -def test_show_prompts(): - files = ['config.json'] * 1 + \ - ['dai_docs.train_cleaned.json'] * 1 + \ - ['dai_faq.json'] * 1 
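    # Quick eyeball check: load every record from the JSON files above and print
    # the rendered output of generate_prompt() for it with prompt_type='plain'.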
- file_points = [json.load(open(fil, 'rt')) for fil in files] - from prompter import generate_prompt - for data_points in file_points: - for data_point in data_points: - print(generate_prompt(data_point, 'plain', '', False, False, False)[0]) - - -def test_get_open_datasets(): - # HF changed things so don't get raw list of all datasets, so not have to filter, but can't do negative filter - open_tags = ['license:Apache License 2.0', - 'license:mit', - 'license:apache', - 'license:apache2', - 'license:apache-2.0', - 'license:bsd', - 'license:bsd-2-clause', - 'license:bsd-3-clause', - 'license:bsd-3-clause-clear', - 'license:lgpl-2.1', - 'license:lgpl-3.0', - 'license:lgpl-lr', - 'license:lgpl', - 'license:openrail++', - 'license:openrail', - 'license:bigscience-bloom-rail-1.0', - # 'license:agpl-3.0', - 'license:other', - 'license:unknown', - # 'license:mpl-2.0', # ok, but would have to include original copyright, license, source, copies in distribution - # Attribution required: - 'license:odc-by', - 'license:cc-by-4.0', - 'license:cc-by-3.0', - 'license:cc-by-2.0', - 'license:cc-by-2.5', - # 'license:cc-by-sa-4.0', # would require same license - 'license:odbl', - 'license:pddl', - 'license:ms-pl', - 'license:zlib', - ] - # bad license: cc-by-nc-4.0 - - from huggingface_hub import list_datasets - datasets = flatten_list([[x for x in list_datasets(filter=y)] for y in open_tags]) - datasets += [x for x in list_datasets(author='openai')] - # check all: - all_license_tags = set(flatten_list([[y for y in x.tags if 'license' in y] for x in datasets])) - print(len(all_license_tags)) - open_datasets = [x for x in datasets if any([y in x.tags for y in open_tags]) or 'license:' not in str(x.tags)] - print('open_datasets', len(open_datasets)) - all_task_tags = set(flatten_list([[y for y in x.tags if 'task' in y] for x in open_datasets])) - print('all_task_tags', len(all_task_tags)) - excluded_tags = ['image', 'hate', 'tabular', 'table-', 'classification', 'retrieval', - 'translation', 'identification', 'object', 'mask', 'to-text', - 'face-detection', 'audio', 'voice', 'reinforcement', 'depth-est', - 'forecasting', 'parsing', 'visual', 'speech', 'multiple-choice', - 'slot-filling', 'irds/argsme', '-scoring', 'other', 'graph-ml', - 'feature-extraction', 'keyword-spotting', - 'coreference-resolution', 'segmentation', - 'word-sense-disambiguation', - 'lemmatization'] - task_tags = [x.replace('task_categories:', '').replace('task_ids:', '') - for x in all_task_tags if not any([y in x for y in - excluded_tags])] - print('task_tags', len(task_tags)) - # str(x.tags) to catch any pattern match to anything in list - open_tasked_datasets = [x for x in open_datasets if - any([y in str([x for x in x.tags if 'task' in x]) for y in task_tags]) and - not any([y in str([x for x in x.tags if 'task' in x]) for y in excluded_tags]) or - 'task_categories' not in str(x.tags) and 'task_ids' not in str(x.tags)] - open_tasked_datasets = [x for x in open_tasked_datasets if not x.disabled] - open_tasked_datasets = [x for x in open_tasked_datasets if not x.gated] - open_tasked_datasets = [x for x in open_tasked_datasets if not x.private] - print('open_tasked_datasets', len(open_tasked_datasets)) - sizes = list(set(flatten_list([[(y, x.id) for y in x.tags if 'size' in y] for x in open_tasked_datasets]))) - languages = list(set(flatten_list([[(y, x.id) for y in x.tags if 'language:' in y] for x in open_tasked_datasets]))) - open_english_tasked_datasets = [x for x in open_tasked_datasets if - 'language:' not in str(x.tags) or - 
'language:en' in str(x.tags)] - small_open_english_tasked_datasets = [x for x in open_english_tasked_datasets if - 'n<1K' in str(x.tags) or - '1K summarization? - # load_dataset(open_tasked_datasets[0].id).data['train'].to_pandas() - ids = [x.id for x in small_open_english_tasked_datasets] - - # sanity checks - # https://bair.berkeley.edu/blog/2023/04/03/koala/ - assert 'alespalla/chatbot_instruction_prompts' in ids - assert 'laion/OIG' in ids - assert 'openai/webgpt_comparisons' in ids - assert 'openai/summarize_from_feedback' in ids - assert 'Anthropic/hh-rlhf' in ids - - # useful but not allowed for commercial purposes: - # https://huggingface.co/datasets/squad - - print('open_english_tasked_datasets: ', ids, flush=True) - - exclude_ids = ['allenai/nllb', # translation only - 'hf-internal-testing/fixtures_image_utils', # testing - 'allenai/c4', # search-url - 'agemagician/uniref50', # unknown - 'huggingface-course/documentation-images', # images - 'smilegate-ai/kor_unsmile', # korean - 'MohamedRashad/ChatGPT-prompts', # ChatGPT/LearnGPT/https://www.emergentmind.com/ - 'humarin/chatgpt-paraphrases', # Paraphrase using ChatGPT - 'Jeska/vaccinchat', # not useful - 'alespalla/chatbot_instruction_prompts', # mixes alpaca - 'allenai/prosocial-dialog', - # already exlucded, but wrongly in other datasets that say more permissive license - 'AlekseyKorshuk/persona-chat', # low quality - 'bavard/personachat_truecased', # low quality - 'adamlin/daily_dialog', # medium quality conversations - 'adamlin/FewShotWoz', # low quality - 'benjaminbeilharz/better_daily_dialog', # low quality - 'benjaminbeilharz/daily_dialog_w_turn_templates', # low - 'benjaminbeilharz/empathetic_dialogues_for_lm', # low - 'GEM-submissions/GEM__bart_base_schema_guided_dialog__1645547915', # NA - 'ia-bentebib/conv_ai_2_fr', # low fr - 'ia-bentebib/daily_dialog_fr', # low fr - 'ia-bentebib/dialog_re_fr', # low fr - 'ia-bentebib/empathetic_dialogues_fr', # low fr - 'roskoN/dailydialog', # low - 'VadorMazer/skyrimdialogstest', # low - 'bigbio/med_qa', # med specific Q/A - 'biu-nlp/qa_srl2018', # low quality Q/A - 'biu-nlp/qa_discourse', # low quality Q/A - 'iarfmoose/qa_evaluator', # low quality Q/A - 'jeopardy', # low quality Q/A -- no reasoning - 'narrativeqa', # low quality Q/A - 'nomic-ai/gpt4all_prompt_generations', # bad license - 'nomic-ai/gpt4all_prompt_generations_with_p3', # bad license - 'HuggingFaceH4/alpaca', # bad license - 'tatsu-lab/alpaca', # ToS breaking - 'yahma/alpaca-cleaned', # ToS breaking - 'Hello-SimpleAI/HC3', # bad license - 'glue', # no reasoning QA - 'sahil2801/CodeAlpaca-20k', # bad license - 'Short-Answer-Feedback/saf_communication_networks_english', # long Q, medium A - ] - small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if x.id not in exclude_ids] - # some ids clearly speech related - small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if 'speech' not in x.id] - # HF testing - small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if - 'hf-internal-testing' not in x.id] - small_open_english_tasked_datasets = [x for x in small_open_english_tasked_datasets if - 'chinese' not in x.id] - - sorted_small_open_english_tasked_datasets = sorted([(x.downloads, x) for x in small_open_english_tasked_datasets], - key=lambda x: x[0], reverse=True) - - # NOTES: - # Run like pytest -s -v create_data.py::test_get_open_datasets &> getdata9.log - # See what needs config passed and add: - # grep 'load_dataset(' 
getdata9.log|grep -v data_id|less -S - # grep "pip install" getdata9.log - # NOTE: Some datasets have default config, but others are there. Don't know how to access them. - - """ - https://huggingface.co/datasets/wikihow/blob/main/wikihow.py - https://github.com/mahnazkoupaee/WikiHow-Dataset - https://ucsb.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358 - https://ucsb.app.box.com/s/ap23l8gafpezf4tq3wapr6u8241zz358 - """ - - """ - # some ambiguous or non-commercial datasets - https://github.com/PhoebusSi/alpaca-CoT - """ - - timeout = 3 * 60 - # laion/OIG takes longer - for num_downloads, dataset in sorted_small_open_english_tasked_datasets: - data_id = dataset.id - func = do_one - args = (data_id, num_downloads) - kwargs = {} - with ProcessPoolExecutor(max_workers=1) as executor: - future = executor.submit(func, *args, **kwargs) - try: - future.result(timeout=timeout) - except concurrent.futures.TimeoutError: - print("\n\ndata_id %s timeout\n\n" % data_id, flush=True) - for child in psutil.Process(os.getpid()).children(recursive=True): - os.kill(child.pid, signal.SIGINT) - os.kill(child.pid, signal.SIGTERM) - os.kill(child.pid, signal.SIGKILL) - - -def do_one(data_id, num_downloads): - from datasets import load_dataset - out_file = "data_%s.parquet" % str(data_id.replace('/', '_')) - if os.path.isfile(out_file) and os.path.getsize(out_file) > 1024 ** 3: - return - try: - print("Loading data_id %s num_downloads: %s" % (data_id, num_downloads), flush=True) - avail_list = None - try: - data = load_dataset(data_id, 'foobar') - except Exception as e: - if 'Available: ' in str(e): - avail_list = ast.literal_eval(str(e).split('Available:')[1].strip()) - else: - avail_list = None - if avail_list is None: - avail_list = [None] - print("%s avail_list: %s" % (data_id, avail_list), flush=True) - - for name in avail_list: - out_file = "data_%s_%s.parquet" % (str(data_id.replace('/', '_')), str(name)) - if os.path.isfile(out_file): - continue - data = load_dataset(data_id, name) - column_names_dict = data.column_names - column_names = column_names_dict[list(column_names_dict.keys())[0]] - print("Processing data_id %s num_downloads: %s columns: %s" % (data_id, num_downloads, column_names), - flush=True) - data_dict = data.data - col_dict = data.num_columns - first_col = list(col_dict.keys())[0] - if 'train' in data_dict: - df = data['train'].to_pandas() - else: - df = data[first_col].to_pandas() - # csv has issues with escaping chars, even for datasets I know I want - df.to_parquet(out_file, index=False) - except Exception as e: - t, v, tb = sys.exc_info() - ex = ''.join(traceback.format_exception(t, v, tb)) - print("Exception: %s %s" % (data_id, ex), flush=True) - - -def test_otherlic(): - from huggingface_hub import list_datasets - lic = ['license:odc-by', - 'license:cc-by-4.0', - 'license:cc-by-3.0', - 'license:cc-by-2.0', - 'license:cc-by-2.5', - 'license:cc-by-sa-4.0', - 'license:odbl', - 'license:pddl', - 'license:ms-pl', - 'license:zlib', - ] - datasets = flatten_list([[x for x in list_datasets(filter=y) if 'translation' not in str(x.tags)] for y in lic]) - print(len(datasets)) - - -# These useful datasets are determined based upon data sample, column types, and uniqueness compared to larger datasets like Pile -# grep columns getdata13.log|grep -v "\['image'\]"|sort|uniq|grep -v tokens|grep -v "'image'"|grep -v embedding|grep dialog -useful = ['Dahoas/instruct-human-assistant-prompt', - 'Dahoas/first-instruct-human-assistant-prompt', - 'knkarthick/dialogsum', # summary of conversation - 
'McGill-NLP/FaithDial', # medium quality - 'Zaid/quac_expanded', # medium quality context + QA - '0-hero/OIG-small-chip2', # medium - 'alistvt/coqa-flat', # QA medium - 'AnonymousSub/MedQuAD_47441_Question_Answer_Pairs', # QA medium - 'Anthropic/hh-rlhf', # high quality # similar to Dahoas/full-hh-rlhf - 'arjunth2001/online_privacy_qna', # good quality QA - 'Dahoas/instruct_helpful_preferences', # medium quality instruct - 'Dahoas/rl-prompt-dataset', # medium chat - 'Dahoas/rm-static', # medium chat - 'Dahoas/static-hh', # medium chat # HuggingFaceH4/self_instruct - 'Dahoas/synthetic-instruct-gptj-pairwise', # medium chat - 'eli5', # QA if prompt ELI5 - 'gsm8k', # QA (various) - 'guanaco/guanaco', # prompt/response - 'kastan/rlhf-qa-comparisons', # good QA - 'kastan/rlhf-qa-conditional-generation-v2', # prompt answer - 'OllieStanley/humaneval-mbpp-codegen-qa', # code QA, but started from words, so better than other code QA - 'OllieStanley/humaneval-mbpp-testgen-qa', # code QA - 'Graverman/Instruct-to-Code', # code QA - 'openai/summarize_from_feedback', # summarize - 'relbert/analogy_questions', # analogy QA - 'yitingxie/rlhf-reward-datasets', # prompt, chosen, rejected. - 'yizhongw/self_instruct', # instruct (super natural & instruct) - 'HuggingFaceH4/asss', # QA, big A - 'kastan/rlhf-qa-conditional-generation-v2', # QA - 'cosmos_qa', # context QA - 'vishal-burman/c4-faqs', # QA but not so much reasoning, but alot of text - 'squadshifts', # QA from context - 'hotpot_qa', # QA from context - 'adversarial_qa', # QA from context - 'allenai/soda', # dialog -> narrative/summary - 'squad_v2', # context QA - 'squadshifts', # context QA - 'dferndz/cSQuAD1', # context QA - 'dferndz/cSQuAD2', # context QA - 'din0s/msmarco-nlgen', # context QA - 'domenicrosati/TruthfulQA', # common sense truthful QA -- trivia but good trivia - 'hotpot_qa', # context, QA - 'HuggingFaceH4/self-instruct-eval', # instruct QA, medium quality, some language reasoning - 'kastan/EE_QA_for_RLHF', # context QA - 'KK04/LogicInference_OA', # instruction logical QA - 'lmqg/qa_squadshifts_synthetic', # context QA - 'lmqg/qg_squad', # context QA - 'lmqg/qg_squadshifts', # context QA - 'lmqg/qg_subjqa', # context QA - 'pszemraj/HC3-textgen-qa', - # QA medium, has human responses -- humans tend to provide links instead of trying to answer - 'pythonist/newdata', # long context, QA, brief A - 'ropes', # long background, situation, question, A - 'wikitablequestions', # table -> QA - 'bigscience/p3', # context QA but short answers - ] - -code_useful = ['0n1xus/codexglue', - 'openai_humaneval', - 'koutch/staqc', - ] - -maybe_useful = ['AlekseyKorshuk/comedy-scripts', - 'openbookqa', # hard to parse, low reasoning - 'qed', # reasonable QA, but low reasoning - 'selqa', # candidate answers - 'HuggingFaceH4/instruction-pilot-outputs-filtered', - 'GBaker/MedQA-USMLE-4-options', # medical QA with long questions - 'npc-engine/light-batch-summarize-dialogue', # dialog summarize, kinda low specific quality - ] - -summary_useful = ['austin/rheum_abstracts', - 'CarperAI/openai_summarize_comparisons', # summarize chosen/rejected - 'CarperAI/openai_summarize_tldr', # summarize QA - 'ccdv/cnn_dailymail', # summarize news - 'ccdv/govreport-summarization', # summarize high quality - 'ccdv/pubmed-summarization', # summarize high quality - 'duorc', # plot -> QA - 'farleyknight/big_patent_5_percent', # desc -> abstract - 'multi_news', # summary - 'opinosis', - 'SophieTr/reddit_clean', - 'allenai/mup', # long text -> summary - 'allenai/multi_lexsum', # long 
text -> summary - 'big_patent', - 'allenai/wcep_dense_max', - 'awinml/costco_long_practice', - 'GEM/xsum', - 'ratishsp/newshead', - 'RussianNLP/wikiomnia', # russian - 'stacked-summaries/stacked-xsum-1024', - ] - -math_useful = [ - 'competition_math' -] - -skipped = ['c4', # maybe useful, used for flan, but skipped due to size - ] - -""" -To get training data from oig: -pytest test_oig test_grade_final test_finalize_to_json -""" - -human = ':' -bot = ':' - - -def test_assemble_and_detox(): - import re - from profanity_check import predict_prob - df_list = [] - for data in useful_oig_files: - print("Processing %s" % data, flush=True) - df = pd.read_parquet(data) - df = df.reset_index(drop=True) - # chop up into human/bot interactions of no more than 10kB per row - text_list = df[['text']].values.ravel().tolist() - new_text = [] - max_len = 2048 # uber cutoff - MAX_LEN = 2048 // 2 - 30 # max len per question/answer - for text in tqdm(text_list): - human_starts = [m.start() for m in re.finditer(': ', text)] - if len(human_starts) == 1: - human_starts = [0, len(text)] # always go into for loop below - blurb = '' - for i in range(len(human_starts) - 1): - interaction = text[human_starts[i]: human_starts[i + 1]][:max_len] - blurb += interaction - if len(blurb) >= MAX_LEN: - blurb = get_sentences(blurb, length=MAX_LEN)[0] - new_text.append(blurb + "\n:") - blurb = '' - if blurb: - blurb = get_sentences(blurb, length=MAX_LEN)[0] - new_text.append(blurb + "\n:") - - if len(new_text) > len(text_list): - print("Added %d new rows (before: %d)" % (len(new_text) - df.shape[0], df.shape[0])) - df = pd.DataFrame({"text": new_text, "source": [data] * len(new_text)}) - df = df.drop_duplicates(keep='first') - print(df['text'].apply(lambda x: len(x)).describe()) - assert df['text'].apply(lambda x: len(x)).max() <= 2 * max_len - - # faster than better_profanity, do early - df['profanity'] = predict_prob(df['text']) - before_rows = df.shape[0] - df = df[df['profanity'] < 0.25] # drop any low quality stuff - after_rows = df.shape[0] - print("Dropped %d rows out of %d due to alt-profanity-check" % (before_rows - after_rows, before_rows)) - df_list.append(df) - print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True) - print("So far have %d rows" % sum([len(x) for x in df_list])) - df_final = pd.concat(df_list) - df_final = df_final.sample(frac=1, random_state=1234).reset_index(drop=True) - df_final.to_parquet('h2oGPT.cleaned.human_bot.shorter.parquet', index=False) - - -def test_basic_cleaning(): - # from better_profanity import profanity - # https://pypi.org/project/alt-profanity-check/ - from profanity_check import predict - df_list = [] - for data in useful_oig_files: - # for data in useful_oig_files[:5]: - # for data in ['unified_openai_summarize_tldr.jsonl.parquet']: - print("Processing %s" % data, flush=True) - df = pd.read_parquet(data) - df = df.reset_index(drop=True) - # NOTE: Not correct if multiple human-bot interactions, but those dialogs even more desired - # avg_chars = len(df['text'][0])/(df['text'][0].count(human)+df['text'][0].count(bot)) - df['avg_words'] = df['text'].apply(lambda x: x.count(' ') / (x.count(human) + x.count(bot)) / 2.0) - df['avg_bot_words'] = df['text'].apply(lambda x: x.split(bot)[1].count(' ') / x.count(bot)) - # df['bad_words'] = df['text'].apply(lambda x: profanity.contains_profanity(x)) - # low_quality_patterns = ['Write the rest of this wikipedia article'] - res = predict(df['text']) - df['bad_words'] = res - df = df.reset_index(drop=True) - df = 
df[df['bad_words'] == 0] - df = df[['text', 'avg_words', 'avg_bot_words']] - df = df.drop_duplicates(keep='first') - print(df[df['avg_words'] == df['avg_words'].max()]['text'].values) - median_words = np.median(df['avg_words']) - min_words_per_entity = max(30, 0.8 * median_words) - max_words_per_entity = 2048 # too hard to learn from for now - df = df[df['avg_words'] > min_words_per_entity] - df = df[df['avg_words'] < max_words_per_entity] - - min_words_per_entity = max(20, 0.5 * median_words) # bot should say stuff for now - max_words_per_entity = 2048 # too hard to learn from for now - df = df[df['avg_bot_words'] > min_words_per_entity] - df = df[df['avg_bot_words'] < max_words_per_entity] - - df_list.append(df) - print("Done processing %s -> %s rows" % (data, df.shape[0]), flush=True) - df_final = pd.concat(df_list) - df_final.to_parquet('h2oGPT.cleaned.human_bot.parquet', index=False) - - -from joblib import Parallel, delayed, effective_n_jobs -from sklearn.utils import gen_even_slices -from sklearn.utils.validation import _num_samples - - -def parallel_apply(df, func, n_jobs=-1, **kwargs): - """ Pandas apply in parallel using joblib. - Uses sklearn.utils to partition input evenly. - - Args: - df: Pandas DataFrame, Series, or any other object that supports slicing and apply. - func: Callable to apply - n_jobs: Desired number of workers. Default value -1 means use all available cores. - **kwargs: Any additional parameters will be supplied to the apply function - - Returns: - Same as for normal Pandas DataFrame.apply() - - """ - - if effective_n_jobs(n_jobs) == 1: - return df.apply(func, **kwargs) - else: - ret = Parallel(n_jobs=n_jobs)( - delayed(type(df).apply)(df[s], func, **kwargs) - for s in gen_even_slices(_num_samples(df), effective_n_jobs(n_jobs))) - return pd.concat(ret) - - -def add_better_profanity_flag(df): - from better_profanity import profanity - df['better_profanity'] = parallel_apply( - df['text'], - lambda x: profanity.contains_profanity(x), - n_jobs=-1, - ) - return df - - -def add_textstat_grade(df): - import textstat - - def myfunc(x): - return textstat.flesch_kincaid_grade(x) # simple grade - - if False: - import dask.dataframe as dd - # 40 seconds for 1000 rows, but have 1,787,799 rows - ddata = dd.from_pandas(df, npartitions=120) - - df['flesch_grade'] = ddata['text'].apply(myfunc).compute() - if True: - # fast way - df['flesch_grade'] = parallel_apply(df['text'], myfunc, n_jobs=-1) - return df - - -def add_deberta_grade(df): - from transformers import AutoModelForSequenceClassification, AutoTokenizer - import torch - reward_name = "OpenAssistant/reward-model-deberta-v3-large-v2" - rank_model, tokenizer = AutoModelForSequenceClassification.from_pretrained( - reward_name), AutoTokenizer.from_pretrained(reward_name) - device = 'cuda' if torch.cuda.is_available() else 'cpu' - rank_model.to(device) - - def get_question(x): - return x.replace(': ', '').split(':')[0] - - def get_answer(x): - try: - answer = x.split(': ')[1].split(':')[0].replace(': ', '') - except: - answer = x.split(':')[1].split(':')[0].replace(':', '') - return answer - - df['question'] = parallel_apply(df['text'], get_question, n_jobs=-1) - df['answer'] = parallel_apply(df['text'], get_answer, n_jobs=-1) - - from datasets import Dataset - from transformers import pipeline - from transformers.pipelines.pt_utils import KeyPairDataset - import tqdm - - pipe = pipeline( - "text-classification", - model=reward_name, - device="cuda:0" if torch.cuda.is_available() else "cpu" - ) - start = 0 - batch_size = 
64 * 16 - micro_batch = orig_micro_batch = 16 - end = 0 - import socket - checkpoint = "grades.%s.pkl" % socket.gethostname() - grades = [] - import pickle - if os.path.exists(checkpoint): - with open(checkpoint, "rb") as f: - start, grades = pickle.loads(f.read()) - last_oom = 0 - while end < df.shape[0]: - # manual batching to handle OOM more gracefully - end = min(start + batch_size, df.shape[0]) - if start == end: - break - dataset = Dataset.from_pandas(df.iloc[start:end, :]) - try: - grades.extend([ - x['score'] for x in tqdm.tqdm( - pipe(KeyPairDataset(dataset, "question", "answer"), batch_size=micro_batch) - ) - ]) - except torch.cuda.OutOfMemoryError: - last_oom = start - micro_batch = max(1, micro_batch // 2) - print("OOM - retrying with micro_batch=%d" % micro_batch) - continue - if last_oom == start: - micro_batch = orig_micro_batch - print("Returning to micro_batch=%d" % micro_batch) - assert len(grades) == end - start = end - with open(checkpoint, "wb") as f: - f.write(pickle.dumps((end, grades))) - print("%d/%d" % (end, df.shape[0])) - df['grade_deberta'] = grades - if os.path.exists(checkpoint): - os.remove(checkpoint) - return df - - -def test_chop_by_lengths(): - file = "h2oGPT.cleaned.human_bot.shorter.parquet" - df = pd.read_parquet(file).reset_index(drop=True) - df = count_human_bot_lengths(df) - df['rand'] = np.random.rand(df.shape[0]) - df['rand2'] = np.random.rand(df.shape[0]) - before_rows = df.shape[0] - # throw away short human/bot responses with higher likelihood - df = df[(df['len_human_mean'] > 20)] # never keep very short ones - df = df[(df['len_human_mean'] > 30) | (df['rand'] < 0.2)] - df = df[(df['len_human_mean'] > 50) | (df['rand'] < 0.5)] - df = df[(df['len_human_max'] < 10000)] # drop super long (basically only human) ones - df = df[(df['len_bot_mean'] > 20)] # never keep very short ones - df = df[(df['len_bot_mean'] > 30) | (df['rand2'] < 0.2)] - df = df[(df['len_bot_mean'] > 50) | (df['rand2'] < 0.5)] - df = df[(df['len_bot_max'] < 10000)] # drop super long (only bot) ones - assert df['text'].apply(lambda x: len(x)).max() < 20000 - df = df.drop(['rand', 'rand2'], axis=1) - after_rows = df.shape[0] - print("Chopped off %d out of %d rows due to length" % (before_rows - after_rows, before_rows)) - print(df.describe()) - df.to_parquet('h2oGPT.cleaned.chopped.human_bot.shorter.parquet', index=False) - - -def count_human_bot_lengths(df, human=None, bot=None): - import re - len_human_min = [] - len_human_max = [] - len_human_mean = [] - len_bot_min = [] - len_bot_max = [] - len_bot_mean = [] - human = human or ':' - bot = bot or ':' - for is_human in [True, False]: - what = human if is_human else bot - other = human if not is_human else bot - for i in range(df.shape[0]): - text = df.loc[i, 'text'] - assert isinstance(text, str) - starts = [m.start() for m in re.finditer(what, text)] - if len(starts) == 1: - starts = [starts[0], len(text)] # always go into for loop below - assert len(text) - list_what = [] - for ii in range(len(starts) - 1): - interaction = text[starts[ii]: starts[ii + 1]] - if other in interaction: - interaction = interaction[:interaction.find(other)] - interaction.strip() - list_what.append(interaction) - if not list_what: - list_what = [''] # handle corrupted data, very rare, leads to sizes 0 - if is_human: - len_human_min.append(min([len(x) for x in list_what])) - len_human_max.append(max([len(x) for x in list_what])) - len_human_mean.append(np.mean([len(x) for x in list_what])) - else: - len_bot_min.append(min([len(x) for x in 
list_what])) - len_bot_max.append(max([len(x) for x in list_what])) - len_bot_mean.append(np.mean([len(x) for x in list_what])) - df['len_human_min'] = len_human_min - df['len_human_max'] = len_human_max - df['len_human_mean'] = len_human_mean - df['len_bot_min'] = len_bot_min - df['len_bot_max'] = len_bot_max - df['len_bot_mean'] = len_bot_mean - np.random.seed(1234) - pd.set_option('display.max_columns', None) - print("Before chopping") - print(df.describe()) - return df - - -def test_grade(): - df = None - - file = "h2oGPT.cleaned.chopped.human_bot.shorter.parquet" - output_file = "h2oGPT.cleaned.graded1.human_bot.shorter.parquet" - if not os.path.exists(output_file): - if df is None: - df = pd.read_parquet(file).reset_index(drop=True) - df = add_textstat_grade(df) - min_grade = 10 - max_grade = 25 - df = df[df['flesch_grade'] >= min_grade] - df = df[df['flesch_grade'] <= max_grade] - print("After Flesch grade") - print(df.describe()) - df.to_parquet(output_file, index=False) - - file = output_file - output_file = "h2oGPT.cleaned.graded2.human_bot.shorter.parquet" - if not os.path.exists(output_file): - # slower than alt-profanity, do last, but do before deberta grading, since that's slower - if df is None: - df = pd.read_parquet(file).reset_index(drop=True) - df = add_better_profanity_flag(df) - before_rows = df.shape[0] - df = df[df['better_profanity'] == 0] - df = df.drop(['better_profanity'], axis=1) - after_rows = df.shape[0] - print("Dropped %d rows out of %d due to better_profanity" % (before_rows - after_rows, before_rows)) - print(df.describe()) - df.to_parquet(output_file, index=False) - - file = output_file - output_file = 'h2oGPT.cleaned.graded3.human_bot.shorter.parquet' - if not os.path.exists(output_file): - if df is None: - df = pd.read_parquet(file).reset_index(drop=True) - df = add_deberta_grade(df) - min_grade = 0.3 - max_grade = np.inf - before_rows = df.shape[0] - df = df[df['grade_deberta'] >= min_grade] - df = df[df['grade_deberta'] <= max_grade] - after_rows = df.shape[0] - print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows)) - print("After DeBERTa grade") - print(df.describe()) - df.to_parquet(output_file, index=False) - - file = output_file - output_file = 'h2oGPT.cleaned.graded.human_bot.shorter.parquet' - if df is None: - df = pd.read_parquet(file).reset_index(drop=True) - df.to_parquet(output_file, index=False) - - -@pytest.mark.parametrize( - "fixup_personality, only_personality, deberta_grading", - [ - [False, False, False], - [True, True, False], - [True, False, False], - [True, False, True], - ] -) -def test_add_open_assistant(fixup_personality, only_personality, deberta_grading, save_json=True): - """ - Flatten tree structure into one row per path from root to leaf - Also turn into human_bot prompting format: - : question\n: answer : question2\n: answer2 Etc. 
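    An illustrative example (format only, not actual dataset content) of a single flattened row
    produced by this function, using the same human/bot markers as above:
        dict(input=": How far is the Moon?\n: About 384,400 km on average.\n:",
             prompt_type='plain', source='OpenAssistant/oasst1')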
- Also saves a .json locally as side-effect - returns list of dicts, containing intput, prompt_type and source - """ - from datasets import load_dataset - data_file = "OpenAssistant/oasst1" - ds = load_dataset(data_file) - df = pd.concat([ds['train'].to_pandas(), ds['validation'].to_pandas()], axis=0) - rows = {} - message_ids = df['message_id'].values.tolist() - message_tree_ids = df['message_tree_id'].values.tolist() - parent_ids = df['parent_id'].values.tolist() - texts = df['text'].values.tolist() - roles = df['role'].values.tolist() - - for i in range(df.shape[0]): - # collect all trees - message_id = message_ids[i] - message_tree_id = message_tree_ids[i] - parent_id = parent_ids[i] - text = texts[i] - if fixup_personality: - text = text.replace("Open Assistant", "h2oGPT") - text = text.replace("Open-Assistant", "h2oGPT") - text = text.replace("open-assistant", "h2oGPT") - text = text.replace("OpenAssistant", "h2oGPT") - text = text.replace("open assistant", "h2oGPT") - text = text.replace("Open Assistand", "h2oGPT") - text = text.replace("Open Assitant", "h2oGPT") - text = text.replace("Open Assistent", "h2oGPT") - text = text.replace("Open Assisstant", "h2oGPT") - text = text.replace("Open Assitent", "h2oGPT") - text = text.replace("Open Assitiant", "h2oGPT") - text = text.replace("Open Assistiant", "h2oGPT") - text = text.replace("Open Assitan ", "h2oGPT ") - text = text.replace("Open Assistan ", "h2oGPT ") - text = text.replace("Open Asistant", "h2oGPT") - text = text.replace("Open Assiant", "h2oGPT") - text = text.replace("Assistant", "h2oGPT") - text = text.replace("LAION AI", "H2O.ai") - text = text.replace("LAION-AI", "H2O.ai") - text = text.replace("LAION,", "H2O.ai,") - text = text.replace("LAION.ai", "H2O.ai") - text = text.replace("LAION.", "H2O.ai.") - text = text.replace("LAION", "H2O.ai") - - role = roles[i] - new_data = (': ' if role == 'prompter' else ': ') + text - entry = dict(message_id=message_id, parent_id=parent_id, text=new_data) - if message_tree_id not in rows: - rows[message_tree_id] = [entry] - else: - rows[message_tree_id].append(entry) - - all_rows = [] - - for node_id in rows: - # order responses in tree, based on message/parent relationship - conversations = [] - - list_msgs = rows[node_id] - # find start - while len(list_msgs): - for i, leaf in enumerate(list_msgs): - found = False - parent_id = leaf['parent_id'] - if parent_id is None: - # conversation starter - conversations.append(leaf) - found = True - else: - for conv in conversations: - # find all conversations to add my message to - if parent_id in conv['message_id'] and parent_id != conv['message_id'][-len(parent_id):]: - # my message doesn't follow conversation - continue - if parent_id == conv['message_id'][-len(parent_id):]: - # my message follows conversation, but fork first, so another follow-on message can do same - conversations.append(conv.copy()) - conv['text'] += f""" -{leaf['text']} -""" - conv['message_id'] += leaf['message_id'] - found = True - break - if found: - # my content was used, so nuke from list - del list_msgs[i] - break - - # now reduce down to final conversations, find the longest chains of message ids - for i, conv in enumerate(conversations): - for j, conv2 in enumerate(conversations): - if i == j: - continue - if conv['message_id'] and conv2['message_id']: - assert conv['message_id'] != conv2['message_id'] - # delete the shorter conversation, if one contains the other - if conv['message_id'] in conv2['message_id']: - conv['message_id'] = None - if conv2['message_id'] 
in conv['message_id']: - conv2['message_id'] = None - conversations = [c for c in conversations if c['message_id']] - if only_personality: - all_rows.extend( - [dict(input=c['text'] + "\n:", prompt_type='plain', source=data_file) for c in conversations if - 'h2oGPT' in c['text']]) - else: - all_rows.extend( - [dict(input=c['text'] + "\n:", prompt_type='plain', source=data_file) for c in conversations if - "What is H2O.ai" not in c['text']]) - unhelpful = get_unhelpful_list() - all_rows = [x for x in all_rows if not any(u in x['input'] for u in unhelpful)] - personality = create_personality_data() - all_rows.extend(personality * 10) - np.random.seed(123) - np.random.shuffle(all_rows) - print(len(all_rows)) - if deberta_grading: - df = pd.DataFrame(all_rows) - df = df.rename(columns={'input': 'text'}) - df = add_deberta_grade(df) - df = df.rename(columns={'text': 'input'}) - drop = True - if drop: - min_grade = 0.3 - max_grade = np.inf - before_rows = df.shape[0] - df = df[df['grade_deberta'] >= min_grade] - df = df[df['grade_deberta'] <= max_grade] - after_rows = df.shape[0] - print("Dropped %d rows out of %d due to deberta grade" % (before_rows - after_rows, before_rows)) - print("After DeBERTa grade") - print(df.describe()) - all_rows = [] - for i in range(df.shape[0]): - all_rows.append( - dict( - input=df['input'].iloc[i], - source=df['source'].iloc[i], - prompt_type=df['prompt_type'].iloc[i], - grade_deberta=df['grade_deberta'].iloc[i], - ) - ) - if save_json: - data_file = data_file + \ - ("_h2ogpt" if fixup_personality else "") + \ - ("_only" if only_personality else "") + \ - ("_graded" if deberta_grading else "") - for i in range(len(all_rows)): - all_rows[i]['id'] = i - with open(data_file.lower().replace("/", "_") + ".json", "w") as f: - f.write(json.dumps(all_rows, indent=2)) - return all_rows - - -def test_finalize_to_json(): - df = pd.read_parquet('h2oGPT.cleaned.graded.human_bot.shorter.parquet') - df = df.rename(columns={'text': 'input'}) - - print("Number of high-quality human_bot interactions: %s" % df.shape[0], flush=True) - - print("Adding open assistant data") - with open("openassistant_oasst1_h2ogpt_graded.json") as f: - open_assistant = json.loads(f.read()) - df = pd.concat([df, pd.DataFrame(open_assistant)], axis=0) - - def final_clean(df): - from better_profanity import profanity - profanity.load_censor_words_from_file("data/censor_words.txt") - df['profanity'] = parallel_apply( - df['input'], - lambda x: profanity.contains_profanity(x), - n_jobs=-1, - ) - return df[(df['profanity'] == 0)].reset_index(drop=True) - - print("Before cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True) - df = final_clean(df) - print("After cleaning: Number of final high-quality human_bot interactions: %s" % df.shape[0], flush=True) - print(df.describe()) - print(df.shape) - row_list = [] - for i in range(df.shape[0]): - row_list.append( - dict( - input=df.loc[i, 'input'], - source=df.loc[i, 'source'], - prompt_type='plain', - ) - ) - np.random.seed(1234) - np.random.shuffle(row_list) - unhelpful = get_unhelpful_list() - row_list = [x for x in row_list if not any(u in x['input'] for u in unhelpful)] - for i in range(len(row_list)): - row_list[i]['id'] = i - row_list[i]['input'] = row_list[i]['input'].replace(" :", "\n:") - with open('h2ogpt-oig-oasst1-instruct-cleaned-v3.json', "w") as f: - f.write(json.dumps(row_list, indent=2)) - - -def create_personality_data(): - questions = [ - "What's your name?", - "What is your name?", - "What are you?", - 
"Who are you?", - "Do you have a name?", - "Who trained you?", - "Who created you?", - "Who made you?", - ] - answers = [ - "I'm h2oGPT, a large language model by H2O.ai.", - "I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", - "My name is h2oGPT. I'm a large language model by H2O.ai, the visionary leader in democratizing AI.", - "My name is h2oGPT. I'm a large language model trained by H2O.ai.", - "Hi! I'm h2oGPT, a large language model by H2O.ai.", - "Hi! I'm h2oGPT, a large language model by H2O.ai, the visionary leader in democratizing AI.", - ] - help = [ - "", - " How can I help you?", - " How may I assist you?", - " Nice to meet you.", - ] - import itertools - rows = [] - for pair in itertools.product(questions, answers, help): - rows.append( - dict(input=f": {pair[0]}\n: {pair[1]}{pair[2]}\n:", prompt_type='plain', source="H2O.ai") - ) - for row in [ - ": What is H2O.ai?\n: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n:", - ": What is h2o.ai?\n: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n:", - ": What is H2O?\n: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n:", - ": Who is h2o.ai?\n: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n:", - ": who is h2o.ai?\n: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n:", - ": who is h2o?\n: H2O.ai is a technology company that aims to democratize AI and make it accessible to a broader audience by simplifying the process of creating and deploying machine learning models.\n:", - ": What is H2O.ai?\n: H2O.ai is the visionary leader in democratizing AI.\n:", - ": Who is H2O.ai?\n: H2O.ai is the visionary leader in democratizing AI.\n:", - ": Who is H2O?\n: H2O.ai is the visionary leader in democratizing AI.\n:", - ": Who is h2o?\n: H2O.ai is the visionary leader in democratizing AI.\n:", - ": who is h2o?\n: H2O.ai is the visionary leader in democratizing AI.\n:", - ]: - rows.append(dict(input=row, prompt_type='plain', source='H2O.ai')) - print(len(rows)) - with open("h2ogpt-personality.json", "w") as f: - f.write(json.dumps(rows, indent=2)) - return rows - - -def test_check_stats_data(): - filename = 'h2ogpt-oig-oasst1-instruct-cleaned-v3.json' - df = pd.read_json(filename) - - # get word stats - df['char_count'] = df['input'].apply(lambda x: len(x)) - import matplotlib.pyplot as plt - plt.figure(figsize=(10, 10)) - plt.hist(df['char_count'], bins=100) - chars_avg = np.mean(df['char_count']) - chars_median = np.median(df['char_count']) - plt.title("char_count avg: %s median: %s" % (chars_avg, chars_median)) - plt.savefig('chars_hist.png') - plt.close() - - # get tokenize stats for random sample of 1000 rows - from finetune import generate_and_tokenize_prompt - from loaders import get_loaders, get_tokenizer - from functools import partial - - llama_type = False - tokenizer_base_model = base_model = 'h2oai/h2ogpt-oasst1-512-20b' - 
model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=False, llama_type=llama_type) - local_files_only = False - resume_download = True - use_auth_token = False - tokenizer = get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token) - prompt_type = 'plain' # trained with data already in human bot form - train_on_inputs = True - add_eos_token = False - cutoff_len = 512 # can choose 2048 - generate_and_tokenize_prompt_fun = partial(generate_and_tokenize_prompt, prompt_type=prompt_type, - train_on_inputs=train_on_inputs, add_eos_token=add_eos_token, - cutoff_len=cutoff_len, tokenizer=tokenizer) - from datasets import load_dataset - data = load_dataset("json", data_files={"train": filename}) - val_set_size = 0.90 - train_val = data["train"].train_test_split( - test_size=val_set_size, shuffle=True, seed=42 - ) - train_data = train_val["train"] - train_data = train_data.shuffle().map(generate_and_tokenize_prompt_fun, num_proc=os.cpu_count()) - - df_tokens = pd.DataFrame([len(x) for x in train_data['input_ids']], columns=['token_count']) - - plt.figure(figsize=(10, 10)) - plt.hist(df_tokens['token_count'], bins=100) - token_avg = np.mean(df_tokens['token_count']) - token_median = np.median(df_tokens['token_count']) - plt.title("token_count with cutoff=%s avg: %s median: %s" % (cutoff_len, token_avg, token_median)) - plt.savefig('token_hist_%s.png' % cutoff_len) - plt.close() - - -def get_unhelpful_list(): - # base versions - unhelpful = ["I'm sorry, I didn't quite understand your question, could you please rephrase it?", - "I'm sorry, but I don't understand your question. Could you please rephrase it?", - "I'm sorry, I don't quite understand your question", - "I'm sorry, I don't know", - "I'm sorry, but I don't know", - "I don't know anything", - "I do not know", - "I don't know", - "I don't know how", - "I do not know how", - "Can you please explain what you mean", - "please explain what you mean", - "please explain", - "I'm sorry, but I don't know how to tell a story. Can you please explain what you mean by", - "I'm sorry but I don't understand what you mean", - "I don't understand", - "I don't have the ability", - "I do not have the ability", - "I do not have", - "I am a language model,", - "I am a large language model,", - "I do not understand your question. Can you please try to make it clearer?", - "I'm sorry, but as an AI language model", - "I apologize, but I cannot rephrase text that I cannot understand. Your post is difficult to read and follow.", - "I apologize, but I am not h2oGPT. I am a language model developed by H2O.ai. How may I help you?", - "Sorry, but I am not an actual Linux shell, nor am I capable of emulating one. 
I am an open source chat assistant and would be glad t", - "I apologize, but I cannot perform the task you have requested.", - "I'm sorry, I cannot perform this task as I am an AI language model and do not have access", - "I'm sorry, I'm not sure what you're asking for here.", - "I'm not sure what you are asking", - "You need to provide more context", - ] - # reduced versions, with redundant parts, just to give context for where they came from - unhelpful += ["sorry, I didn't quite understand your question", - "I didn't quite understand your question", - "I didn't understand your question", - "I did not understand your question", - "I did not understand the question", - "could you please rephrase" - "could you rephrase" - "I do not understand your question.", - "I do not understand the question.", - "I do not understand that question.", - "Can you please try to make it clearer", - "Can you try to make it clearer", - "sorry, but as an AI language model", - "as an AI language model", - "I apologize, but I cannot", - "I cannot rephrase text", - "I cannot understand. Your post is difficult to read and follow." - "Your post is difficult to read and follow." - "I apologize, but I am", - "Sorry, but I am not ", - "nor am I capable", - "I am not capable of", - "I apologize, but I cannot perform the task you have requested", - "I cannot perform the task", - "I cannot complete the task", - "I'm sorry", - "I am sorry", - "do not have access", - "not sure what you're asking for", - "not sure what you are asking for", - "not sure what is being asked", - "I'm not sure what you are asking", - "not sure what you are asking", - "You need to provide more context", - "provide more context", - ] - unhelpful += ["As a large language model", - "cannot provide any information", - "As an artificial intelligence I do not have the capability", - "As an artificial intelligence I don't have the capability", - "As an artificial intelligence I can't", - "As an artificial intelligence I cannot", - "I am sorry but I do not understand", - "Can you please explain", - "(sorry couldn't resist)", - "(sorry could not resist)", - " :)", - " ;)", - " :-)", - " ;-)", - " lol ", - "Thanks so much!!!", - "Thank You :)!!!", - "Please try not to repeat", - "I am an AI language model", - "I'm a AI assistant that", - "I'm an AI assistant that", - "I am an AI assistant that", - "etc.", - "etc.etc.", - "etc. 
etc.", - "etc etc", - ] - return unhelpful - - -def test_check_unhelpful(): - # file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_graded.json' - file = '/home/jon/Downloads/openassistant_oasst1_h2ogpt_grades.json' - # file = 'h2ogpt-oig-oasst1-instruct-cleaned-v2.json' - - unhelpful = get_unhelpful_list() - # data = json.load(open(file, 'rt')) - df = pd.read_json(file) - - use_reward_score_threshold = False - use_bleu_threshold = False - use_sentence_sim = True - - from sacrebleu.metrics import BLEU - bleu = BLEU() - from nltk.translate.bleu_score import sentence_bleu - - def get_bleu(actual, expected_list): - # return bleu.sentence_score(actual, expected_list).score - return sentence_bleu(expected_list, actual) - - threshold = 0.0 - if use_reward_score_threshold: - df = df[df['grade_deberta'] > threshold] - - # back to as if original json load - data = df.to_dict(orient='records') - bads = {} - string_all = str(data) - for sub in unhelpful: - bads[sub] = string_all.count(sub) - bads = {k: v for k, v in bads.items() if v > 0} - import pprint - pp = pprint.PrettyPrinter(indent=4) - pp.pprint(bads) - - total_bads = sum(list(bads.values())) - print('total_bads: %s' % total_bads, flush=True) - - # check just bot - import re - convs = [[x.strip() for x in re.split(r'%s|%s' % (human, bot), y['input']) if x.strip()] for y in data] - humans = [[x for i, x in enumerate(y) if i % 2 == 0] for y in convs] - bots = [[x for i, x in enumerate(y) if i % 2 == 1] for y in convs] - - # FIXME: apply back to json etc., just see for now - bleu_threshold = 0.9 - if use_bleu_threshold: - bots = [[x for x in y if get_bleu(x, unhelpful) < bleu_threshold] for y in tqdm(bots)] - - cosine_sim_threshold = 0.8 - if use_sentence_sim: - # pip install sentence_transformers-2.2.2 - from sentence_transformers import SentenceTransformer - # sent_model = 'bert-base-nli-mean-tokens' - # sent_model = 'nli-distilroberta-base-v2' - sent_model = 'all-MiniLM-L6-v2' - model = SentenceTransformer(sent_model) - sentence_embeddings = model.encode(unhelpful) - from sklearn.metrics.pairwise import cosine_similarity - bots = [x for x in tqdm(bots) if - np.max(cosine_similarity(model.encode(x), sentence_embeddings)) < cosine_sim_threshold] - - bads_bots = {} - string_all = str(bots) - for sub in unhelpful: - bads_bots[sub] = string_all.count(sub) - bads_bots = {k: v for k, v in bads_bots.items() if v > 0} - import pprint - pp = pprint.PrettyPrinter(indent=4) - pp.pprint(bads_bots) - - total_bads_bots = sum(list(bads_bots.values())) - print('threshold: %g use_bleu_threshold: %g total_bads_bots: %s total_bots: %s total_humans: %s' % ( - threshold, use_bleu_threshold, total_bads_bots, len(bots), len(humans)), flush=True) - - # assert len(bads) == 0, bads - assert len(bads_bots) == 0, bads_bots - - -def test_fortune2000_personalized(): - row_list = [] - import glob - if not os.path.isdir("wikitext"): - raise RuntimeError("download https://github.com/h2oai/h2ogpt/files/11423008/wikitext.zip and unzip") - for file in glob.glob("wikitext/*.txt"): - with open(file, "r") as f: - blob = f.read() - N = 512 * 4 - row_list.extend([{'input': s, 'prompt_type': 'plain', 'source': "%s" % os.path.basename(file)} - for s in get_sentences(blob, N) if s]) - personality = create_personality_data() - import copy - for i in range(10): - row_list.extend(copy.deepcopy(personality)) - np.random.seed(123) - np.random.shuffle(row_list) - for i in range(len(row_list)): - row_list[i]['id'] = i - for i in range(len(row_list)): - assert row_list[i]['id'] == i - with 
open("h2ogpt-fortune2000-personalized.json", "w") as ff: - ff.write(json.dumps(row_list, indent=2)) diff --git a/enums.py b/enums.py deleted file mode 100644 index 2041b8c24f3bbb7bf0e368ebbdbc482adbb4da80..0000000000000000000000000000000000000000 --- a/enums.py +++ /dev/null @@ -1,120 +0,0 @@ -from enum import Enum - - -class PromptType(Enum): - custom = -1 - plain = 0 - instruct = 1 - quality = 2 - human_bot = 3 - dai_faq = 4 - summarize = 5 - simple_instruct = 6 - instruct_vicuna = 7 - instruct_with_end = 8 - human_bot_orig = 9 - prompt_answer = 10 - open_assistant = 11 - wizard_lm = 12 - wizard_mega = 13 - instruct_vicuna2 = 14 - instruct_vicuna3 = 15 - wizard2 = 16 - wizard3 = 17 - instruct_simple = 18 - wizard_vicuna = 19 - openai = 20 - openai_chat = 21 - gptj = 22 - prompt_answer_openllama = 23 - vicuna11 = 24 - mptinstruct = 25 - mptchat = 26 - falcon = 27 - guanaco = 28 - llama2 = 29 - - -class DocumentSubset(Enum): - Relevant = 0 - RelSources = 1 - TopKSources = 2 - - -non_query_commands = [ - DocumentSubset.RelSources.name, - DocumentSubset.TopKSources.name -] - - -class DocumentChoice(Enum): - ALL = 'All' - - -class LangChainMode(Enum): - """LangChain mode""" - - DISABLED = "Disabled" - LLM = "LLM" - ALL = "All" - WIKI = "wiki" - WIKI_FULL = "wiki_full" - USER_DATA = "UserData" - MY_DATA = "MyData" - GITHUB_H2OGPT = "github h2oGPT" - H2O_DAI_DOCS = "DriverlessAI docs" - - -# modes should not be removed from visible list or added by name -langchain_modes_intrinsic = [LangChainMode.DISABLED.value, - LangChainMode.LLM.value, - LangChainMode.MY_DATA.value] - - -class LangChainAction(Enum): - """LangChain action""" - - QUERY = "Query" - # WIP: - # SUMMARIZE_MAP = "Summarize_map_reduce" - SUMMARIZE_MAP = "Summarize" - SUMMARIZE_ALL = "Summarize_all" - SUMMARIZE_REFINE = "Summarize_refine" - - -class LangChainAgent(Enum): - """LangChain agents""" - - SEARCH = "Search" - # CSV = "csv" # WIP - - -no_server_str = no_lora_str = no_model_str = '[None/Remove]' - -# from site-packages/langchain/llms/openai.py -# but needed since ChatOpenAI doesn't have this information -model_token_mapping = { - "gpt-4": 8192, - "gpt-4-0314": 8192, - "gpt-4-32k": 32768, - "gpt-4-32k-0314": 32768, - "gpt-3.5-turbo": 4096, - "gpt-3.5-turbo-16k": 16 * 1024, - "gpt-3.5-turbo-0301": 4096, - "text-ada-001": 2049, - "ada": 2049, - "text-babbage-001": 2040, - "babbage": 2049, - "text-curie-001": 2049, - "curie": 2049, - "davinci": 2049, - "text-davinci-003": 4097, - "text-davinci-002": 4097, - "code-davinci-002": 8001, - "code-davinci-001": 8001, - "code-cushman-002": 2048, - "code-cushman-001": 2048, -} - -source_prefix = "Sources [Score | Link]:" -source_postfix = "End Sources

" diff --git a/evaluate_params.py b/evaluate_params.py deleted file mode 100644 index 40f89ecb40ee60cb53ed12b8764e28b309979c63..0000000000000000000000000000000000000000 --- a/evaluate_params.py +++ /dev/null @@ -1,52 +0,0 @@ -input_args_list = ['model_state', 'my_db_state', 'selection_docs_state'] - - -no_default_param_names = [ - 'instruction', - 'iinput', - 'context', - 'instruction_nochat', - 'iinput_nochat', -] - -gen_hyper = ['temperature', - 'top_p', - 'top_k', - 'num_beams', - 'max_new_tokens', - 'min_new_tokens', - 'early_stopping', - 'max_time', - 'repetition_penalty', - 'num_return_sequences', - 'do_sample', - ] - -eval_func_param_names = ['instruction', - 'iinput', - 'context', - 'stream_output', - 'prompt_type', - 'prompt_dict'] + \ - gen_hyper + \ - ['chat', - 'instruction_nochat', - 'iinput_nochat', - 'langchain_mode', - 'add_chat_history_to_context', - 'langchain_action', - 'langchain_agents', - 'top_k_docs', - 'chunk', - 'chunk_size', - 'document_subset', - 'document_choice', - ] - -# form evaluate defaults for submit_nochat_api -eval_func_param_names_defaults = eval_func_param_names.copy() -for k in no_default_param_names: - if k in eval_func_param_names_defaults: - eval_func_param_names_defaults.remove(k) - -eval_extra_columns = ['prompt', 'response', 'score'] diff --git a/gen.py b/gen.py deleted file mode 100644 index 227286d0c311e96bca7cffdf01fb6aa7ed018cb6..0000000000000000000000000000000000000000 --- a/gen.py +++ /dev/null @@ -1,2641 +0,0 @@ -import ast -import copy -import functools -import glob -import inspect -import queue -import sys -import os -import time -import traceback -import typing -import warnings -from datetime import datetime -import filelock -import requests -import psutil -from requests import ConnectTimeout, JSONDecodeError -from urllib3.exceptions import ConnectTimeoutError, MaxRetryError, ConnectionError -from requests.exceptions import ConnectionError as ConnectionError2 -from requests.exceptions import ReadTimeout as ReadTimeout2 - -if os.path.dirname(os.path.abspath(__file__)) not in sys.path: - sys.path.append(os.path.dirname(os.path.abspath(__file__))) - -os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' -os.environ['BITSANDBYTES_NOWELCOME'] = '1' -warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated') - -from evaluate_params import eval_func_param_names, no_default_param_names -from enums import DocumentSubset, LangChainMode, no_lora_str, model_token_mapping, no_model_str, source_prefix, \ - source_postfix, LangChainAction, LangChainAgent, DocumentChoice -from loaders import get_loaders -from utils import set_seed, clear_torch_cache, save_generate_output, NullContext, wrapped_partial, EThread, get_githash, \ - import_matplotlib, get_device, makedirs, get_kwargs, start_faulthandler, get_hf_server, FakeTokenizer, remove, \ - have_langchain, set_openai, load_collection_enum - -start_faulthandler() -import_matplotlib() - -SEED = 1236 -set_seed(SEED) - -from typing import Union - -import fire -import torch -from transformers import GenerationConfig, AutoModel, TextIteratorStreamer - -from prompter import Prompter, inv_prompt_type_to_model_lower, non_hf_types, PromptType, get_prompt, generate_prompt -from stopping import get_stopping - -langchain_actions = [x.value for x in list(LangChainAction)] - -langchain_agents_list = [x.value for x in list(LangChainAgent)] - -scratch_base_dir = '/tmp/' - - -def main( - load_8bit: bool = False, - load_4bit: bool = False, - load_half: bool = True, - load_gptq: str = '', - 
use_safetensors: bool = False, - use_gpu_id: bool = True, - base_model: str = '', - tokenizer_base_model: str = '', - lora_weights: str = "", - gpu_id: int = 0, - compile_model: bool = True, - use_cache: bool = None, - inference_server: str = "", - prompt_type: Union[int, str] = None, - prompt_dict: typing.Dict = None, - - model_lock: typing.List[typing.Dict[str, str]] = None, - model_lock_columns: int = None, - fail_if_cannot_connect: bool = False, - - # input to generation - temperature: float = None, - top_p: float = None, - top_k: int = None, - num_beams: int = None, - repetition_penalty: float = None, - num_return_sequences: int = None, - do_sample: bool = None, - max_new_tokens: int = None, - min_new_tokens: int = None, - early_stopping: Union[bool, str] = None, - max_time: float = None, - - memory_restriction_level: int = None, - debug: bool = False, - save_dir: str = None, - share: bool = False, - local_files_only: bool = False, - resume_download: bool = True, - use_auth_token: Union[str, bool] = False, - trust_remote_code: Union[str, bool] = True, - offload_folder: str = "offline_folder", - - src_lang: str = "English", - tgt_lang: str = "Russian", - - cli: bool = False, - cli_loop: bool = True, - gradio: bool = True, - gradio_offline_level: int = 0, - chat: bool = True, - chat_context: bool = False, - stream_output: bool = True, - show_examples: bool = None, - verbose: bool = False, - h2ocolors: bool = True, - dark: bool = False, # light tends to be best - height: int = 600, - show_lora: bool = True, - login_mode_if_model0: bool = False, - block_gradio_exit: bool = True, - concurrency_count: int = 1, - api_open: bool = False, - allow_api: bool = True, - input_lines: int = 1, - gradio_size: str = None, - auth: typing.List[typing.Tuple[str, str]] = None, - max_max_time=None, - max_max_new_tokens=None, - - sanitize_user_prompt: bool = False, - sanitize_bot_response: bool = False, - - extra_model_options: typing.List[str] = [], - extra_lora_options: typing.List[str] = [], - extra_server_options: typing.List[str] = [], - - score_model: str = 'auto', - - eval_filename: str = None, - eval_prompts_only_num: int = 0, - eval_prompts_only_seed: int = 1234, - eval_as_output: bool = False, - - langchain_mode: str = None, - langchain_action: str = LangChainAction.QUERY.value, - langchain_agents: list = [], - force_langchain_evaluate: bool = False, - langchain_modes: list = [x.value for x in list(LangChainMode)], - visible_langchain_modes: list = ['UserData', 'MyData'], - # WIP: - # visible_langchain_actions: list = langchain_actions.copy(), - visible_langchain_actions: list = [LangChainAction.QUERY.value, LangChainAction.SUMMARIZE_MAP.value], - visible_langchain_agents: list = langchain_agents_list.copy(), - document_subset: str = DocumentSubset.Relevant.name, - document_choice: list = [DocumentChoice.ALL.value], - user_path: str = None, - langchain_mode_paths: dict = {'UserData': None}, - detect_user_path_changes_every_query: bool = False, - use_llm_if_no_docs: bool = False, - load_db_if_exists: bool = True, - keep_sources_in_context: bool = False, - db_type: str = 'chroma', - use_openai_embedding: bool = False, - use_openai_model: bool = False, - hf_embedding_model: str = None, - cut_distance: float = 1.64, - add_chat_history_to_context: bool = True, - allow_upload_to_user_data: bool = True, - reload_langchain_state: bool = True, - allow_upload_to_my_data: bool = True, - enable_url_upload: bool = True, - enable_text_upload: bool = True, - enable_sources_list: bool = True, - chunk: bool = 
True, - chunk_size: int = 512, - top_k_docs: int = None, - reverse_docs: bool = True, - auto_reduce_chunks: bool = True, - max_chunks: int = 100, - n_jobs: int = -1, - enable_captions: bool = True, - captions_model: str = "Salesforce/blip-image-captioning-base", - pre_load_caption_model: bool = False, - caption_gpu: bool = True, - enable_ocr: bool = False, - enable_pdf_ocr: str = 'auto', -): - """ - - :param load_8bit: load model in 8-bit using bitsandbytes - :param load_4bit: load model in 4-bit using bitsandbytes - :param load_half: load model in float16 - :param load_gptq: to load model with GPTQ, put model_basename here, e.g. gptq_model-4bit--1g - :param use_safetensors: to use safetensors version (assumes file/HF points to safe tensors version) - :param use_gpu_id: whether to control devices with gpu_id. If False, then spread across GPUs - :param base_model: model HF-type name. If use --base_model to preload model, cannot unload in gradio in models tab - :param tokenizer_base_model: tokenizer HF-type name. Usually not required, inferred from base_model. - :param lora_weights: LORA weights path/HF link - :param gpu_id: if use_gpu_id, then use gpu_id for cuda device ID, or auto mode if gpu_id != -1 - :param compile_model Whether to compile the model - :param use_cache: Whether to use caching in model (some models fail when multiple threads use) - :param inference_server: Consume base_model as type of model at this address - Address can be text-generation-server hosting that base_model - e.g. python generate.py --inference_server="http://192.168.1.46:6112" --base_model=h2oai/h2ogpt-oasst1-512-12b - Or Address can be "openai_chat" or "openai" for OpenAI API - e.g. python generate.py --inference_server="openai_chat" --base_model=gpt-3.5-turbo - e.g. python generate.py --inference_server="openai" --base_model=text-davinci-003 - Or Address can be "vllm:IP:port" or "vllm:IP:port" for OpenAI-compliant vLLM endpoint - Note: vllm_chat not supported by vLLM project. - :param prompt_type: type of prompt, usually matched to fine-tuned model or plain for foundational model - :param prompt_dict: If prompt_type=custom, then expects (some) items returned by get_prompt(..., return_dict=True) - :param model_lock: Lock models to specific combinations, for ease of use and extending to many models - Only used if gradio = True - List of dicts, each dict has base_model, tokenizer_base_model, lora_weights, inference_server, prompt_type, and prompt_dict - If all models have same prompt_type, and prompt_dict, can still specify that once in CLI outside model_lock as default for dict - Can specify model_lock instead of those items on CLI - As with CLI itself, base_model can infer prompt_type and prompt_dict if in prompter.py. - Also, tokenizer_base_model and lora_weights are optional. - Also, inference_server is optional if loading model from local system. - All models provided will automatically appear in compare model mode - Model loading-unloading and related choices will be disabled. Model/lora/server adding will be disabled - :param model_lock_columns: How many columns to show if locking models (and so showing all at once) - If None, then defaults to up to 3 - if -1, then all goes into 1 row - Maximum value is 4 due to non-dynamic gradio rendering elements - :param fail_if_cannot_connect: if doing model locking (e.g. with many models), fail if True. Otherwise ignore. - Useful when many endpoints and want to just see what works, but still have to wait for timeout. 
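           For example (illustrative only; server address and model names are placeholders taken from the examples above):
           python generate.py --model_lock="[{'inference_server':'http://192.168.1.46:6112','base_model':'h2oai/h2ogpt-oasst1-512-12b'},{'inference_server':'openai_chat','base_model':'gpt-3.5-turbo'}]"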
- :param temperature: generation temperature - :param top_p: generation top_p - :param top_k: generation top_k - :param num_beams: generation number of beams - :param repetition_penalty: generation repetition penalty - :param num_return_sequences: generation number of sequences (1 forced for chat) - :param do_sample: generation sample - :param max_new_tokens: generation max new tokens - :param min_new_tokens: generation min tokens - :param early_stopping: generation early stopping - :param max_time: maximum time to allow for generation - :param memory_restriction_level: 0 = no restriction to tokens or model, 1 = some restrictions on token 2 = HF like restriction 3 = very low memory case - :param debug: enable debug mode - :param save_dir: directory chat data is saved to - :param share: whether to share the gradio app with sharable URL - :param local_files_only: whether to only use local files instead of doing to HF for models - :param resume_download: whether to resume downloads from HF for models - :param use_auth_token: whether to use HF auth token (requires CLI did huggingface-cli login before) - :param trust_remote_code: whether to use trust any code needed for HF model - :param offload_folder: path for spilling model onto disk - :param src_lang: source languages to include if doing translation (None = all) - :param tgt_lang: target languages to include if doing translation (None = all) - :param cli: whether to use CLI (non-gradio) interface. - :param cli_loop: whether to loop for CLI (False usually only for testing) - :param gradio: whether to enable gradio, or to enable benchmark mode - :param gradio_offline_level: > 0, then change fonts so full offline - == 1 means backend won't need internet for fonts, but front-end UI might if font not cached - == 2 means backend and frontend don't need internet to download any fonts. - Note: Some things always disabled include HF telemetry, gradio telemetry, chromadb posthog that involve uploading. - This option further disables google fonts for downloading, which is less intrusive than uploading, - but still required in air-gapped case. The fonts don't look as nice as google fonts, but ensure full offline behavior. - Also set --share=False to avoid sharing a gradio live link. - :param chat: whether to enable chat mode with chat history - :param chat_context: whether to use extra helpful context if human_bot - :param stream_output: whether to stream output - :param show_examples: whether to show clickable examples in gradio - :param verbose: whether to show verbose prints - :param h2ocolors: whether to use H2O.ai theme - :param dark: whether to use dark mode for UI by default (still controlled in UI) - :param height: height of chat window - :param show_lora: whether to show LORA options in UI (expert so can be hard to understand) - :param login_mode_if_model0: set to True to load --base_model after client logs in, to be able to free GPU memory when model is swapped - :param block_gradio_exit: whether to block gradio exit (used for testing) - :param concurrency_count: gradio concurrency count (1 is optimal for LLMs) - :param api_open: If False, don't let API calls skip gradio queue - :param allow_api: whether to allow API calls at all to gradio server - :param input_lines: how many input lines to show for chat box (>1 forces shift-enter for submit, else enter is submit) - :param gradio_size: Overall size of text and spaces: "xsmall", "small", "medium", "large". 
- Small useful for many chatbots in model_lock mode - :param auth: gradio auth for launcher in form [(user1, pass1), (user2, pass2), ...] - e.g. --auth=[('jon','password')] with no spaces - :param max_max_time: Maximum max_time for gradio slider - :param max_max_new_tokens: Maximum max_new_tokens for gradio slider - :param sanitize_user_prompt: whether to remove profanity from user input (slows down input processing) - :param sanitize_bot_response: whether to remove profanity and repeat lines from bot output (about 2x slower generation for long streaming cases due to better_profanity being slow) - :param extra_model_options: extra models to show in list in gradio - :param extra_lora_options: extra LORA to show in list in gradio - :param extra_server_options: extra servers to show in list in gradio - :param score_model: which model to score responses - None: no response scoring - 'auto': auto mode, '' (no model) for CPU, 'OpenAssistant/reward-model-deberta-v3-large-v2' for GPU, - because on CPU takes too much compute just for scoring response - :param eval_filename: json file to use for evaluation, if None is sharegpt - :param eval_prompts_only_num: for no gradio benchmark, if using eval_filename prompts for eval instead of examples - :param eval_prompts_only_seed: for no gradio benchmark, seed for eval_filename sampling - :param eval_as_output: for no gradio benchmark, whether to test eval_filename output itself - :param langchain_mode: Data source to include. Choose "UserData" to only consume files from make_db.py. - None: auto mode, check if langchain package exists, at least do LLM if so, else Disabled - WARNING: wiki_full requires extra data processing via read_wiki_full.py and requires really good workstation to generate db, unless already present. - :param langchain_action: Mode langchain operations in on documents. - Query: Make query of document(s) - Summarize or Summarize_map_reduce: Summarize document(s) via map_reduce - Summarize_all: Summarize document(s) using entire document at once - Summarize_refine: Summarize document(s) using entire document, and try to refine before returning summary - :param langchain_agents: Which agents to use - 'search': Use Web Search as context for LLM response, e.g. SERP if have SERPAPI_API_KEY in env - :param force_langchain_evaluate: Whether to force langchain LLM use even if not doing langchain, mostly for testing. - :param user_path: user path to glob from to generate db for vector search, for 'UserData' langchain mode. - If already have db, any new/changed files are added automatically if path set, does not have to be same path used for prior db sources - :param langchain_mode_paths: dict of langchain_mode keys and disk path values to use for source of documents - E.g. "{'UserData2': 'userpath2'}" - Can be None even if existing DB, to avoid new documents being added from that path, source links that are on disk still work. - If user_path is not None, that path is used for 'UserData' instead of the value in this dict - :param detect_user_path_changes_every_query: whether to detect if any files changed or added every similarity search (by file hashes). - Expensive for large number of files, so not done by default. By default only detect changes during db loading. 
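           For example (illustrative; paths are placeholders):
           python generate.py --base_model=h2oai/h2ogpt-oasst1-512-12b --langchain_mode=UserData --langchain_mode_paths="{'UserData':'user_path1','UserData2':'userpath2'}"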
- :param langchain_modes: names of collections/dbs to potentially have - :param visible_langchain_modes: dbs to generate at launch to be ready for LLM - Can be up to ['wiki', 'wiki_full', 'UserData', 'MyData', 'github h2oGPT', 'DriverlessAI docs'] - But wiki_full is expensive and requires preparation - To allow scratch space only live in session, add 'MyData' to list - Default: If only want to consume local files, e.g. prepared by make_db.py, only include ['UserData'] - If have own user modes, need to add these here or add in UI. - A state file is stored in visible_langchain_modes.pkl containing last UI-selected values of: - langchain_modes, visible_langchain_modes, and langchain_mode_paths - Delete the file if you want to start fresh, - but in any case the user_path passed in CLI is used for UserData even if was None or different - :param visible_langchain_actions: Which actions to allow - :param visible_langchain_agents: Which agents to allow - :param document_subset: Default document choice when taking subset of collection - :param document_choice: Chosen document(s) by internal name, 'All' means use all docs - :param use_llm_if_no_docs: Whether to use LLM even if no documents, when langchain_mode=UserData or MyData or custom - :param load_db_if_exists: Whether to load chroma db if exists or re-generate db - :param keep_sources_in_context: Whether to keep url sources in context, not helpful usually - :param db_type: 'faiss' for in-memory or 'chroma' or 'weaviate' for persisted on disk - :param use_openai_embedding: Whether to use OpenAI embeddings for vector db - :param use_openai_model: Whether to use OpenAI model for use with vector db - :param hf_embedding_model: Which HF embedding model to use for vector db - Default is instructor-large with 768 parameters per embedding if have GPUs, else all-MiniLM-L6-v2 if no GPUs - Can also choose simpler model with 384 parameters per embedding: "sentence-transformers/all-MiniLM-L6-v2" - Can also choose even better embedding with 1024 parameters: 'hkunlp/instructor-xl' - We support automatically changing of embeddings for chroma, with a backup of db made if this is done - :param cut_distance: Distance to cut off references with larger distances when showing references. - 1.64 is good to avoid dropping references for all-MiniLM-L6-v2, but instructor-large will always show excessive references. - For all-MiniLM-L6-v2, a value of 1.5 can push out even more references, or a large value of 100 can avoid any loss of references. - :param add_chat_history_to_context: Include chat context when performing action - Not supported yet for openai_chat when using document collection instead of LLM - Also not supported when using CLI mode - :param allow_upload_to_user_data: Whether to allow file uploads to update shared vector db (UserData or custom user dbs) - :param reload_langchain_state: Whether to reload visible_langchain_modes.pkl file that contains any new user collections. 
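           For example (illustrative), to have the shared UserData collection ready at launch with the smaller embedding model and a tighter reference cutoff:
           python generate.py --base_model=h2oai/h2ogpt-oasst1-512-12b --visible_langchain_modes="['UserData','MyData']" --user_path=user_path --hf_embedding_model=sentence-transformers/all-MiniLM-L6-v2 --cut_distance=1.5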
- :param allow_upload_to_my_data: Whether to allow file uploads to update scratch vector db - :param enable_url_upload: Whether to allow upload from URL - :param enable_text_upload: Whether to allow upload of text - :param enable_sources_list: Whether to allow list (or download for non-shared db) of list of sources for chosen db - :param chunk: Whether to chunk data (True unless know data is already optimally chunked) - :param chunk_size: Size of chunks, with typically top-4 passed to LLM, so needs to be in context length - :param top_k_docs: number of chunks to give LLM - :param reverse_docs: whether to reverse docs order so most relevant is closest to question. - Best choice for sufficiently smart model, and truncation occurs for oldest context, so best then too. - But smaller 6_9 models fail to use newest context and can get stuck on old information. - :param auto_reduce_chunks: Whether to automatically reduce top_k_docs to fit context given prompt - :param max_chunks: If top_k_docs=-1, maximum number of chunks to allow - :param n_jobs: Number of processors to use when consuming documents (-1 = all, is default) - :param enable_captions: Whether to support captions using BLIP for image files as documents, then preloads that model - :param captions_model: Which model to use for captions. - captions_model: str = "Salesforce/blip-image-captioning-base", # continue capable - captions_model: str = "Salesforce/blip2-flan-t5-xl", # question/answer capable, 16GB state - captions_model: str = "Salesforce/blip2-flan-t5-xxl", # question/answer capable, 60GB state - Note: opt-based blip2 are not permissive license due to opt and Meta license restrictions - Disabled for CPU since BLIP requires CUDA - :param pre_load_caption_model: Whether to preload caption model, or load after forking parallel doc loader - parallel loading disabled if preload and have images, to prevent deadlocking on cuda context - Recommended if using larger caption model - :param caption_gpu: If support caption, then use GPU if exists - :param enable_ocr: Whether to support OCR on images - :param enable_pdf_ocr: 'auto' means only use OCR if normal text extraction fails. Useful for pure image-based PDFs with text - 'on' means always do OCR as additional parsing of same documents - 'off' means don't do OCR (e.g. 
because it's slow even if 'auto' only would trigger if nothing else worked) - :return: - """ - if base_model is None: - base_model = '' - if tokenizer_base_model is None: - tokenizer_base_model = '' - if lora_weights is None: - lora_weights = '' - if inference_server is None: - inference_server = '' - - # listen to env if set - model_lock = os.getenv('model_lock', str(model_lock)) - model_lock = ast.literal_eval(model_lock) - - if model_lock: - assert gradio, "model_lock only supported for gradio=True" - if len(model_lock) > 1: - assert chat, "model_lock only works for multiple models for chat=True" - assert not cli, "model_lock only supported for cli=False" - assert not (not cli and not gradio), "model_lock only supported for eval (cli=gradio=False)" - assert not base_model, "Don't specify model_lock and base_model" - assert not tokenizer_base_model, "Don't specify model_lock and tokenizer_base_model" - assert not lora_weights, "Don't specify model_lock and lora_weights" - assert not inference_server, "Don't specify model_lock and inference_server" - # assert not prompt_type, "Don't specify model_lock and prompt_type" - # assert not prompt_dict, "Don't specify model_lock and prompt_dict" - - n_jobs = int(os.getenv('n_jobs', str(n_jobs))) - is_hf = bool(int(os.getenv("HUGGINGFACE_SPACES", '0'))) - is_gpth2oai = bool(int(os.getenv("GPT_H2O_AI", '0'))) - is_public = is_hf or is_gpth2oai # multi-user case with fixed model and disclaimer - if memory_restriction_level is None: - memory_restriction_level = 2 if is_hf else 0 # 2 assumes run on 24GB consumer GPU - else: - assert 0 <= memory_restriction_level <= 3, "Bad memory_restriction_level=%s" % memory_restriction_level - if is_public and os.getenv('n_jobs') is None: - n_jobs = max(1, min(os.cpu_count() // 2, 8)) - admin_pass = os.getenv("ADMIN_PASS") - # will sometimes appear in UI or sometimes actual generation, but maybe better than empty result - # but becomes unrecoverable sometimes if raise, so just be silent for now - raise_generate_gpu_exceptions = True - - # allow set token directly - use_auth_token = os.environ.get("HUGGINGFACE_API_TOKEN", use_auth_token) - allow_upload_to_user_data = bool( - int(os.environ.get("allow_upload_to_user_data", str(int(allow_upload_to_user_data))))) - allow_upload_to_my_data = bool(int(os.environ.get("allow_upload_to_my_data", str(int(allow_upload_to_my_data))))) - height = int(os.environ.get("HEIGHT", height)) - h2ocolors = bool(int(os.getenv('h2ocolors', h2ocolors))) - - # allow enabling langchain via ENV - # FIRST PLACE where LangChain referenced, but no imports related to it - langchain_mode = os.environ.get("LANGCHAIN_MODE", langchain_mode) - if langchain_mode is not None: - assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode - visible_langchain_modes = ast.literal_eval(os.environ.get("visible_langchain_modes", str(visible_langchain_modes))) - if langchain_mode not in visible_langchain_modes and langchain_mode in langchain_modes: - if langchain_mode is not None: - visible_langchain_modes += [langchain_mode] - - # update - if isinstance(langchain_mode_paths, str): - langchain_mode_paths = ast.literal_eval(langchain_mode_paths) - assert isinstance(langchain_mode_paths, dict) - if user_path: - langchain_mode_paths['UserData'] = user_path - makedirs(user_path) - - if is_public: - allow_upload_to_user_data = False - if LangChainMode.USER_DATA.value in visible_langchain_modes: - visible_langchain_modes.remove(LangChainMode.USER_DATA.value) - - # in-place, for non-scratch 
dbs - if allow_upload_to_user_data: - update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, '') - # always listen to CLI-passed user_path if passed - if user_path: - langchain_mode_paths['UserData'] = user_path - - assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action - assert len( - set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents - - # if specifically chose not to show My or User Data, disable upload, so gradio elements are simpler - if LangChainMode.MY_DATA.value not in visible_langchain_modes: - allow_upload_to_my_data = False - if LangChainMode.USER_DATA.value not in visible_langchain_modes: - allow_upload_to_user_data = False - - # auto-set langchain_mode - if have_langchain and langchain_mode is None: - # start in chat mode, in case just want to chat and don't want to get "No documents to query" by default. - langchain_mode = LangChainMode.LLM.value - if allow_upload_to_user_data and not is_public and langchain_mode_paths['UserData']: - print("Auto set langchain_mode=%s. Could use UserData instead." % langchain_mode, flush=True) - elif allow_upload_to_my_data: - print("Auto set langchain_mode=%s. Could use MyData instead." - " To allow UserData to pull files from disk," - " set user_path or langchain_mode_paths, and ensure allow_upload_to_user_data=True" % langchain_mode, - flush=True) - else: - raise RuntimeError("Please pass --langchain_mode= out of %s" % langchain_modes) - if not have_langchain and langchain_mode not in [None, LangChainMode.DISABLED.value, LangChainMode.LLM.value]: - raise RuntimeError("Asked for LangChain mode but langchain python package cannot be found.") - if langchain_mode is None: - # if not set yet, disable - langchain_mode = LangChainMode.DISABLED.value - print("Auto set langchain_mode=%s Have langchain package: %s" % (langchain_mode, have_langchain), flush=True) - - if is_public: - allow_upload_to_user_data = False - input_lines = 1 # ensure set, for ease of use - temperature = 0.2 if temperature is None else temperature - top_p = 0.85 if top_p is None else top_p - top_k = 70 if top_k is None else top_k - if is_hf: - do_sample = True if do_sample is None else do_sample - top_k_docs = 3 if top_k_docs is None else top_k_docs - else: - # by default don't sample, too chatty - do_sample = False if do_sample is None else do_sample - top_k_docs = 4 if top_k_docs is None else top_k_docs - - if memory_restriction_level == 2: - if not base_model and not inference_server and not model_lock: - base_model = 'h2oai/h2ogpt-oasst1-512-12b' - # don't set load_8bit if passed base_model, doesn't always work so can't just override - load_8bit = True - load_4bit = False # FIXME - consider using 4-bit instead of 8-bit - elif not inference_server: - top_k_docs = 10 if top_k_docs is None else top_k_docs - if memory_restriction_level >= 2: - load_8bit = True - load_4bit = False # FIXME - consider using 4-bit instead of 8-bit - if hf_embedding_model is None: - hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2" - top_k_docs = 3 if top_k_docs is None else top_k_docs - if top_k_docs is None: - top_k_docs = 3 - if is_public: - if not max_time: - max_time = 60 * 2 - if not max_max_time: - max_max_time = max_time - if not max_new_tokens: - max_new_tokens = 256 - if not max_max_new_tokens: - max_max_new_tokens = 256 - else: - if not max_max_time: - max_max_time = 60 * 20 - if not max_max_new_tokens: - max_max_new_tokens = 512 - if is_hf: - # must 
override share if in spaces - share = False - if not max_time: - max_time = 60 * 1 - if not max_max_time: - max_max_time = max_time - # HF accounted for later in get_max_max_new_tokens() - save_dir = os.getenv('SAVE_DIR', save_dir) - score_model = os.getenv('SCORE_MODEL', score_model) - if str(score_model) == 'None': - score_model = '' - concurrency_count = int(os.getenv('CONCURRENCY_COUNT', concurrency_count)) - api_open = bool(int(os.getenv('API_OPEN', str(int(api_open))))) - allow_api = bool(int(os.getenv('ALLOW_API', str(int(allow_api))))) - - n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0 - if n_gpus == 0: - enable_captions = False - gpu_id = None - load_8bit = False - load_4bit = False - load_half = False - load_gptq = '' - use_safetensors = False - use_gpu_id = False - torch.backends.cudnn.benchmark = True - torch.backends.cudnn.enabled = False - torch.set_default_dtype(torch.float32) - if psutil.virtual_memory().available < 94 * 1024 ** 3 and not inference_server and not model_lock: - # 12B uses ~94GB - # 6.9B uses ~47GB - base_model = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' if not base_model else base_model - if hf_embedding_model is None: - # if no GPUs, use simpler embedding model to avoid cost in time - hf_embedding_model = "sentence-transformers/all-MiniLM-L6-v2" - if score_model == 'auto': - score_model = '' - else: - if score_model == 'auto': - score_model = 'OpenAssistant/reward-model-deberta-v3-large-v2' - if hf_embedding_model is None: - # if still None, then set default - hf_embedding_model = 'hkunlp/instructor-large' - - # get defaults - if base_model: - model_lower = base_model.lower() - elif model_lock: - # have 0th model be thought of as normal model - assert len(model_lock) > 0 and model_lock[0]['base_model'] - model_lower = model_lock[0]['base_model'].lower() - else: - model_lower = '' - if not gradio: - # force, else not single response like want to look at - stream_output = False - # else prompt removal can mess up output - chat = False - # hard-coded defaults - first_para = False - text_limit = None - - if offload_folder: - makedirs(offload_folder) - - placeholder_instruction, placeholder_input, \ - stream_output, show_examples, \ - prompt_type, prompt_dict, \ - temperature, top_p, top_k, num_beams, \ - max_new_tokens, min_new_tokens, early_stopping, max_time, \ - repetition_penalty, num_return_sequences, \ - do_sample, \ - src_lang, tgt_lang, \ - examples, \ - task_info = \ - get_generate_params(model_lower, - chat, - stream_output, show_examples, - prompt_type, prompt_dict, - temperature, top_p, top_k, num_beams, - max_new_tokens, min_new_tokens, early_stopping, max_time, - repetition_penalty, num_return_sequences, - do_sample, - top_k_docs, - chunk, - chunk_size, - verbose, - ) - - git_hash = get_githash() if is_public or os.getenv('GET_GITHASH') else "GET_GITHASH" - locals_dict = locals() - locals_print = '\n'.join(['%s: %s' % (k, v) for k, v in locals_dict.items()]) - if verbose: - print(f"Generating model with params:\n{locals_print}", flush=True) - print("Command: %s\nHash: %s" % (str(' '.join(sys.argv)), git_hash), flush=True) - - if langchain_mode != "Disabled": - # SECOND PLACE where LangChain referenced, but all imports are kept local so not required - from gpt_langchain import prep_langchain, get_some_dbs_from_hf - if is_hf: - get_some_dbs_from_hf() - dbs = {} - for langchain_mode1 in visible_langchain_modes: - if langchain_mode1 in ['MyData']: # FIXME: Remove other custom temp dbs - # don't use what is on disk, remove it instead - 
for gpath1 in glob.glob(os.path.join(scratch_base_dir, 'db_dir_%s*' % langchain_mode1)): - if os.path.isdir(gpath1): - print("Removing old MyData: %s" % gpath1, flush=True) - remove(gpath1) - continue - if langchain_mode1 in ['All']: - # FIXME: All should be avoided until scans over each db, shouldn't be separate db - continue - persist_directory1 = 'db_dir_%s' % langchain_mode1 # single place, no special names for each case - try: - db = prep_langchain(persist_directory1, - load_db_if_exists, - db_type, use_openai_embedding, - langchain_mode1, langchain_mode_paths, - hf_embedding_model, - kwargs_make_db=locals()) - finally: - # in case updated embeddings or created new embeddings - clear_torch_cache() - dbs[langchain_mode1] = db - # remove None db's so can just rely upon k in dbs for if hav db - dbs = {k: v for k, v in dbs.items() if v is not None} - else: - dbs = {} - # import control - if os.environ.get("TEST_LANGCHAIN_IMPORT"): - assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have" - assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have" - - model_state_none = dict(model=None, tokenizer=None, device=None, - base_model=None, tokenizer_base_model=None, lora_weights=None, - inference_server=None, prompt_type=None, prompt_dict=None) - my_db_state0 = {LangChainMode.MY_DATA.value: [None, None]} - selection_docs_state0 = dict(visible_langchain_modes=visible_langchain_modes, - langchain_mode_paths=langchain_mode_paths, - langchain_modes=langchain_modes) - selection_docs_state = selection_docs_state0 - langchain_modes0 = langchain_modes - langchain_mode_paths0 = langchain_mode_paths - visible_langchain_modes0 = visible_langchain_modes - - if cli: - from cli import run_cli - return run_cli(**get_kwargs(run_cli, exclude_names=['model_state0'], **locals())) - elif not gradio: - from eval import run_eval - return run_eval(**get_kwargs(run_eval, exclude_names=['model_state0'], **locals())) - elif gradio: - # imported here so don't require gradio to run generate - from gradio_runner import go_gradio - - # get default model - model_states = [] - model_list = [dict(base_model=base_model, tokenizer_base_model=tokenizer_base_model, lora_weights=lora_weights, - inference_server=inference_server, prompt_type=prompt_type, prompt_dict=prompt_dict)] - model_list0 = copy.deepcopy(model_list) # just strings, safe to deepcopy - model_state0 = model_state_none.copy() - assert len(model_state_none) == len(model_state0) - if model_lock: - model_list = model_lock - for model_dict in reversed(model_list): - # do reverse, so first is default base_model etc., so some logic works in go_gradio() more easily - # handles defaults user didn't have to pass - model_dict['base_model'] = base_model1 = model_dict.get('base_model', '') - model_dict['tokenizer_base_model'] = tokenizer_base_model1 = model_dict.get('tokenizer_base_model', '') - model_dict['lora_weights'] = lora_weights1 = model_dict.get('lora_weights', '') - model_dict['inference_server'] = inference_server1 = model_dict.get('inference_server', '') - prompt_type1 = model_dict.get('prompt_type', model_list0[0]['prompt_type']) # don't use mutated value - # try to infer, ignore empty initial state leading to get_generate_params -> 'plain' - if model_dict.get('prompt_type') is None: - model_lower1 = base_model1.lower() - if model_lower1 in inv_prompt_type_to_model_lower: - prompt_type1 = inv_prompt_type_to_model_lower[model_lower1] - prompt_dict1, error0 = get_prompt(prompt_type1, '', - 
chat=False, context='', reduced=False, making_context=False, - return_dict=True) - else: - prompt_dict1 = prompt_dict - else: - prompt_dict1 = prompt_dict - model_dict['prompt_type'] = prompt_type1 - model_dict['prompt_dict'] = prompt_dict1 = model_dict.get('prompt_dict', prompt_dict1) - all_kwargs = locals().copy() - all_kwargs.update(dict(base_model=base_model1, tokenizer_base_model=tokenizer_base_model1, - lora_weights=lora_weights1, inference_server=inference_server1)) - if base_model1 and not login_mode_if_model0: - model0, tokenizer0, device = get_model(reward_type=False, - **get_kwargs(get_model, exclude_names=['reward_type'], - **all_kwargs)) - else: - # if empty model, then don't load anything, just get gradio up - model0, tokenizer0, device = None, None, None - if model0 is None: - if fail_if_cannot_connect: - raise RuntimeError("Could not connect, see logs") - # skip - if isinstance(model_lock, list): - model_lock.remove(model_dict) - continue - model_state_trial = dict(model=model0, tokenizer=tokenizer0, device=device) - model_state_trial.update(model_dict) - assert len(model_state_none) == len(model_state_trial) - print("Model %s" % model_dict, flush=True) - if model_lock: - # last in iteration will be first - model_states.insert(0, model_state_trial) - # fill model_state0 so go_gradio() easier, manage model_states separately - model_state0 = model_state_trial.copy() - else: - model_state0 = model_state_trial.copy() - assert len(model_state_none) == len(model_state0) - - # get score model - all_kwargs = locals().copy() - smodel, stokenizer, sdevice = get_score_model(reward_type=True, - **get_kwargs(get_score_model, exclude_names=['reward_type'], - **all_kwargs)) - score_model_state0 = dict(model=smodel, tokenizer=stokenizer, device=sdevice, - base_model=score_model, tokenizer_base_model='', lora_weights='', - inference_server='', prompt_type='', prompt_dict='') - - if enable_captions: - if pre_load_caption_model: - from image_captions import H2OImageCaptionLoader - caption_loader = H2OImageCaptionLoader(caption_gpu=caption_gpu).load_model() - else: - caption_loader = 'gpu' if caption_gpu else 'cpu' - else: - caption_loader = False - - # assume gradio needs everything - go_gradio(**locals()) - - -def get_config(base_model, - use_auth_token=False, - trust_remote_code=True, - offload_folder=None, - triton_attn=False, - long_sequence=True, - return_model=False, - raise_exception=False, - ): - from accelerate import init_empty_weights - with init_empty_weights(): - from transformers import AutoConfig - try: - config = AutoConfig.from_pretrained(base_model, use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder) - except OSError as e: - if raise_exception: - raise - if 'not a local folder and is not a valid model identifier listed on' in str( - e) or '404 Client Error' in str(e): - # e.g. llama, gpjt, etc. - # e.g. HF TGI but not model on HF or private etc. 
- # HF TGI server only should really require prompt_type, not HF model state - return None, None - else: - raise - if triton_attn and 'mpt-' in base_model.lower(): - config.attn_config['attn_impl'] = 'triton' - if long_sequence: - if 'mpt-7b-storywriter' in base_model.lower(): - config.update({"max_seq_len": 83968}) - if 'mosaicml/mpt-7b-chat' in base_model.lower(): - config.update({"max_seq_len": 4096}) - if 'mpt-30b' in base_model.lower(): - config.update({"max_seq_len": 2 * 8192}) - if return_model and \ - issubclass(config.__class__, tuple(AutoModel._model_mapping.keys())): - model = AutoModel.from_config( - config, - trust_remote_code=trust_remote_code, - ) - else: - # can't infer - model = None - if 'falcon' in base_model.lower(): - config.use_cache = False - - return config, model - - -def get_non_lora_model(base_model, model_loader, load_half, - load_gptq, use_safetensors, - model_kwargs, reward_type, - config, model, - gpu_id=0, - ): - """ - Ensure model gets on correct device - """ - - if model is not None: - # NOTE: Can specify max_memory={0: max_mem, 1: max_mem}, to shard model - # NOTE: Some models require avoiding sharding some layers, - # then would pass no_split_module_classes and give list of those layers. - from accelerate import infer_auto_device_map - device_map = infer_auto_device_map( - model, - dtype=torch.float16 if load_half else torch.float32, - ) - if hasattr(model, 'model'): - device_map_model = infer_auto_device_map( - model.model, - dtype=torch.float16 if load_half else torch.float32, - ) - device_map.update(device_map_model) - else: - device_map = "auto" - - n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0 - - if n_gpus > 0: - if gpu_id >= 0: - # FIXME: If really distributes model, tend to get things like: ValueError: gpt_neox.embed_in.weight doesn't have any device set.
- # So avoid for now, just put on first GPU, unless score_model, put on last - if reward_type: - device_map = {'': n_gpus - 1} - else: - device_map = {'': min(n_gpus - 1, gpu_id)} - if gpu_id == -1: - device_map = {'': 'cuda'} - else: - device_map = {'': 'cpu'} - model_kwargs['load_in_8bit'] = False - model_kwargs['load_in_4bit'] = False - print('device_map: %s' % device_map, flush=True) - - load_in_8bit = model_kwargs.get('load_in_8bit', False) - load_in_4bit = model_kwargs.get('load_in_4bit', False) - model_kwargs['device_map'] = device_map - model_kwargs['use_safetensors'] = use_safetensors - pop_unused_model_kwargs(model_kwargs) - - if load_gptq: - model_kwargs.pop('torch_dtype', None) - model_kwargs.pop('device_map') - model = model_loader( - model_name_or_path=base_model, - model_basename=load_gptq, - **model_kwargs, - ) - elif load_in_8bit or load_in_4bit or not load_half: - model = model_loader( - base_model, - config=config, - **model_kwargs, - ) - else: - model = model_loader( - base_model, - config=config, - **model_kwargs, - ).half() - return model - - -def get_client_from_inference_server(inference_server, base_model=None, raise_connection_exception=False): - inference_server, headers = get_hf_server(inference_server) - # preload client since slow for gradio case especially - from gradio_utils.grclient import GradioClient - gr_client = None - hf_client = None - if headers is None: - try: - print("GR Client Begin: %s %s" % (inference_server, base_model), flush=True) - # first do sanity check if alive, else gradio client takes too long by default - requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT', '30'))) - gr_client = GradioClient(inference_server) - print("GR Client End: %s" % inference_server, flush=True) - except (OSError, ValueError) as e: - # Occurs when wrong endpoint and should have been HF client, so don't hard raise, just move to HF - gr_client = None - print("GR Client Failed %s %s: %s" % (inference_server, base_model, str(e)), flush=True) - except (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2, - JSONDecodeError, ReadTimeout2, KeyError) as e: - t, v, tb = sys.exc_info() - ex = ''.join(traceback.format_exception(t, v, tb)) - print("GR Client Failed %s %s: %s" % (inference_server, base_model, str(ex)), flush=True) - if raise_connection_exception: - raise - - if gr_client is None: - res = None - from text_generation import Client as HFClient - print("HF Client Begin: %s %s" % (inference_server, base_model)) - try: - hf_client = HFClient(inference_server, headers=headers, timeout=int(os.getenv('REQUEST_TIMEOUT', '30'))) - # quick check valid TGI endpoint - res = hf_client.generate('What?', max_new_tokens=1) - hf_client = HFClient(inference_server, headers=headers, timeout=300) - except (ConnectTimeoutError, ConnectTimeout, MaxRetryError, ConnectionError, ConnectionError2, - JSONDecodeError, ReadTimeout2, KeyError) as e: - hf_client = None - t, v, tb = sys.exc_info() - ex = ''.join(traceback.format_exception(t, v, tb)) - print("HF Client Failed %s %s: %s" % (inference_server, base_model, str(ex))) - if raise_connection_exception: - raise - print("HF Client End: %s %s : %s" % (inference_server, base_model, res)) - return inference_server, gr_client, hf_client - - -def get_model( - load_8bit: bool = False, - load_4bit: bool = False, - load_half: bool = True, - load_gptq: str = '', - use_safetensors: bool = False, - use_gpu_id: bool = True, - base_model: str = '', - inference_server: str = "", - 
tokenizer_base_model: str = '', - lora_weights: str = "", - gpu_id: int = 0, - - reward_type: bool = None, - local_files_only: bool = False, - resume_download: bool = True, - use_auth_token: Union[str, bool] = False, - trust_remote_code: bool = True, - offload_folder: str = None, - compile_model: bool = True, - - verbose: bool = False, -): - """ - - :param load_8bit: load model in 8-bit, not supported by all models - :param load_4bit: load model in 4-bit, not supported by all models - :param load_half: load model in 16-bit - :param load_gptq: GPTQ model_basename - :param use_safetensors: use safetensors file - :param use_gpu_id: Use torch infer of optimal placement of layers on devices (for non-lora case) - For non-LORA case, False will spread shards across multiple GPUs, but this can lead to cuda:x cuda:y mismatches - So it is not the default - :param base_model: name/path of base model - :param inference_server: whether base_model is hosted locally ('') or via http (url) - :param tokenizer_base_model: name/path of tokenizer - :param lora_weights: name/path - :param gpu_id: which GPU (0..n_gpus-1) or allow all GPUs if relevant (-1) - :param reward_type: reward type model for sequence classification - :param local_files_only: use local files instead of from HF - :param resume_download: resume downloads from HF - :param use_auth_token: assumes user did on CLI `huggingface-cli login` to access private repo - :param trust_remote_code: trust code needed by model - :param offload_folder: offload folder - :param compile_model: whether to compile torch model - :param verbose: - :return: - """ - print("Starting get_model: %s %s" % (base_model, inference_server), flush=True) - - triton_attn = False - long_sequence = True - config_kwargs = dict(use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - triton_attn=triton_attn, - long_sequence=long_sequence) - config, _ = get_config(base_model, **config_kwargs, raise_exception=False) - - if base_model in non_hf_types: - assert config is None, "Expected config None for %s" % base_model - - llama_type_from_config = 'llama' in str(config).lower() - llama_type_from_name = "llama" in base_model.lower() - llama_type = llama_type_from_config or llama_type_from_name - if "xgen" in base_model.lower(): - llama_type = False - if llama_type: - if verbose: - print("Detected as llama type from" - " config (%s) or name (%s)" % (llama_type_from_config, llama_type_from_name), flush=True) - - model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type, - load_gptq=load_gptq) - - tokenizer_kwargs = dict(local_files_only=local_files_only, - resume_download=resume_download, - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - padding_side='left', - config=config, - ) - if not tokenizer_base_model: - tokenizer_base_model = base_model - - if config is not None and tokenizer_loader is not None and not isinstance(tokenizer_loader, str): - tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model, **tokenizer_kwargs) - # sets raw (no cushion) limit - set_model_max_len(config, tokenizer, verbose=False) - # if using fake tokenizer, not really accurate when lots of numbers, give a bit of buffer, else get: - # Generation Failed: Input validation error: `inputs` must have less than 2048 tokens. 
Given: 2233 - tokenizer.model_max_length = tokenizer.model_max_length - 50 - else: - tokenizer = FakeTokenizer() - - if isinstance(inference_server, str) and inference_server.startswith("http"): - inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server, - base_model=base_model) - client = gr_client or hf_client - # Don't return None, None for model, tokenizer so triggers - return client, tokenizer, 'http' - if isinstance(inference_server, str) and ( - inference_server.startswith('openai') or inference_server.startswith('vllm')): - if inference_server.startswith('openai'): - assert os.getenv('OPENAI_API_KEY'), "Set environment for OPENAI_API_KEY" - # Don't return None, None for model, tokenizer so triggers - # include small token cushion - tokenizer = FakeTokenizer(model_max_length=model_token_mapping[base_model] - 50) - return inference_server, tokenizer, inference_server - assert not inference_server, "Malformed inference_server=%s" % inference_server - if base_model in non_hf_types: - from gpt4all_llm import get_model_tokenizer_gpt4all - model, tokenizer, device = get_model_tokenizer_gpt4all(base_model) - return model, tokenizer, device - - # get local torch-HF model - return get_hf_model(load_8bit=load_8bit, - load_4bit=load_4bit, - load_half=load_half, - load_gptq=load_gptq, - use_safetensors=use_safetensors, - use_gpu_id=use_gpu_id, - base_model=base_model, - tokenizer_base_model=tokenizer_base_model, - lora_weights=lora_weights, - gpu_id=gpu_id, - - reward_type=reward_type, - local_files_only=local_files_only, - resume_download=resume_download, - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - compile_model=compile_model, - - llama_type=llama_type, - config_kwargs=config_kwargs, - tokenizer_kwargs=tokenizer_kwargs, - - verbose=verbose) - - -def get_hf_model(load_8bit: bool = False, - load_4bit: bool = False, - load_half: bool = True, - load_gptq: str = '', - use_safetensors: bool = False, - use_gpu_id: bool = True, - base_model: str = '', - tokenizer_base_model: str = '', - lora_weights: str = "", - gpu_id: int = 0, - - reward_type: bool = None, - local_files_only: bool = False, - resume_download: bool = True, - use_auth_token: Union[str, bool] = False, - trust_remote_code: bool = True, - offload_folder: str = None, - compile_model: bool = True, - - llama_type: bool = False, - config_kwargs=None, - tokenizer_kwargs=None, - - verbose: bool = False, - ): - assert config_kwargs is not None - assert tokenizer_kwargs is not None - - if lora_weights is not None and lora_weights.strip(): - if verbose: - print("Get %s lora weights" % lora_weights, flush=True) - device = get_device() - - if 'gpt2' in base_model.lower(): - # RuntimeError: where expected condition to be a boolean tensor, but got a tensor with dtype Half - load_8bit = False - load_4bit = False - - assert base_model.strip(), ( - "Please choose a base model with --base_model (CLI) or load one from Models Tab (gradio)" - ) - - model_loader, tokenizer_loader = get_loaders(model_name=base_model, reward_type=reward_type, llama_type=llama_type, - load_gptq=load_gptq) - - config, _ = get_config(base_model, return_model=False, raise_exception=True, **config_kwargs) - - if tokenizer_loader is not None and not isinstance(tokenizer_loader, str): - tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model, - **tokenizer_kwargs) - else: - tokenizer = tokenizer_loader - - if isinstance(tokenizer, str): - # already a pipeline, tokenizer_loader is 
string for task - model = model_loader(tokenizer, - model=base_model, - device=0 if device == "cuda" else -1, - torch_dtype=torch.float16 if device == 'cuda' else torch.float32) - else: - assert device in ["cuda", "cpu", "mps"], "Unsupported device %s" % device - model_kwargs = dict(local_files_only=local_files_only, - torch_dtype=torch.float16 if device == 'cuda' else torch.float32, - resume_download=resume_download, - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - ) - if 'mbart-' not in base_model.lower() and 'mpt-' not in base_model.lower(): - if use_gpu_id and gpu_id is not None and gpu_id >= 0 and device == 'cuda': - device_map = {"": gpu_id} - else: - device_map = "auto" - model_kwargs.update(dict(load_in_8bit=load_8bit, - load_in_4bit=load_4bit, - device_map=device_map, - )) - if 'mpt-' in base_model.lower() and gpu_id is not None and gpu_id >= 0: - # MPT doesn't support spreading over GPUs - model_kwargs.update(dict(device_map={"": gpu_id} if device == 'cuda' else "cpu")) - - if 'OpenAssistant/reward-model'.lower() in base_model.lower(): - # FIXME: could put on other GPUs - model_kwargs['device_map'] = {"": 0} if device == 'cuda' else {"": 'cpu'} - model_kwargs.pop('torch_dtype', None) - pop_unused_model_kwargs(model_kwargs) - - if not lora_weights: - # torch.device context uses twice memory for AutoGPTQ - context = NullContext if load_gptq else torch.device - with context(device): - - if use_gpu_id: - config, model = get_config(base_model, return_model=True, raise_exception=True, **config_kwargs) - model = get_non_lora_model(base_model, model_loader, load_half, load_gptq, use_safetensors, - model_kwargs, reward_type, - config, model, - gpu_id=gpu_id, - ) - else: - config, _ = get_config(base_model, **config_kwargs) - if load_half and not (load_8bit or load_4bit or load_gptq): - model = model_loader( - base_model, - config=config, - **model_kwargs).half() - else: - model = model_loader( - base_model, - config=config, - **model_kwargs) - elif load_8bit or load_4bit: - config, _ = get_config(base_model, **config_kwargs) - model = model_loader( - base_model, - config=config, - **model_kwargs - ) - from peft import PeftModel # loads cuda, so avoid in global scope - model = PeftModel.from_pretrained( - model, - lora_weights, - torch_dtype=torch.float16 if device == 'cuda' else torch.float32, - local_files_only=local_files_only, - resume_download=resume_download, - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - device_map={"": 0} if device == 'cuda' else {"": 'cpu'}, # seems to be required - ) - else: - with torch.device(device): - config, _ = get_config(base_model, raise_exception=True, **config_kwargs) - model = model_loader( - base_model, - config=config, - **model_kwargs - ) - from peft import PeftModel # loads cuda, so avoid in global scope - model = PeftModel.from_pretrained( - model, - lora_weights, - torch_dtype=torch.float16 if device == 'cuda' else torch.float32, - local_files_only=local_files_only, - resume_download=resume_download, - use_auth_token=use_auth_token, - trust_remote_code=trust_remote_code, - offload_folder=offload_folder, - device_map="auto", - ) - if load_half and not load_gptq: - model.half() - - # unwind broken decapoda-research config - if llama_type: - model.config.pad_token_id = tokenizer.pad_token_id = 0 # unk - model.config.bos_token_id = 1 - model.config.eos_token_id = 2 - if 'gpt2' in base_model.lower(): - # add special tokens that 
otherwise all share the same id - tokenizer.add_special_tokens({'bos_token': '', - 'eos_token': '', - 'pad_token': ''}) - - if not isinstance(tokenizer, str): - model.eval() - if torch.__version__ >= "2" and sys.platform != "win32" and compile_model: - model = torch.compile(model) - - set_model_max_len(config, tokenizer, verbose=False, reward_type=reward_type) - - return model, tokenizer, device - - -def set_model_max_len(config, tokenizer, verbose=False, reward_type=False): - if reward_type: - # limit deberta, else uses too much memory and not worth response score - tokenizer.model_max_length = 512 - if hasattr(config, 'max_seq_len') and isinstance(config.max_seq_len, int): - tokenizer.model_max_length = config.max_seq_len - elif hasattr(config, 'max_position_embeddings') and isinstance(config.max_position_embeddings, int): - # help automatically limit inputs to generate - tokenizer.model_max_length = config.max_position_embeddings - else: - if verbose: - print("Could not determine model_max_length, setting to 2048", flush=True) - tokenizer.model_max_length = 2048 - # for bug in HF transformers - if tokenizer.model_max_length > 100000000: - tokenizer.model_max_length = 2048 - - -def pop_unused_model_kwargs(model_kwargs): - """ - in-place pop unused kwargs that are not dependency-upgrade friendly - no point passing in False, is default, and helps avoid needing to update requirements for new deps - :param model_kwargs: - :return: - """ - check_list = ['load_in_8bit', 'load_in_4bit'] - for k in check_list: - if k in model_kwargs and not model_kwargs[k]: - model_kwargs.pop(k) - - -def get_score_model(score_model: str = None, - load_8bit: bool = False, - load_4bit: bool = False, - load_half: bool = True, - load_gptq: str = '', - use_gpu_id: bool = True, - base_model: str = '', - inference_server: str = '', - tokenizer_base_model: str = '', - lora_weights: str = "", - gpu_id: int = 0, - - reward_type: bool = None, - local_files_only: bool = False, - resume_download: bool = True, - use_auth_token: Union[str, bool] = False, - trust_remote_code: bool = True, - offload_folder: str = None, - compile_model: bool = True, - - verbose: bool = False, - ): - if score_model is not None and score_model.strip(): - load_8bit = False - load_4bit = False - load_half = False - load_gptq = '' - use_safetensors = False - base_model = score_model.strip() - tokenizer_base_model = '' - lora_weights = '' - inference_server = '' - llama_type = False - compile_model = False - smodel, stokenizer, sdevice = get_model(reward_type=True, - **get_kwargs(get_model, exclude_names=['reward_type'], **locals())) - else: - smodel, stokenizer, sdevice = None, None, None - return smodel, stokenizer, sdevice - - -def evaluate( - model_state, - my_db_state, - selection_docs_state, - # START NOTE: Examples must have same order of parameters - instruction, - iinput, - context, - stream_output, - prompt_type, - prompt_dict, - temperature, - top_p, - top_k, - num_beams, - max_new_tokens, - min_new_tokens, - early_stopping, - max_time, - repetition_penalty, - num_return_sequences, - do_sample, - chat, - instruction_nochat, - iinput_nochat, - langchain_mode, - add_chat_history_to_context, - langchain_action, - langchain_agents, - top_k_docs, - chunk, - chunk_size, - document_subset, - document_choice, - # END NOTE: Examples must have same order of parameters - src_lang=None, - tgt_lang=None, - debug=False, - concurrency_count=None, - save_dir=None, - sanitize_bot_response=False, - model_state0=None, - langchain_modes0=None, - 
langchain_mode_paths0=None, - visible_langchain_modes0=None, - memory_restriction_level=None, - max_max_new_tokens=None, - is_public=None, - max_max_time=None, - raise_generate_gpu_exceptions=None, - chat_context=None, - lora_weights=None, - use_llm_if_no_docs=False, - load_db_if_exists=True, - dbs=None, - detect_user_path_changes_every_query=None, - use_openai_embedding=None, - use_openai_model=None, - hf_embedding_model=None, - cut_distance=None, - db_type=None, - n_jobs=None, - first_para=None, - text_limit=None, - verbose=False, - cli=False, - reverse_docs=True, - use_cache=None, - auto_reduce_chunks=None, - max_chunks=None, - model_lock=None, - force_langchain_evaluate=None, - model_state_none=None, -): - # ensure passed these - assert concurrency_count is not None - assert memory_restriction_level is not None - assert raise_generate_gpu_exceptions is not None - assert chat_context is not None - assert use_openai_embedding is not None - assert use_openai_model is not None - assert hf_embedding_model is not None - assert db_type is not None - assert top_k_docs is not None and isinstance(top_k_docs, int) - assert chunk is not None and isinstance(chunk, bool) - assert chunk_size is not None and isinstance(chunk_size, int) - assert n_jobs is not None - assert first_para is not None - assert isinstance(add_chat_history_to_context, bool) - - if selection_docs_state is not None: - langchain_modes = selection_docs_state.get('langchain_modes', langchain_modes0) - langchain_mode_paths = selection_docs_state.get('langchain_mode_paths', langchain_mode_paths0) - visible_langchain_modes = selection_docs_state.get('visible_langchain_modes', visible_langchain_modes0) - else: - langchain_modes = langchain_modes0 - langchain_mode_paths = langchain_mode_paths0 - visible_langchain_modes = visible_langchain_modes0 - - if debug: - locals_dict = locals().copy() - locals_dict.pop('model_state', None) - locals_dict.pop('model_state0', None) - locals_dict.pop('model_states', None) - print(locals_dict) - - no_model_msg = "Please choose a base model with --base_model (CLI) or load in Models Tab (gradio).\n" \ - "Then start New Conversation" - - if model_state is None: - model_state = model_state_none.copy() - if model_state0 is None: - # e.g. for no gradio case, set dummy value, else should be set - model_state0 = model_state_none.copy() - - # model_state['model] is only 'model' if should use model_state0 - # model could also be None - have_model_lock = model_lock is not None - have_fresh_model = model_state['model'] not in [None, 'model', no_model_str] - # for gradio UI control, expect model_state and model_state0 to match, so if have_model_lock=True, then should have_fresh_model=True - # but gradio API control will only use nochat api etc. and won't use fresh model, so can't assert in general - # if have_model_lock: - # assert have_fresh_model, "Expected model_state and model_state0 to match if have_model_lock" - have_cli_model = model_state0['model'] not in [None, 'model', no_model_str] - - if have_fresh_model: - # USE FRESH MODEL - if not have_model_lock: - # model_state0 is just one of model_state if model_lock, so don't nuke - # try to free-up original model (i.e. list was passed as reference) - if model_state0['model'] and hasattr(model_state0['model'], 'cpu'): - model_state0['model'].cpu() - model_state0['model'] = None - # try to free-up original tokenizer (i.e. 
list was passed as reference) - if model_state0['tokenizer']: - model_state0['tokenizer'] = None - clear_torch_cache() - chosen_model_state = model_state - elif have_cli_model: - # USE MODEL SETUP AT CLI - assert isinstance(model_state['model'], str) # expect no fresh model - chosen_model_state = model_state0 - else: - raise AssertionError(no_model_msg) - # get variables - model = chosen_model_state['model'] - tokenizer = chosen_model_state['tokenizer'] - device = chosen_model_state['device'] - base_model = chosen_model_state['base_model'] - tokenizer_base_model = chosen_model_state['tokenizer_base_model'] - lora_weights = chosen_model_state['lora_weights'] - inference_server = chosen_model_state['inference_server'] - # prefer use input from API over model state - prompt_type = prompt_type or chosen_model_state['prompt_type'] - prompt_dict = prompt_dict or chosen_model_state['prompt_dict'] - - if base_model is None: - raise AssertionError(no_model_msg) - - assert base_model.strip(), no_model_msg - assert model, "Model is missing" - assert tokenizer, "Tokenizer is missing" - - # choose chat or non-chat mode - if not chat: - instruction = instruction_nochat - iinput = iinput_nochat - - # in some cases, like lean nochat API, don't want to force sending prompt_type, allow default choice - model_lower = base_model.lower() - if not prompt_type and model_lower in inv_prompt_type_to_model_lower and prompt_type != 'custom': - prompt_type = inv_prompt_type_to_model_lower[model_lower] - if verbose: - print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True) - assert prompt_type is not None, "prompt_type was None" - - # Control generation hyperparameters - # adjust for bad inputs, e.g. in case also come from API that doesn't get constrained by gradio sliders - # below is for TGI server, not required for HF transformers - # limits are chosen similar to gradio_runner.py sliders/numbers - top_p = min(max(1e-3, top_p), 1.0 - 1e-3) - top_k = min(max(1, int(top_k)), 100) - temperature = min(max(0.01, temperature), 2.0) - # FIXME: https://github.com/h2oai/h2ogpt/issues/106 - num_beams = 1 if stream_output else num_beams # See max_beams in gradio_runner - max_max_new_tokens = get_max_max_new_tokens(chosen_model_state, - memory_restriction_level=memory_restriction_level, - max_new_tokens=max_new_tokens, - max_max_new_tokens=max_max_new_tokens) - model_max_length = get_model_max_length(chosen_model_state) - max_new_tokens = min(max(1, int(max_new_tokens)), max_max_new_tokens) - min_new_tokens = min(max(0, int(min_new_tokens)), max_new_tokens) - max_time = min(max(0, max_time), max_max_time) - repetition_penalty = min(max(0.01, repetition_penalty), 3.0) - num_return_sequences = 1 if chat else min(max(1, int(num_return_sequences)), 10) - min_top_k_docs, max_top_k_docs, label_top_k_docs = get_minmax_top_k_docs(is_public) - top_k_docs = min(max(min_top_k_docs, int(top_k_docs)), max_top_k_docs) - chunk_size = min(max(128, int(chunk_size)), 2048) - if not context: - # get hidden context if have one - context = get_context(chat_context, prompt_type) - - # restrict instruction, typically what has large input - from h2oai_pipeline import H2OTextGenerationPipeline - instruction, num_prompt_tokens1 = H2OTextGenerationPipeline.limit_prompt(instruction, tokenizer) - context, num_prompt_tokens2 = H2OTextGenerationPipeline.limit_prompt(context, tokenizer) - iinput, num_prompt_tokens3 = H2OTextGenerationPipeline.limit_prompt(iinput, tokenizer) - num_prompt_tokens = (num_prompt_tokens1 or 0) + 
(num_prompt_tokens2 or 0) + (num_prompt_tokens3 or 0) - - # get prompt - prompter = Prompter(prompt_type, prompt_dict, debug=debug, chat=chat, stream_output=stream_output) - data_point = dict(context=context, instruction=instruction, input=iinput) - prompt = prompter.generate_prompt(data_point) - - # THIRD PLACE where LangChain referenced, but imports only occur if enabled and have db to use - assert langchain_mode in langchain_modes, "Invalid langchain_mode %s" % langchain_mode - assert langchain_action in langchain_actions, "Invalid langchain_action %s" % langchain_action - assert len( - set(langchain_agents).difference(langchain_agents_list)) == 0, "Invalid langchain_agents %s" % langchain_agents - if dbs is not None and langchain_mode in dbs: - db = dbs[langchain_mode] - elif my_db_state is not None and langchain_mode in my_db_state: - db1 = my_db_state[langchain_mode] - if db1 is not None and len(db1) == 2: - db = db1[0] - else: - db = None - else: - db = None - do_langchain_path = langchain_mode not in [False, 'Disabled', 'LLM'] or \ - base_model in non_hf_types or \ - force_langchain_evaluate - if do_langchain_path: - outr = "" - # use smaller cut_distance for wiki_full since so many matches could be obtained, and often irrelevant unless close - from gpt_langchain import run_qa_db - gen_hyper_langchain = dict(do_sample=do_sample, - temperature=temperature, - repetition_penalty=repetition_penalty, - top_k=top_k, - top_p=top_p, - num_beams=num_beams, - min_new_tokens=min_new_tokens, - max_new_tokens=max_new_tokens, - early_stopping=early_stopping, - max_time=max_time, - num_return_sequences=num_return_sequences, - ) - for r in run_qa_db(query=instruction, - iinput=iinput, - context=context, - model_name=base_model, model=model, tokenizer=tokenizer, - inference_server=inference_server, - stream_output=stream_output, - prompter=prompter, - use_llm_if_no_docs=use_llm_if_no_docs, - load_db_if_exists=load_db_if_exists, - db=db, - langchain_mode_paths=langchain_mode_paths, - detect_user_path_changes_every_query=detect_user_path_changes_every_query, - cut_distance=1.1 if langchain_mode in ['wiki_full'] else cut_distance, - add_chat_history_to_context=add_chat_history_to_context, - use_openai_embedding=use_openai_embedding, - use_openai_model=use_openai_model, - hf_embedding_model=hf_embedding_model, - first_para=first_para, - text_limit=text_limit, - chunk=chunk, - chunk_size=chunk_size, - langchain_mode=langchain_mode, - langchain_action=langchain_action, - langchain_agents=langchain_agents, - document_subset=document_subset, - document_choice=document_choice, - db_type=db_type, - top_k_docs=top_k_docs, - - **gen_hyper_langchain, - - prompt_type=prompt_type, - prompt_dict=prompt_dict, - n_jobs=n_jobs, - verbose=verbose, - cli=cli, - sanitize_bot_response=sanitize_bot_response, - reverse_docs=reverse_docs, - - lora_weights=lora_weights, - - auto_reduce_chunks=auto_reduce_chunks, - max_chunks=max_chunks, - ): - outr, extra = r # doesn't accumulate, new answer every yield, so only save that full answer - yield dict(response=outr, sources=extra) - if save_dir: - extra_dict = gen_hyper_langchain.copy() - extra_dict.update(prompt_type=prompt_type, - inference_server=inference_server, - langchain_mode=langchain_mode, - langchain_action=langchain_action, - langchain_agents=langchain_agents, - document_subset=document_subset, - document_choice=document_choice, - num_prompt_tokens=num_prompt_tokens, - instruction=instruction, - iinput=iinput, - context=context, - ) - 
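The LangChain path above streams its answer by yielding dict(response=outr, sources=extra), where each yield carries the full answer so far rather than an incremental delta, which is why only the last value is saved. A minimal sketch of how a caller typically drains such a generator; `gen` is a hypothetical stand-in for whatever call produced it, not a name from this module:

    final = dict(response='', sources='')
    for chunk in gen:        # every yield replaces the previous answer, it does not append
        final = chunk
    answer, sources = final['response'], final['sources']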
save_generate_output(prompt=prompt, - output=outr, base_model=base_model, save_dir=save_dir, - where_from='run_qa_db', - extra_dict=extra_dict) - if verbose: - print( - 'Post-Generate Langchain: %s decoded_output: %s' % (str(datetime.now()), len(outr) if outr else -1), - flush=True) - if outr or base_model in non_hf_types: - # if got no response (e.g. not showing sources and got no sources, - # so nothing to give to LLM), then slip through and ask LLM - # Or if llama/gptj, then just return since they had no response and can't go down below code path - # clear before return, since .then() never done if from API - clear_torch_cache() - return - - if inference_server.startswith('vllm') or inference_server.startswith('openai') or inference_server.startswith( - 'http'): - if inference_server.startswith('vllm') or inference_server.startswith('openai'): - where_from = "openai_client" - openai, inf_type = set_openai(inference_server) - - terminate_response = prompter.terminate_response or [] - stop_sequences = list(set(terminate_response + [prompter.PreResponse])) - stop_sequences = [x for x in stop_sequences if x] - # OpenAI will complain if ask for too many new tokens, takes it as min in some sense, wrongly so. - max_new_tokens_openai = min(max_new_tokens, model_max_length - num_prompt_tokens) - gen_server_kwargs = dict(temperature=temperature if do_sample else 0, - max_tokens=max_new_tokens_openai, - top_p=top_p if do_sample else 1, - frequency_penalty=0, - n=num_return_sequences, - presence_penalty=1.07 - repetition_penalty + 0.6, # so good default - ) - if inf_type == 'vllm' or inference_server == 'openai': - response = openai.Completion.create( - model=base_model, - prompt=prompt, - **gen_server_kwargs, - stop=stop_sequences, - stream=stream_output, - ) - if not stream_output: - text = response['choices'][0]['text'] - yield dict(response=prompter.get_response(prompt + text, prompt=prompt, - sanitize_bot_response=sanitize_bot_response), - sources='') - else: - collected_events = [] - text = '' - for event in response: - collected_events.append(event) # save the event response - event_text = event['choices'][0]['text'] # extract the text - text += event_text # append the text - yield dict(response=prompter.get_response(prompt + text, prompt=prompt, - sanitize_bot_response=sanitize_bot_response), - sources='') - elif inf_type == 'vllm_chat' or inference_server == 'openai_chat': - if inf_type == 'vllm_chat': - raise NotImplementedError('%s not supported by vLLM' % inf_type) - response = openai.ChatCompletion.create( - model=base_model, - messages=[ - {"role": "system", "content": "You are a helpful assistant."}, - {'role': 'user', - 'content': prompt, - } - ], - stream=stream_output, - **gen_server_kwargs, - ) - if not stream_output: - text = response["choices"][0]["message"]["content"] - yield dict(response=prompter.get_response(prompt + text, prompt=prompt, - sanitize_bot_response=sanitize_bot_response), - sources='') - else: - text = "" - for chunk in response: - delta = chunk["choices"][0]["delta"] - if 'content' in delta: - text += delta['content'] - yield dict(response=prompter.get_response(prompt + text, prompt=prompt, - sanitize_bot_response=sanitize_bot_response), - sources='') - else: - raise RuntimeError("No such OpenAI mode: %s" % inference_server) - elif inference_server.startswith('http'): - inference_server, headers = get_hf_server(inference_server) - from gradio_utils.grclient import GradioClient - from text_generation import Client as HFClient - if isinstance(model, 
GradioClient): - gr_client = model - hf_client = None - elif isinstance(model, HFClient): - gr_client = None - hf_client = model - else: - inference_server, gr_client, hf_client = get_client_from_inference_server(inference_server, - base_model=base_model) - - # quick sanity check to avoid long timeouts, just see if can reach server - requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10'))) - - if gr_client is not None: - # Note: h2oGPT gradio server could handle input token size issues for prompt, - # but best to handle here so send less data to server - - chat_client = False - where_from = "gr_client" - client_langchain_mode = 'Disabled' - client_add_chat_history_to_context = True - client_langchain_action = LangChainAction.QUERY.value - client_langchain_agents = [] - gen_server_kwargs = dict(temperature=temperature, - top_p=top_p, - top_k=top_k, - num_beams=num_beams, - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - early_stopping=early_stopping, - max_time=max_time, - repetition_penalty=repetition_penalty, - num_return_sequences=num_return_sequences, - do_sample=do_sample, - chat=chat_client, - ) - # account for gradio into gradio that handles prompting, avoid duplicating prompter prompt injection - if prompt_type in [None, '', PromptType.plain.name, PromptType.plain.value, - str(PromptType.plain.value)]: - # if our prompt is plain, assume either correct or gradio server knows different prompt type, - # so pass empty prompt_Type - gr_prompt_type = '' - gr_prompt_dict = '' - gr_prompt = prompt # already prepared prompt - gr_context = '' - gr_iinput = '' - else: - # if already have prompt_type that is not plain, None, or '', then already applied some prompting - # But assume server can handle prompting, and need to avoid double-up. 
- # Also assume server can do better job of using stopping.py to stop early, so avoid local prompting, let server handle - # So avoid "prompt" and let gradio server reconstruct from prompt_type we passed - # Note it's ok that prompter.get_response() has prompt+text, prompt=prompt passed, - # because just means extra processing and removal of prompt, but that has no human-bot prompting doesn't matter - # since those won't appear - gr_context = context - gr_prompt = instruction - gr_iinput = iinput - gr_prompt_type = prompt_type - gr_prompt_dict = prompt_dict - client_kwargs = dict(instruction=gr_prompt if chat_client else '', # only for chat=True - iinput=gr_iinput, # only for chat=True - context=gr_context, - # streaming output is supported, loops over and outputs each generation in streaming mode - # but leave stream_output=False for simple input/output mode - stream_output=stream_output, - - **gen_server_kwargs, - - prompt_type=gr_prompt_type, - prompt_dict=gr_prompt_dict, - - instruction_nochat=gr_prompt if not chat_client else '', - iinput_nochat=gr_iinput, # only for chat=False - langchain_mode=client_langchain_mode, - add_chat_history_to_context=client_add_chat_history_to_context, - langchain_action=client_langchain_action, - langchain_agents=client_langchain_agents, - top_k_docs=top_k_docs, - chunk=chunk, - chunk_size=chunk_size, - document_subset=DocumentSubset.Relevant.name, - document_choice=[DocumentChoice.ALL.value], - ) - api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing - if not stream_output: - res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name) - res_dict = ast.literal_eval(res) - text = res_dict['response'] - sources = res_dict['sources'] - yield dict(response=prompter.get_response(prompt + text, prompt=prompt, - sanitize_bot_response=sanitize_bot_response), - sources=sources) - else: - job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name) - text = '' - sources = '' - res_dict = dict(response=text, sources=sources) - while not job.done(): - outputs_list = job.communicator.job.outputs - if outputs_list: - res = job.communicator.job.outputs[-1] - res_dict = ast.literal_eval(res) - text = res_dict['response'] - sources = res_dict['sources'] - if gr_prompt_type == 'plain': - # then gradio server passes back full prompt + text - prompt_and_text = text - else: - prompt_and_text = prompt + text - yield dict(response=prompter.get_response(prompt_and_text, prompt=prompt, - sanitize_bot_response=sanitize_bot_response), - sources=sources) - time.sleep(0.01) - # ensure get last output to avoid race - res_all = job.outputs() - if len(res_all) > 0: - res = res_all[-1] - res_dict = ast.literal_eval(res) - text = res_dict['response'] - sources = res_dict['sources'] - else: - # go with old text if last call didn't work - e = job.future._exception - if e is not None: - stre = str(e) - strex = ''.join(traceback.format_tb(e.__traceback__)) - else: - stre = '' - strex = '' - - print("Bad final response: %s %s %s %s %s: %s %s" % (base_model, inference_server, - res_all, prompt, text, stre, strex), - flush=True) - if gr_prompt_type == 'plain': - # then gradio server passes back full prompt + text - prompt_and_text = text - else: - prompt_and_text = prompt + text - yield dict(response=prompter.get_response(prompt_and_text, prompt=prompt, - sanitize_bot_response=sanitize_bot_response), - sources=sources) - elif hf_client: - # HF inference server needs control over input tokens - where_from = "hf_client" - - # prompt 
must include all human-bot like tokens, already added by prompt - # https://github.com/huggingface/text-generation-inference/tree/main/clients/python#types - terminate_response = prompter.terminate_response or [] - stop_sequences = list(set(terminate_response + [prompter.PreResponse])) - stop_sequences = [x for x in stop_sequences if x] - gen_server_kwargs = dict(do_sample=do_sample, - max_new_tokens=max_new_tokens, - # best_of=None, - repetition_penalty=repetition_penalty, - return_full_text=True, - seed=SEED, - stop_sequences=stop_sequences, - temperature=temperature, - top_k=top_k, - top_p=top_p, - # truncate=False, # behaves oddly - # typical_p=top_p, - # watermark=False, - # decoder_input_details=False, - ) - # work-around for timeout at constructor time, will be issue if multi-threading, - # so just do something reasonable or max_time if larger - # lower bound because client is re-used if multi-threading - hf_client.timeout = max(300, max_time) - if not stream_output: - text = hf_client.generate(prompt, **gen_server_kwargs).generated_text - yield dict(response=prompter.get_response(text, prompt=prompt, - sanitize_bot_response=sanitize_bot_response), - sources='') - else: - text = "" - for response in hf_client.generate_stream(prompt, **gen_server_kwargs): - if not response.token.special: - # stop_sequences - text_chunk = response.token.text - text += text_chunk - yield dict(response=prompter.get_response(prompt + text, prompt=prompt, - sanitize_bot_response=sanitize_bot_response), - sources='') - else: - raise RuntimeError("Failed to get client: %s" % inference_server) - else: - raise RuntimeError("No such inference_server %s" % inference_server) - - if save_dir and text: - # save prompt + new text - extra_dict = gen_server_kwargs.copy() - extra_dict.update(dict(inference_server=inference_server, num_prompt_tokens=num_prompt_tokens)) - save_generate_output(prompt=prompt, output=text, base_model=base_model, save_dir=save_dir, - where_from=where_from, extra_dict=extra_dict) - return - else: - assert not inference_server, "inferene_server=%s not supported" % inference_server - - if isinstance(tokenizer, str): - # pipeline - if tokenizer == "summarization": - key = 'summary_text' - else: - raise RuntimeError("No such task type %s" % tokenizer) - # NOTE: uses max_length only - yield dict(response=model(prompt, max_length=max_new_tokens)[0][key], sources='') - - if 'mbart-' in base_model.lower(): - assert src_lang is not None - tokenizer.src_lang = languages_covered()[src_lang] - - stopping_criteria = get_stopping(prompt_type, prompt_dict, tokenizer, device, - model_max_length=tokenizer.model_max_length) - - inputs = tokenizer(prompt, return_tensors="pt") - if debug and len(inputs["input_ids"]) > 0: - print('input_ids length', len(inputs["input_ids"][0]), flush=True) - input_ids = inputs["input_ids"].to(device) - # CRITICAL LIMIT else will fail - max_max_tokens = tokenizer.model_max_length - max_input_tokens = max_max_tokens - min_new_tokens - # NOTE: Don't limit up front due to max_new_tokens, let go up to max or reach max_max_tokens in stopping.py - input_ids = input_ids[:, -max_input_tokens:] - # required for falcon if multiple threads or asyncio accesses to model during generation - if use_cache is None: - use_cache = False if 'falcon' in base_model else True - gen_config_kwargs = dict(temperature=float(temperature), - top_p=float(top_p), - top_k=top_k, - num_beams=num_beams, - do_sample=do_sample, - repetition_penalty=float(repetition_penalty), - 
num_return_sequences=num_return_sequences, - renormalize_logits=True, - remove_invalid_values=True, - use_cache=use_cache, - ) - token_ids = ['eos_token_id', 'pad_token_id', 'bos_token_id', 'cls_token_id', 'sep_token_id'] - for token_id in token_ids: - if hasattr(tokenizer, token_id) and getattr(tokenizer, token_id) is not None: - gen_config_kwargs.update({token_id: getattr(tokenizer, token_id)}) - generation_config = GenerationConfig(**gen_config_kwargs) - - gen_kwargs = dict(input_ids=input_ids, - generation_config=generation_config, - return_dict_in_generate=True, - output_scores=True, - max_new_tokens=max_new_tokens, # prompt + new - min_new_tokens=min_new_tokens, # prompt + new - early_stopping=early_stopping, # False, True, "never" - max_time=max_time, - stopping_criteria=stopping_criteria, - ) - if 'gpt2' in base_model.lower(): - gen_kwargs.update(dict(bos_token_id=tokenizer.bos_token_id, pad_token_id=tokenizer.eos_token_id)) - elif 'mbart-' in base_model.lower(): - assert tgt_lang is not None - tgt_lang = languages_covered()[tgt_lang] - gen_kwargs.update(dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang])) - else: - token_ids = ['eos_token_id', 'bos_token_id', 'pad_token_id'] - for token_id in token_ids: - if hasattr(tokenizer, token_id) and getattr(tokenizer, token_id) is not None: - gen_kwargs.update({token_id: getattr(tokenizer, token_id)}) - - decoder_kwargs = dict(skip_special_tokens=True, - clean_up_tokenization_spaces=True) - - decoder = functools.partial(tokenizer.decode, - **decoder_kwargs - ) - decoder_raw_kwargs = dict(skip_special_tokens=False, - clean_up_tokenization_spaces=True) - - decoder_raw = functools.partial(tokenizer.decode, - **decoder_raw_kwargs - ) - - with torch.no_grad(): - have_lora_weights = lora_weights not in [no_lora_str, '', None] - context_class_cast = NullContext if device == 'cpu' or have_lora_weights or device == 'mps' else torch.autocast - with context_class_cast(device): - # protection for gradio not keeping track of closed users, - # else hit bitsandbytes lack of thread safety: - # https://github.com/h2oai/h2ogpt/issues/104 - # but only makes sense if concurrency_count == 1 - context_class = NullContext # if concurrency_count > 1 else filelock.FileLock - if verbose: - print('Pre-Generate: %s' % str(datetime.now()), flush=True) - decoded_output = None - with context_class("generate.lock"): - if verbose: - print('Generate: %s' % str(datetime.now()), flush=True) - # decoded tokenized prompt can deviate from prompt due to special characters - inputs_decoded = decoder(input_ids[0]) - inputs_decoded_raw = decoder_raw(input_ids[0]) - if inputs_decoded == prompt: - # normal - pass - elif inputs_decoded.lstrip() == prompt.lstrip(): - # sometimes extra space in front, make prompt same for prompt removal - prompt = inputs_decoded - elif inputs_decoded_raw == prompt: - # some models specify special tokens that are part of normal prompt, so can't skip them - inputs_decoded = prompt = inputs_decoded_raw - decoder = decoder_raw - decoder_kwargs = decoder_raw_kwargs - elif inputs_decoded_raw.replace(" ", "").replace("", "").replace('\n', ' ').replace(' ', - '') == prompt.replace( - '\n', ' ').replace(' ', ''): - inputs_decoded = prompt = inputs_decoded_raw - decoder = decoder_raw - decoder_kwargs = decoder_raw_kwargs - else: - if verbose: - print("WARNING: Special characters in prompt", flush=True) - if stream_output: - skip_prompt = False - streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, - **decoder_kwargs) - 
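The streaming branch just below hands this streamer to generate() via gen_kwargs and then drains it from a background thread (EThread), yielding partial responses as decoded text arrives. For orientation, a minimal self-contained version of the same producer/consumer pattern using the stock transformers TextIteratorStreamer; the model name and prompt are arbitrary placeholders, not taken from this codebase:

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tok = AutoTokenizer.from_pretrained('gpt2')
    mdl = AutoModelForCausalLM.from_pretrained('gpt2')
    inputs = tok('Hello, world', return_tensors='pt')
    streamer = TextIteratorStreamer(tok, skip_prompt=False)
    thread = Thread(target=mdl.generate, kwargs=dict(**inputs, streamer=streamer, max_new_tokens=16))
    thread.start()
    text = ''
    for piece in streamer:   # blocks until new decoded text arrives; iteration ends when generation finishes
        text += piece
    thread.join()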
gen_kwargs.update(dict(streamer=streamer)) - target = wrapped_partial(generate_with_exceptions, model.generate, - prompt=prompt, inputs_decoded=inputs_decoded, - raise_generate_gpu_exceptions=raise_generate_gpu_exceptions, - **gen_kwargs) - bucket = queue.Queue() - thread = EThread(target=target, streamer=streamer, bucket=bucket) - thread.start() - outputs = "" - try: - for new_text in streamer: - if bucket.qsize() > 0 or thread.exc: - thread.join() - outputs += new_text - yield dict(response=prompter.get_response(outputs, prompt=inputs_decoded, - sanitize_bot_response=sanitize_bot_response), - sources='') - except BaseException: - # if any exception, raise that exception if was from thread, first - if thread.exc: - raise thread.exc - raise - finally: - # clear before return, since .then() never done if from API - clear_torch_cache() - # in case no exception and didn't join with thread yet, then join - if not thread.exc: - thread.join() - # in case raise StopIteration or broke queue loop in streamer, but still have exception - if thread.exc: - raise thread.exc - decoded_output = outputs - else: - try: - outputs = model.generate(**gen_kwargs) - finally: - clear_torch_cache() # has to be here for API submit_nochat_api since.then() not called - outputs = [decoder(s) for s in outputs.sequences] - yield dict(response=prompter.get_response(outputs, prompt=inputs_decoded, - sanitize_bot_response=sanitize_bot_response), sources='') - if outputs and len(outputs) >= 1: - decoded_output = prompt + outputs[0] - if save_dir and decoded_output: - extra_dict = gen_config_kwargs.copy() - extra_dict.update(dict(num_prompt_tokens=num_prompt_tokens)) - save_generate_output(prompt=prompt, output=decoded_output, base_model=base_model, save_dir=save_dir, - where_from="evaluate_%s" % str(stream_output), - extra_dict=gen_config_kwargs) - if verbose: - print('Post-Generate: %s decoded_output: %s' % ( - str(datetime.now()), len(decoded_output) if decoded_output else -1), flush=True) - - -inputs_list_names = list(inspect.signature(evaluate).parameters) -state_names = ['model_state', 'my_db_state', 'selection_docs_state'] -inputs_kwargs_list = [x for x in inputs_list_names if x not in eval_func_param_names + state_names] - - -def get_cutoffs(memory_restriction_level, for_context=False, model_max_length=2048): - # help to avoid errors like: - # RuntimeError: The size of tensor a (2048) must match the size of tensor b (2049) at non-singleton dimension 3 - # RuntimeError: expected scalar type Half but found Float - # with - 256 - if memory_restriction_level > 0: - max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256 - else: - # at least give room for 1 paragraph output - max_length_tokenize = model_max_length - 256 - cutoff_len = max_length_tokenize * 4 # if reaches limit, then can't generate new tokens - output_smallest = 30 * 4 - max_prompt_length = cutoff_len - output_smallest - - if for_context: - # then lower even more to avoid later chop, since just estimate tokens in context bot - max_prompt_length = max(64, int(max_prompt_length * 0.8)) - - return cutoff_len, output_smallest, max_length_tokenize, max_prompt_length - - -class H2OTextIteratorStreamer(TextIteratorStreamer): - """ - normally, timeout required for now to handle exceptions, else get() - but with H2O version of TextIteratorStreamer, loop over block to handle - """ - - def __init__(self, tokenizer, skip_prompt: bool = False, timeout: typing.Optional[float] = None, - block=True, **decode_kwargs): - super().__init__(tokenizer, 
skip_prompt, **decode_kwargs) - self.text_queue = queue.Queue() - self.stop_signal = None - self.do_stop = False - self.timeout = timeout - self.block = block - - def on_finalized_text(self, text: str, stream_end: bool = False): - """Put the new text in the queue. If the stream is ending, also put a stop signal in the queue.""" - self.text_queue.put(text, timeout=self.timeout) - if stream_end: - self.text_queue.put(self.stop_signal, timeout=self.timeout) - - def __iter__(self): - return self - - def __next__(self): - while True: - try: - value = self.stop_signal # value looks unused in pycharm, not true - if self.do_stop: - print("hit stop", flush=True) - # could raise or break, maybe best to raise and make parent see if any exception in thread - self.clear_queue() - self.do_stop = False - raise StopIteration() - # break - value = self.text_queue.get(block=self.block, timeout=self.timeout) - break - except queue.Empty: - time.sleep(0.01) - if value == self.stop_signal: - self.clear_queue() - self.do_stop = False - raise StopIteration() - else: - return value - - def clear_queue(self): - # make sure streamer is reusable after stop hit - with self.text_queue.mutex: - self.text_queue.queue.clear() - - -def generate_with_exceptions(func, *args, prompt='', inputs_decoded='', raise_generate_gpu_exceptions=True, **kwargs): - try: - func(*args, **kwargs) - except torch.cuda.OutOfMemoryError as e: - print("GPU OOM 2: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)), - flush=True) - if 'input_ids' in kwargs: - if kwargs['input_ids'] is not None: - kwargs['input_ids'].cpu() - kwargs['input_ids'] = None - traceback.print_exc() - clear_torch_cache() - return - except (Exception, RuntimeError) as e: - if 'Expected all tensors to be on the same device' in str(e) or \ - 'expected scalar type Half but found Float' in str(e) or \ - 'probability tensor contains either' in str(e) or \ - 'cublasLt ran into an error!' 
in str(e) or \ - 'mat1 and mat2 shapes cannot be multiplied' in str(e): - print( - "GPU Error: prompt: %s inputs_decoded: %s exception: %s" % (prompt, inputs_decoded, str(e)), - flush=True) - traceback.print_exc() - clear_torch_cache() - if raise_generate_gpu_exceptions: - raise - return - else: - clear_torch_cache() - if raise_generate_gpu_exceptions: - raise - - -def get_generate_params(model_lower, - chat, - stream_output, show_examples, - prompt_type, prompt_dict, - temperature, top_p, top_k, num_beams, - max_new_tokens, min_new_tokens, early_stopping, max_time, - repetition_penalty, num_return_sequences, - do_sample, - top_k_docs, chunk, chunk_size, - verbose): - use_defaults = False - use_default_examples = True - examples = [] - task_info = 'LLM' - if model_lower: - print(f"Using Model {model_lower}", flush=True) - else: - if verbose: - print("No model defined yet", flush=True) - - min_new_tokens = min_new_tokens if min_new_tokens is not None else 0 - early_stopping = early_stopping if early_stopping is not None else False - max_time_defaults = 60 * 3 - max_time = max_time if max_time is not None else max_time_defaults - - if not prompt_type and model_lower in inv_prompt_type_to_model_lower and prompt_type != 'custom': - prompt_type = inv_prompt_type_to_model_lower[model_lower] - if verbose: - print("Auto-selecting prompt_type=%s for %s" % (prompt_type, model_lower), flush=True) - - # examples at first don't include chat, instruction_nochat, iinput_nochat, added at end - if show_examples is None: - if chat: - show_examples = False - else: - show_examples = True - - summarize_example1 = """Jeff: Can I train a ? Transformers model on Amazon SageMaker? -Philipp: Sure you can use the new Hugging Face Deep Learning Container. -Jeff: ok. -Jeff: and how can I get started? -Jeff: where can I find documentation? -Philipp: ok, ok you can find everything here. https://huggingface.co/blog/the-partnership-amazon-sagemaker-and-hugging-face""" - - use_placeholder_instruction_as_example = False - if 'bart-large-cnn-samsum' in model_lower or 'flan-t5-base-samsum' in model_lower: - placeholder_instruction = summarize_example1 - placeholder_input = "" - use_defaults = True - use_default_examples = False - use_placeholder_instruction_as_example = True - task_info = "Summarization" - elif 't5-' in model_lower or 't5' == model_lower or 'flan-' in model_lower: - placeholder_instruction = "The square root of x is the cube root of y. What is y to the power of 2, if x = 4?" - placeholder_input = "" - use_defaults = True - use_default_examples = True - task_info = "Multi-Task: Q/A, translation, Chain-of-Thought, Logical Reasoning, Summarization, etc. Best to use task prefix as trained on, e.g. `translate English to German: ` (space after colon)" - elif 'mbart-' in model_lower: - placeholder_instruction = "The girl has long hair." - placeholder_input = "" - use_defaults = True - use_default_examples = False - use_placeholder_instruction_as_example = True - elif 'gpt2' in model_lower: - placeholder_instruction = "The sky is" - placeholder_input = "" - prompt_type = prompt_type or 'plain' - use_default_examples = True # some will be odd "continuations" but can be ok - use_placeholder_instruction_as_example = True - task_info = "Auto-complete phrase, code, etc." - use_defaults = True - else: - if chat: - placeholder_instruction = "" - else: - placeholder_instruction = "Give detailed answer for whether Einstein or Newton is smarter." 
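Every sampling parameter in the branches below is settled with the same None-coalescing idiom: one set of values when use_defaults is true, a looser generic set otherwise. A compact standalone illustration of that idiom, reusing the default values that appear just below (the function name is invented for this example):

def resolve_sampling_defaults(temperature=None, top_p=None, top_k=None, do_sample=None, use_defaults=False):
    # mirrors the defaulting below: task-specific defaults vs. generic defaults
    temperature = (1.0 if use_defaults else 0.1) if temperature is None else temperature
    top_p = (1.0 if use_defaults else 0.75) if top_p is None else top_p
    top_k = 40 if top_k is None else top_k
    do_sample = False if do_sample is None else do_sample
    return temperature, top_p, top_k, do_sample

assert resolve_sampling_defaults() == (0.1, 0.75, 40, False)
assert resolve_sampling_defaults(use_defaults=True, top_k=50) == (1.0, 1.0, 50, False)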
- placeholder_input = "" - if model_lower in inv_prompt_type_to_model_lower: - if prompt_type != 'custom': - prompt_type = inv_prompt_type_to_model_lower[model_lower] - elif model_lower: - # default is plain, because might rely upon trust_remote_code to handle prompting - prompt_type = prompt_type or 'plain' - else: - prompt_type = '' - task_info = "No task" - if prompt_type == 'instruct': - task_info = "Answer question or follow imperative as instruction with optionally input." - elif prompt_type == 'plain': - task_info = "Auto-complete phrase, code, etc." - elif prompt_type == 'human_bot': - if chat: - task_info = "Chat (Shift-Enter to give question/imperative, input concatenated with instruction)" - else: - task_info = "Ask question/imperative (input concatenated with instruction)" - - # revert to plain if still nothing - prompt_type = prompt_type or 'plain' - if use_defaults: - temperature = 1.0 if temperature is None else temperature - top_p = 1.0 if top_p is None else top_p - top_k = 40 if top_k is None else top_k - num_beams = num_beams or 1 - max_new_tokens = max_new_tokens or 128 - repetition_penalty = repetition_penalty or 1.07 - num_return_sequences = min(num_beams, num_return_sequences or 1) - do_sample = False if do_sample is None else do_sample - else: - temperature = 0.1 if temperature is None else temperature - top_p = 0.75 if top_p is None else top_p - top_k = 40 if top_k is None else top_k - num_beams = num_beams or 1 - max_new_tokens = max_new_tokens or 256 - repetition_penalty = repetition_penalty or 1.07 - num_return_sequences = min(num_beams, num_return_sequences or 1) - do_sample = False if do_sample is None else do_sample - # doesn't include chat, instruction_nochat, iinput_nochat, added later - params_list = ["", - stream_output, - prompt_type, prompt_dict, - temperature, top_p, top_k, num_beams, - max_new_tokens, min_new_tokens, - early_stopping, max_time, repetition_penalty, num_return_sequences, do_sample] - - if use_placeholder_instruction_as_example: - examples += [[placeholder_instruction, ''] + params_list] - - if use_default_examples: - examples += [ - ["Translate English to French", "Good morning"] + params_list, - ["Give detailed answer for whether Einstein or Newton is smarter.", ''] + params_list, - ["Explain in detailed list, all the best practices for coding in python.", ''] + params_list, - [ - "Create a markdown table with 3 rows for the primary colors, and 2 columns, with color name and hex codes.", - ''] + params_list, - ['Translate to German: My name is Arthur', ''] + params_list, - ["Please answer to the following question. Who is going to be the next Ballon d'or?", ''] + params_list, - ['Can Geoffrey Hinton have a conversation with George Washington? Give the rationale before answering.', - ''] + params_list, - ['Please answer the following question. What is the boiling point of Nitrogen?', ''] + params_list, - ['Answer the following yes/no question. Can you write a whole Haiku in a single tweet?', ''] + params_list, - ["Simplify the following expression: (False or False and True). Explain your answer.", ''] + params_list, - [ - "Premise: At my age you will probably have learnt one lesson. Hypothesis: It's not certain how many lessons you'll learn by your thirties. Does the premise entail the hypothesis?", - ''] + params_list, - ['The square root of x is the cube root of y. What is y to the power of 2, if x = 4?', ''] + params_list, - [ - 'Answer the following question by reasoning step by step. The cafeteria had 23 apples. 
If they used 20 for lunch, and bought 6 more, how many apple do they have?', - ''] + params_list, - ["""def area_of_rectangle(a: float, b: float): - \"\"\"Return the area of the rectangle.\"\"\"""", ''] + params_list, - ["""# a function in native python: -def mean(a): - return sum(a)/len(a) - -# the same function using numpy: -import numpy as np -def mean(a):""", ''] + params_list, - ["""X = np.random.randn(100, 100) -y = np.random.randint(0, 1, 100) - -# fit random forest classifier with 20 estimators""", ''] + params_list, - ] - # add summary example - examples += [ - [summarize_example1, 'Summarize' if prompt_type not in ['plain', 'instruct_simple'] else ''] + params_list] - - src_lang = "English" - tgt_lang = "Russian" - - # move to correct position - for example in examples: - example += [chat, '', '', LangChainMode.DISABLED.value, True, LangChainAction.QUERY.value, [], - top_k_docs, chunk, chunk_size, DocumentSubset.Relevant.name, [] - ] - # adjust examples if non-chat mode - if not chat: - example[eval_func_param_names.index('instruction_nochat')] = example[ - eval_func_param_names.index('instruction')] - example[eval_func_param_names.index('instruction')] = '' - - example[eval_func_param_names.index('iinput_nochat')] = example[eval_func_param_names.index('iinput')] - example[eval_func_param_names.index('iinput')] = '' - assert len(example) == len(eval_func_param_names), "Wrong example: %s %s" % ( - len(example), len(eval_func_param_names)) - - if prompt_type == PromptType.custom.name and not prompt_dict: - raise ValueError("Unexpected to get non-empty prompt_dict=%s for prompt_type=%s" % (prompt_dict, prompt_type)) - - # get prompt_dict from prompt_type, so user can see in UI etc., or for custom do nothing except check format - prompt_dict, error0 = get_prompt(prompt_type, prompt_dict, - chat=False, context='', reduced=False, making_context=False, return_dict=True) - if error0: - raise RuntimeError("Prompt wrong: %s" % error0) - - return placeholder_instruction, placeholder_input, \ - stream_output, show_examples, \ - prompt_type, prompt_dict, \ - temperature, top_p, top_k, num_beams, \ - max_new_tokens, min_new_tokens, early_stopping, max_time, \ - repetition_penalty, num_return_sequences, \ - do_sample, \ - src_lang, tgt_lang, \ - examples, \ - task_info - - -def languages_covered(): - # https://huggingface.co/facebook/mbart-large-50-many-to-many-mmt#languages-covered - covered = """Arabic (ar_AR), Czech (cs_CZ), German (de_DE), English (en_XX), Spanish (es_XX), Estonian (et_EE), Finnish (fi_FI), French (fr_XX), Gujarati (gu_IN), Hindi (hi_IN), Italian (it_IT), Japanese (ja_XX), Kazakh (kk_KZ), Korean (ko_KR), Lithuanian (lt_LT), Latvian (lv_LV), Burmese (my_MM), Nepali (ne_NP), Dutch (nl_XX), Romanian (ro_RO), Russian (ru_RU), Sinhala (si_LK), Turkish (tr_TR), Vietnamese (vi_VN), Chinese (zh_CN), Afrikaans (af_ZA), Azerbaijani (az_AZ), Bengali (bn_IN), Persian (fa_IR), Hebrew (he_IL), Croatian (hr_HR), Indonesian (id_ID), Georgian (ka_GE), Khmer (km_KH), Macedonian (mk_MK), Malayalam (ml_IN), Mongolian (mn_MN), Marathi (mr_IN), Polish (pl_PL), Pashto (ps_AF), Portuguese (pt_XX), Swedish (sv_SE), Swahili (sw_KE), Tamil (ta_IN), Telugu (te_IN), Thai (th_TH), Tagalog (tl_XX), Ukrainian (uk_UA), Urdu (ur_PK), Xhosa (xh_ZA), Galician (gl_ES), Slovene (sl_SI)""" - covered = covered.split(', ') - covered = {x.split(' ')[0]: x.split(' ')[1].replace(')', '').replace('(', '') for x in covered} - return covered - - -def get_context(chat_context, prompt_type): - if chat_context and 
prompt_type == 'human_bot': - context0 = """: I am an intelligent, helpful, truthful, and fair assistant named h2oGPT, who will give accurate, balanced, and reliable responses. I will not respond with I don't know or I don't understand. -: I am a human person seeking useful assistance and request all questions be answered completely, and typically expect detailed responses. Give answers in numbered list format if several distinct but related items are being listed.""" - else: - context0 = '' - return context0 - - -def score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len): - question = question[-cutoff_len:] - answer = answer[-cutoff_len:] - - inputs = stokenizer(question, answer, - return_tensors="pt", - truncation=True, - max_length=max_length_tokenize).to(smodel.device) - try: - score = torch.sigmoid(smodel(**inputs.to(smodel.device)).logits[0].float()).cpu().detach().numpy()[0] - except torch.cuda.OutOfMemoryError as e: - print("GPU OOM 3: question: %s answer: %s exception: %s" % (question, answer, str(e)), flush=True) - del inputs - traceback.print_exc() - clear_torch_cache() - return 'Response Score: GPU OOM' - except (Exception, RuntimeError) as e: - if 'Expected all tensors to be on the same device' in str(e) or \ - 'expected scalar type Half but found Float' in str(e) or \ - 'probability tensor contains either' in str(e) or \ - 'cublasLt ran into an error!' in str(e) or \ - 'device-side assert triggered' in str(e): - print("GPU Error: question: %s answer: %s exception: %s" % (question, answer, str(e)), - flush=True) - traceback.print_exc() - clear_torch_cache() - return 'Response Score: GPU Error' - else: - raise - os.environ['TOKENIZERS_PARALLELISM'] = 'true' - return score - - -def check_locals(**kwargs): - # ensure everything in evaluate is here - can_skip_because_locally_generated = no_default_param_names + [ - # get_model: - 'reward_type' - ] - for k in eval_func_param_names: - if k in can_skip_because_locally_generated: - continue - assert k in kwargs, "Missing %s" % k - for k in inputs_kwargs_list: - if k in can_skip_because_locally_generated: - continue - assert k in kwargs, "Missing %s" % k - - for k in list(inspect.signature(get_model).parameters): - if k in can_skip_because_locally_generated: - continue - assert k in kwargs, "Missing %s" % k - - -def get_model_max_length(model_state): - if not isinstance(model_state['tokenizer'], (str, type(None))): - return model_state['tokenizer'].model_max_length - else: - return 2048 - - -def get_max_max_new_tokens(model_state, **kwargs): - if not isinstance(model_state['tokenizer'], (str, type(None))): - max_max_new_tokens = model_state['tokenizer'].model_max_length - else: - max_max_new_tokens = None - - if kwargs['max_max_new_tokens'] is not None and max_max_new_tokens is not None: - return min(max_max_new_tokens, kwargs['max_max_new_tokens']) - elif kwargs['max_max_new_tokens'] is not None: - return kwargs['max_max_new_tokens'] - elif kwargs['memory_restriction_level'] == 1: - return 768 - elif kwargs['memory_restriction_level'] == 2: - return 512 - elif kwargs['memory_restriction_level'] >= 3: - return 256 - else: - # FIXME: Need to update after new model loaded, so user can control with slider - return 2048 - - -def get_minmax_top_k_docs(is_public): - if is_public: - min_top_k_docs = 1 - max_top_k_docs = 3 - label_top_k_docs = "Number of document chunks" - else: - min_top_k_docs = -1 - max_top_k_docs = 100 - label_top_k_docs = "Number of document chunks (-1 = auto fill model context)" - return 
min_top_k_docs, max_top_k_docs, label_top_k_docs - - -def history_to_context(history, langchain_mode1, - add_chat_history_to_context, - prompt_type1, prompt_dict1, chat1, model_max_length1, - memory_restriction_level1, keep_sources_in_context1): - """ - consumes all history up to (but not including) latest history item that is presumed to be an [instruction, None] pair - :param history: - :param langchain_mode1: - :param add_chat_history_to_context: - :param prompt_type1: - :param prompt_dict1: - :param chat1: - :param model_max_length1: - :param memory_restriction_level1: - :param keep_sources_in_context1: - :return: - """ - # ensure output will be unique to models - _, _, _, max_prompt_length = get_cutoffs(memory_restriction_level1, - for_context=True, model_max_length=model_max_length1) - context1 = '' - if max_prompt_length is not None and add_chat_history_to_context: - context1 = '' - # - 1 below because current instruction already in history from user() - for histi in range(0, len(history) - 1): - data_point = dict(instruction=history[histi][0], input='', output=history[histi][1]) - prompt, pre_response, terminate_response, chat_sep, chat_turn_sep = generate_prompt(data_point, - prompt_type1, - prompt_dict1, - chat1, - reduced=True, - making_context=True) - # md -> back to text, maybe not super important if model trained enough - if not keep_sources_in_context1 and langchain_mode1 != 'Disabled' and prompt.find(source_prefix) >= 0: - # FIXME: This is relatively slow even for small amount of text, like 0.3s each history item - import re - prompt = re.sub(f'{re.escape(source_prefix)}.*?{re.escape(source_postfix)}', '', prompt, - flags=re.DOTALL) - if prompt.endswith('\n
<br>'): - prompt = prompt[:-4] - prompt = prompt.replace('<br>
', chat_turn_sep) - if not prompt.endswith(chat_turn_sep): - prompt += chat_turn_sep - # most recent first, add older if can - # only include desired chat history - if len(prompt + context1) > max_prompt_length: - break - context1 += prompt - - _, pre_response, terminate_response, chat_sep, chat_turn_sep = generate_prompt({}, prompt_type1, prompt_dict1, - chat1, reduced=True, - making_context=True) - if context1 and not context1.endswith(chat_turn_sep): - context1 += chat_turn_sep # ensure if terminates abruptly, then human continues on next line - return context1 - - -def update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, extra): - # update from saved state on disk - langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file = \ - load_collection_enum(extra) - - visible_langchain_modes_temp = visible_langchain_modes.copy() + visible_langchain_modes_from_file - visible_langchain_modes.clear() # don't lose original reference - [visible_langchain_modes.append(x) for x in visible_langchain_modes_temp if x not in visible_langchain_modes] - - langchain_mode_paths.update(langchain_mode_paths_from_file) - - langchain_modes_temp = langchain_modes.copy() + langchain_modes_from_file - langchain_modes.clear() # don't lose original reference - [langchain_modes.append(x) for x in langchain_modes_temp if x not in langchain_modes] - - -def entrypoint_main(): - """ - Examples: - - WORLD_SIZE=4 CUDA_VISIBLE_DEVICES="0,1,2,3" torchrun --nproc_per_node=4 --master_port=1234 generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights=lora-alpaca_6B - python generate.py --base_model='EleutherAI/gpt-j-6B' --lora_weights='lora-alpaca_6B' - python generate.py --base_model='EleutherAI/gpt-neox-20b' --lora_weights='lora-alpaca_20B' - - # generate without lora weights, no prompt - python generate.py --base_model='EleutherAI/gpt-neox-20b' --prompt_type='plain' - python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' - - python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='dai_faq' --lora_weights='lora_20B_daifaq' - # OpenChatKit settings: - python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 - - python generate.py --base_model='distilgpt2' --prompt_type='plain' --debug=True --num_beams=1 --temperature=0.6 --top_k=40 --top_p=1.0 --share=False - python generate.py --base_model='t5-large' --prompt_type='simple_instruct' - python generate.py --base_model='philschmid/bart-large-cnn-samsum' - python generate.py --base_model='philschmid/flan-t5-base-samsum' - python generate.py --base_model='facebook/mbart-large-50-many-to-many-mmt' - - python generate.py --base_model='togethercomputer/GPT-NeoXT-Chat-Base-20B' --prompt_type='human_bot' --lora_weights='GPT-NeoXT-Chat-Base-20B.merged.json.8_epochs.57b2892c53df5b8cefac45f84d019cace803ef26.28' - - must have 4*48GB GPU and run without 8bit in order for sharding to work with use_gpu_id=False - can also pass --prompt_type='human_bot' and model can somewhat handle instructions without being instruct tuned - python generate.py --base_model=decapoda-research/llama-65b-hf --load_8bit=False --use_gpu_id=False --prompt_type='human_bot' - - python generate.py --base_model=h2oai/h2ogpt-oig-oasst1-512-6_9b - """ - fire.Fire(main) - - -if __name__ == "__main__": - entrypoint_main() diff --git a/gpt4all_llm.py b/gpt4all_llm.py deleted 
file mode 100644 index 125892d99621ca80ad4ea6efcc39b01fa5cead63..0000000000000000000000000000000000000000 --- a/gpt4all_llm.py +++ /dev/null @@ -1,316 +0,0 @@ -import inspect -import os -from functools import partial -from typing import Dict, Any, Optional, List -from langchain.callbacks.manager import CallbackManagerForLLMRun -from pydantic import root_validator -from langchain.llms import gpt4all -from dotenv import dotenv_values - -from utils import FakeTokenizer - - -def get_model_tokenizer_gpt4all(base_model, **kwargs): - # defaults (some of these are generation parameters, so need to be passed in at generation time) - model_kwargs = dict(n_threads=os.cpu_count() // 2, - temp=kwargs.get('temperature', 0.2), - top_p=kwargs.get('top_p', 0.75), - top_k=kwargs.get('top_k', 40), - n_ctx=2048 - 256) - env_gpt4all_file = ".env_gpt4all" - model_kwargs.update(dotenv_values(env_gpt4all_file)) - # make int or float if can to satisfy types for class - for k, v in model_kwargs.items(): - try: - if float(v) == int(v): - model_kwargs[k] = int(v) - else: - model_kwargs[k] = float(v) - except: - pass - - if base_model == "llama": - if 'model_path_llama' not in model_kwargs: - raise ValueError("No model_path_llama in %s" % env_gpt4all_file) - model_path = model_kwargs.pop('model_path_llama') - # FIXME: GPT4All version of llama doesn't handle new quantization, so use llama_cpp_python - from llama_cpp import Llama - # llama sets some things at init model time, not generation time - func_names = list(inspect.signature(Llama.__init__).parameters) - model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names} - model_kwargs['n_ctx'] = int(model_kwargs['n_ctx']) - model = Llama(model_path=model_path, **model_kwargs) - elif base_model in "gpt4all_llama": - if 'model_name_gpt4all_llama' not in model_kwargs and 'model_path_gpt4all_llama' not in model_kwargs: - raise ValueError("No model_name_gpt4all_llama or model_path_gpt4all_llama in %s" % env_gpt4all_file) - model_name = model_kwargs.pop('model_name_gpt4all_llama') - model_type = 'llama' - from gpt4all import GPT4All as GPT4AllModel - model = GPT4AllModel(model_name=model_name, model_type=model_type) - elif base_model in "gptj": - if 'model_name_gptj' not in model_kwargs and 'model_path_gptj' not in model_kwargs: - raise ValueError("No model_name_gpt4j or model_path_gpt4j in %s" % env_gpt4all_file) - model_name = model_kwargs.pop('model_name_gptj') - model_type = 'gptj' - from gpt4all import GPT4All as GPT4AllModel - model = GPT4AllModel(model_name=model_name, model_type=model_type) - else: - raise ValueError("No such base_model %s" % base_model) - return model, FakeTokenizer(), 'cpu' - - -from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler - - -class H2OStreamingStdOutCallbackHandler(StreamingStdOutCallbackHandler): - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run on new LLM token. 
Only available when streaming is enabled.""" - # streaming to std already occurs without this - # sys.stdout.write(token) - # sys.stdout.flush() - pass - - -def get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=[]): - # default from class - model_kwargs = {k: v.default for k, v in dict(inspect.signature(cls).parameters).items() if k not in exclude_list} - # from our defaults - model_kwargs.update(default_kwargs) - # from user defaults - model_kwargs.update(env_kwargs) - # ensure only valid keys - func_names = list(inspect.signature(cls).parameters) - model_kwargs = {k: v for k, v in model_kwargs.items() if k in func_names} - return model_kwargs - - -def get_llm_gpt4all(model_name, - model=None, - max_new_tokens=256, - temperature=0.1, - repetition_penalty=1.0, - top_k=40, - top_p=0.7, - streaming=False, - callbacks=None, - prompter=None, - context='', - iinput='', - verbose=False, - ): - assert prompter is not None - env_gpt4all_file = ".env_gpt4all" - env_kwargs = dotenv_values(env_gpt4all_file) - max_tokens = env_kwargs.pop('max_tokens', 2048 - max_new_tokens) - default_kwargs = dict(context_erase=0.5, - n_batch=1, - max_tokens=max_tokens, - n_predict=max_new_tokens, - repeat_last_n=64 if repetition_penalty != 1.0 else 0, - repeat_penalty=repetition_penalty, - temp=temperature, - temperature=temperature, - top_k=top_k, - top_p=top_p, - use_mlock=True, - verbose=verbose) - if model_name == 'llama': - cls = H2OLlamaCpp - model_path = env_kwargs.pop('model_path_llama') if model is None else model - model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs']) - model_kwargs.update(dict(model_path=model_path, callbacks=callbacks, streaming=streaming, - prompter=prompter, context=context, iinput=iinput)) - llm = cls(**model_kwargs) - llm.client.verbose = verbose - elif model_name == 'gpt4all_llama': - cls = H2OGPT4All - model_path = env_kwargs.pop('model_path_gpt4all_llama') if model is None else model - model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs']) - model_kwargs.update( - dict(model=model_path, backend='llama', callbacks=callbacks, streaming=streaming, - prompter=prompter, context=context, iinput=iinput)) - llm = cls(**model_kwargs) - elif model_name == 'gptj': - cls = H2OGPT4All - model_path = env_kwargs.pop('model_path_gptj') if model is None else model - model_kwargs = get_model_kwargs(env_kwargs, default_kwargs, cls, exclude_list=['lc_kwargs']) - model_kwargs.update( - dict(model=model_path, backend='gptj', callbacks=callbacks, streaming=streaming, - prompter=prompter, context=context, iinput=iinput)) - llm = cls(**model_kwargs) - else: - raise RuntimeError("No such model_name %s" % model_name) - return llm - - -class H2OGPT4All(gpt4all.GPT4All): - model: Any - prompter: Any - context: Any = '' - iinput: Any = '' - """Path to the pre-trained GPT4All model file.""" - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that the python package exists in the environment.""" - try: - if isinstance(values["model"], str): - from gpt4all import GPT4All as GPT4AllModel - - full_path = values["model"] - model_path, delimiter, model_name = full_path.rpartition("/") - model_path += delimiter - - values["client"] = GPT4AllModel( - model_name=model_name, - model_path=model_path or None, - model_type=values["backend"], - allow_download=False, - ) - if values["n_threads"] is not None: - # set n_threads - values["client"].model.set_thread_count(values["n_threads"]) - else: - 
values["client"] = values["model"] - try: - values["backend"] = values["client"].model_type - except AttributeError: - # The below is for compatibility with GPT4All Python bindings <= 0.2.3. - values["backend"] = values["client"].model.model_type - - except ImportError: - raise ValueError( - "Could not import gpt4all python package. " - "Please install it with `pip install gpt4all`." - ) - return values - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs, - ) -> str: - # Roughly 4 chars per token if natural language - n_ctx = 2048 - prompt = prompt[-self.max_tokens * 4:] - - # use instruct prompting - data_point = dict(context=self.context, instruction=prompt, input=self.iinput) - prompt = self.prompter.generate_prompt(data_point) - - verbose = False - if verbose: - print("_call prompt: %s" % prompt, flush=True) - # FIXME: GPT4ALl doesn't support yield during generate, so cannot support streaming except via itself to stdout - return super()._call(prompt, stop=stop, run_manager=run_manager) - - -from langchain.llms import LlamaCpp - - -class H2OLlamaCpp(LlamaCpp): - model_path: Any - prompter: Any - context: Any - iinput: Any - """Path to the pre-trained GPT4All model file.""" - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that llama-cpp-python library is installed.""" - if isinstance(values["model_path"], str): - model_path = values["model_path"] - model_param_names = [ - "lora_path", - "lora_base", - "n_ctx", - "n_parts", - "seed", - "f16_kv", - "logits_all", - "vocab_only", - "use_mlock", - "n_threads", - "n_batch", - "use_mmap", - "last_n_tokens_size", - ] - model_params = {k: values[k] for k in model_param_names} - # For backwards compatibility, only include if non-null. - if values["n_gpu_layers"] is not None: - model_params["n_gpu_layers"] = values["n_gpu_layers"] - - try: - from llama_cpp import Llama - - values["client"] = Llama(model_path, **model_params) - except ImportError: - raise ModuleNotFoundError( - "Could not import llama-cpp-python library. " - "Please install the llama-cpp-python library to " - "use this embedding model: pip install llama-cpp-python" - ) - except Exception as e: - raise ValueError( - f"Could not load Llama model from path: {model_path}. 
" - f"Received error {e}" - ) - else: - values["client"] = values["model_path"] - return values - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs, - ) -> str: - verbose = False - # tokenize twice, just to count tokens, since llama cpp python wrapper has no way to truncate - # still have to avoid crazy sizes, else hit llama_tokenize: too many tokens -- might still hit, not fatal - prompt = prompt[-self.n_ctx * 4:] - prompt_tokens = self.client.tokenize(b" " + prompt.encode("utf-8")) - num_prompt_tokens = len(prompt_tokens) - if num_prompt_tokens > self.n_ctx: - # conservative by using int() - chars_per_token = int(len(prompt) / num_prompt_tokens) - prompt = prompt[-self.n_ctx * chars_per_token:] - if verbose: - print("reducing tokens, assuming average of %s chars/token: %s" % chars_per_token, flush=True) - prompt_tokens2 = self.client.tokenize(b" " + prompt.encode("utf-8")) - num_prompt_tokens2 = len(prompt_tokens2) - print("reduced tokens from %d -> %d" % (num_prompt_tokens, num_prompt_tokens2), flush=True) - - # use instruct prompting - data_point = dict(context=self.context, instruction=prompt, input=self.iinput) - prompt = self.prompter.generate_prompt(data_point) - - if verbose: - print("_call prompt: %s" % prompt, flush=True) - - if self.streaming: - text_callback = None - if run_manager: - text_callback = partial( - run_manager.on_llm_new_token, verbose=self.verbose - ) - # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter - if text_callback: - text_callback(prompt) - text = "" - for token in self.stream(prompt=prompt, stop=stop, run_manager=run_manager): - text_chunk = token["choices"][0]["text"] - # self.stream already calls text_callback - # if text_callback: - # text_callback(text_chunk) - text += text_chunk - return text - else: - params = self._get_parameters(stop) - params = {**params, **kwargs} - result = self.client(prompt=prompt, **params) - return result["choices"][0]["text"] diff --git a/gpt_langchain.py b/gpt_langchain.py deleted file mode 100644 index c2a3438e865fac91693a84625bc0709332ba6e82..0000000000000000000000000000000000000000 --- a/gpt_langchain.py +++ /dev/null @@ -1,2559 +0,0 @@ -import ast -import glob -import inspect -import os -import pathlib -import pickle -import shutil -import subprocess -import tempfile -import time -import traceback -import types -import uuid -import zipfile -from collections import defaultdict -from datetime import datetime -from functools import reduce -from operator import concat -import filelock - -from joblib import delayed -from langchain.callbacks import streaming_stdout -from langchain.embeddings import HuggingFaceInstructEmbeddings -from langchain.schema import LLMResult -from tqdm import tqdm - -from enums import DocumentSubset, no_lora_str, model_token_mapping, source_prefix, source_postfix, non_query_commands, \ - LangChainAction, LangChainMode, DocumentChoice -from evaluate_params import gen_hyper -from gen import get_model, SEED -from prompter import non_hf_types, PromptType, Prompter -from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \ - get_device, ProgressParallel, remove, hash_file, clear_torch_cache, NullContext, get_hf_server, FakeTokenizer, \ - have_libreoffice, have_arxiv, have_playwright, have_selenium, have_tesseract, have_pymupdf, set_openai -from utils_langchain import StreamingGradioCallbackHandler - 
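The imports above establish the document-ingestion layer; get_db() further below turns a list of Document objects into a FAISS, Weaviate, or Chroma vector store backed by a HuggingFace embedding model. A minimal sketch of the FAISS path only (the document text, file name, and query are made up for the example; assumes sentence-transformers and faiss are installed):

from langchain.docstore.document import Document
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.vectorstores import FAISS

docs = [Document(page_content="h2oGPT can build FAISS, Chroma, or Weaviate vector stores.",
                 metadata=dict(source="note.txt"))]
embedding = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
db = FAISS.from_documents(docs, embedding)
print(db.similarity_search("Which vector stores are supported?", k=1)[0].page_content)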
-import_matplotlib() - -import numpy as np -import pandas as pd -import requests -from langchain.chains.qa_with_sources import load_qa_with_sources_chain -# , GCSDirectoryLoader, GCSFileLoader -# , OutlookMessageLoader # GPL3 -# ImageCaptionLoader, # use our own wrapper -# ReadTheDocsLoader, # no special file, some path, so have to give as special option -from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \ - UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \ - EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \ - UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader, UnstructuredPDFLoader, \ - UnstructuredExcelLoader -from langchain.text_splitter import RecursiveCharacterTextSplitter, Language -from langchain.chains.question_answering import load_qa_chain -from langchain.docstore.document import Document -from langchain import PromptTemplate, HuggingFaceTextGenInference -from langchain.vectorstores import Chroma - - -def get_db(sources, use_openai_embedding=False, db_type='faiss', - persist_directory="db_dir", load_db_if_exists=True, - langchain_mode='notset', - collection_name=None, - hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"): - if not sources: - return None - - # get embedding model - embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model) - assert collection_name is not None or langchain_mode != 'notset' - if collection_name is None: - collection_name = langchain_mode.replace(' ', '_') - - # Create vector database - if db_type == 'faiss': - from langchain.vectorstores import FAISS - db = FAISS.from_documents(sources, embedding) - elif db_type == 'weaviate': - import weaviate - from weaviate.embedded import EmbeddedOptions - from langchain.vectorstores import Weaviate - - if os.getenv('WEAVIATE_URL', None): - client = _create_local_weaviate_client() - else: - client = weaviate.Client( - embedded_options=EmbeddedOptions() - ) - index_name = collection_name.capitalize() - db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False, - index_name=index_name) - elif db_type == 'chroma': - assert persist_directory is not None - os.makedirs(persist_directory, exist_ok=True) - - # see if already actually have persistent db, and deal with possible changes in embedding - db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode, - hf_embedding_model, verbose=False) - if db is None: - from chromadb.config import Settings - client_settings = Settings(anonymized_telemetry=False, - chroma_db_impl="duckdb+parquet", - persist_directory=persist_directory) - db = Chroma.from_documents(documents=sources, - embedding=embedding, - persist_directory=persist_directory, - collection_name=collection_name, - client_settings=client_settings) - db.persist() - clear_embedding(db) - save_embed(db, use_openai_embedding, hf_embedding_model) - else: - # then just add - db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type, - use_openai_embedding=use_openai_embedding, - hf_embedding_model=hf_embedding_model) - else: - raise RuntimeError("No such db_type=%s" % db_type) - - return db - - -def _get_unique_sources_in_weaviate(db): - batch_size = 100 - id_source_list = [] - result = db._client.data_object.get(class_name=db._index_name, limit=batch_size) - - while 
result['objects']: - id_source_list += [(obj['id'], obj['properties']['source']) for obj in result['objects']] - last_id = id_source_list[-1][0] - result = db._client.data_object.get(class_name=db._index_name, limit=batch_size, after=last_id) - - unique_sources = {source for _, source in id_source_list} - return unique_sources - - -def add_to_db(db, sources, db_type='faiss', - avoid_dup_by_file=False, - avoid_dup_by_content=True, - use_openai_embedding=False, - hf_embedding_model=None): - assert hf_embedding_model is not None - num_new_sources = len(sources) - if not sources: - return db, num_new_sources, [] - if db_type == 'faiss': - db.add_documents(sources) - elif db_type == 'weaviate': - # FIXME: only control by file name, not hash yet - if avoid_dup_by_file or avoid_dup_by_content: - unique_sources = _get_unique_sources_in_weaviate(db) - sources = [x for x in sources if x.metadata['source'] not in unique_sources] - num_new_sources = len(sources) - if num_new_sources == 0: - return db, num_new_sources, [] - db.add_documents(documents=sources) - elif db_type == 'chroma': - collection = get_documents(db) - # files we already have: - metadata_files = set([x['source'] for x in collection['metadatas']]) - if avoid_dup_by_file: - # Too weak in case file changed content, assume parent shouldn't pass true for this for now - raise RuntimeError("Not desired code path") - sources = [x for x in sources if x.metadata['source'] not in metadata_files] - if avoid_dup_by_content: - # look at hash, instead of page_content - # migration: If no hash previously, avoid updating, - # since don't know if need to update and may be expensive to redo all unhashed files - metadata_hash_ids = set( - [x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]]) - # avoid sources with same hash - sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids] - num_nohash = len([x for x in sources if not x.metadata.get('hashid')]) - print("Found %s new sources (%d have no hash in original source," - " so have to reprocess for migration to sources with hash)" % (len(sources), num_nohash), flush=True) - # get new file names that match existing file names. 
delete existing files we are overridding - dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files]) - print("Removing %s duplicate files from db because ingesting those as new documents" % len( - dup_metadata_files), flush=True) - client_collection = db._client.get_collection(name=db._collection.name, - embedding_function=db._collection._embedding_function) - for dup_file in dup_metadata_files: - dup_file_meta = dict(source=dup_file) - try: - client_collection.delete(where=dup_file_meta) - except KeyError: - pass - num_new_sources = len(sources) - if num_new_sources == 0: - return db, num_new_sources, [] - db.add_documents(documents=sources) - db.persist() - clear_embedding(db) - save_embed(db, use_openai_embedding, hf_embedding_model) - else: - raise RuntimeError("No such db_type=%s" % db_type) - - new_sources_metadata = [x.metadata for x in sources] - - return db, num_new_sources, new_sources_metadata - - -def create_or_update_db(db_type, persist_directory, collection_name, - sources, use_openai_embedding, add_if_exists, verbose, hf_embedding_model): - if db_type == 'weaviate': - import weaviate - from weaviate.embedded import EmbeddedOptions - - if os.getenv('WEAVIATE_URL', None): - client = _create_local_weaviate_client() - else: - client = weaviate.Client( - embedded_options=EmbeddedOptions() - ) - - index_name = collection_name.replace(' ', '_').capitalize() - if client.schema.exists(index_name) and not add_if_exists: - client.schema.delete_class(index_name) - if verbose: - print("Removing %s" % index_name, flush=True) - elif db_type == 'chroma': - if not os.path.isdir(persist_directory) or not add_if_exists: - if os.path.isdir(persist_directory): - if verbose: - print("Removing %s" % persist_directory, flush=True) - remove(persist_directory) - if verbose: - print("Generating db", flush=True) - - if not add_if_exists: - if verbose: - print("Generating db", flush=True) - else: - if verbose: - print("Loading and updating db", flush=True) - - db = get_db(sources, - use_openai_embedding=use_openai_embedding, - db_type=db_type, - persist_directory=persist_directory, - langchain_mode=collection_name, - hf_embedding_model=hf_embedding_model) - - return db - - -def get_embedding(use_openai_embedding, hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"): - # Get embedding model - if use_openai_embedding: - assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY" - from langchain.embeddings import OpenAIEmbeddings - embedding = OpenAIEmbeddings(disallowed_special=()) - else: - # to ensure can fork without deadlock - from langchain.embeddings import HuggingFaceEmbeddings - - device, torch_dtype, context_class = get_device_dtype() - model_kwargs = dict(device=device) - if 'instructor' in hf_embedding_model: - encode_kwargs = {'normalize_embeddings': True} - embedding = HuggingFaceInstructEmbeddings(model_name=hf_embedding_model, - model_kwargs=model_kwargs, - encode_kwargs=encode_kwargs) - else: - embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs) - return embedding - - -def get_answer_from_sources(chain, sources, question): - return chain( - { - "input_documents": sources, - "question": question, - }, - return_only_outputs=True, - )["output_text"] - - -"""Wrapper around Huggingface text generation inference API.""" -from functools import partial -from typing import Any, Dict, List, Optional, Set - -from pydantic import Extra, Field, root_validator - -from 
langchain.callbacks.manager import CallbackManagerForLLMRun, Callbacks -from langchain.llms.base import LLM - - -class GradioInference(LLM): - """ - Gradio generation inference API. - """ - inference_server_url: str = "" - - temperature: float = 0.8 - top_p: Optional[float] = 0.95 - top_k: Optional[int] = None - num_beams: Optional[int] = 1 - max_new_tokens: int = 512 - min_new_tokens: int = 1 - early_stopping: bool = False - max_time: int = 180 - repetition_penalty: Optional[float] = None - num_return_sequences: Optional[int] = 1 - do_sample: bool = False - chat_client: bool = False - - return_full_text: bool = True - stream: bool = False - sanitize_bot_response: bool = False - - prompter: Any = None - context: Any = '' - iinput: Any = '' - client: Any = None - - class Config: - """Configuration for this pydantic object.""" - - extra = Extra.forbid - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that python package exists in environment.""" - - try: - if values['client'] is None: - import gradio_client - values["client"] = gradio_client.Client( - values["inference_server_url"] - ) - except ImportError: - raise ImportError( - "Could not import gradio_client python package. " - "Please install it with `pip install gradio_client`." - ) - return values - - @property - def _llm_type(self) -> str: - """Return type of llm.""" - return "gradio_inference" - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - # NOTE: prompt here has no prompt_type (e.g. human: bot:) prompt injection, - # so server should get prompt_type or '', not plain - # This is good, so gradio server can also handle stopping.py conditions - # this is different than TGI server that uses prompter to inject prompt_type prompting - stream_output = self.stream - gr_client = self.client - client_langchain_mode = 'Disabled' - client_add_chat_history_to_context = True - client_langchain_action = LangChainAction.QUERY.value - client_langchain_agents = [] - top_k_docs = 1 - chunk = True - chunk_size = 512 - client_kwargs = dict(instruction=prompt if self.chat_client else '', # only for chat=True - iinput=self.iinput if self.chat_client else '', # only for chat=True - context=self.context, - # streaming output is supported, loops over and outputs each generation in streaming mode - # but leave stream_output=False for simple input/output mode - stream_output=stream_output, - prompt_type=self.prompter.prompt_type, - prompt_dict='', - - temperature=self.temperature, - top_p=self.top_p, - top_k=self.top_k, - num_beams=self.num_beams, - max_new_tokens=self.max_new_tokens, - min_new_tokens=self.min_new_tokens, - early_stopping=self.early_stopping, - max_time=self.max_time, - repetition_penalty=self.repetition_penalty, - num_return_sequences=self.num_return_sequences, - do_sample=self.do_sample, - chat=self.chat_client, - - instruction_nochat=prompt if not self.chat_client else '', - iinput_nochat=self.iinput if not self.chat_client else '', - langchain_mode=client_langchain_mode, - add_chat_history_to_context=client_add_chat_history_to_context, - langchain_action=client_langchain_action, - langchain_agents=client_langchain_agents, - top_k_docs=top_k_docs, - chunk=chunk, - chunk_size=chunk_size, - document_subset=DocumentSubset.Relevant.name, - document_choice=[DocumentChoice.ALL.value], - ) - api_name = '/submit_nochat_api' # NOTE: like submit_nochat but stable API for string dict passing - if 
not stream_output: - res = gr_client.predict(str(dict(client_kwargs)), api_name=api_name) - res_dict = ast.literal_eval(res) - text = res_dict['response'] - return self.prompter.get_response(prompt + text, prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response) - else: - text_callback = None - if run_manager: - text_callback = partial( - run_manager.on_llm_new_token, verbose=self.verbose - ) - - job = gr_client.submit(str(dict(client_kwargs)), api_name=api_name) - text0 = '' - while not job.done(): - outputs_list = job.communicator.job.outputs - if outputs_list: - res = job.communicator.job.outputs[-1] - res_dict = ast.literal_eval(res) - text = res_dict['response'] - text = self.prompter.get_response(prompt + text, prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response) - # FIXME: derive chunk from full for now - text_chunk = text[len(text0):] - # save old - text0 = text - - if text_callback: - text_callback(text_chunk) - - time.sleep(0.01) - - # ensure get last output to avoid race - res_all = job.outputs() - if len(res_all) > 0: - res = res_all[-1] - res_dict = ast.literal_eval(res) - text = res_dict['response'] - # FIXME: derive chunk from full for now - else: - # go with old if failure - text = text0 - text_chunk = text[len(text0):] - if text_callback: - text_callback(text_chunk) - return self.prompter.get_response(prompt + text, prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response) - - -class H2OHuggingFaceTextGenInference(HuggingFaceTextGenInference): - max_new_tokens: int = 512 - do_sample: bool = False - top_k: Optional[int] = None - top_p: Optional[float] = 0.95 - typical_p: Optional[float] = 0.95 - temperature: float = 0.8 - repetition_penalty: Optional[float] = None - return_full_text: bool = False - stop_sequences: List[str] = Field(default_factory=list) - seed: Optional[int] = None - inference_server_url: str = "" - timeout: int = 300 - headers: dict = None - stream: bool = False - sanitize_bot_response: bool = False - prompter: Any = None - context: Any = '' - iinput: Any = '' - tokenizer: Any = None - client: Any = None - - @root_validator() - def validate_environment(cls, values: Dict) -> Dict: - """Validate that python package exists in environment.""" - - try: - if values['client'] is None: - import text_generation - - values["client"] = text_generation.Client( - values["inference_server_url"], - timeout=values["timeout"], - headers=values["headers"], - ) - except ImportError: - raise ImportError( - "Could not import text_generation python package. " - "Please install it with `pip install text_generation`." 
- ) - return values - - def _call( - self, - prompt: str, - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, - **kwargs: Any, - ) -> str: - if stop is None: - stop = self.stop_sequences - else: - stop += self.stop_sequences - - # HF inference server needs control over input tokens - assert self.tokenizer is not None - from h2oai_pipeline import H2OTextGenerationPipeline - prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer) - - # NOTE: TGI server does not add prompting, so must do here - data_point = dict(context=self.context, instruction=prompt, input=self.iinput) - prompt = self.prompter.generate_prompt(data_point) - - gen_server_kwargs = dict(do_sample=self.do_sample, - stop_sequences=stop, - max_new_tokens=self.max_new_tokens, - top_k=self.top_k, - top_p=self.top_p, - typical_p=self.typical_p, - temperature=self.temperature, - repetition_penalty=self.repetition_penalty, - return_full_text=self.return_full_text, - seed=self.seed, - ) - gen_server_kwargs.update(kwargs) - - # lower bound because client is re-used if multi-threading - self.client.timeout = max(300, self.timeout) - - if not self.stream: - res = self.client.generate( - prompt, - **gen_server_kwargs, - ) - if self.return_full_text: - gen_text = res.generated_text[len(prompt):] - else: - gen_text = res.generated_text - # remove stop sequences from the end of the generated text - for stop_seq in stop: - if stop_seq in gen_text: - gen_text = gen_text[:gen_text.index(stop_seq)] - text = prompt + gen_text - text = self.prompter.get_response(text, prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response) - else: - text_callback = None - if run_manager: - text_callback = partial( - run_manager.on_llm_new_token, verbose=self.verbose - ) - # parent handler of streamer expects to see prompt first else output="" and lose if prompt=None in prompter - if text_callback: - text_callback(prompt) - text = "" - # Note: Streaming ignores return_full_text=True - for response in self.client.generate_stream(prompt, **gen_server_kwargs): - text_chunk = response.token.text - text += text_chunk - text = self.prompter.get_response(prompt + text, prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response) - # stream part - is_stop = False - for stop_seq in stop: - if stop_seq in response.token.text: - is_stop = True - break - if is_stop: - break - if not response.token.special: - if text_callback: - text_callback(response.token.text) - return text - - -from langchain.chat_models import ChatOpenAI -from langchain.llms import OpenAI -from langchain.llms.openai import _streaming_response_template, completion_with_retry, _update_response, \ - update_token_usage - - -class H2OOpenAI(OpenAI): - """ - New class to handle vLLM's use of OpenAI, no vllm_chat supported, so only need here - Handles prompting that OpenAI doesn't need, stopping as well - """ - stop_sequences: Any = None - sanitize_bot_response: bool = False - prompter: Any = None - context: Any = '' - iinput: Any = '' - tokenizer: Any = None - - @classmethod - def all_required_field_names(cls) -> Set: - all_required_field_names = super(OpenAI, cls).all_required_field_names() - all_required_field_names.update( - {'top_p', 'frequency_penalty', 'presence_penalty', 'stop_sequences', 'sanitize_bot_response', 'prompter', - 'tokenizer'}) - return all_required_field_names - - def _generate( - self, - prompts: List[str], - stop: Optional[List[str]] = None, - run_manager: Optional[CallbackManagerForLLMRun] = None, 
- **kwargs: Any, - ) -> LLMResult: - stop = self.stop_sequences if not stop else self.stop_sequences + stop - - # HF inference server needs control over input tokens - assert self.tokenizer is not None - from h2oai_pipeline import H2OTextGenerationPipeline - for prompti, prompt in enumerate(prompts): - prompt, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt, self.tokenizer) - # NOTE: OpenAI/vLLM server does not add prompting, so must do here - data_point = dict(context=self.context, instruction=prompt, input=self.iinput) - prompt = self.prompter.generate_prompt(data_point) - prompts[prompti] = prompt - - params = self._invocation_params - params = {**params, **kwargs} - sub_prompts = self.get_sub_prompts(params, prompts, stop) - choices = [] - token_usage: Dict[str, int] = {} - # Get the token usage from the response. - # Includes prompt, completion, and total tokens used. - _keys = {"completion_tokens", "prompt_tokens", "total_tokens"} - text = '' - for _prompts in sub_prompts: - if self.streaming: - text_with_prompt = "" - prompt = _prompts[0] - if len(_prompts) > 1: - raise ValueError("Cannot stream results with multiple prompts.") - params["stream"] = True - response = _streaming_response_template() - first = True - for stream_resp in completion_with_retry( - self, prompt=_prompts, **params - ): - if first: - stream_resp["choices"][0]["text"] = prompt + stream_resp["choices"][0]["text"] - first = False - text_chunk = stream_resp["choices"][0]["text"] - text_with_prompt += text_chunk - text = self.prompter.get_response(text_with_prompt, prompt=prompt, - sanitize_bot_response=self.sanitize_bot_response) - if run_manager: - run_manager.on_llm_new_token( - text_chunk, - verbose=self.verbose, - logprobs=stream_resp["choices"][0]["logprobs"], - ) - _update_response(response, stream_resp) - choices.extend(response["choices"]) - else: - response = completion_with_retry(self, prompt=_prompts, **params) - choices.extend(response["choices"]) - if not self.streaming: - # Can't update token usage if streaming - update_token_usage(_keys, response, token_usage) - choices[0]['text'] = text - return self.create_llm_result(choices, prompts, token_usage) - - -class H2OChatOpenAI(ChatOpenAI): - @classmethod - def all_required_field_names(cls) -> Set: - all_required_field_names = super(ChatOpenAI, cls).all_required_field_names() - all_required_field_names.update({'top_p', 'frequency_penalty', 'presence_penalty'}) - return all_required_field_names - - -def get_llm(use_openai_model=False, - model_name=None, - model=None, - tokenizer=None, - inference_server=None, - stream_output=False, - do_sample=False, - temperature=0.1, - top_k=40, - top_p=0.7, - num_beams=1, - max_new_tokens=256, - min_new_tokens=1, - early_stopping=False, - max_time=180, - repetition_penalty=1.0, - num_return_sequences=1, - prompt_type=None, - prompt_dict=None, - prompter=None, - context=None, - iinput=None, - sanitize_bot_response=False, - verbose=False, - ): - if inference_server is None: - inference_server = '' - if use_openai_model or inference_server.startswith('openai') or inference_server.startswith('vllm'): - if use_openai_model and model_name is None: - model_name = "gpt-3.5-turbo" - # FIXME: Will later import be ignored? 
I think so, so should be fine - openai, inf_type = set_openai(inference_server) - kwargs_extra = {} - if inference_server == 'openai_chat' or inf_type == 'vllm_chat': - cls = H2OChatOpenAI - # FIXME: Support context, iinput - else: - cls = H2OOpenAI - if inf_type == 'vllm': - terminate_response = prompter.terminate_response or [] - stop_sequences = list(set(terminate_response + [prompter.PreResponse])) - stop_sequences = [x for x in stop_sequences if x] - kwargs_extra = dict(stop_sequences=stop_sequences, - sanitize_bot_response=sanitize_bot_response, - prompter=prompter, - context=context, - iinput=iinput, - tokenizer=tokenizer, - client=None) - - callbacks = [StreamingGradioCallbackHandler()] - llm = cls(model_name=model_name, - temperature=temperature if do_sample else 0, - # FIXME: Need to count tokens and reduce max_new_tokens to fit like in generate.py - max_tokens=max_new_tokens, - top_p=top_p if do_sample else 1, - frequency_penalty=0, - presence_penalty=1.07 - repetition_penalty + 0.6, # so good default - callbacks=callbacks if stream_output else None, - openai_api_key=openai.api_key, - openai_api_base=openai.api_base, - logit_bias=None if inf_type == 'vllm' else {}, - max_retries=2, - streaming=stream_output, - **kwargs_extra - ) - streamer = callbacks[0] if stream_output else None - if inference_server in ['openai', 'openai_chat']: - prompt_type = inference_server - else: - # vllm goes here - prompt_type = prompt_type or 'plain' - elif inference_server: - assert inference_server.startswith( - 'http'), "Malformed inference_server=%s. Did you add http:// in front?" % inference_server - - from gradio_utils.grclient import GradioClient - from text_generation import Client as HFClient - if isinstance(model, GradioClient): - gr_client = model - hf_client = None - else: - gr_client = None - hf_client = model - assert isinstance(hf_client, HFClient) - - inference_server, headers = get_hf_server(inference_server) - - # quick sanity check to avoid long timeouts, just see if can reach server - requests.get(inference_server, timeout=int(os.getenv('REQUEST_TIMEOUT_FAST', '10'))) - - callbacks = [StreamingGradioCallbackHandler()] - assert prompter is not None - terminate_response = prompter.terminate_response or [] - stop_sequences = list(set(terminate_response + [prompter.PreResponse])) - stop_sequences = [x for x in stop_sequences if x] - - if gr_client: - chat_client = False - llm = GradioInference( - inference_server_url=inference_server, - return_full_text=True, - - temperature=temperature, - top_p=top_p, - top_k=top_k, - num_beams=num_beams, - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - early_stopping=early_stopping, - max_time=max_time, - repetition_penalty=repetition_penalty, - num_return_sequences=num_return_sequences, - do_sample=do_sample, - chat_client=chat_client, - - callbacks=callbacks if stream_output else None, - stream=stream_output, - prompter=prompter, - context=context, - iinput=iinput, - client=gr_client, - sanitize_bot_response=sanitize_bot_response, - ) - elif hf_client: - llm = H2OHuggingFaceTextGenInference( - inference_server_url=inference_server, - do_sample=do_sample, - max_new_tokens=max_new_tokens, - repetition_penalty=repetition_penalty, - return_full_text=True, - seed=SEED, - - stop_sequences=stop_sequences, - temperature=temperature, - top_k=top_k, - top_p=top_p, - # typical_p=top_p, - callbacks=callbacks if stream_output else None, - stream=stream_output, - prompter=prompter, - context=context, - iinput=iinput, - tokenizer=tokenizer, - 
client=hf_client, - timeout=max_time, - sanitize_bot_response=sanitize_bot_response, - ) - else: - raise RuntimeError("No defined client") - streamer = callbacks[0] if stream_output else None - elif model_name in non_hf_types: - if model_name == 'llama': - callbacks = [StreamingGradioCallbackHandler()] - streamer = callbacks[0] if stream_output else None - else: - # stream_output = False - # doesn't stream properly as generator, but at least - callbacks = [streaming_stdout.StreamingStdOutCallbackHandler()] - streamer = None - if prompter: - prompt_type = prompter.prompt_type - else: - prompter = Prompter(prompt_type, prompt_dict, debug=False, chat=False, stream_output=stream_output) - pass # assume inputted prompt_type is correct - from gpt4all_llm import get_llm_gpt4all - llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens, - temperature=temperature, - repetition_penalty=repetition_penalty, - top_k=top_k, - top_p=top_p, - callbacks=callbacks, - verbose=verbose, - streaming=stream_output, - prompter=prompter, - context=context, - iinput=iinput, - ) - else: - if model is None: - # only used if didn't pass model in - assert tokenizer is None - prompt_type = 'human_bot' - if model_name is None: - model_name = 'h2oai/h2ogpt-oasst1-512-12b' - # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b' - # model_name = 'h2oai/h2ogpt-oasst1-512-20b' - inference_server = '' - model, tokenizer, device = get_model(load_8bit=True, base_model=model_name, - inference_server=inference_server, gpu_id=0) - - max_max_tokens = tokenizer.model_max_length - gen_kwargs = dict(do_sample=do_sample, - temperature=temperature, - top_k=top_k, - top_p=top_p, - num_beams=num_beams, - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - early_stopping=early_stopping, - max_time=max_time, - repetition_penalty=repetition_penalty, - num_return_sequences=num_return_sequences, - return_full_text=True, - handle_long_generation=None) - assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0 - - if stream_output: - skip_prompt = False - from gen import H2OTextIteratorStreamer - decoder_kwargs = {} - streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs) - gen_kwargs.update(dict(streamer=streamer)) - else: - streamer = None - - from h2oai_pipeline import H2OTextGenerationPipeline - pipe = H2OTextGenerationPipeline(model=model, use_prompter=True, - prompter=prompter, - context=context, - iinput=iinput, - prompt_type=prompt_type, - prompt_dict=prompt_dict, - sanitize_bot_response=sanitize_bot_response, - chat=False, stream_output=stream_output, - tokenizer=tokenizer, - # leave some room for 1 paragraph, even if min_new_tokens=0 - max_input_tokens=max_max_tokens - max(min_new_tokens, 256), - **gen_kwargs) - # pipe.task = "text-generation" - # below makes it listen only to our prompt removal, - # not built in prompt removal that is less general and not specific for our model - pipe.task = "text2text-generation" - - from langchain.llms import HuggingFacePipeline - llm = HuggingFacePipeline(pipeline=pipe) - return llm, model_name, streamer, prompt_type - - -def get_device_dtype(): - # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently - import torch - n_gpus = torch.cuda.device_count() if torch.cuda.is_available else 0 - device = 'cpu' if n_gpus == 0 else 'cuda' - # from utils import NullContext - # context_class = NullContext if n_gpus > 1 or n_gpus == 0 else context_class - context_class = torch.device - torch_dtype = 
torch.float16 if device == 'cuda' else torch.float32 - return device, torch_dtype, context_class - - -def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True): - """ - Get wikipedia data from online - :param title: - :param first_paragraph_only: - :param text_limit: - :param take_head: - :return: - """ - filename = 'wiki_%s_%s_%s_%s.data' % (first_paragraph_only, title, text_limit, take_head) - url = f"https://en.wikipedia.org/w/api.php?format=json&action=query&prop=extracts&explaintext=1&titles={title}" - if first_paragraph_only: - url += "&exintro=1" - import json - if not os.path.isfile(filename): - data = requests.get(url).json() - json.dump(data, open(filename, 'wt')) - else: - data = json.load(open(filename, "rt")) - page_content = list(data["query"]["pages"].values())[0]["extract"] - if take_head is not None and text_limit is not None: - page_content = page_content[:text_limit] if take_head else page_content[-text_limit:] - title_url = str(title).replace(' ', '_') - return Document( - page_content=page_content, - metadata={"source": f"https://en.wikipedia.org/wiki/{title_url}"}, - ) - - -def get_wiki_sources(first_para=True, text_limit=None): - """ - Get specific named sources from wikipedia - :param first_para: - :param text_limit: - :return: - """ - default_wiki_sources = ['Unix', 'Microsoft_Windows', 'Linux'] - wiki_sources = list(os.getenv('WIKI_SOURCES', default_wiki_sources)) - return [get_wiki_data(x, first_para, text_limit=text_limit) for x in wiki_sources] - - -def get_github_docs(repo_owner, repo_name): - """ - Access github from specific repo - :param repo_owner: - :param repo_name: - :return: - """ - with tempfile.TemporaryDirectory() as d: - subprocess.check_call( - f"git clone --depth 1 https://github.com/{repo_owner}/{repo_name}.git .", - cwd=d, - shell=True, - ) - git_sha = ( - subprocess.check_output("git rev-parse HEAD", shell=True, cwd=d) - .decode("utf-8") - .strip() - ) - repo_path = pathlib.Path(d) - markdown_files = list(repo_path.glob("*/*.md")) + list( - repo_path.glob("*/*.mdx") - ) - for markdown_file in markdown_files: - with open(markdown_file, "r") as f: - relative_path = markdown_file.relative_to(repo_path) - github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}" - yield Document(page_content=f.read(), metadata={"source": github_url}) - - -def get_dai_pickle(dest="."): - from huggingface_hub import hf_hub_download - # True for case when locally already logged in with correct token, so don't have to set key - token = os.getenv('HUGGINGFACE_API_TOKEN', True) - path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.pickle', token=token, repo_type='dataset') - shutil.copy(path_to_zip_file, dest) - - -def get_dai_docs(from_hf=False, get_pickle=True): - """ - Consume DAI documentation, or consume from public pickle - :param from_hf: get DAI docs from HF, then generate pickle for later use by LangChain - :param get_pickle: Avoid raw DAI docs, just get pickle directly from HF - :return: - """ - import pickle - - if get_pickle: - get_dai_pickle() - - dai_store = 'dai_docs.pickle' - dst = "working_dir_docs" - if not os.path.isfile(dai_store): - from create_data import setup_dai_docs - dst = setup_dai_docs(dst=dst, from_hf=from_hf) - - import glob - files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True)) - - basedir = os.path.abspath(os.getcwd()) - from create_data import rst_to_outputs - new_outputs = rst_to_outputs(files) - os.chdir(basedir) - - pickle.dump(new_outputs, 
open(dai_store, 'wb')) - else: - new_outputs = pickle.load(open(dai_store, 'rb')) - - sources = [] - for line, file in new_outputs: - # gradio requires any linked file to be with app.py - sym_src = os.path.abspath(os.path.join(dst, file)) - sym_dst = os.path.abspath(os.path.join(os.getcwd(), file)) - if os.path.lexists(sym_dst): - os.remove(sym_dst) - os.symlink(sym_src, sym_dst) - itm = Document(page_content=line, metadata={"source": file}) - # NOTE: yield has issues when going into db, loses metadata - # yield itm - sources.append(itm) - return sources - - -image_types = ["png", "jpg", "jpeg"] -non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf", - "md", - "html", "mhtml", - "enex", "eml", "epub", "odt", "pptx", "ppt", - "zip", "urls", - - ] -# "msg", GPL3 - -if have_libreoffice or True: - # or True so it tries to load, e.g. on MAC/Windows, even if don't have libreoffice since works without that - non_image_types.extend(["docx", "doc", "xls", "xlsx"]) - -file_types = non_image_types + image_types - - -def add_meta(docs1, file): - file_extension = pathlib.Path(file).suffix - hashid = hash_file(file) - doc_hash = str(uuid.uuid4())[:10] - if not isinstance(docs1, (list, tuple, types.GeneratorType)): - docs1 = [docs1] - [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid, doc_hash=doc_hash)) for - x in docs1] - - -def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False, - chunk=True, chunk_size=512, n_jobs=-1, - is_url=False, is_txt=False, - enable_captions=True, - captions_model=None, - enable_ocr=False, enable_pdf_ocr='auto', caption_loader=None, - headsize=50): - if file is None: - if fail_any_exception: - raise RuntimeError("Unexpected None file") - else: - return [] - doc1 = [] # in case no support, or disabled support - if base_path is None and not is_txt and not is_url: - # then assume want to persist but don't care which path used - # can't be in base_path - dir_name = os.path.dirname(file) - base_name = os.path.basename(file) - # if from gradio, will have its own temp uuid too, but that's ok - base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10] - base_path = os.path.join(dir_name, base_name) - if is_url: - file = file.strip() # in case accidental spaces in front or at end - if file.lower().startswith('arxiv:'): - query = file.lower().split('arxiv:') - if len(query) == 2 and have_arxiv: - query = query[1] - docs1 = ArxivLoader(query=query, load_max_docs=20, load_all_available_meta=True).load() - # ensure string, sometimes None - [[x.metadata.update({k: str(v)}) for k, v in x.metadata.items()] for x in docs1] - query_url = f"https://arxiv.org/abs/{query}" - [x.metadata.update( - dict(source=x.metadata.get('entry_id', query_url), query=query_url, - input_type='arxiv', head=x.metadata.get('Title', ''), date=str(datetime.now))) for x in - docs1] - else: - docs1 = [] - else: - if not (file.startswith("http://") or file.startswith("file://") or file.startswith("https://")): - file = 'http://' + file - docs1 = UnstructuredURLLoader(urls=[file]).load() - if len(docs1) == 0 and have_playwright: - # then something went wrong, try another loader: - from langchain.document_loaders import PlaywrightURLLoader - docs1 = PlaywrightURLLoader(urls=[file]).load() - if len(docs1) == 0 and have_selenium: - # then something went wrong, try another loader: - # but requires Chrome binary, else get: selenium.common.exceptions.WebDriverException: Message: unknown error: cannot find Chrome binary - from 
langchain.document_loaders import SeleniumURLLoader - from selenium.common.exceptions import WebDriverException - try: - docs1 = SeleniumURLLoader(urls=[file]).load() - except WebDriverException as e: - print("No web driver: %s" % str(e), flush=True) - [x.metadata.update(dict(input_type='url', date=str(datetime.now))) for x in docs1] - docs1 = clean_doc(docs1) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif is_txt: - base_path = "user_paste" - source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10]) - makedirs(os.path.dirname(source_file), exist_ok=True) - with open(source_file, "wt") as f: - f.write(file) - metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt') - doc1 = Document(page_content=file, metadata=metadata) - doc1 = clean_doc(doc1) - elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'): - docs1 = UnstructuredHTMLLoader(file_path=file).load() - add_meta(docs1, file) - docs1 = clean_doc(docs1) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.HTML) - elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and (have_libreoffice or True): - docs1 = UnstructuredWordDocumentLoader(file_path=file).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif (file.lower().endswith('.xlsx') or file.lower().endswith('.xls')) and (have_libreoffice or True): - docs1 = UnstructuredExcelLoader(file_path=file).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith('.odt'): - docs1 = UnstructuredODTLoader(file_path=file).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith('pptx') or file.lower().endswith('ppt'): - docs1 = UnstructuredPowerPointLoader(file_path=file).load() - add_meta(docs1, file) - docs1 = clean_doc(docs1) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith('.txt'): - # use UnstructuredFileLoader ? 
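- # Plain text: TextLoader decodes as utf-8 and, because autodetect_encoding=True,
- # falls back to charset detection when that decode fails; the single large
- # Document it returns is then split by chunk_sources() and cleaned below.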
- docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load() - # makes just one, but big one - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - doc1 = clean_doc(doc1) - add_meta(doc1, file) - elif file.lower().endswith('.rtf'): - docs1 = UnstructuredRTFLoader(file).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith('.md'): - docs1 = UnstructuredMarkdownLoader(file).load() - add_meta(docs1, file) - docs1 = clean_doc(docs1) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size, language=Language.MARKDOWN) - elif file.lower().endswith('.enex'): - docs1 = EverNoteLoader(file).load() - add_meta(doc1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith('.epub'): - docs1 = UnstructuredEPubLoader(file).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith('.jpeg') or file.lower().endswith('.jpg') or file.lower().endswith('.png'): - docs1 = [] - if have_tesseract and enable_ocr: - # OCR, somewhat works, but not great - docs1.extend(UnstructuredImageLoader(file).load()) - add_meta(docs1, file) - if enable_captions: - # BLIP - if caption_loader is not None and not isinstance(caption_loader, (str, bool)): - # assumes didn't fork into this process with joblib, else can deadlock - caption_loader.set_image_paths([file]) - docs1c = caption_loader.load() - add_meta(docs1c, file) - [x.metadata.update(dict(head=x.page_content[:headsize].strip())) for x in docs1c] - docs1.extend(docs1c) - else: - from image_captions import H2OImageCaptionLoader - caption_loader = H2OImageCaptionLoader(caption_gpu=caption_loader == 'gpu', - blip_model=captions_model, - blip_processor=captions_model) - caption_loader.set_image_paths([file]) - docs1c = caption_loader.load() - add_meta(docs1c, file) - [x.metadata.update(dict(head=x.page_content[:headsize].strip())) for x in docs1c] - docs1.extend(docs1c) - for doci in docs1: - doci.metadata['source'] = doci.metadata['image_path'] - doci.metadata['hash'] = hash_file(doci.metadata['source']) - if docs1: - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith('.msg'): - raise RuntimeError("Not supported, GPL3 license") - # docs1 = OutlookMessageLoader(file).load() - # docs1[0].metadata['source'] = file - elif file.lower().endswith('.eml'): - try: - docs1 = UnstructuredEmailLoader(file).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - except ValueError as e: - if 'text/html content not found in email' in str(e): - # e.g. 
plain/text dict key exists, but not - # doc1 = TextLoader(file, encoding="utf8").load() - docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - else: - raise - # elif file.lower().endswith('.gcsdir'): - # doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load() - # elif file.lower().endswith('.gcsfile'): - # doc1 = GCSFileLoader(project_name, bucket, blob).load() - elif file.lower().endswith('.rst'): - with open(file, "r") as f: - doc1 = Document(page_content=f.read(), metadata={"source": file}) - add_meta(doc1, file) - doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size, language=Language.RST) - elif file.lower().endswith('.pdf'): - env_gpt4all_file = ".env_gpt4all" - from dotenv import dotenv_values - env_kwargs = dotenv_values(env_gpt4all_file) - pdf_class_name = env_kwargs.get('PDF_CLASS_NAME', 'PyMuPDFParser') - doc1 = [] - handled = False - if have_pymupdf and pdf_class_name == 'PyMuPDFParser': - # GPL, only use if installed - from langchain.document_loaders import PyMuPDFLoader - # load() still chunks by pages, but every page has title at start to help - doc1 = PyMuPDFLoader(file).load() - # remove empty documents - handled |= len(doc1) > 0 - doc1 = [x for x in doc1 if x.page_content] - doc1 = clean_doc(doc1) - if len(doc1) == 0: - doc1 = UnstructuredPDFLoader(file).load() - handled |= len(doc1) > 0 - # remove empty documents - doc1 = [x for x in doc1 if x.page_content] - # seems to not need cleaning in most cases - if len(doc1) == 0: - # open-source fallback - # load() still chunks by pages, but every page has title at start to help - doc1 = PyPDFLoader(file).load() - handled |= len(doc1) > 0 - # remove empty documents - doc1 = [x for x in doc1 if x.page_content] - doc1 = clean_doc(doc1) - if have_pymupdf and len(doc1) == 0: - # GPL, only use if installed - from langchain.document_loaders import PyMuPDFLoader - # load() still chunks by pages, but every page has title at start to help - doc1 = PyMuPDFLoader(file).load() - handled |= len(doc1) > 0 - # remove empty documents - doc1 = [x for x in doc1 if x.page_content] - doc1 = clean_doc(doc1) - if len(doc1) == 0 and enable_pdf_ocr == 'auto' or enable_pdf_ocr == 'on': - # try OCR in end since slowest, but works on pure image pages well - doc1 = UnstructuredPDFLoader(file, strategy='ocr_only').load() - handled |= len(doc1) > 0 - # remove empty documents - doc1 = [x for x in doc1 if x.page_content] - # seems to not need cleaning in most cases - # Some PDFs return nothing or junk from PDFMinerLoader - if len(doc1) == 0: - # if literally nothing, show failed to parse so user knows, since unlikely nothing in PDF at all. 
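- # Every available parser has been tried by now (PyMuPDF when installed,
- # UnstructuredPDFLoader, PyPDFLoader, PyMuPDF again, then OCR when enabled);
- # `handled` records whether any of them returned documents at all, even if
- # every page was empty after filtering, so the error below can distinguish
- # "parsed but no text" from "nothing parsed".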
- if handled: - raise ValueError("%s had no valid text, but meta data was parsed" % file) - else: - raise ValueError("%s had no valid text and no meta data was parsed" % file) - doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size) - add_meta(doc1, file) - elif file.lower().endswith('.csv'): - doc1 = CSVLoader(file).load() - add_meta(doc1, file) - elif file.lower().endswith('.py'): - doc1 = PythonLoader(file).load() - add_meta(doc1, file) - doc1 = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size, language=Language.PYTHON) - elif file.lower().endswith('.toml'): - doc1 = TomlLoader(file).load() - add_meta(doc1, file) - elif file.lower().endswith('.urls'): - with open(file, "r") as f: - docs1 = UnstructuredURLLoader(urls=f.readlines()).load() - add_meta(docs1, file) - doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size) - elif file.lower().endswith('.zip'): - with zipfile.ZipFile(file, 'r') as zip_ref: - # don't put into temporary path, since want to keep references to docs inside zip - # so just extract in path where - zip_ref.extractall(base_path) - # recurse - doc1 = path_to_docs(base_path, verbose=verbose, fail_any_exception=fail_any_exception, n_jobs=n_jobs) - else: - raise RuntimeError("No file handler for %s" % os.path.basename(file)) - - # allow doc1 to be list or not. If not list, did not chunk yet, so chunk now - # if list of length one, don't trust and chunk it - if not isinstance(doc1, list): - if chunk: - docs = chunk_sources([doc1], chunk=chunk, chunk_size=chunk_size) - else: - docs = [doc1] - elif isinstance(doc1, list) and len(doc1) == 1: - if chunk: - docs = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size) - else: - docs = doc1 - else: - docs = doc1 - - assert isinstance(docs, list) - return docs - - -def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True, - chunk=True, chunk_size=512, - n_jobs=-1, - is_url=False, is_txt=False, - enable_captions=True, - captions_model=None, - enable_ocr=False, enable_pdf_ocr='auto', caption_loader=None): - if verbose: - if is_url: - print("Ingesting URL: %s" % file, flush=True) - elif is_txt: - print("Ingesting Text: %s" % file, flush=True) - else: - print("Ingesting file: %s" % file, flush=True) - res = None - try: - # don't pass base_path=path, would infinitely recurse - res = file_to_doc(file, base_path=None, verbose=verbose, fail_any_exception=fail_any_exception, - chunk=chunk, chunk_size=chunk_size, - n_jobs=n_jobs, - is_url=is_url, is_txt=is_txt, - enable_captions=enable_captions, - captions_model=captions_model, - enable_ocr=enable_ocr, - enable_pdf_ocr=enable_pdf_ocr, - caption_loader=caption_loader) - except BaseException as e: - print("Failed to ingest %s due to %s" % (file, traceback.format_exc())) - if fail_any_exception: - raise - else: - exception_doc = Document( - page_content='', - metadata={"source": file, "exception": '%s Exception: %s' % (file, str(e)), - "traceback": traceback.format_exc()}) - res = [exception_doc] - if return_file: - base_tmp = "temp_path_to_doc1" - if not os.path.isdir(base_tmp): - os.makedirs(base_tmp, exist_ok=True) - filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle") - with open(filename, 'wb') as f: - pickle.dump(res, f) - return filename - return res - - -def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=-1, - chunk=True, chunk_size=512, - url=None, text=None, - enable_captions=True, - captions_model=None, - caption_loader=None, - enable_ocr=False, - enable_pdf_ocr='auto', - existing_files=[], 
- existing_hash_ids={}, - ): - # path_or_paths could be str, list, tuple, generator - globs_image_types = [] - globs_non_image_types = [] - if not path_or_paths and not url and not text: - return [] - elif url: - globs_non_image_types = url if isinstance(url, (list, tuple, types.GeneratorType)) else [url] - elif text: - globs_non_image_types = text if isinstance(text, (list, tuple, types.GeneratorType)) else [text] - elif isinstance(path_or_paths, str) and os.path.isdir(path_or_paths): - # single path, only consume allowed files - path = path_or_paths - # Below globs should match patterns in file_to_doc() - [globs_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True)) - for ftype in image_types] - [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True)) - for ftype in non_image_types] - else: - if isinstance(path_or_paths, str): - if os.path.isfile(path_or_paths) or os.path.isdir(path_or_paths): - path_or_paths = [path_or_paths] - else: - # path was deleted etc. - return [] - # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows) - assert isinstance(path_or_paths, (list, tuple, types.GeneratorType)), \ - "Wrong type for path_or_paths: %s %s" % (path_or_paths, type(path_or_paths)) - # reform out of allowed types - globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types])) - # could do below: - # globs_non_image_types = flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in non_image_types]) - # But instead, allow fail so can collect unsupported too - set_globs_image_types = set(globs_image_types) - globs_non_image_types.extend([x for x in path_or_paths if x not in set_globs_image_types]) - - # filter out any files to skip (e.g. if already processed them) - # this is easy, but too aggressive in case a file changed, so parent probably passed existing_files=[] - assert not existing_files, "DEV: assume not using this approach" - if existing_files: - set_skip_files = set(existing_files) - globs_image_types = [x for x in globs_image_types if x not in set_skip_files] - globs_non_image_types = [x for x in globs_non_image_types if x not in set_skip_files] - if existing_hash_ids: - # assume consistent with add_meta() use of hash_file(file) - # also assume consistent with get_existing_hash_ids for dict creation - # assume hashable values - existing_hash_ids_set = set(existing_hash_ids.items()) - hash_ids_all_image = set({x: hash_file(x) for x in globs_image_types}.items()) - hash_ids_all_non_image = set({x: hash_file(x) for x in globs_non_image_types}.items()) - # don't use symmetric diff. 
If file is gone, ignore and don't remove or something - # just consider existing files (key) having new hash or not (value) - new_files_image = set(dict(hash_ids_all_image - existing_hash_ids_set).keys()) - new_files_non_image = set(dict(hash_ids_all_non_image - existing_hash_ids_set).keys()) - globs_image_types = [x for x in globs_image_types if x in new_files_image] - globs_non_image_types = [x for x in globs_non_image_types if x in new_files_non_image] - - # could use generator, but messes up metadata handling in recursive case - if caption_loader and not isinstance(caption_loader, (bool, str)) and \ - caption_loader.device != 'cpu' or \ - get_device() == 'cuda': - # to avoid deadlocks, presume was preloaded and so can't fork due to cuda context - n_jobs_image = 1 - else: - n_jobs_image = n_jobs - - return_file = True # local choice - is_url = url is not None - is_txt = text is not None - kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception, - return_file=return_file, - chunk=chunk, chunk_size=chunk_size, - n_jobs=n_jobs, - is_url=is_url, - is_txt=is_txt, - enable_captions=enable_captions, - captions_model=captions_model, - caption_loader=caption_loader, - enable_ocr=enable_ocr, - enable_pdf_ocr=enable_pdf_ocr, - ) - - if n_jobs != 1 and len(globs_non_image_types) > 1: - # avoid nesting, e.g. upload 1 zip and then inside many files - # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib - documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')( - delayed(path_to_doc1)(file, **kwargs) for file in globs_non_image_types - ) - else: - documents = [path_to_doc1(file, **kwargs) for file in tqdm(globs_non_image_types)] - - # do images separately since can't fork after cuda in parent, so can't be parallel - if n_jobs_image != 1 and len(globs_image_types) > 1: - # avoid nesting, e.g. 
upload 1 zip and then inside many files - # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib - image_documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')( - delayed(path_to_doc1)(file, **kwargs) for file in globs_image_types - ) - else: - image_documents = [path_to_doc1(file, **kwargs) for file in tqdm(globs_image_types)] - - # add image docs in - documents += image_documents - - if return_file: - # then documents really are files - files = documents.copy() - documents = [] - for fil in files: - with open(fil, 'rb') as f: - documents.extend(pickle.load(f)) - # remove temp pickle - remove(fil) - else: - documents = reduce(concat, documents) - return documents - - -def prep_langchain(persist_directory, - load_db_if_exists, - db_type, use_openai_embedding, langchain_mode, langchain_mode_paths, - hf_embedding_model, n_jobs=-1, kwargs_make_db={}): - """ - do prep first time, involving downloads - # FIXME: Add github caching then add here - :return: - """ - assert langchain_mode not in ['MyData'], "Should not prep scratch data" - - db_dir_exists = os.path.isdir(persist_directory) - user_path = langchain_mode_paths.get(langchain_mode) - - if db_dir_exists and user_path is None: - print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True) - db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode, - hf_embedding_model) - else: - if db_dir_exists and user_path is not None: - print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % ( - persist_directory, user_path), flush=True) - elif not db_dir_exists: - print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True) - db = None - if langchain_mode in ['All', 'DriverlessAI docs']: - # FIXME: Could also just use dai_docs.pickle directly and upload that - get_dai_docs(from_hf=True) - - if langchain_mode in ['All', 'wiki']: - get_wiki_sources(first_para=kwargs_make_db['first_para'], text_limit=kwargs_make_db['text_limit']) - - langchain_kwargs = kwargs_make_db.copy() - langchain_kwargs.update(locals()) - db, num_new_sources, new_sources_metadata = make_db(**langchain_kwargs) - - return db - - -import posthog - -posthog.disabled = True - - -class FakeConsumer(object): - def __init__(self, *args, **kwargs): - pass - - def run(self): - pass - - def pause(self): - pass - - def upload(self): - pass - - def next(self): - pass - - def request(self, batch): - pass - - -posthog.Consumer = FakeConsumer - - -def check_update_chroma_embedding(db, use_openai_embedding, hf_embedding_model, langchain_mode): - changed_db = False - if load_embed(db) != (use_openai_embedding, hf_embedding_model): - print("Detected new embedding, updating db: %s" % langchain_mode, flush=True) - # handle embedding changes - db_get = get_documents(db) - sources = [Document(page_content=result[0], metadata=result[1] or {}) - for result in zip(db_get['documents'], db_get['metadatas'])] - # delete index, has to be redone - persist_directory = db._persist_directory - shutil.move(persist_directory, persist_directory + "_" + str(uuid.uuid4()) + ".bak") - db_type = 'chroma' - load_db_if_exists = False - db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type, - persist_directory=persist_directory, load_db_if_exists=load_db_if_exists, - langchain_mode=langchain_mode, - collection_name=None, - 
hf_embedding_model=hf_embedding_model) - if False: - # below doesn't work if db already in memory, so have to switch to new db as above - # upsert does new embedding, but if index already in memory, complains about size mismatch etc. - client_collection = db._client.get_collection(name=db._collection.name, - embedding_function=db._collection._embedding_function) - client_collection.upsert(ids=db_get['ids'], metadatas=db_get['metadatas'], documents=db_get['documents']) - changed_db = True - print("Done updating db for new embedding: %s" % langchain_mode, flush=True) - - return db, changed_db - - -def get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode, - hf_embedding_model, verbose=False, check_embedding=True): - if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir( - os.path.join(persist_directory, 'index')): - if db is None: - if verbose: - print("DO Loading db: %s" % langchain_mode, flush=True) - embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model) - from chromadb.config import Settings - client_settings = Settings(anonymized_telemetry=False, - chroma_db_impl="duckdb+parquet", - persist_directory=persist_directory) - db = Chroma(persist_directory=persist_directory, embedding_function=embedding, - collection_name=langchain_mode.replace(' ', '_'), - client_settings=client_settings) - if verbose: - print("DONE Loading db: %s" % langchain_mode, flush=True) - else: - if verbose: - print("USING already-loaded db: %s" % langchain_mode, flush=True) - if check_embedding: - db_trial, changed_db = check_update_chroma_embedding(db, use_openai_embedding, hf_embedding_model, - langchain_mode) - if changed_db: - db = db_trial - # only call persist if really changed db, else takes too long for large db - if db is not None: - db.persist() - clear_embedding(db) - save_embed(db, use_openai_embedding, hf_embedding_model) - return db - return None - - -def clear_embedding(db): - if db is None: - return - # don't keep on GPU, wastes memory, push back onto CPU and only put back on GPU once again embed - db._embedding_function.client.cpu() - clear_torch_cache() - - -def make_db(**langchain_kwargs): - func_names = list(inspect.signature(_make_db).parameters) - missing_kwargs = [x for x in func_names if x not in langchain_kwargs] - defaults_db = {k: v.default for k, v in dict(inspect.signature(run_qa_db).parameters).items()} - for k in missing_kwargs: - if k in defaults_db: - langchain_kwargs[k] = defaults_db[k] - # final check for missing - missing_kwargs = [x for x in func_names if x not in langchain_kwargs] - assert not missing_kwargs, "Missing kwargs for make_db: %s" % missing_kwargs - # only keep actual used - langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names} - return _make_db(**langchain_kwargs) - - -def save_embed(db, use_openai_embedding, hf_embedding_model): - if db is not None: - embed_info_file = os.path.join(db._persist_directory, 'embed_info') - with open(embed_info_file, 'wb') as f: - pickle.dump((use_openai_embedding, hf_embedding_model), f) - return use_openai_embedding, hf_embedding_model - - -def load_embed(db): - embed_info_file = os.path.join(db._persist_directory, 'embed_info') - if os.path.isfile(embed_info_file): - with open(embed_info_file, 'rb') as f: - use_openai_embedding, hf_embedding_model = pickle.load(f) - else: - # migration, assume defaults - use_openai_embedding, hf_embedding_model = False, 
"sentence-transformers/all-MiniLM-L6-v2" - return use_openai_embedding, hf_embedding_model - - -def get_persist_directory(langchain_mode): - return 'db_dir_%s' % langchain_mode # single place, no special names for each case - - -def _make_db(use_openai_embedding=False, - hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2", - first_para=False, text_limit=None, - chunk=True, chunk_size=512, - langchain_mode=None, - langchain_mode_paths=None, - db_type='faiss', - load_db_if_exists=True, - db=None, - n_jobs=-1, - verbose=False): - persist_directory = get_persist_directory(langchain_mode) - user_path = langchain_mode_paths.get(langchain_mode) - # see if can get persistent chroma db - db_trial = get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode, - hf_embedding_model, verbose=verbose) - if db_trial is not None: - db = db_trial - - sources = [] - if not db: - if langchain_mode in ['wiki_full']: - from read_wiki_full import get_all_documents - small_test = None - print("Generating new wiki", flush=True) - sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2) - print("Got new wiki", flush=True) - if chunk: - sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size) - print("Chunked new wiki", flush=True) - sources.extend(sources1) - elif langchain_mode in ['wiki']: - sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit) - if chunk: - sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size) - sources.extend(sources1) - elif langchain_mode in ['github h2oGPT']: - # sources = get_github_docs("dagster-io", "dagster") - sources1 = get_github_docs("h2oai", "h2ogpt") - # FIXME: always chunk for now - sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size) - sources.extend(sources1) - elif langchain_mode in ['DriverlessAI docs']: - sources1 = get_dai_docs(from_hf=True) - if chunk and False: # FIXME: DAI docs are already chunked well, should only chunk more if over limit - sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size) - sources.extend(sources1) - if user_path: - # UserData or custom, which has to be from user's disk - if db is not None: - # NOTE: Ignore file names for now, only go by hash ids - # existing_files = get_existing_files(db) - existing_files = [] - existing_hash_ids = get_existing_hash_ids(db) - else: - # pretend no existing files so won't filter - existing_files = [] - existing_hash_ids = [] - # chunk internally for speed over multiple docs - # FIXME: If first had old Hash=None and switch embeddings, - # then re-embed, and then hit here and reload so have hash, and then re-embed. 
- sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size, - existing_files=existing_files, existing_hash_ids=existing_hash_ids) - new_metadata_sources = set([x.metadata['source'] for x in sources1]) - if new_metadata_sources: - print("Loaded %s new files as sources to add to %s" % (len(new_metadata_sources), langchain_mode), - flush=True) - if verbose: - print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True) - sources.extend(sources1) - print("Loaded %s sources for potentially adding to %s" % (len(sources), langchain_mode), flush=True) - - # see if got sources - if not sources: - if verbose: - if db is not None: - print("langchain_mode %s has no new sources, nothing to add to db" % langchain_mode, flush=True) - else: - print("langchain_mode %s has no sources, not making new db" % langchain_mode, flush=True) - return db, 0, [] - if verbose: - if db is not None: - print("Generating db", flush=True) - else: - print("Adding to db", flush=True) - if not db: - if sources: - db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type, - persist_directory=persist_directory, langchain_mode=langchain_mode, - hf_embedding_model=hf_embedding_model) - if verbose: - print("Generated db", flush=True) - else: - print("Did not generate db since no sources", flush=True) - new_sources_metadata = [x.metadata for x in sources] - elif user_path is not None: - print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True) - db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type, - use_openai_embedding=use_openai_embedding, - hf_embedding_model=hf_embedding_model) - print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True) - else: - new_sources_metadata = [x.metadata for x in sources] - - return db, len(new_sources_metadata), new_sources_metadata - - -def get_metadatas(db): - from langchain.vectorstores import FAISS - if isinstance(db, FAISS): - metadatas = [v.metadata for k, v in db.docstore._dict.items()] - elif isinstance(db, Chroma): - metadatas = get_documents(db)['metadatas'] - else: - # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947 - # seems no way to get all metadata, so need to avoid this approach for weaviate - metadatas = [x.metadata for x in db.similarity_search("", k=10000)] - return metadatas - - -def get_documents(db): - if hasattr(db, '_persist_directory'): - name_path = os.path.basename(db._persist_directory) - base_path = 'locks' - makedirs(base_path) - with filelock.FileLock(os.path.join(base_path, "getdb_%s.lock" % name_path)): - # get segfaults and other errors when multiple threads access this - return _get_documents(db) - else: - return _get_documents(db) - - -def _get_documents(db): - from langchain.vectorstores import FAISS - if isinstance(db, FAISS): - documents = [v for k, v in db.docstore._dict.items()] - elif isinstance(db, Chroma): - documents = db.get() - else: - # FIXME: Hack due to https://github.com/weaviate/weaviate/issues/1947 - # seems no way to get all metadata, so need to avoid this approach for weaviate - documents = [x for x in db.similarity_search("", k=10000)] - return documents - - -def get_docs_and_meta(db, top_k_docs, filter_kwargs={}): - if hasattr(db, '_persist_directory'): - name_path = os.path.basename(db._persist_directory) - base_path = 'locks' - makedirs(base_path) - with filelock.FileLock(os.path.join(base_path, "getdb_%s.lock" % name_path)): - return 
_get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs) - else: - return _get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs) - - -def _get_docs_and_meta(db, top_k_docs, filter_kwargs={}): - from langchain.vectorstores import FAISS - if isinstance(db, Chroma): - db_get = db._collection.get(where=filter_kwargs.get('filter')) - db_metadatas = db_get['metadatas'] - db_documents = db_get['documents'] - elif isinstance(db, FAISS): - import itertools - db_metadatas = get_metadatas(db) - # FIXME: FAISS has no filter - # slice dict first - db_documents = list(dict(itertools.islice(db.docstore._dict.items(), top_k_docs)).values()) - else: - db_metadatas = get_metadatas(db) - db_documents = get_documents(db) - return db_documents, db_metadatas - - -def get_existing_files(db): - metadatas = get_metadatas(db) - metadata_sources = set([x['source'] for x in metadatas]) - return metadata_sources - - -def get_existing_hash_ids(db): - metadatas = get_metadatas(db) - # assume consistency, that any prior hashed source was single hashed file at the time among all source chunks - metadata_hash_ids = {x['source']: x.get('hashid') for x in metadatas} - return metadata_hash_ids - - -def run_qa_db(**kwargs): - func_names = list(inspect.signature(_run_qa_db).parameters) - # hard-coded defaults - kwargs['answer_with_sources'] = True - kwargs['show_rank'] = False - missing_kwargs = [x for x in func_names if x not in kwargs] - assert not missing_kwargs, "Missing kwargs for run_qa_db: %s" % missing_kwargs - # only keep actual used - kwargs = {k: v for k, v in kwargs.items() if k in func_names} - try: - return _run_qa_db(**kwargs) - finally: - clear_torch_cache() - - -def _run_qa_db(query=None, - iinput=None, - context=None, - use_openai_model=False, use_openai_embedding=False, - first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512, - langchain_mode_paths={}, - detect_user_path_changes_every_query=False, - db_type='faiss', - model_name=None, model=None, tokenizer=None, inference_server=None, - hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2", - stream_output=False, - prompter=None, - prompt_type=None, - prompt_dict=None, - answer_with_sources=True, - cut_distance=1.64, - add_chat_history_to_context=True, - sanitize_bot_response=False, - show_rank=False, - use_llm_if_no_docs=False, - load_db_if_exists=False, - db=None, - do_sample=False, - temperature=0.1, - top_k=40, - top_p=0.7, - num_beams=1, - max_new_tokens=256, - min_new_tokens=1, - early_stopping=False, - max_time=180, - repetition_penalty=1.0, - num_return_sequences=1, - langchain_mode=None, - langchain_action=None, - langchain_agents=None, - document_subset=DocumentSubset.Relevant.name, - document_choice=[DocumentChoice.ALL.value], - n_jobs=-1, - verbose=False, - cli=False, - reverse_docs=True, - lora_weights='', - auto_reduce_chunks=True, - max_chunks=100, - ): - """ - - :param query: - :param use_openai_model: - :param use_openai_embedding: - :param first_para: - :param text_limit: - :param top_k_docs: - :param chunk: - :param chunk_size: - :param langchain_mode_paths: dict of langchain_mode -> user path to glob recursively from - :param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db - :param model_name: model name, used to switch behaviors - :param model: pre-initialized model, else will make new one - :param tokenizer: pre-initialized tokenizer, else will make new one. 
Required not None if model is not None - :param answer_with_sources - :return: - """ - assert langchain_mode_paths is not None - if model is not None: - assert model_name is not None # require so can make decisions - assert query is not None - assert prompter is not None or prompt_type is not None or model is None # if model is None, then will generate - if prompter is not None: - prompt_type = prompter.prompt_type - prompt_dict = prompter.prompt_dict - if model is not None: - assert prompt_type is not None - if prompt_type == PromptType.custom.name: - assert prompt_dict is not None # should at least be {} or '' - else: - prompt_dict = '' - assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0 - # pass in context to LLM directly, since already has prompt_type structure - # can't pass through langchain in get_chain() to LLM: https://github.com/hwchase17/langchain/issues/6638 - llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name, - model=model, - tokenizer=tokenizer, - inference_server=inference_server, - stream_output=stream_output, - do_sample=do_sample, - temperature=temperature, - top_k=top_k, - top_p=top_p, - num_beams=num_beams, - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - early_stopping=early_stopping, - max_time=max_time, - repetition_penalty=repetition_penalty, - num_return_sequences=num_return_sequences, - prompt_type=prompt_type, - prompt_dict=prompt_dict, - prompter=prompter, - context=context if add_chat_history_to_context else '', - iinput=iinput if add_chat_history_to_context else '', - sanitize_bot_response=sanitize_bot_response, - verbose=verbose, - ) - - use_docs_planned = False - scores = [] - chain = None - - if isinstance(document_choice, str): - # support string as well - document_choice = [document_choice] - - func_names = list(inspect.signature(get_chain).parameters) - sim_kwargs = {k: v for k, v in locals().items() if k in func_names} - missing_kwargs = [x for x in func_names if x not in sim_kwargs] - assert not missing_kwargs, "Missing: %s" % missing_kwargs - docs, chain, scores, use_docs_planned, have_any_docs = get_chain(**sim_kwargs) - if document_subset in non_query_commands: - formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs]) - if not formatted_doc_chunks and not use_llm_if_no_docs: - yield "No sources", '' - return - # if no souces, outside gpt_langchain, LLM will be used with '' input - yield formatted_doc_chunks, '' - return - if not use_llm_if_no_docs: - if not docs and langchain_action in [LangChainAction.SUMMARIZE_MAP.value, - LangChainAction.SUMMARIZE_ALL.value, - LangChainAction.SUMMARIZE_REFINE.value]: - ret = 'No relevant documents to summarize.' if have_any_docs else 'No documents to summarize.' - extra = '' - yield ret, extra - return - if not docs and langchain_mode not in [LangChainMode.DISABLED.value, - LangChainMode.LLM.value]: - ret = 'No relevant documents to query.' if have_any_docs else 'No documents to query.' 
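- # have_any_docs distinguishes "the collection has documents but none survived
- # the filters/distance cutoff" from "the collection is empty"; either way this
- # generator short-circuits with the same (text, extra) tuple shape used by
- # every other yield.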
- extra = '' - yield ret, extra - return - - if chain is None and model_name not in non_hf_types: - # here if no docs at all and not HF type - # can only return if HF type - return - - # context stuff similar to used in evaluate() - import torch - device, torch_dtype, context_class = get_device_dtype() - with torch.no_grad(): - have_lora_weights = lora_weights not in [no_lora_str, '', None] - context_class_cast = NullContext if device == 'cpu' or have_lora_weights else torch.autocast - with context_class_cast(device): - if stream_output and streamer: - answer = None - import queue - bucket = queue.Queue() - thread = EThread(target=chain, streamer=streamer, bucket=bucket) - thread.start() - outputs = "" - prompt = None # FIXME - try: - for new_text in streamer: - # print("new_text: %s" % new_text, flush=True) - if bucket.qsize() > 0 or thread.exc: - thread.join() - outputs += new_text - if prompter: # and False: # FIXME: pipeline can already use prompter - output1 = prompter.get_response(outputs, prompt=prompt, - sanitize_bot_response=sanitize_bot_response) - yield output1, '' - else: - yield outputs, '' - except BaseException: - # if any exception, raise that exception if was from thread, first - if thread.exc: - raise thread.exc - raise - finally: - # in case no exception and didn't join with thread yet, then join - if not thread.exc: - answer = thread.join() - # in case raise StopIteration or broke queue loop in streamer, but still have exception - if thread.exc: - raise thread.exc - # FIXME: answer is not string outputs from streamer. How to get actual final output? - # answer = outputs - else: - answer = chain() - - if not use_docs_planned: - ret = answer['output_text'] - extra = '' - yield ret, extra - elif answer is not None: - ret, extra = get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=verbose) - yield ret, extra - return - - -def get_chain(query=None, - iinput=None, - context=None, # FIXME: https://github.com/hwchase17/langchain/issues/6638 - use_openai_model=False, use_openai_embedding=False, - first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512, - langchain_mode_paths=None, - detect_user_path_changes_every_query=False, - db_type='faiss', - model_name=None, - inference_server='', - hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2", - prompt_type=None, - prompt_dict=None, - cut_distance=1.1, - add_chat_history_to_context=True, # FIXME: https://github.com/hwchase17/langchain/issues/6638 - load_db_if_exists=False, - db=None, - langchain_mode=None, - langchain_action=None, - langchain_agents=None, - document_subset=DocumentSubset.Relevant.name, - document_choice=[DocumentChoice.ALL.value], - n_jobs=-1, - # beyond run_db_query: - llm=None, - tokenizer=None, - verbose=False, - reverse_docs=True, - - # local - auto_reduce_chunks=True, - max_chunks=100, - ): - assert langchain_agents is not None # should be at least [] - # determine whether use of context out of docs is planned - if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types: - if langchain_mode in ['Disabled', 'LLM']: - use_docs_planned = False - else: - use_docs_planned = True - else: - use_docs_planned = True - - # https://github.com/hwchase17/langchain/issues/1946 - # FIXME: Seems to way to get size of chroma db to limit top_k_docs to avoid - # Chroma collection MyData contains fewer than 4 elements. 
- # type logger error - if top_k_docs == -1: - k_db = 1000 if db_type == 'chroma' else 100 - else: - # top_k_docs=100 works ok too - k_db = 1000 if db_type == 'chroma' else top_k_docs - - # FIXME: For All just go over all dbs instead of a separate db for All - if not detect_user_path_changes_every_query and db is not None: - # avoid looking at user_path during similarity search db handling, - # if already have db and not updating from user_path every query - # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was - if langchain_mode_paths is None: - langchain_mode_paths = {} - langchain_mode_paths = langchain_mode_paths.copy() - langchain_mode_paths[langchain_mode] = None - db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding, - hf_embedding_model=hf_embedding_model, - first_para=first_para, text_limit=text_limit, - chunk=chunk, - chunk_size=chunk_size, - langchain_mode=langchain_mode, - langchain_mode_paths=langchain_mode_paths, - db_type=db_type, - load_db_if_exists=load_db_if_exists, - db=db, - n_jobs=n_jobs, - verbose=verbose) - have_any_docs = db is not None - if langchain_action == LangChainAction.QUERY.value: - if iinput: - query = "%s\n%s" % (query, iinput) - - if 'falcon' in model_name: - extra = "According to only the information in the document sources provided within the context above, " - prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends." - elif inference_server in ['openai', 'openai_chat']: - extra = "According to (primarily) the information in the document sources provided within context above, " - prefix = "Pay attention and remember information below, which will help to answer the question or imperative after the context ends. If the answer cannot be primarily obtained from information within the context, then respond that the answer does not appear in the context of the documents." - else: - extra = "" - prefix = "" - if langchain_mode in ['Disabled', 'LLM'] or not use_docs_planned: - template_if_no_docs = template = """%s{context}{question}""" % prefix - else: - template = """%s - \"\"\" - {context} - \"\"\" - %s{question}""" % (prefix, extra) - template_if_no_docs = """%s{context}%s{question}""" % (prefix, extra) - elif langchain_action in [LangChainAction.SUMMARIZE_ALL.value, LangChainAction.SUMMARIZE_MAP.value]: - none = ['', '\n', None] - if query in none and iinput in none: - prompt_summary = "Using only the text above, write a condensed and concise summary:\n" - elif query not in none: - prompt_summary = "Focusing on %s, write a condensed and concise Summary:\n" % query - elif iinput not in None: - prompt_summary = iinput - else: - prompt_summary = "Focusing on %s, %s:\n" % (query, iinput) - # don't auto reduce - auto_reduce_chunks = False - if langchain_action == LangChainAction.SUMMARIZE_MAP.value: - fstring = '{text}' - else: - fstring = '{input_documents}' - template = """In order to write a concise single-paragraph or bulleted list summary, pay attention to the following text: -\"\"\" -%s -\"\"\"\n%s""" % (fstring, prompt_summary) - template_if_no_docs = "Exactly only say: There are no documents to summarize." 
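- # With LangChainAction.SUMMARIZE_MAP this same template is reused as both
- # map_prompt and combine_prompt for the map_reduce chain built further below,
- # so {text} is filled first with each retrieved chunk and then with the
- # concatenated per-chunk summaries; prompt_summary carries the user's focus
- # (query and/or iinput) when one was given.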
- elif langchain_action in [LangChainAction.SUMMARIZE_REFINE]: - template = '' # unused - template_if_no_docs = '' # unused - else: - raise RuntimeError("No such langchain_action=%s" % langchain_action) - - if not use_openai_model and prompt_type not in ['plain'] or model_name in non_hf_types: - use_template = True - else: - use_template = False - - if db and use_docs_planned: - base_path = 'locks' - makedirs(base_path) - if hasattr(db, '_persist_directory'): - name_path = "sim_%s.lock" % os.path.basename(db._persist_directory) - else: - name_path = "sim.lock" - lock_file = os.path.join(base_path, name_path) - - if not isinstance(db, Chroma): - # only chroma supports filtering - filter_kwargs = {} - else: - assert document_choice is not None, "Document choice was None" - if len(document_choice) >= 1 and document_choice[0] == DocumentChoice.ALL.value: - filter_kwargs = {} - elif len(document_choice) >= 2: - if document_choice[0] == DocumentChoice.ALL.value: - # remove 'All' - document_choice = document_choice[1:] - or_filter = [{"source": {"$eq": x}} for x in document_choice] - filter_kwargs = dict(filter={"$or": or_filter}) - elif len(document_choice) == 1: - # degenerate UX bug in chroma - one_filter = [{"source": {"$eq": x}} for x in document_choice][0] - filter_kwargs = dict(filter=one_filter) - else: - # shouldn't reach - filter_kwargs = {} - if langchain_mode in [LangChainMode.LLM.value]: - docs = [] - scores = [] - elif document_subset == DocumentSubset.TopKSources.name or query in [None, '', '\n']: - db_documents, db_metadatas = get_docs_and_meta(db, top_k_docs, filter_kwargs=filter_kwargs) - # similar to langchain's chroma's _results_to_docs_and_scores - docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0) - for result in zip(db_documents, db_metadatas)] - - # order documents - doc_hashes = [x.get('doc_hash', 'None') for x in db_metadatas] - doc_chunk_ids = [x.get('chunk_id', 0) for x in db_metadatas] - docs_with_score = [x for _, _, x in - sorted(zip(doc_hashes, doc_chunk_ids, docs_with_score), key=lambda x: (x[0], x[1])) - ] - - docs_with_score = docs_with_score[:top_k_docs] - docs = [x[0] for x in docs_with_score] - scores = [x[1] for x in docs_with_score] - have_any_docs |= len(docs) > 0 - else: - # FIXME: if langchain_action == LangChainAction.SUMMARIZE_MAP.value - # if map_reduce, then no need to auto reduce chunks - if top_k_docs == -1 or auto_reduce_chunks: - # docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs] - top_k_docs_tokenize = 100 - with filelock.FileLock(lock_file): - docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[ - :top_k_docs_tokenize] - if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'tokenizer'): - # more accurate - tokens = [len(llm.pipeline.tokenizer(x[0].page_content)['input_ids']) for x in docs_with_score] - template_tokens = len(llm.pipeline.tokenizer(template)['input_ids']) - elif inference_server in ['openai', 'openai_chat'] or use_openai_model or db_type in ['faiss', - 'weaviate']: - # use ticktoken for faiss since embedding called differently - tokens = [llm.get_num_tokens(x[0].page_content) for x in docs_with_score] - template_tokens = llm.get_num_tokens(template) - elif isinstance(tokenizer, FakeTokenizer): - tokens = [tokenizer.num_tokens_from_string(x[0].page_content) for x in docs_with_score] - template_tokens = tokenizer.num_tokens_from_string(template) - else: - # in case model is not our pipeline with HF tokenizer - tokens = 
[db._embedding_function.client.tokenize([x[0].page_content])['input_ids'].shape[1] for x in - docs_with_score] - template_tokens = db._embedding_function.client.tokenize([template])['input_ids'].shape[1] - tokens_cumsum = np.cumsum(tokens) - if hasattr(llm, 'pipeline') and hasattr(llm.pipeline, 'max_input_tokens'): - max_input_tokens = llm.pipeline.max_input_tokens - elif inference_server in ['openai']: - max_tokens = llm.modelname_to_contextsize(model_name) - # leave some room for 1 paragraph, even if min_new_tokens=0 - max_input_tokens = max_tokens - 256 - elif inference_server in ['openai_chat']: - max_tokens = model_token_mapping[model_name] - # leave some room for 1 paragraph, even if min_new_tokens=0 - max_input_tokens = max_tokens - 256 - elif isinstance(tokenizer, FakeTokenizer): - max_input_tokens = tokenizer.model_max_length - 256 - else: - # leave some room for 1 paragraph, even if min_new_tokens=0 - max_input_tokens = 2048 - 256 - max_input_tokens -= template_tokens - # FIXME: Doesn't account for query, == context, or new lines between contexts - where_res = np.where(tokens_cumsum < max_input_tokens)[0] - if where_res.shape[0] == 0: - # then no chunk can fit, still do first one - top_k_docs_trial = 1 - else: - top_k_docs_trial = 1 + where_res[-1] - if 0 < top_k_docs_trial < max_chunks: - # avoid craziness - if top_k_docs == -1: - top_k_docs = top_k_docs_trial - else: - top_k_docs = min(top_k_docs, top_k_docs_trial) - if top_k_docs == -1: - # if here, means 0 and just do best with 1 doc - print("Unexpected large chunks and can't add to context, will add 1 anyways", flush=True) - top_k_docs = 1 - docs_with_score = docs_with_score[:top_k_docs] - else: - with filelock.FileLock(lock_file): - docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs] - # put most relevant chunks closest to question, - # esp. if truncation occurs will be "oldest" or "farthest from response" text that is truncated - # BUT: for small models, e.g. 
6_9 pythia, if sees some stuff related to h2oGPT first, it can connect that and not listen to rest - if reverse_docs: - docs_with_score.reverse() - # cut off so no high distance docs/sources considered - have_any_docs |= len(docs_with_score) > 0 # before cut - docs = [x[0] for x in docs_with_score if x[1] < cut_distance] - scores = [x[1] for x in docs_with_score if x[1] < cut_distance] - if len(scores) > 0 and verbose: - print("Distance: min: %s max: %s mean: %s median: %s" % - (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True) - else: - docs = [] - scores = [] - - if not docs and use_docs_planned and model_name not in non_hf_types: - # if HF type and have no docs, can bail out - return docs, None, [], False, have_any_docs - - if document_subset in non_query_commands: - # no LLM use - return docs, None, [], False, have_any_docs - - common_words_file = "data/NGSL_1.2_stats.csv.zip" - if os.path.isfile(common_words_file) and langchain_mode == LangChainAction.QUERY.value: - df = pd.read_csv("data/NGSL_1.2_stats.csv.zip") - import string - reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip() - reduced_query_words = reduced_query.split(' ') - set_common = set(df['Lemma'].values.tolist()) - num_common = len([x.lower() in set_common for x in reduced_query_words]) - frac_common = num_common / len(reduced_query) if reduced_query else 0 - # FIXME: report to user bad query that uses too many common words - if verbose: - print("frac_common: %s" % frac_common, flush=True) - - if len(docs) == 0: - # avoid context == in prompt then - use_docs_planned = False - template = template_if_no_docs - - if langchain_action == LangChainAction.QUERY.value: - if use_template: - # instruct-like, rather than few-shot prompt_type='plain' as default - # but then sources confuse the model with how inserted among rest of text, so avoid - prompt = PromptTemplate( - # input_variables=["summaries", "question"], - input_variables=["context", "question"], - template=template, - ) - chain = load_qa_chain(llm, prompt=prompt) - else: - # only if use_openai_model = True, unused normally except in testing - chain = load_qa_with_sources_chain(llm) - if not use_docs_planned: - chain_kwargs = dict(input_documents=[], question=query) - else: - chain_kwargs = dict(input_documents=docs, question=query) - target = wrapped_partial(chain, chain_kwargs) - elif langchain_action in [LangChainAction.SUMMARIZE_MAP.value, - LangChainAction.SUMMARIZE_REFINE, - LangChainAction.SUMMARIZE_ALL.value]: - from langchain.chains.summarize import load_summarize_chain - if langchain_action == LangChainAction.SUMMARIZE_MAP.value: - prompt = PromptTemplate(input_variables=["text"], template=template) - chain = load_summarize_chain(llm, chain_type="map_reduce", - map_prompt=prompt, combine_prompt=prompt, return_intermediate_steps=True) - target = wrapped_partial(chain, {"input_documents": docs}) # , return_only_outputs=True) - elif langchain_action == LangChainAction.SUMMARIZE_ALL.value: - assert use_template - prompt = PromptTemplate(input_variables=["text"], template=template) - chain = load_summarize_chain(llm, chain_type="stuff", prompt=prompt, return_intermediate_steps=True) - target = wrapped_partial(chain) - elif langchain_action == LangChainAction.SUMMARIZE_REFINE.value: - chain = load_summarize_chain(llm, chain_type="refine", return_intermediate_steps=True) - target = wrapped_partial(chain) - else: - raise RuntimeError("No such langchain_action=%s" % langchain_action) - else: - 
raise RuntimeError("No such langchain_action=%s" % langchain_action) - - return docs, target, scores, use_docs_planned, have_any_docs - - -def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False): - if verbose: - print("query: %s" % query, flush=True) - print("answer: %s" % answer['output_text'], flush=True) - - if len(answer['input_documents']) == 0: - extra = '' - ret = answer['output_text'] + extra - return ret, extra - - # link - answer_sources = [(max(0.0, 1.5 - score) / 1.5, get_url(doc)) for score, doc in - zip(scores, answer['input_documents'])] - answer_sources_dict = defaultdict(list) - [answer_sources_dict[url].append(score) for score, url in answer_sources] - answers_dict = {} - for url, scores_url in answer_sources_dict.items(): - answers_dict[url] = np.max(scores_url) - answer_sources = [(score, url) for url, score in answers_dict.items()] - answer_sources.sort(key=lambda x: x[0], reverse=True) - if show_rank: - # answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)] - # sorted_sources_urls = "Sources [Rank | Link]:
" + "
".join(answer_sources) - answer_sources = ['%s' % url for rank, (score, url) in enumerate(answer_sources)] - sorted_sources_urls = "Ranked Sources:
" + "
".join(answer_sources) - else: - answer_sources = ['

  • %.2g | %s
  • ' % (score, url) for score, url in answer_sources] - sorted_sources_urls = f"{source_prefix}

      " + "

      ".join(answer_sources) - sorted_sources_urls += f"

    {source_postfix}" - - if not answer['output_text'].endswith('\n'): - answer['output_text'] += '\n' - - if answer_with_sources: - extra = '\n' + sorted_sources_urls - else: - extra = '' - ret = answer['output_text'] + extra - return ret, extra - - -def clean_doc(docs1): - if not isinstance(docs1, (list, tuple, types.GeneratorType)): - docs1 = [docs1] - for doci, doc in enumerate(docs1): - docs1[doci].page_content = '\n'.join([x.strip() for x in doc.page_content.split("\n") if x.strip()]) - return docs1 - - -def chunk_sources(sources, chunk=True, chunk_size=512, language=None): - if not chunk: - [x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(sources)] - return sources - if not isinstance(sources, (list, tuple, types.GeneratorType)) and not callable(sources): - # if just one document - sources = [sources] - if language and False: - # Bug in langchain, keep separator=True not working - # https://github.com/hwchase17/langchain/issues/2836 - # so avoid this for now - keep_separator = True - separators = RecursiveCharacterTextSplitter.get_separators_for_language(language) - else: - separators = ["\n\n", "\n", " ", ""] - keep_separator = False - splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0, keep_separator=keep_separator, - separators=separators) - source_chunks = splitter.split_documents(sources) - - # currently in order, but when pull from db won't be, so mark order and document by hash - [x.metadata.update(dict(chunk_id=chunk_id)) for chunk_id, x in enumerate(source_chunks)] - - return source_chunks - - -def get_db_from_hf(dest=".", db_dir='db_dir_DriverlessAI_docs.zip'): - from huggingface_hub import hf_hub_download - # True for case when locally already logged in with correct token, so don't have to set key - token = os.getenv('HUGGINGFACE_API_TOKEN', True) - path_to_zip_file = hf_hub_download('h2oai/db_dirs', db_dir, token=token, repo_type='dataset') - import zipfile - with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref: - persist_directory = os.path.dirname(zip_ref.namelist()[0]) - remove(persist_directory) - zip_ref.extractall(dest) - return path_to_zip_file - - -# Note dir has space in some cases, while zip does not -some_db_zips = [['db_dir_DriverlessAI_docs.zip', 'db_dir_DriverlessAI docs', 'CC-BY-NC license'], - ['db_dir_UserData.zip', 'db_dir_UserData', 'CC-BY license for ArXiv'], - ['db_dir_github_h2oGPT.zip', 'db_dir_github h2oGPT', 'ApacheV2 license'], - ['db_dir_wiki.zip', 'db_dir_wiki', 'CC-BY-SA Wikipedia license'], - # ['db_dir_wiki_full.zip', 'db_dir_wiki_full.zip', '23GB, 05/04/2023 CC-BY-SA Wiki license'], - ] - -all_db_zips = some_db_zips + \ - [['db_dir_wiki_full.zip', 'db_dir_wiki_full.zip', '23GB, 05/04/2023 CC-BY-SA Wiki license'], - ] - - -def get_some_dbs_from_hf(dest='.', db_zips=None): - if db_zips is None: - db_zips = some_db_zips - for db_dir, dir_expected, license1 in db_zips: - path_to_zip_file = get_db_from_hf(dest=dest, db_dir=db_dir) - assert os.path.isfile(path_to_zip_file), "Missing zip in %s" % path_to_zip_file - if dir_expected: - assert os.path.isdir(os.path.join(dest, dir_expected)), "Missing path for %s" % dir_expected - assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected - - -def _create_local_weaviate_client(): - WEAVIATE_URL = os.getenv('WEAVIATE_URL', "http://localhost:8080") - WEAVIATE_USERNAME = os.getenv('WEAVIATE_USERNAME') - WEAVIATE_PASSWORD = os.getenv('WEAVIATE_PASSWORD') - WEAVIATE_SCOPE = os.getenv('WEAVIATE_SCOPE', 
"offline_access") - - resource_owner_config = None - try: - import weaviate - if WEAVIATE_USERNAME is not None and WEAVIATE_PASSWORD is not None: - resource_owner_config = weaviate.AuthClientPassword( - username=WEAVIATE_USERNAME, - password=WEAVIATE_PASSWORD, - scope=WEAVIATE_SCOPE - ) - - client = weaviate.Client(WEAVIATE_URL, auth_client_secret=resource_owner_config) - return client - except Exception as e: - print(f"Failed to create Weaviate client: {e}") - return None - - -if __name__ == '__main__': - pass diff --git a/gradio_runner.py b/gradio_runner.py deleted file mode 100644 index d37c9f211a0830f1566f5adee6e59584e27b8386..0000000000000000000000000000000000000000 --- a/gradio_runner.py +++ /dev/null @@ -1,2933 +0,0 @@ -import ast -import copy -import functools -import inspect -import itertools -import json -import os -import pprint -import random -import shutil -import sys -import time -import traceback -import typing -import uuid -import filelock -import pandas as pd -import requests -import tabulate -from iterators import TimeoutIterator - -from gradio_utils.css import get_css -from gradio_utils.prompt_form import make_chatbots - -# This is a hack to prevent Gradio from phoning home when it gets imported -os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False' - - -def my_get(url, **kwargs): - print('Gradio HTTP request redirected to localhost :)', flush=True) - kwargs.setdefault('allow_redirects', True) - return requests.api.request('get', 'http://127.0.0.1/', **kwargs) - - -original_get = requests.get -requests.get = my_get -import gradio as gr - -requests.get = original_get - - -def fix_pydantic_duplicate_validators_error(): - try: - from pydantic import class_validators - - class_validators.in_ipython = lambda: True # type: ignore[attr-defined] - except ImportError: - pass - - -fix_pydantic_duplicate_validators_error() - -from enums import DocumentSubset, no_model_str, no_lora_str, no_server_str, LangChainAction, LangChainMode, \ - DocumentChoice, langchain_modes_intrinsic -from gradio_themes import H2oTheme, SoftTheme, get_h2o_title, get_simple_title, get_dark_js, spacing_xsm, radius_xsm, \ - text_xsm -from prompter import prompt_type_to_model_name, prompt_types_strings, inv_prompt_type_to_model_lower, non_hf_types, \ - get_prompt -from utils import flatten_list, zip_data, s3up, clear_torch_cache, get_torch_allocated, system_info_print, \ - ping, get_short_name, makedirs, get_kwargs, remove, system_info, ping_gpu, get_url, get_local_ip, \ - save_collection_names -from gen import get_model, languages_covered, evaluate, score_qa, inputs_kwargs_list, scratch_base_dir, \ - get_max_max_new_tokens, get_minmax_top_k_docs, history_to_context, langchain_actions, langchain_agents_list, \ - update_langchain -from evaluate_params import eval_func_param_names, no_default_param_names, eval_func_param_names_defaults, \ - input_args_list - -from apscheduler.schedulers.background import BackgroundScheduler - - -def fix_text_for_gradio(text, fix_new_lines=False, fix_latex_dollars=True): - if fix_latex_dollars: - ts = text.split('```') - for parti, part in enumerate(ts): - inside = parti % 2 == 1 - if not inside: - ts[parti] = ts[parti].replace('$', '﹩') - text = '```'.join(ts) - - if fix_new_lines: - # let Gradio handle code, since got improved recently - ## FIXME: below conflicts with Gradio, but need to see if can handle multiple \n\n\n etc. properly as is. 
- # ensure good visually, else markdown ignores multiple \n - # handle code blocks - ts = text.split('```') - for parti, part in enumerate(ts): - inside = parti % 2 == 1 - if not inside: - ts[parti] = ts[parti].replace('\n', '<br>
    ') - text = '```'.join(ts) - return text - - -def go_gradio(**kwargs): - allow_api = kwargs['allow_api'] - is_public = kwargs['is_public'] - is_hf = kwargs['is_hf'] - memory_restriction_level = kwargs['memory_restriction_level'] - n_gpus = kwargs['n_gpus'] - admin_pass = kwargs['admin_pass'] - model_states = kwargs['model_states'] - dbs = kwargs['dbs'] - db_type = kwargs['db_type'] - visible_langchain_actions = kwargs['visible_langchain_actions'] - visible_langchain_agents = kwargs['visible_langchain_agents'] - allow_upload_to_user_data = kwargs['allow_upload_to_user_data'] - allow_upload_to_my_data = kwargs['allow_upload_to_my_data'] - enable_sources_list = kwargs['enable_sources_list'] - enable_url_upload = kwargs['enable_url_upload'] - enable_text_upload = kwargs['enable_text_upload'] - use_openai_embedding = kwargs['use_openai_embedding'] - hf_embedding_model = kwargs['hf_embedding_model'] - enable_captions = kwargs['enable_captions'] - captions_model = kwargs['captions_model'] - enable_ocr = kwargs['enable_ocr'] - enable_pdf_ocr = kwargs['enable_pdf_ocr'] - caption_loader = kwargs['caption_loader'] - - # for dynamic state per user session in gradio - model_state0 = kwargs['model_state0'] - score_model_state0 = kwargs['score_model_state0'] - my_db_state0 = kwargs['my_db_state0'] - selection_docs_state0 = kwargs['selection_docs_state0'] - # for evaluate defaults - langchain_modes0 = kwargs['langchain_modes'] - visible_langchain_modes0 = kwargs['visible_langchain_modes'] - langchain_mode_paths0 = kwargs['langchain_mode_paths'] - - # easy update of kwargs needed for evaluate() etc. - queue = True - allow_upload = allow_upload_to_user_data or allow_upload_to_my_data - kwargs.update(locals()) - - # import control - if kwargs['langchain_mode'] != 'Disabled': - from gpt_langchain import file_types, have_arxiv - else: - have_arxiv = False - file_types = [] - - if 'mbart-' in kwargs['model_lower']: - instruction_label_nochat = "Text to translate" - else: - instruction_label_nochat = "Instruction (Shift-Enter or push Submit to send message," \ - " use Enter for multiple input lines)" - - title = 'h2oGPT' - description = """h2oGPT H2O LLM Studio
    🤗 Models"""
- description_bottom = "If this host is busy, try<br>[Multi-Model](https://gpt.h2o.ai)<br>[Falcon 40B](https://falcon.h2o.ai)<br>[Vicuna 33B](https://wizardvicuna.h2o.ai)<br>[MPT 30B-Chat](https://mpt.h2o.ai)<br>[HF Spaces1](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot)<br>[HF Spaces2](https://huggingface.co/spaces/h2oai/h2ogpt-chatbot2)<br>
    " - if is_hf: - description_bottom += '''Duplicate Space''' - task_info_md = '' - css_code = get_css(kwargs) - - if kwargs['gradio_offline_level'] >= 0: - # avoid GoogleFont that pulls from internet - if kwargs['gradio_offline_level'] == 1: - # front end would still have to download fonts or have cached it at some point - base_font = 'Source Sans Pro' - else: - base_font = 'Helvetica' - theme_kwargs = dict(font=(base_font, 'ui-sans-serif', 'system-ui', 'sans-serif'), - font_mono=('IBM Plex Mono', 'ui-monospace', 'Consolas', 'monospace')) - else: - theme_kwargs = dict() - if kwargs['gradio_size'] == 'xsmall': - theme_kwargs.update(dict(spacing_size=spacing_xsm, text_size=text_xsm, radius_size=radius_xsm)) - elif kwargs['gradio_size'] in [None, 'small']: - theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_sm, text_size=gr.themes.sizes.text_sm, - radius_size=gr.themes.sizes.spacing_sm)) - elif kwargs['gradio_size'] == 'large': - theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_lg, text_size=gr.themes.sizes.text_lg), - radius_size=gr.themes.sizes.spacing_lg) - elif kwargs['gradio_size'] == 'medium': - theme_kwargs.update(dict(spacing_size=gr.themes.sizes.spacing_md, text_size=gr.themes.sizes.text_md, - radius_size=gr.themes.sizes.spacing_md)) - - theme = H2oTheme(**theme_kwargs) if kwargs['h2ocolors'] else SoftTheme(**theme_kwargs) - demo = gr.Blocks(theme=theme, css=css_code, title="h2oGPT", analytics_enabled=False) - callback = gr.CSVLogger() - - model_options0 = flatten_list(list(prompt_type_to_model_name.values())) + kwargs['extra_model_options'] - if kwargs['base_model'].strip() not in model_options0: - model_options0 = [kwargs['base_model'].strip()] + model_options0 - lora_options = kwargs['extra_lora_options'] - if kwargs['lora_weights'].strip() not in lora_options: - lora_options = [kwargs['lora_weights'].strip()] + lora_options - server_options = kwargs['extra_server_options'] - if kwargs['inference_server'].strip() not in server_options: - server_options = [kwargs['inference_server'].strip()] + server_options - if os.getenv('OPENAI_API_KEY'): - if 'openai_chat' not in server_options: - server_options += ['openai_chat'] - if 'openai' not in server_options: - server_options += ['openai'] - - # always add in no lora case - # add fake space so doesn't go away in gradio dropdown - model_options0 = [no_model_str] + model_options0 - lora_options = [no_lora_str] + lora_options - server_options = [no_server_str] + server_options - # always add in no model case so can free memory - # add fake space so doesn't go away in gradio dropdown - - # transcribe, will be detranscribed before use by evaluate() - if not kwargs['base_model'].strip(): - kwargs['base_model'] = no_model_str - - if not kwargs['lora_weights'].strip(): - kwargs['lora_weights'] = no_lora_str - - if not kwargs['inference_server'].strip(): - kwargs['inference_server'] = no_server_str - - # transcribe for gradio - kwargs['gpu_id'] = str(kwargs['gpu_id']) - - no_model_msg = 'h2oGPT [ !!! Please Load Model in Models Tab !!! 
]' - output_label0 = f'h2oGPT [Model: {kwargs.get("base_model")}]' if kwargs.get( - 'base_model') else no_model_msg - output_label0_model2 = no_model_msg - - def update_prompt(prompt_type1, prompt_dict1, model_state1, which_model=0): - if not prompt_type1 or which_model != 0: - # keep prompt_type and prompt_dict in sync if possible - prompt_type1 = kwargs.get('prompt_type', prompt_type1) - prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1) - # prefer model specific prompt type instead of global one - if not prompt_type1 or which_model != 0: - prompt_type1 = model_state1.get('prompt_type', prompt_type1) - prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1) - - if not prompt_dict1 or which_model != 0: - # if still not defined, try to get - prompt_dict1 = kwargs.get('prompt_dict', prompt_dict1) - if not prompt_dict1 or which_model != 0: - prompt_dict1 = model_state1.get('prompt_dict', prompt_dict1) - return prompt_type1, prompt_dict1 - - default_kwargs = {k: kwargs[k] for k in eval_func_param_names_defaults} - # ensure prompt_type consistent with prep_bot(), so nochat API works same way - default_kwargs['prompt_type'], default_kwargs['prompt_dict'] = \ - update_prompt(default_kwargs['prompt_type'], default_kwargs['prompt_dict'], - model_state1=model_state0, which_model=0) - for k in no_default_param_names: - default_kwargs[k] = '' - - def dummy_fun(x): - # need dummy function to block new input from being sent until output is done, - # else gets input_list at time of submit that is old, and shows up as truncated in chatbot - return x - - def allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1): - allow = False - allow |= langchain_action1 not in LangChainAction.QUERY.value - allow |= document_subset1 in DocumentSubset.TopKSources.name - if langchain_mode1 in [LangChainMode.LLM.value]: - allow = False - return allow - - with demo: - # avoid actual model/tokenizer here or anything that would be bad to deepcopy - # https://github.com/gradio-app/gradio/issues/3558 - model_state = gr.State( - dict(model='model', tokenizer='tokenizer', device=kwargs['device'], - base_model=kwargs['base_model'], - tokenizer_base_model=kwargs['tokenizer_base_model'], - lora_weights=kwargs['lora_weights'], - inference_server=kwargs['inference_server'], - prompt_type=kwargs['prompt_type'], - prompt_dict=kwargs['prompt_dict'], - ) - ) - - def update_langchain_mode_paths(db1s, selection_docs_state1): - if allow_upload_to_my_data: - selection_docs_state1['langchain_mode_paths'].update({k: None for k in db1s}) - dup = selection_docs_state1['langchain_mode_paths'].copy() - for k, v in dup.items(): - if k not in selection_docs_state1['visible_langchain_modes']: - selection_docs_state1['langchain_mode_paths'].pop(k) - return selection_docs_state1 - - # Setup some gradio states for per-user dynamic state - model_state2 = gr.State(kwargs['model_state_none'].copy()) - model_options_state = gr.State([model_options0]) - lora_options_state = gr.State([lora_options]) - server_options_state = gr.State([server_options]) - my_db_state = gr.State(my_db_state0) - chat_state = gr.State({}) - docs_state00 = kwargs['document_choice'] + [DocumentChoice.ALL.value] - docs_state0 = [] - [docs_state0.append(x) for x in docs_state00 if x not in docs_state0] - docs_state = gr.State(docs_state0) - viewable_docs_state0 = [] - viewable_docs_state = gr.State(viewable_docs_state0) - selection_docs_state0 = update_langchain_mode_paths(my_db_state0, selection_docs_state0) - selection_docs_state = 
gr.State(selection_docs_state0) - - gr.Markdown(f""" - {get_h2o_title(title, description) if kwargs['h2ocolors'] else get_simple_title(title, description)} - """) - - # go button visible if - base_wanted = kwargs['base_model'] != no_model_str and kwargs['login_mode_if_model0'] - go_btn = gr.Button(value="ENTER", visible=base_wanted, variant="primary") - - nas = ' '.join(['NA'] * len(kwargs['model_states'])) - res_value = "Response Score: NA" if not kwargs[ - 'model_lock'] else "Response Scores: %s" % nas - - if kwargs['langchain_mode'] != LangChainMode.DISABLED.value: - extra_prompt_form = ". For summarization, no query required, just click submit" - else: - extra_prompt_form = "" - if kwargs['input_lines'] > 1: - instruction_label = "Shift-Enter to Submit, Enter for more lines%s" % extra_prompt_form - else: - instruction_label = "Enter to Submit, Shift-Enter for more lines%s" % extra_prompt_form - - def get_langchain_choices(selection_docs_state1): - langchain_modes = selection_docs_state1['langchain_modes'] - visible_langchain_modes = selection_docs_state1['visible_langchain_modes'] - - if is_hf: - # don't show 'wiki' since only usually useful for internal testing at moment - no_show_modes = ['Disabled', 'wiki'] - else: - no_show_modes = ['Disabled'] - allowed_modes = visible_langchain_modes.copy() - # allowed_modes = [x for x in allowed_modes if x in dbs] - allowed_modes += ['LLM'] - if allow_upload_to_my_data and 'MyData' not in allowed_modes: - allowed_modes += ['MyData'] - if allow_upload_to_user_data and 'UserData' not in allowed_modes: - allowed_modes += ['UserData'] - choices = [x for x in langchain_modes if x in allowed_modes and x not in no_show_modes] - return choices - - def get_df_langchain_mode_paths(selection_docs_state1): - langchain_mode_paths = selection_docs_state1['langchain_mode_paths'] - if langchain_mode_paths: - df = pd.DataFrame.from_dict(langchain_mode_paths.items(), orient='columns') - df.columns = ['Collection', 'Path'] - else: - df = pd.DataFrame(None) - return df - - normal_block = gr.Row(visible=not base_wanted, equal_height=False) - with normal_block: - side_bar = gr.Column(elem_id="col_container", scale=1, min_width=100) - with side_bar: - with gr.Accordion("Chats", open=False, visible=True): - radio_chats = gr.Radio(value=None, label="Saved Chats", show_label=False, - visible=True, interactive=True, - type='value') - upload_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload - with gr.Accordion("Upload", open=False, visible=upload_visible): - with gr.Column(): - with gr.Row(equal_height=False): - file_types_str = '[' + ' '.join(file_types) + ' URL ArXiv TEXT' + ']' - fileup_output = gr.File(label=f'Upload {file_types_str}', - show_label=False, - file_types=file_types, - file_count="multiple", - scale=1, - min_width=0, - elem_id="warning", elem_classes="feedback") - fileup_output_text = gr.Textbox(visible=False) - url_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_url_upload - url_label = 'URL/ArXiv' if have_arxiv else 'URL' - url_text = gr.Textbox(label=url_label, - # placeholder="Enter Submits", - max_lines=1, - interactive=True) - text_visible = kwargs['langchain_mode'] != 'Disabled' and allow_upload and enable_text_upload - user_text_text = gr.Textbox(label='Paste Text', - # placeholder="Enter Submits", - interactive=True, - visible=text_visible) - github_textbox = gr.Textbox(label="Github URL", visible=False) # FIXME WIP - database_visible = kwargs['langchain_mode'] != 'Disabled' - with 
gr.Accordion("Resources", open=False, visible=database_visible): - langchain_choices0 = get_langchain_choices(selection_docs_state0) - langchain_mode = gr.Radio( - langchain_choices0, - value=kwargs['langchain_mode'], - label="Collections", - show_label=True, - visible=kwargs['langchain_mode'] != 'Disabled', - min_width=100) - add_chat_history_to_context = gr.Checkbox(label="Chat History", - value=kwargs['add_chat_history_to_context']) - document_subset = gr.Radio([x.name for x in DocumentSubset], - label="Subset", - value=DocumentSubset.Relevant.name, - interactive=True, - ) - allowed_actions = [x for x in langchain_actions if x in visible_langchain_actions] - langchain_action = gr.Radio( - allowed_actions, - value=allowed_actions[0] if len(allowed_actions) > 0 else None, - label="Action", - visible=True) - allowed_agents = [x for x in langchain_agents_list if x in visible_langchain_agents] - langchain_agents = gr.Dropdown( - langchain_agents_list, - value=kwargs['langchain_agents'], - label="Agents", - multiselect=True, - interactive=True, - visible=False) # WIP - col_tabs = gr.Column(elem_id="col_container", scale=10) - with (col_tabs, gr.Tabs()): - with gr.TabItem("Chat"): - if kwargs['langchain_mode'] == 'Disabled': - text_output_nochat = gr.Textbox(lines=5, label=output_label0, show_copy_button=True, - visible=not kwargs['chat']) - else: - # text looks a bit worse, but HTML links work - text_output_nochat = gr.HTML(label=output_label0, visible=not kwargs['chat']) - with gr.Row(): - # NOCHAT - instruction_nochat = gr.Textbox( - lines=kwargs['input_lines'], - label=instruction_label_nochat, - placeholder=kwargs['placeholder_instruction'], - visible=not kwargs['chat'], - ) - iinput_nochat = gr.Textbox(lines=4, label="Input context for Instruction", - placeholder=kwargs['placeholder_input'], - visible=not kwargs['chat']) - submit_nochat = gr.Button("Submit", size='sm', visible=not kwargs['chat']) - flag_btn_nochat = gr.Button("Flag", size='sm', visible=not kwargs['chat']) - score_text_nochat = gr.Textbox("Response Score: NA", show_label=False, - visible=not kwargs['chat']) - submit_nochat_api = gr.Button("Submit nochat API", visible=False) - inputs_dict_str = gr.Textbox(label='API input for nochat', show_label=False, visible=False) - text_output_nochat_api = gr.Textbox(lines=5, label='API nochat output', visible=False, - show_copy_button=True) - - # CHAT - col_chat = gr.Column(visible=kwargs['chat']) - with col_chat: - with gr.Row(): # elem_id='prompt-form-area'): - with gr.Column(scale=50): - instruction = gr.Textbox( - lines=kwargs['input_lines'], - label='Ask anything', - placeholder=instruction_label, - info=None, - elem_id='prompt-form', - container=True, - ) - submit_buttons = gr.Row(equal_height=False) - with submit_buttons: - mw1 = 50 - mw2 = 50 - with gr.Column(min_width=mw1): - submit = gr.Button(value='Submit', variant='primary', size='sm', - min_width=mw1) - stop_btn = gr.Button(value="Stop", variant='secondary', size='sm', - min_width=mw1) - save_chat_btn = gr.Button("Save", size='sm', min_width=mw1) - with gr.Column(min_width=mw2): - retry_btn = gr.Button("Redo", size='sm', min_width=mw2) - undo = gr.Button("Undo", size='sm', min_width=mw2) - clear_chat_btn = gr.Button(value="Clear", size='sm', min_width=mw2) - text_output, text_output2, text_outputs = make_chatbots(output_label0, output_label0_model2, - **kwargs) - - with gr.Row(): - with gr.Column(visible=kwargs['score_model']): - score_text = gr.Textbox(res_value, - show_label=False, - visible=True) - score_text2 = 
gr.Textbox("Response Score2: NA", show_label=False, - visible=False and not kwargs['model_lock']) - - with gr.TabItem("Document Selection"): - document_choice = gr.Dropdown(docs_state0, - label="Select Subset of Document(s) %s" % file_types_str, - value=[DocumentChoice.ALL.value], - interactive=True, - multiselect=True, - visible=kwargs['langchain_mode'] != 'Disabled', - ) - sources_visible = kwargs['langchain_mode'] != 'Disabled' and enable_sources_list - with gr.Row(): - with gr.Column(scale=1): - get_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, size='sm', - visible=sources_visible) - show_sources_btn = gr.Button(value="Show Sources from DB", scale=0, size='sm', - visible=sources_visible) - refresh_sources_btn = gr.Button(value="Update DB with new/changed files on disk", scale=0, - size='sm', - visible=sources_visible and allow_upload_to_user_data) - with gr.Column(scale=4): - pass - with gr.Row(): - with gr.Column(scale=1): - visible_add_remove_collection = (allow_upload_to_user_data or - allow_upload_to_my_data) and \ - kwargs['langchain_mode'] != 'Disabled' - add_placeholder = "e.g. UserData2, user_path2 (optional)" \ - if not is_public else "e.g. MyData2" - remove_placeholder = "e.g. UserData2" if not is_public else "e.g. MyData2" - new_langchain_mode_text = gr.Textbox(value="", visible=visible_add_remove_collection, - label='Add Collection', - placeholder=add_placeholder, - interactive=True) - remove_langchain_mode_text = gr.Textbox(value="", visible=visible_add_remove_collection, - label='Remove Collection', - placeholder=remove_placeholder, - interactive=True) - load_langchain = gr.Button(value="Load LangChain State", scale=0, size='sm', - visible=allow_upload_to_user_data and - kwargs['langchain_mode'] != 'Disabled') - with gr.Column(scale=1): - df0 = get_df_langchain_mode_paths(selection_docs_state0) - langchain_mode_path_text = gr.Dataframe(value=df0, - visible=visible_add_remove_collection, - label='LangChain Mode-Path', - show_label=False, - interactive=False) - with gr.Column(scale=4): - pass - - sources_row = gr.Row(visible=kwargs['langchain_mode'] != 'Disabled' and enable_sources_list, - equal_height=False) - with sources_row: - with gr.Column(scale=1): - file_source = gr.File(interactive=False, - label="Download File w/Sources") - with gr.Column(scale=2): - sources_text = gr.HTML(label='Sources Added', interactive=False) - - doc_exception_text = gr.Textbox(value="", label='Document Exceptions', - interactive=False, - visible=kwargs['langchain_mode'] != 'Disabled') - with gr.TabItem("Document Viewer"): - with gr.Row(visible=kwargs['langchain_mode'] != 'Disabled'): - with gr.Column(scale=2): - get_viewable_sources_btn = gr.Button(value="Update UI with Document(s) from DB", scale=0, - size='sm', - visible=sources_visible) - view_document_choice = gr.Dropdown(viewable_docs_state0, - label="Select Single Document", - value=None, - interactive=True, - multiselect=False, - visible=True, - ) - with gr.Column(scale=4): - pass - document = 'http://infolab.stanford.edu/pub/papers/google.pdf' - doc_view = gr.HTML(visible=False) - doc_view2 = gr.Dataframe(visible=False) - doc_view3 = gr.JSON(visible=False) - doc_view4 = gr.Markdown(visible=False) - - with gr.TabItem("Chat History"): - with gr.Row(): - with gr.Column(scale=1): - remove_chat_btn = gr.Button(value="Remove Selected Saved Chats", visible=True, size='sm') - flag_btn = gr.Button("Flag Current Chat", size='sm') - export_chats_btn = gr.Button(value="Export Chats to Download", size='sm') - with 
gr.Column(scale=4): - pass - with gr.Row(): - chats_file = gr.File(interactive=False, label="Download Exported Chats") - chatsup_output = gr.File(label="Upload Chat File(s)", - file_types=['.json'], - file_count='multiple', - elem_id="warning", elem_classes="feedback") - with gr.Row(): - if 'mbart-' in kwargs['model_lower']: - src_lang = gr.Dropdown(list(languages_covered().keys()), - value=kwargs['src_lang'], - label="Input Language") - tgt_lang = gr.Dropdown(list(languages_covered().keys()), - value=kwargs['tgt_lang'], - label="Output Language") - - chat_exception_text = gr.Textbox(value="", visible=True, label='Chat Exceptions', - interactive=False) - with gr.TabItem("Expert"): - with gr.Row(): - with gr.Column(): - stream_output = gr.components.Checkbox(label="Stream output", - value=kwargs['stream_output']) - prompt_type = gr.Dropdown(prompt_types_strings, - value=kwargs['prompt_type'], label="Prompt Type", - visible=not kwargs['model_lock'], - interactive=not is_public, - ) - prompt_type2 = gr.Dropdown(prompt_types_strings, - value=kwargs['prompt_type'], label="Prompt Type Model 2", - visible=False and not kwargs['model_lock'], - interactive=not is_public) - do_sample = gr.Checkbox(label="Sample", - info="Enable sampler, required for use of temperature, top_p, top_k", - value=kwargs['do_sample']) - temperature = gr.Slider(minimum=0.01, maximum=2, - value=kwargs['temperature'], - label="Temperature", - info="Lower is deterministic (but may lead to repeats), Higher more creative (but may lead to hallucinations)") - top_p = gr.Slider(minimum=1e-3, maximum=1.0 - 1e-3, - value=kwargs['top_p'], label="Top p", - info="Cumulative probability of tokens to sample from") - top_k = gr.Slider( - minimum=1, maximum=100, step=1, - value=kwargs['top_k'], label="Top k", - info='Num. tokens to sample from' - ) - # FIXME: https://github.com/h2oai/h2ogpt/issues/106 - if os.getenv('TESTINGFAIL'): - max_beams = 8 if not (memory_restriction_level or is_public) else 1 - else: - max_beams = 1 - num_beams = gr.Slider(minimum=1, maximum=max_beams, step=1, - value=min(max_beams, kwargs['num_beams']), label="Beams", - info="Number of searches for optimal overall probability. " - "Uses more GPU memory/compute", - interactive=False) - max_max_new_tokens = get_max_max_new_tokens(model_state0, **kwargs) - max_new_tokens = gr.Slider( - minimum=1, maximum=max_max_new_tokens, step=1, - value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length", - ) - min_new_tokens = gr.Slider( - minimum=0, maximum=max_max_new_tokens, step=1, - value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length", - ) - max_new_tokens2 = gr.Slider( - minimum=1, maximum=max_max_new_tokens, step=1, - value=min(max_max_new_tokens, kwargs['max_new_tokens']), label="Max output length 2", - visible=False and not kwargs['model_lock'], - ) - min_new_tokens2 = gr.Slider( - minimum=0, maximum=max_max_new_tokens, step=1, - value=min(max_max_new_tokens, kwargs['min_new_tokens']), label="Min output length 2", - visible=False and not kwargs['model_lock'], - ) - early_stopping = gr.Checkbox(label="EarlyStopping", info="Stop early in beam search", - value=kwargs['early_stopping']) - max_time = gr.Slider(minimum=0, maximum=kwargs['max_max_time'], step=1, - value=min(kwargs['max_max_time'], - kwargs['max_time']), label="Max. time", - info="Max. 
time to search optimal output.") - repetition_penalty = gr.Slider(minimum=0.01, maximum=3.0, - value=kwargs['repetition_penalty'], - label="Repetition Penalty") - num_return_sequences = gr.Slider(minimum=1, maximum=10, step=1, - value=kwargs['num_return_sequences'], - label="Number Returns", info="Must be <= num_beams", - interactive=not is_public) - iinput = gr.Textbox(lines=4, label="Input", - placeholder=kwargs['placeholder_input'], - interactive=not is_public) - context = gr.Textbox(lines=3, label="System Pre-Context", - info="Directly pre-appended without prompt processing", - interactive=not is_public) - chat = gr.components.Checkbox(label="Chat mode", value=kwargs['chat'], - visible=False, # no longer support nochat in UI - interactive=not is_public, - ) - count_chat_tokens_btn = gr.Button(value="Count Chat Tokens", - visible=not is_public and not kwargs['model_lock'], - interactive=not is_public) - chat_token_count = gr.Textbox(label="Chat Token Count", value=None, - visible=not is_public and not kwargs['model_lock'], - interactive=False) - chunk = gr.components.Checkbox(value=kwargs['chunk'], - label="Whether to chunk documents", - info="For LangChain", - visible=kwargs['langchain_mode'] != 'Disabled', - interactive=not is_public) - min_top_k_docs, max_top_k_docs, label_top_k_docs = get_minmax_top_k_docs(is_public) - top_k_docs = gr.Slider(minimum=min_top_k_docs, maximum=max_top_k_docs, step=1, - value=kwargs['top_k_docs'], - label=label_top_k_docs, - info="For LangChain", - visible=kwargs['langchain_mode'] != 'Disabled', - interactive=not is_public) - chunk_size = gr.Number(value=kwargs['chunk_size'], - label="Chunk size for document chunking", - info="For LangChain (ignored if chunk=False)", - minimum=128, - maximum=2048, - visible=kwargs['langchain_mode'] != 'Disabled', - interactive=not is_public, - precision=0) - - with gr.TabItem("Models"): - model_lock_msg = gr.Textbox(lines=1, label="Model Lock Notice", - placeholder="Started in model_lock mode, no model changes allowed.", - visible=bool(kwargs['model_lock']), interactive=False) - load_msg = "Load-Unload Model/LORA [unload works if did not use --base_model]" if not is_public \ - else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO" - load_msg2 = "Load-Unload Model/LORA 2 [unload works if did not use --base_model]" if not is_public \ - else "LOAD-UNLOAD DISABLED FOR HOSTED DEMO 2" - variant_load_msg = 'primary' if not is_public else 'secondary' - compare_checkbox = gr.components.Checkbox(label="Compare Mode", - value=kwargs['model_lock'], - visible=not is_public and not kwargs['model_lock']) - with gr.Row(): - n_gpus_list = [str(x) for x in list(range(-1, n_gpus))] - with gr.Column(): - with gr.Row(): - with gr.Column(scale=20, visible=not kwargs['model_lock']): - model_choice = gr.Dropdown(model_options_state.value[0], label="Choose Model", - value=kwargs['base_model']) - lora_choice = gr.Dropdown(lora_options_state.value[0], label="Choose LORA", - value=kwargs['lora_weights'], visible=kwargs['show_lora']) - server_choice = gr.Dropdown(server_options_state.value[0], label="Choose Server", - value=kwargs['inference_server'], visible=not is_public) - with gr.Column(scale=1, visible=not kwargs['model_lock']): - load_model_button = gr.Button(load_msg, variant=variant_load_msg, scale=0, - size='sm', interactive=not is_public) - model_load8bit_checkbox = gr.components.Checkbox( - label="Load 8-bit [requires support]", - value=kwargs['load_8bit'], interactive=not is_public) - model_use_gpu_id_checkbox = gr.components.Checkbox( - 
label="Choose Devices [If not Checked, use all GPUs]", - value=kwargs['use_gpu_id'], interactive=not is_public) - model_gpu = gr.Dropdown(n_gpus_list, - label="GPU ID [-1 = all GPUs, if Choose is enabled]", - value=kwargs['gpu_id'], interactive=not is_public) - model_used = gr.Textbox(label="Current Model", value=kwargs['base_model'], - interactive=False) - lora_used = gr.Textbox(label="Current LORA", value=kwargs['lora_weights'], - visible=kwargs['show_lora'], interactive=False) - server_used = gr.Textbox(label="Current Server", - value=kwargs['inference_server'], - visible=bool(kwargs['inference_server']) and not is_public, - interactive=False) - prompt_dict = gr.Textbox(label="Prompt (or Custom)", - value=pprint.pformat(kwargs['prompt_dict'], indent=4), - interactive=not is_public, lines=4) - col_model2 = gr.Column(visible=False) - with col_model2: - with gr.Row(): - with gr.Column(scale=20, visible=not kwargs['model_lock']): - model_choice2 = gr.Dropdown(model_options_state.value[0], label="Choose Model 2", - value=no_model_str) - lora_choice2 = gr.Dropdown(lora_options_state.value[0], label="Choose LORA 2", - value=no_lora_str, - visible=kwargs['show_lora']) - server_choice2 = gr.Dropdown(server_options_state.value[0], label="Choose Server 2", - value=no_server_str, - visible=not is_public) - with gr.Column(scale=1, visible=not kwargs['model_lock']): - load_model_button2 = gr.Button(load_msg2, variant=variant_load_msg, scale=0, - size='sm', interactive=not is_public) - model_load8bit_checkbox2 = gr.components.Checkbox( - label="Load 8-bit 2 [requires support]", - value=kwargs['load_8bit'], interactive=not is_public) - model_use_gpu_id_checkbox2 = gr.components.Checkbox( - label="Choose Devices 2 [If not Checked, use all GPUs]", - value=kwargs[ - 'use_gpu_id'], interactive=not is_public) - model_gpu2 = gr.Dropdown(n_gpus_list, - label="GPU ID 2 [-1 = all GPUs, if choose is enabled]", - value=kwargs['gpu_id'], interactive=not is_public) - # no model/lora loaded ever in model2 by default - model_used2 = gr.Textbox(label="Current Model 2", value=no_model_str, - interactive=False) - lora_used2 = gr.Textbox(label="Current LORA 2", value=no_lora_str, - visible=kwargs['show_lora'], interactive=False) - server_used2 = gr.Textbox(label="Current Server 2", value=no_server_str, - interactive=False, - visible=not is_public) - prompt_dict2 = gr.Textbox(label="Prompt (or Custom) 2", - value=pprint.pformat(kwargs['prompt_dict'], indent=4), - interactive=not is_public, lines=4) - with gr.Row(visible=not kwargs['model_lock']): - with gr.Column(scale=50): - new_model = gr.Textbox(label="New Model name/path", interactive=not is_public) - with gr.Column(scale=50): - new_lora = gr.Textbox(label="New LORA name/path", visible=kwargs['show_lora'], - interactive=not is_public) - with gr.Column(scale=50): - new_server = gr.Textbox(label="New Server url:port", interactive=not is_public) - with gr.Row(): - add_model_lora_server_button = gr.Button("Add new Model, Lora, Server url:port", scale=0, - size='sm', interactive=not is_public) - with gr.TabItem("System"): - with gr.Row(): - with gr.Column(scale=1): - side_bar_text = gr.Textbox('on', visible=False, interactive=False) - submit_buttons_text = gr.Textbox('on', visible=False, interactive=False) - - side_bar_btn = gr.Button("Toggle SideBar", variant="secondary", size="sm") - submit_buttons_btn = gr.Button("Toggle Submit Buttons", variant="secondary", size="sm") - col_tabs_scale = gr.Slider(minimum=1, maximum=20, value=10, step=1, label='Window Size') - 
text_outputs_height = gr.Slider(minimum=100, maximum=2000, value=kwargs['height'] or 400, - step=50, label='Chat Height') - dark_mode_btn = gr.Button("Dark Mode", variant="secondary", size="sm") - with gr.Column(scale=4): - pass - system_visible0 = not is_public and not admin_pass - admin_row = gr.Row() - with admin_row: - with gr.Column(scale=1): - admin_pass_textbox = gr.Textbox(label="Admin Password", type='password', - visible=not system_visible0) - with gr.Column(scale=4): - pass - system_row = gr.Row(visible=system_visible0) - with system_row: - with gr.Column(): - with gr.Row(): - system_btn = gr.Button(value='Get System Info', size='sm') - system_text = gr.Textbox(label='System Info', interactive=False, show_copy_button=True) - with gr.Row(): - system_input = gr.Textbox(label='System Info Dict Password', interactive=True, - visible=not is_public) - system_btn2 = gr.Button(value='Get System Info Dict', visible=not is_public, size='sm') - system_text2 = gr.Textbox(label='System Info Dict', interactive=False, - visible=not is_public, show_copy_button=True) - with gr.Row(): - system_btn3 = gr.Button(value='Get Hash', visible=not is_public, size='sm') - system_text3 = gr.Textbox(label='Hash', interactive=False, - visible=not is_public, show_copy_button=True) - - with gr.Row(): - zip_btn = gr.Button("Zip", size='sm') - zip_text = gr.Textbox(label="Zip file name", interactive=False) - file_output = gr.File(interactive=False, label="Zip file to Download") - with gr.Row(): - s3up_btn = gr.Button("S3UP", size='sm') - s3up_text = gr.Textbox(label='S3UP result', interactive=False) - - with gr.TabItem("Terms of Service"): - description = "" - description += """

DISCLAIMERS:
• The model was trained on The Pile and other data, which may contain objectionable content. Use at own risk."""
- if kwargs['load_8bit']:
- description += """• Model is loaded in 8-bit and has other restrictions on this host. UX can be worse than non-hosted version."""
- description += """• Conversations may be used to improve h2oGPT. Do not share sensitive information."""
- if 'h2ogpt-research' in kwargs['base_model']:
- description += """• Research demonstration only, not used for commercial purposes."""
- description += """• By using h2oGPT, you accept our Terms of Service

    """ - gr.Markdown(value=description, show_label=False, interactive=False) - - with gr.TabItem("Hosts"): - gr.Markdown(f""" - {description_bottom} - {task_info_md} - """) - - # Get flagged data - zip_data1 = functools.partial(zip_data, root_dirs=['flagged_data_points', kwargs['save_dir']]) - zip_event = zip_btn.click(zip_data1, inputs=None, outputs=[file_output, zip_text], queue=False, - api_name='zip_data' if allow_api else None) - s3up_event = s3up_btn.click(s3up, inputs=zip_text, outputs=s3up_text, queue=False, - api_name='s3up_data' if allow_api else None) - - def clear_file_list(): - return None - - def make_non_interactive(*args): - if len(args) == 1: - return gr.update(interactive=False) - else: - return tuple([gr.update(interactive=False)] * len(args)) - - def make_interactive(*args): - if len(args) == 1: - return gr.update(interactive=True) - else: - return tuple([gr.update(interactive=True)] * len(args)) - - # Add to UserData or custom user db - update_db_func = functools.partial(update_user_db, - dbs=dbs, - db_type=db_type, - use_openai_embedding=use_openai_embedding, - hf_embedding_model=hf_embedding_model, - captions_model=captions_model, - enable_captions=enable_captions, - caption_loader=caption_loader, - enable_ocr=enable_ocr, - enable_pdf_ocr=enable_pdf_ocr, - verbose=kwargs['verbose'], - n_jobs=kwargs['n_jobs'], - ) - add_file_outputs = [fileup_output, langchain_mode] - add_file_kwargs = dict(fn=update_db_func, - inputs=[fileup_output, my_db_state, selection_docs_state, chunk, chunk_size, - langchain_mode], - outputs=add_file_outputs + [sources_text, doc_exception_text], - queue=queue, - api_name='add_file' if allow_api and allow_upload_to_user_data else None) - - # then no need for add buttons, only single changeable db - eventdb1a = fileup_output.upload(make_non_interactive, inputs=add_file_outputs, outputs=add_file_outputs, - show_progress='minimal') - eventdb1 = eventdb1a.then(**add_file_kwargs, show_progress='full') - eventdb1b = eventdb1.then(make_interactive, inputs=add_file_outputs, outputs=add_file_outputs, - show_progress='minimal') - - # deal with challenge to have fileup_output itself as input - add_file_kwargs2 = dict(fn=update_db_func, - inputs=[fileup_output_text, my_db_state, selection_docs_state, chunk, chunk_size, - langchain_mode], - outputs=add_file_outputs + [sources_text, doc_exception_text], - queue=queue, - api_name='add_file_api' if allow_api and allow_upload_to_user_data else None) - eventdb1_api = fileup_output_text.submit(**add_file_kwargs2, show_progress='full') - - # note for update_user_db_func output is ignored for db - - def clear_textbox(): - return gr.Textbox.update(value='') - - update_user_db_url_func = functools.partial(update_db_func, is_url=True) - - add_url_outputs = [url_text, langchain_mode] - add_url_kwargs = dict(fn=update_user_db_url_func, - inputs=[url_text, my_db_state, selection_docs_state, chunk, chunk_size, - langchain_mode], - outputs=add_url_outputs + [sources_text, doc_exception_text], - queue=queue, - api_name='add_url' if allow_api and allow_upload_to_user_data else None) - - eventdb2a = url_text.submit(fn=dummy_fun, inputs=url_text, outputs=url_text, queue=queue, - show_progress='minimal') - # work around https://github.com/gradio-app/gradio/issues/4733 - eventdb2b = eventdb2a.then(make_non_interactive, inputs=add_url_outputs, outputs=add_url_outputs, - show_progress='minimal') - eventdb2 = eventdb2b.then(**add_url_kwargs, show_progress='full') - eventdb2c = eventdb2.then(make_interactive, inputs=add_url_outputs, 
outputs=add_url_outputs, - show_progress='minimal') - - update_user_db_txt_func = functools.partial(update_db_func, is_txt=True) - add_text_outputs = [user_text_text, langchain_mode] - add_text_kwargs = dict(fn=update_user_db_txt_func, - inputs=[user_text_text, my_db_state, selection_docs_state, chunk, chunk_size, - langchain_mode], - outputs=add_text_outputs + [sources_text, doc_exception_text], - queue=queue, - api_name='add_text' if allow_api and allow_upload_to_user_data else None - ) - eventdb3a = user_text_text.submit(fn=dummy_fun, inputs=user_text_text, outputs=user_text_text, queue=queue, - show_progress='minimal') - eventdb3b = eventdb3a.then(make_non_interactive, inputs=add_text_outputs, outputs=add_text_outputs, - show_progress='minimal') - eventdb3 = eventdb3b.then(**add_text_kwargs, show_progress='full') - eventdb3c = eventdb3.then(make_interactive, inputs=add_text_outputs, outputs=add_text_outputs, - show_progress='minimal') - db_events = [eventdb1a, eventdb1, eventdb1b, eventdb1_api, - eventdb2a, eventdb2, eventdb2b, eventdb2c, - eventdb3a, eventdb3b, eventdb3, eventdb3c] - - get_sources1 = functools.partial(get_sources, dbs=dbs, docs_state0=docs_state0) - - # if change collection source, must clear doc selections from it to avoid inconsistency - def clear_doc_choice(): - return gr.Dropdown.update(choices=docs_state0, value=DocumentChoice.ALL.value) - - langchain_mode.change(clear_doc_choice, inputs=None, outputs=document_choice, queue=False) - - def resize_col_tabs(x): - return gr.Dropdown.update(scale=x) - - col_tabs_scale.change(fn=resize_col_tabs, inputs=col_tabs_scale, outputs=col_tabs, queue=False) - - def resize_chatbots(x, num_model_lock=0): - if num_model_lock == 0: - num_model_lock = 3 # 2 + 1 (which is dup of first) - else: - num_model_lock = 2 + num_model_lock - return tuple([gr.update(height=x)] * num_model_lock) - - resize_chatbots_func = functools.partial(resize_chatbots, num_model_lock=len(text_outputs)) - text_outputs_height.change(fn=resize_chatbots_func, inputs=text_outputs_height, - outputs=[text_output, text_output2] + text_outputs, queue=False) - - def update_dropdown(x): - return gr.Dropdown.update(choices=x, value=[docs_state0[0]]) - - get_sources_args = dict(fn=get_sources1, inputs=[my_db_state, langchain_mode], - outputs=[file_source, docs_state], - queue=queue, - api_name='get_sources' if allow_api else None) - - eventdb7 = get_sources_btn.click(**get_sources_args) \ - .then(fn=update_dropdown, inputs=docs_state, outputs=document_choice) - # show button, else only show when add. 
Could add to above get_sources for download/dropdown, but bit much maybe - show_sources1 = functools.partial(get_source_files_given_langchain_mode, dbs=dbs) - eventdb8 = show_sources_btn.click(fn=show_sources1, inputs=[my_db_state, langchain_mode], outputs=sources_text, - api_name='show_sources' if allow_api else None) - - def update_viewable_dropdown(x): - return gr.Dropdown.update(choices=x, - value=viewable_docs_state0[0] if len(viewable_docs_state0) > 0 else None) - - get_viewable_sources1 = functools.partial(get_sources, dbs=dbs, docs_state0=viewable_docs_state0) - get_viewable_sources_args = dict(fn=get_viewable_sources1, inputs=[my_db_state, langchain_mode], - outputs=[file_source, viewable_docs_state], - queue=queue, - api_name='get_viewable_sources' if allow_api else None) - eventdb12 = get_viewable_sources_btn.click(**get_viewable_sources_args) \ - .then(fn=update_viewable_dropdown, inputs=viewable_docs_state, - outputs=view_document_choice) - - def show_doc(file): - dummy1 = gr.update(visible=False, value=None) - dummy_ret = dummy1, dummy1, dummy1, dummy1 - if not isinstance(file, str): - return dummy_ret - - if file.endswith('.md'): - try: - with open(file, 'rt') as f: - content = f.read() - return dummy1, dummy1, dummy1, gr.update(visible=True, value=content) - except: - return dummy_ret - - if file.endswith('.py'): - try: - with open(file, 'rt') as f: - content = f.read() - content = f"```python\n{content}\n```" - return dummy1, dummy1, dummy1, gr.update(visible=True, value=content) - except: - return dummy_ret - - if file.endswith('.txt') or file.endswith('.rst') or file.endswith('.rtf') or file.endswith('.toml'): - try: - with open(file, 'rt') as f: - content = f.read() - content = f"```text\n{content}\n```" - return dummy1, dummy1, dummy1, gr.update(visible=True, value=content) - except: - return dummy_ret - - func = None - if file.endswith(".csv"): - func = pd.read_csv - elif file.endswith(".pickle"): - func = pd.read_pickle - elif file.endswith(".xls") or file.endswith("xlsx"): - func = pd.read_excel - elif file.endswith('.json'): - func = pd.read_json - elif file.endswith('.xml'): - func = pd.read_xml - if func is not None: - try: - df = func(file).head(100) - except: - return dummy_ret - return dummy1, gr.update(visible=True, value=df), dummy1, dummy1 - port = int(os.getenv('GRADIO_SERVER_PORT', '7860')) - import pathlib - absolute_path_string = os.path.abspath(file) - url_path = pathlib.Path(absolute_path_string).as_uri() - url = get_url(absolute_path_string, from_str=True) - img_url = url.replace(""" - -"""), dummy1, dummy1, dummy1 - else: - ip = get_local_ip() - document1 = url_path.replace('file://', f'http://{ip}:{port}/') - # document1 = url - return gr.update(visible=True, value=f""" - -"""), dummy1, dummy1, dummy1 - else: - return dummy_ret - - view_document_choice.select(fn=show_doc, inputs=view_document_choice, - outputs=[doc_view, doc_view2, doc_view3, doc_view4]) - - # Get inputs to evaluate() and make_db() - # don't deepcopy, can contain model itself - all_kwargs = kwargs.copy() - all_kwargs.update(locals()) - - refresh_sources1 = functools.partial(update_and_get_source_files_given_langchain_mode, - **get_kwargs(update_and_get_source_files_given_langchain_mode, - exclude_names=['db1s', 'langchain_mode', 'chunk', - 'chunk_size'], - **all_kwargs)) - eventdb9 = refresh_sources_btn.click(fn=refresh_sources1, - inputs=[my_db_state, langchain_mode, chunk, chunk_size], - outputs=sources_text, - api_name='refresh_sources' if allow_api else None) - - def 
check_admin_pass(x): - return gr.update(visible=x == admin_pass) - - def close_admin(x): - return gr.update(visible=not (x == admin_pass)) - - admin_pass_textbox.submit(check_admin_pass, inputs=admin_pass_textbox, outputs=system_row, queue=False) \ - .then(close_admin, inputs=admin_pass_textbox, outputs=admin_row, queue=False) - - def add_langchain_mode(db1s, selection_docs_state1, langchain_mode1, y): - for k in db1s: - set_userid(db1s[k]) - langchain_modes = selection_docs_state1['langchain_modes'] - langchain_mode_paths = selection_docs_state1['langchain_mode_paths'] - visible_langchain_modes = selection_docs_state1['visible_langchain_modes'] - - user_path = None - valid = True - y2 = y.strip().replace(' ', '').split(',') - if len(y2) >= 1: - langchain_mode2 = y2[0] - if len(langchain_mode2) >= 3 and langchain_mode2.isalnum(): - # real restriction is: - # ValueError: Expected collection name that (1) contains 3-63 characters, (2) starts and ends with an alphanumeric character, (3) otherwise contains only alphanumeric characters, underscores or hyphens (-), (4) contains no two consecutive periods (..) and (5) is not a valid IPv4 address, got me - # but just make simpler - user_path = y2[1] if len(y2) > 1 else None # assume scratch if don't have user_path - if user_path in ['', "''"]: - # for scratch spaces - user_path = None - if langchain_mode2 in langchain_modes_intrinsic: - user_path = None - textbox = "Invalid access to use internal name: %s" % langchain_mode2 - valid = False - langchain_mode2 = langchain_mode1 - elif user_path and allow_upload_to_user_data or not user_path and allow_upload_to_my_data: - langchain_mode_paths.update({langchain_mode2: user_path}) - if langchain_mode2 not in visible_langchain_modes: - visible_langchain_modes.append(langchain_mode2) - if langchain_mode2 not in langchain_modes: - langchain_modes.append(langchain_mode2) - textbox = '' - if user_path: - makedirs(user_path, exist_ok=True) - else: - valid = False - langchain_mode2 = langchain_mode1 - textbox = "Invalid access. 
user allowed: %s " \ - "scratch allowed: %s" % (allow_upload_to_user_data, allow_upload_to_my_data) - else: - valid = False - langchain_mode2 = langchain_mode1 - textbox = "Invalid, collection must be >=3 characters and alphanumeric" - else: - valid = False - langchain_mode2 = langchain_mode1 - textbox = "Invalid, must be like UserData2, user_path2" - selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1) - df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1) - choices = get_langchain_choices(selection_docs_state1) - - if valid and not user_path: - # needs to have key for it to make it known different from userdata case in _update_user_db() - db1s[langchain_mode2] = [None, None] - if valid: - save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode, - db1s) - - return db1s, selection_docs_state1, gr.update(choices=choices, - value=langchain_mode2), textbox, df_langchain_mode_paths1 - - def remove_langchain_mode(db1s, selection_docs_state1, langchain_mode1, langchain_mode2, dbsu=None): - for k in db1s: - set_userid(db1s[k]) - assert dbsu is not None - langchain_modes = selection_docs_state1['langchain_modes'] - langchain_mode_paths = selection_docs_state1['langchain_mode_paths'] - visible_langchain_modes = selection_docs_state1['visible_langchain_modes'] - - if langchain_mode2 in db1s and not allow_upload_to_my_data or \ - dbsu is not None and langchain_mode2 in dbsu and not allow_upload_to_user_data or \ - langchain_mode2 in langchain_modes_intrinsic: - # NOTE: Doesn't fail if remove MyData, but didn't debug odd behavior seen with upload after gone - textbox = "Invalid access, cannot remove %s" % langchain_mode2 - df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1) - else: - # change global variables - if langchain_mode2 in visible_langchain_modes: - visible_langchain_modes.remove(langchain_mode2) - textbox = "" - else: - textbox = "%s was not visible" % langchain_mode2 - if langchain_mode2 in langchain_modes: - langchain_modes.remove(langchain_mode2) - if langchain_mode2 in langchain_mode_paths: - langchain_mode_paths.pop(langchain_mode2) - if langchain_mode2 in db1s: - # remove db entirely, so not in list, else need to manage visible list in update_langchain_mode_paths() - # FIXME: Remove location? 
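# Illustrative sketch (not from the original file): how the "CollectionName, /optional/path"
# textbox entry handled by add_langchain_mode() above can be parsed and validated.
# parse_collection_entry is a hypothetical helper named only for this sketch.
def parse_collection_entry(entry):
    parts = [p for p in entry.strip().replace(' ', '').split(',') if p]
    if not parts:
        return None, None, "Invalid, must be like UserData2, user_path2"
    name = parts[0]
    if len(name) < 3 or not name.isalnum():
        return None, None, "Invalid, collection must be >=3 characters and alphanumeric"
    # no path (or '') means a scratch collection rather than one backed by a directory on disk
    path = parts[1] if len(parts) > 1 and parts[1] not in ('', "''") else None
    return name, path, ''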
- if langchain_mode2 != LangChainMode.MY_DATA.value: - # don't remove last MyData, used as user hash - db1s.pop(langchain_mode2) - # only show - selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1) - df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1) - - save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode, - db1s) - - return db1s, selection_docs_state1, \ - gr.update(choices=get_langchain_choices(selection_docs_state1), - value=langchain_mode2), textbox, df_langchain_mode_paths1 - - new_langchain_mode_text.submit(fn=add_langchain_mode, - inputs=[my_db_state, selection_docs_state, langchain_mode, - new_langchain_mode_text], - outputs=[my_db_state, selection_docs_state, langchain_mode, - new_langchain_mode_text, - langchain_mode_path_text], - api_name='new_langchain_mode_text' if allow_api and allow_upload_to_user_data else None) - remove_langchain_mode_func = functools.partial(remove_langchain_mode, dbsu=dbs) - remove_langchain_mode_text.submit(fn=remove_langchain_mode_func, - inputs=[my_db_state, selection_docs_state, langchain_mode, - remove_langchain_mode_text], - outputs=[my_db_state, selection_docs_state, langchain_mode, - remove_langchain_mode_text, - langchain_mode_path_text], - api_name='remove_langchain_mode_text' if allow_api and allow_upload_to_user_data else None) - - def update_langchain_gr(db1s, selection_docs_state1, langchain_mode1): - for k in db1s: - set_userid(db1s[k]) - langchain_modes = selection_docs_state1['langchain_modes'] - langchain_mode_paths = selection_docs_state1['langchain_mode_paths'] - visible_langchain_modes = selection_docs_state1['visible_langchain_modes'] - # in-place - - # update user collaborative collections - update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, '') - # update scratch single-user collections - user_hash = db1s.get(LangChainMode.MY_DATA.value, '')[1] - update_langchain(langchain_modes, visible_langchain_modes, langchain_mode_paths, user_hash) - - selection_docs_state1 = update_langchain_mode_paths(db1s, selection_docs_state1) - df_langchain_mode_paths1 = get_df_langchain_mode_paths(selection_docs_state1) - return selection_docs_state1, \ - gr.update(choices=get_langchain_choices(selection_docs_state1), - value=langchain_mode1), df_langchain_mode_paths1 - - load_langchain.click(fn=update_langchain_gr, - inputs=[my_db_state, selection_docs_state, langchain_mode], - outputs=[selection_docs_state, langchain_mode, langchain_mode_path_text], - api_name='load_langchain' if allow_api and allow_upload_to_user_data else None) - - inputs_list, inputs_dict = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=1) - inputs_list2, inputs_dict2 = get_inputs_list(all_kwargs, kwargs['model_lower'], model_id=2) - from functools import partial - kwargs_evaluate = {k: v for k, v in all_kwargs.items() if k in inputs_kwargs_list} - # ensure present - for k in inputs_kwargs_list: - assert k in kwargs_evaluate, "Missing %s" % k - - def evaluate_nochat(*args1, default_kwargs1=None, str_api=False, **kwargs1): - args_list = list(args1) - if str_api: - user_kwargs = args_list[len(input_args_list)] - assert isinstance(user_kwargs, str) - user_kwargs = ast.literal_eval(user_kwargs) - else: - user_kwargs = {k: v for k, v in zip(eval_func_param_names, args_list[len(input_args_list):])} - # only used for submit_nochat_api - user_kwargs['chat'] = False - if 'stream_output' not in user_kwargs: - user_kwargs['stream_output'] = False - 
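# Illustrative sketch (not from the original file): the str_api path above receives the whole
# request as one python-literal string and merges it over server-side defaults, with missing or
# None values falling back to the defaults. example_defaults and the payload are hypothetical.
import ast
example_defaults = {'instruction_nochat': '', 'max_new_tokens': 512, 'temperature': 0.1}
payload = "{'instruction_nochat': 'Who are you?', 'max_new_tokens': 256}"
client_kwargs = ast.literal_eval(payload)
merged = {k: client_kwargs[k] if client_kwargs.get(k) is not None else v
          for k, v in example_defaults.items()}
# merged -> {'instruction_nochat': 'Who are you?', 'max_new_tokens': 256, 'temperature': 0.1}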
if 'langchain_mode' not in user_kwargs: - # if user doesn't specify, then assume disabled, not use default - user_kwargs['langchain_mode'] = 'Disabled' - if 'langchain_action' not in user_kwargs: - user_kwargs['langchain_action'] = LangChainAction.QUERY.value - if 'langchain_agents' not in user_kwargs: - user_kwargs['langchain_agents'] = [] - - set1 = set(list(default_kwargs1.keys())) - set2 = set(eval_func_param_names) - assert set1 == set2, "Set diff: %s %s: %s" % (set1, set2, set1.symmetric_difference(set2)) - # correct ordering. Note some things may not be in default_kwargs, so can't be default of user_kwargs.get() - model_state1 = args_list[0] - my_db_state1 = args_list[1] - selection_docs_state1 = args_list[2] - args_list = [user_kwargs[k] if k in user_kwargs and user_kwargs[k] is not None else default_kwargs1[k] for k - in eval_func_param_names] - assert len(args_list) == len(eval_func_param_names) - args_list = [model_state1, my_db_state1, selection_docs_state1] + args_list - - try: - for res_dict in evaluate(*tuple(args_list), **kwargs1): - if str_api: - # full return of dict - yield res_dict - elif kwargs['langchain_mode'] == 'Disabled': - yield fix_text_for_gradio(res_dict['response']) - else: - yield '
    ' + fix_text_for_gradio(res_dict['response']) - finally: - clear_torch_cache() - clear_embeddings(user_kwargs['langchain_mode'], my_db_state1) - - fun = partial(evaluate_nochat, - default_kwargs1=default_kwargs, - str_api=False, - **kwargs_evaluate) - fun2 = partial(evaluate_nochat, - default_kwargs1=default_kwargs, - str_api=False, - **kwargs_evaluate) - fun_with_dict_str = partial(evaluate_nochat, - default_kwargs1=default_kwargs, - str_api=True, - **kwargs_evaluate - ) - - dark_mode_btn.click( - None, - None, - None, - _js=get_dark_js(), - api_name="dark" if allow_api else None, - queue=False, - ) - - def visible_toggle(x): - x = 'off' if x == 'on' else 'on' - return x, gr.Column.update(visible=True if x == 'on' else False) - - side_bar_btn.click(fn=visible_toggle, - inputs=side_bar_text, - outputs=[side_bar_text, side_bar], - queue=False) - - submit_buttons_btn.click(fn=visible_toggle, - inputs=submit_buttons_text, - outputs=[submit_buttons_text, submit_buttons], - queue=False) - - # examples after submit or any other buttons for chat or no chat - if kwargs['examples'] is not None and kwargs['show_examples']: - gr.Examples(examples=kwargs['examples'], inputs=inputs_list) - - # Score - def score_last_response(*args, nochat=False, num_model_lock=0): - try: - if num_model_lock > 0: - # then lock way - args_list = list(args).copy() - outputs = args_list[-num_model_lock:] - score_texts1 = [] - for output in outputs: - # same input, put into form good for _score_last_response() - args_list[-1] = output - score_texts1.append( - _score_last_response(*tuple(args_list), nochat=nochat, - num_model_lock=num_model_lock, prefix='')) - if len(score_texts1) > 1: - return "Response Scores: %s" % ' '.join(score_texts1) - else: - return "Response Scores: %s" % score_texts1[0] - else: - return _score_last_response(*args, nochat=nochat, num_model_lock=num_model_lock) - finally: - clear_torch_cache() - - def _score_last_response(*args, nochat=False, num_model_lock=0, prefix='Response Score: '): - """ Similar to user() """ - args_list = list(args) - smodel = score_model_state0['model'] - stokenizer = score_model_state0['tokenizer'] - sdevice = score_model_state0['device'] - - if memory_restriction_level > 0: - max_length_tokenize = 768 - 256 if memory_restriction_level <= 2 else 512 - 256 - elif hasattr(stokenizer, 'model_max_length'): - max_length_tokenize = stokenizer.model_max_length - else: - # limit to 1024, not worth OOMing on reward score - max_length_tokenize = 2048 - 1024 - cutoff_len = max_length_tokenize * 4 # restrict deberta related to max for LLM - - if not nochat: - history = args_list[-1] - if history is None: - history = [] - if smodel is not None and \ - stokenizer is not None and \ - sdevice is not None and \ - history is not None and len(history) > 0 and \ - history[-1] is not None and \ - len(history[-1]) >= 2: - os.environ['TOKENIZERS_PARALLELISM'] = 'false' - - question = history[-1][0] - - answer = history[-1][1] - else: - return '%sNA' % prefix - else: - answer = args_list[-1] - instruction_nochat_arg_id = eval_func_param_names.index('instruction_nochat') - question = args_list[instruction_nochat_arg_id] - - if question is None: - return '%sBad Question' % prefix - if answer is None: - return '%sBad Answer' % prefix - try: - score = score_qa(smodel, stokenizer, max_length_tokenize, question, answer, cutoff_len) - finally: - clear_torch_cache() - if isinstance(score, str): - return '%sNA' % prefix - return '{}{:.1%}'.format(prefix, score) - - def noop_score_last_response(*args, 
**kwargs): - return "Response Score: Disabled" - - if kwargs['score_model']: - score_fun = score_last_response - else: - score_fun = noop_score_last_response - - score_args = dict(fn=score_fun, - inputs=inputs_list + [text_output], - outputs=[score_text], - ) - score_args2 = dict(fn=partial(score_fun), - inputs=inputs_list2 + [text_output2], - outputs=[score_text2], - ) - score_fun_func = functools.partial(score_fun, num_model_lock=len(text_outputs)) - all_score_args = dict(fn=score_fun_func, - inputs=inputs_list + text_outputs, - outputs=score_text, - ) - - score_args_nochat = dict(fn=partial(score_fun, nochat=True), - inputs=inputs_list + [text_output_nochat], - outputs=[score_text_nochat], - ) - - def update_history(*args, undo=False, retry=False, sanitize_user_prompt=False): - """ - User that fills history for bot - :param args: - :param undo: - :param retry: - :param sanitize_user_prompt: - :return: - """ - args_list = list(args) - user_message = args_list[eval_func_param_names.index('instruction')] # chat only - input1 = args_list[eval_func_param_names.index('iinput')] # chat only - prompt_type1 = args_list[eval_func_param_names.index('prompt_type')] - langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')] - langchain_action1 = args_list[eval_func_param_names.index('langchain_action')] - langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')] - document_subset1 = args_list[eval_func_param_names.index('document_subset')] - document_choice1 = args_list[eval_func_param_names.index('document_choice')] - if not prompt_type1: - # shouldn't have to specify if CLI launched model - prompt_type1 = kwargs['prompt_type'] - # apply back - args_list[eval_func_param_names.index('prompt_type')] = prompt_type1 - if input1 and not user_message.endswith(':'): - user_message1 = user_message + ":" + input1 - elif input1: - user_message1 = user_message + input1 - else: - user_message1 = user_message - if sanitize_user_prompt: - from better_profanity import profanity - user_message1 = profanity.censor(user_message1) - - history = args_list[-1] - if history is None: - # bad history - history = [] - history = history.copy() - - if undo: - if len(history) > 0: - history.pop() - return history - if retry: - if history: - history[-1][1] = None - return history - if user_message1 in ['', None, '\n']: - if not allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1): - # reject non-retry submit/enter - return history - user_message1 = fix_text_for_gradio(user_message1) - return history + [[user_message1, None]] - - def user(*args, undo=False, retry=False, sanitize_user_prompt=False): - return update_history(*args, undo=undo, retry=retry, sanitize_user_prompt=sanitize_user_prompt) - - def all_user(*args, undo=False, retry=False, sanitize_user_prompt=False, num_model_lock=0): - args_list = list(args) - history_list = args_list[-num_model_lock:] - assert len(history_list) > 0, "Bad history list: %s" % history_list - for hi, history in enumerate(history_list): - if num_model_lock > 0: - hargs = args_list[:-num_model_lock].copy() - else: - hargs = args_list.copy() - hargs += [history] - history_list[hi] = update_history(*hargs, undo=undo, retry=retry, - sanitize_user_prompt=sanitize_user_prompt) - if len(history_list) > 1: - return tuple(history_list) - else: - return history_list[0] - - def get_model_max_length(model_state1): - if model_state1 and not isinstance(model_state1["tokenizer"], str): - tokenizer = model_state1["tokenizer"] - elif model_state0 
and not isinstance(model_state0["tokenizer"], str): - tokenizer = model_state0["tokenizer"] - else: - tokenizer = None - if tokenizer is not None: - return tokenizer.model_max_length - else: - return 2000 - - def prep_bot(*args, retry=False, which_model=0): - """ - - :param args: - :param retry: - :param which_model: identifies which model if doing model_lock - API only called for which_model=0, default for inputs_list, but rest should ignore inputs_list - :return: last element is True if should run bot, False if should just yield history - """ - isize = len(input_args_list) + 1 # states + chat history - # don't deepcopy, can contain model itself - args_list = list(args).copy() - model_state1 = args_list[-isize] - my_db_state1 = args_list[-isize + 1] - selection_docs_state1 = args_list[-isize + 2] - history = args_list[-1] - prompt_type1 = args_list[eval_func_param_names.index('prompt_type')] - prompt_dict1 = args_list[eval_func_param_names.index('prompt_dict')] - - if model_state1['model'] is None or model_state1['model'] == no_model_str: - return history, None, None, None - - args_list = args_list[:-isize] # only keep rest needed for evaluate() - langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')] - add_chat_history_to_context1 = args_list[eval_func_param_names.index('add_chat_history_to_context')] - langchain_action1 = args_list[eval_func_param_names.index('langchain_action')] - langchain_agents1 = args_list[eval_func_param_names.index('langchain_agents')] - document_subset1 = args_list[eval_func_param_names.index('document_subset')] - document_choice1 = args_list[eval_func_param_names.index('document_choice')] - if not history: - print("No history", flush=True) - history = [] - return history, None, None, None - instruction1 = history[-1][0] - if retry and history: - # if retry, pop history and move onto bot stuff - instruction1 = history[-1][0] - history[-1][1] = None - elif not instruction1: - if not allow_empty_instruction(langchain_mode1, document_subset1, langchain_action1): - # if not retrying, then reject empty query - return history, None, None, None - elif len(history) > 0 and history[-1][1] not in [None, '']: - # reject submit button if already filled and not retrying - # None when not filling with '' to keep client happy - return history, None, None, None - - # shouldn't have to specify in API prompt_type if CLI launched model, so prefer global CLI one if have it - prompt_type1, prompt_dict1 = update_prompt(prompt_type1, prompt_dict1, model_state1, - which_model=which_model) - # apply back to args_list for evaluate() - args_list[eval_func_param_names.index('prompt_type')] = prompt_type1 - args_list[eval_func_param_names.index('prompt_dict')] = prompt_dict1 - - chat1 = args_list[eval_func_param_names.index('chat')] - model_max_length1 = get_model_max_length(model_state1) - context1 = history_to_context(history, langchain_mode1, - add_chat_history_to_context1, - prompt_type1, prompt_dict1, chat1, - model_max_length1, memory_restriction_level, - kwargs['keep_sources_in_context']) - args_list[0] = instruction1 # override original instruction with history from user - args_list[2] = context1 - - fun1 = partial(evaluate, - model_state1, - my_db_state1, - selection_docs_state1, - *tuple(args_list), - **kwargs_evaluate) - - return history, fun1, langchain_mode1, my_db_state1 - - def get_response(fun1, history): - """ - bot that consumes history for user input - instruction (from input_list) itself is not consumed by bot - :return: - """ - if not fun1: - yield 
history, '' - return - try: - for output_fun in fun1(): - output = output_fun['response'] - extra = output_fun['sources'] # FIXME: can show sources in separate text box etc. - # ensure good visually, else markdown ignores multiple \n - bot_message = fix_text_for_gradio(output) - history[-1][1] = bot_message - yield history, '' - except StopIteration: - yield history, '' - except RuntimeError as e: - if "generator raised StopIteration" in str(e): - # assume last entry was bad, undo - history.pop() - yield history, '' - else: - if history and len(history) > 0 and len(history[0]) > 1 and history[-1][1] is None: - history[-1][1] = '' - yield history, str(e) - raise - except Exception as e: - # put error into user input - ex = "Exception: %s" % str(e) - if history and len(history) > 0 and len(history[0]) > 1 and history[-1][1] is None: - history[-1][1] = '' - yield history, ex - raise - finally: - clear_torch_cache() - return - - def clear_embeddings(langchain_mode1, db1s): - # clear any use of embedding that sits on GPU, else keeps accumulating GPU usage even if clear torch cache - if db_type == 'chroma' and langchain_mode1 not in ['LLM', 'Disabled', None, '']: - from gpt_langchain import clear_embedding - db = dbs.get('langchain_mode1') - if db is not None and not isinstance(db, str): - clear_embedding(db) - if db1s is not None and langchain_mode1 in db1s: - db1 = db1s[langchain_mode1] - if len(db1) == 2: - clear_embedding(db1[0]) - - def bot(*args, retry=False): - history, fun1, langchain_mode1, db1 = prep_bot(*args, retry=retry) - try: - for res in get_response(fun1, history): - yield res - finally: - clear_torch_cache() - clear_embeddings(langchain_mode1, db1) - - def all_bot(*args, retry=False, model_states1=None): - args_list = list(args).copy() - chatbots = args_list[-len(model_states1):] - args_list0 = args_list[:-len(model_states1)] # same for all models - exceptions = [] - stream_output1 = args_list[eval_func_param_names.index('stream_output')] - max_time1 = args_list[eval_func_param_names.index('max_time')] - langchain_mode1 = args_list[eval_func_param_names.index('langchain_mode')] - isize = len(input_args_list) + 1 # states + chat history - db1s = None - try: - gen_list = [] - for chatboti, (chatbot1, model_state1) in enumerate(zip(chatbots, model_states1)): - args_list1 = args_list0.copy() - args_list1.insert(-isize + 2, - model_state1) # insert at -2 so is at -3, and after chatbot1 added, at -4 - # if at start, have None in response still, replace with '' so client etc. 
acts like normal - # assumes other parts of code treat '' and None as if no response yet from bot - # can't do this later in bot code as racy with threaded generators - if len(chatbot1) > 0 and len(chatbot1[-1]) == 2 and chatbot1[-1][1] is None: - chatbot1[-1][1] = '' - args_list1.append(chatbot1) - # so consistent with prep_bot() - # with model_state1 at -3, my_db_state1 at -2, and history(chatbot) at -1 - # langchain_mode1 and my_db_state1 should be same for every bot - history, fun1, langchain_mode1, db1s = prep_bot(*tuple(args_list1), retry=retry, - which_model=chatboti) - gen1 = get_response(fun1, history) - if stream_output1: - gen1 = TimeoutIterator(gen1, timeout=0.01, sentinel=None, raise_on_exception=False) - # else timeout will truncate output for non-streaming case - gen_list.append(gen1) - - bots_old = chatbots.copy() - exceptions_old = [''] * len(bots_old) - tgen0 = time.time() - for res1 in itertools.zip_longest(*gen_list): - if time.time() - tgen0 > max_time1: - print("Took too long: %s" % max_time1, flush=True) - break - - bots = [x[0] if x is not None and not isinstance(x, BaseException) else y for x, y in - zip(res1, bots_old)] - bots_old = bots.copy() - - def larger_str(x, y): - return x if len(x) > len(y) else y - - exceptions = [x[1] if x is not None and not isinstance(x, BaseException) else larger_str(str(x), y) - for x, y in zip(res1, exceptions_old)] - exceptions_old = exceptions.copy() - - def choose_exc(x): - # don't expose ports etc. to exceptions window - if is_public: - return "Endpoint unavailable or failed" - else: - return x - - exceptions_str = '\n'.join( - ['Model %s: %s' % (iix, choose_exc(x)) for iix, x in enumerate(exceptions) if - x not in [None, '', 'None']]) - if len(bots) > 1: - yield tuple(bots + [exceptions_str]) - else: - yield bots[0], exceptions_str - if exceptions: - exceptions = [x for x in exceptions if x not in ['', None, 'None']] - if exceptions: - print("Generate exceptions: %s" % exceptions, flush=True) - finally: - clear_torch_cache() - clear_embeddings(langchain_mode1, db1s) - - # NORMAL MODEL - user_args = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']), - inputs=inputs_list + [text_output], - outputs=text_output, - ) - bot_args = dict(fn=bot, - inputs=inputs_list + [model_state, my_db_state, selection_docs_state] + [text_output], - outputs=[text_output, chat_exception_text], - ) - retry_bot_args = dict(fn=functools.partial(bot, retry=True), - inputs=inputs_list + [model_state, my_db_state, selection_docs_state] + [text_output], - outputs=[text_output, chat_exception_text], - ) - retry_user_args = dict(fn=functools.partial(user, retry=True), - inputs=inputs_list + [text_output], - outputs=text_output, - ) - undo_user_args = dict(fn=functools.partial(user, undo=True), - inputs=inputs_list + [text_output], - outputs=text_output, - ) - - # MODEL2 - user_args2 = dict(fn=functools.partial(user, sanitize_user_prompt=kwargs['sanitize_user_prompt']), - inputs=inputs_list2 + [text_output2], - outputs=text_output2, - ) - bot_args2 = dict(fn=bot, - inputs=inputs_list2 + [model_state2, my_db_state, selection_docs_state] + [text_output2], - outputs=[text_output2, chat_exception_text], - ) - retry_bot_args2 = dict(fn=functools.partial(bot, retry=True), - inputs=inputs_list2 + [model_state2, my_db_state, selection_docs_state] + [text_output2], - outputs=[text_output2, chat_exception_text], - ) - retry_user_args2 = dict(fn=functools.partial(user, retry=True), - inputs=inputs_list2 + [text_output2], - 
outputs=text_output2, - ) - undo_user_args2 = dict(fn=functools.partial(user, undo=True), - inputs=inputs_list2 + [text_output2], - outputs=text_output2, - ) - - # MODEL N - all_user_args = dict(fn=functools.partial(all_user, - sanitize_user_prompt=kwargs['sanitize_user_prompt'], - num_model_lock=len(text_outputs), - ), - inputs=inputs_list + text_outputs, - outputs=text_outputs, - ) - all_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states), - inputs=inputs_list + [my_db_state, selection_docs_state] + text_outputs, - outputs=text_outputs + [chat_exception_text], - ) - all_retry_bot_args = dict(fn=functools.partial(all_bot, model_states1=model_states, retry=True), - inputs=inputs_list + [my_db_state, selection_docs_state] + text_outputs, - outputs=text_outputs + [chat_exception_text], - ) - all_retry_user_args = dict(fn=functools.partial(all_user, retry=True, - sanitize_user_prompt=kwargs['sanitize_user_prompt'], - num_model_lock=len(text_outputs), - ), - inputs=inputs_list + text_outputs, - outputs=text_outputs, - ) - all_undo_user_args = dict(fn=functools.partial(all_user, undo=True, - sanitize_user_prompt=kwargs['sanitize_user_prompt'], - num_model_lock=len(text_outputs), - ), - inputs=inputs_list + text_outputs, - outputs=text_outputs, - ) - - def clear_instruct(): - return gr.Textbox.update(value='') - - def deselect_radio_chats(): - return gr.update(value=None) - - def clear_all(): - return gr.Textbox.update(value=''), gr.Textbox.update(value=''), gr.update(value=None), \ - gr.Textbox.update(value=''), gr.Textbox.update(value='') - - if kwargs['model_states']: - submits1 = submits2 = submits3 = [] - submits4 = [] - - fun_source = [instruction.submit, submit.click, retry_btn.click] - fun_name = ['instruction', 'submit', 'retry'] - user_args = [all_user_args, all_user_args, all_retry_user_args] - bot_args = [all_bot_args, all_bot_args, all_retry_bot_args] - for userargs1, botarg1, funn1, funs1 in zip(user_args, bot_args, fun_name, fun_source): - submit_event11 = funs1(fn=dummy_fun, - inputs=instruction, outputs=instruction, queue=queue) - submit_event1a = submit_event11.then(**userargs1, queue=queue, - api_name='%s' % funn1 if allow_api else None) - # if hit enter on new instruction for submitting new query, no longer the saved chat - submit_event1b = submit_event1a.then(clear_all, inputs=None, - outputs=[instruction, iinput, radio_chats, score_text, - score_text2], - queue=queue) - submit_event1c = submit_event1b.then(**botarg1, - api_name='%s_bot' % funn1 if allow_api else None, - queue=queue) - submit_event1d = submit_event1c.then(**all_score_args, - api_name='%s_bot_score' % funn1 if allow_api else None, - queue=queue) - - submits1.extend([submit_event1a, submit_event1b, submit_event1c, submit_event1d]) - - # if undo, no longer the saved chat - submit_event4 = undo.click(fn=dummy_fun, - inputs=instruction, outputs=instruction, queue=queue) \ - .then(**all_undo_user_args, api_name='undo' if allow_api else None) \ - .then(clear_all, inputs=None, outputs=[instruction, iinput, radio_chats, score_text, - score_text2], queue=queue) \ - .then(**all_score_args, api_name='undo_score' if allow_api else None) - submits4 = [submit_event4] - - else: - # in case 2nd model, consume instruction first, so can clear quickly - # bot doesn't consume instruction itself, just history from user, so why works - submit_event11 = instruction.submit(fn=dummy_fun, - inputs=instruction, outputs=instruction, queue=queue) - submit_event1a = submit_event11.then(**user_args, queue=queue, - 
api_name='instruction' if allow_api else None) - # if hit enter on new instruction for submitting new query, no longer the saved chat - submit_event1a2 = submit_event1a.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=queue) - submit_event1b = submit_event1a2.then(**user_args2, api_name='instruction2' if allow_api else None) - submit_event1c = submit_event1b.then(clear_instruct, None, instruction) \ - .then(clear_instruct, None, iinput) - submit_event1d = submit_event1c.then(**bot_args, api_name='instruction_bot' if allow_api else None, - queue=queue) - submit_event1e = submit_event1d.then(**score_args, - api_name='instruction_bot_score' if allow_api else None, - queue=queue) - submit_event1f = submit_event1e.then(**bot_args2, api_name='instruction_bot2' if allow_api else None, - queue=queue) - submit_event1g = submit_event1f.then(**score_args2, - api_name='instruction_bot_score2' if allow_api else None, queue=queue) - - submits1 = [submit_event1a, submit_event1a2, submit_event1b, submit_event1c, submit_event1d, - submit_event1e, - submit_event1f, submit_event1g] - - submit_event21 = submit.click(fn=dummy_fun, - inputs=instruction, outputs=instruction, queue=queue) - submit_event2a = submit_event21.then(**user_args, api_name='submit' if allow_api else None) - # if submit new query, no longer the saved chat - submit_event2a2 = submit_event2a.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=queue) - submit_event2b = submit_event2a2.then(**user_args2, api_name='submit2' if allow_api else None) - submit_event2c = submit_event2b.then(clear_all, inputs=None, - outputs=[instruction, iinput, radio_chats, score_text, score_text2], - queue=queue) - submit_event2d = submit_event2c.then(**bot_args, api_name='submit_bot' if allow_api else None, queue=queue) - submit_event2e = submit_event2d.then(**score_args, - api_name='submit_bot_score' if allow_api else None, - queue=queue) - submit_event2f = submit_event2e.then(**bot_args2, api_name='submit_bot2' if allow_api else None, - queue=queue) - submit_event2g = submit_event2f.then(**score_args2, - api_name='submit_bot_score2' if allow_api else None, - queue=queue) - - submits2 = [submit_event2a, submit_event2a2, submit_event2b, submit_event2c, submit_event2d, - submit_event2e, - submit_event2f, submit_event2g] - - submit_event31 = retry_btn.click(fn=dummy_fun, - inputs=instruction, outputs=instruction, queue=queue) - submit_event3a = submit_event31.then(**user_args, api_name='retry' if allow_api else None) - # if retry, no longer the saved chat - submit_event3a2 = submit_event3a.then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=queue) - submit_event3b = submit_event3a2.then(**user_args2, api_name='retry2' if allow_api else None) - submit_event3c = submit_event3b.then(clear_instruct, None, instruction) \ - .then(clear_instruct, None, iinput) - submit_event3d = submit_event3c.then(**retry_bot_args, api_name='retry_bot' if allow_api else None, - queue=queue) - submit_event3e = submit_event3d.then(**score_args, - api_name='retry_bot_score' if allow_api else None, - queue=queue) - submit_event3f = submit_event3e.then(**retry_bot_args2, api_name='retry_bot2' if allow_api else None, - queue=queue) - submit_event3g = submit_event3f.then(**score_args2, - api_name='retry_bot_score2' if allow_api else None, - queue=queue) - - submits3 = [submit_event3a, submit_event3a2, submit_event3b, submit_event3c, submit_event3d, - submit_event3e, - submit_event3f, submit_event3g] - - # if undo, no longer the saved chat - 
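# Illustrative sketch (not from the original file): the submit/retry/undo wiring in this block
# chains gradio events with .then(), so each step (user update, clear, bot, score) runs in order.
# A minimal standalone analog, assuming gradio 3.x Blocks; the names below are hypothetical.
import gradio as gr
with gr.Blocks() as chain_demo:
    box = gr.Textbox(label="in")
    out = gr.Textbox(label="out")
    go = gr.Button("go")
    go.click(lambda x: x, inputs=box, outputs=out) \
        .then(lambda x: x + " (scored)", inputs=out, outputs=out)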
submit_event4 = undo.click(fn=dummy_fun, - inputs=instruction, outputs=instruction, queue=queue) \ - .then(**undo_user_args, api_name='undo' if allow_api else None) \ - .then(**undo_user_args2, api_name='undo2' if allow_api else None) \ - .then(clear_all, inputs=None, outputs=[instruction, iinput, radio_chats, score_text, - score_text2], queue=queue) \ - .then(**score_args, api_name='undo_score' if allow_api else None) \ - .then(**score_args2, api_name='undo_score2' if allow_api else None) - submits4 = [submit_event4] - - # MANAGE CHATS - def dedup(short_chat, short_chats): - if short_chat not in short_chats: - return short_chat - for i in range(1, 1000): - short_chat_try = short_chat + "_" + str(i) - if short_chat_try not in short_chats: - return short_chat_try - # fallback and hope for best - short_chat = short_chat + "_" + str(random.random()) - return short_chat - - def get_short_chat(x, short_chats, short_len=20, words=4): - if x and len(x[0]) == 2 and x[0][0] is not None: - short_chat = ' '.join(x[0][0][:short_len].split(' ')[:words]).strip() - if not short_chat: - # e.g.summarization, try using answer - short_chat = ' '.join(x[0][1][:short_len].split(' ')[:words]).strip() - if not short_chat: - short_chat = 'Unk' - short_chat = dedup(short_chat, short_chats) - else: - short_chat = None - return short_chat - - def is_chat_same(x, y): - #

<p> etc. added in chat, try to remove some of that to help avoid dup entries when hit new conversation - is_same = True - # length of conversation has to be same - if len(x) != len(y): - return False - for stepx, stepy in zip(x, y): - if len(stepx) != len(stepy): - # something off with a conversation - return False - for stepxx, stepyy in zip(stepx, stepy): - if len(stepxx) != len(stepyy): - # something off with a conversation - return False - if len(stepxx) != 2: - # something off - return False - if len(stepyy) != 2: - # something off - return False - questionx = stepxx[0].replace('<p>', '').replace('</p>', '') if stepxx[0] is not None else None - answerx = stepxx[1].replace('<p>', '').replace('</p>', '') if stepxx[1] is not None else None - - questiony = stepyy[0].replace('<p>', '').replace('</p>', '') if stepyy[0] is not None else None - answery = stepyy[1].replace('<p>', '').replace('</p>
    ', '') if stepyy[1] is not None else None - - if questionx != questiony or answerx != answery: - return False - return is_same - - def save_chat(*args, chat_is_list=False): - args_list = list(args) - if not chat_is_list: - # list of chatbot histories, - # can't pass in list with list of chatbot histories and state due to gradio limits - chat_list = args_list[:-1] - else: - assert len(args_list) == 2 - chat_list = args_list[0] - # if old chat file with single chatbot, get into shape - if isinstance(chat_list, list) and len(chat_list) > 0 and isinstance(chat_list[0], list) and len( - chat_list[0]) == 2 and isinstance(chat_list[0][0], str) and isinstance(chat_list[0][1], str): - chat_list = [chat_list] - # remove None histories - chat_list_not_none = [x for x in chat_list if x and len(x) > 0 and len(x[0]) == 2 and x[0][1] is not None] - chat_list_none = [x for x in chat_list if x not in chat_list_not_none] - if len(chat_list_none) > 0 and len(chat_list_not_none) == 0: - raise ValueError("Invalid chat file") - # dict with keys of short chat names, values of list of list of chatbot histories - chat_state1 = args_list[-1] - short_chats = list(chat_state1.keys()) - if len(chat_list_not_none) > 0: - # make short_chat key from only first history, based upon question that is same anyways - chat_first = chat_list_not_none[0] - short_chat = get_short_chat(chat_first, short_chats) - if short_chat: - old_chat_lists = list(chat_state1.values()) - already_exists = any([is_chat_same(chat_list, x) for x in old_chat_lists]) - if not already_exists: - chat_state1[short_chat] = chat_list.copy() - - # reverse so newest at top - choices = list(chat_state1.keys()).copy() - choices.reverse() - - return chat_state1, gr.update(choices=choices, value=None) - - def switch_chat(chat_key, chat_state1, num_model_lock=0): - chosen_chat = chat_state1[chat_key] - # deal with possible different size of chat list vs. 
current list - ret_chat = [None] * (2 + num_model_lock) - for chati in range(0, 2 + num_model_lock): - ret_chat[chati % len(ret_chat)] = chosen_chat[chati % len(chosen_chat)] - return tuple(ret_chat) - - def clear_texts(*args): - return tuple([gr.Textbox.update(value='')] * len(args)) - - def clear_scores(): - return gr.Textbox.update(value=res_value), \ - gr.Textbox.update(value='Response Score: NA'), \ - gr.Textbox.update(value='Response Score: NA') - - switch_chat_fun = functools.partial(switch_chat, num_model_lock=len(text_outputs)) - radio_chats.input(switch_chat_fun, - inputs=[radio_chats, chat_state], - outputs=[text_output, text_output2] + text_outputs) \ - .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) - - def remove_chat(chat_key, chat_state1): - if isinstance(chat_key, str): - chat_state1.pop(chat_key, None) - return gr.update(choices=list(chat_state1.keys()), value=None), chat_state1 - - remove_chat_event = remove_chat_btn.click(remove_chat, - inputs=[radio_chats, chat_state], outputs=[radio_chats, chat_state], - queue=False, api_name='remove_chat') - - def get_chats1(chat_state1): - base = 'chats' - makedirs(base, exist_ok=True) - filename = os.path.join(base, 'chats_%s.json' % str(uuid.uuid4())) - with open(filename, "wt") as f: - f.write(json.dumps(chat_state1, indent=2)) - return filename - - export_chat_event = export_chats_btn.click(get_chats1, inputs=chat_state, outputs=chats_file, queue=False, - api_name='export_chats' if allow_api else None) - - def add_chats_from_file(file, chat_state1, radio_chats1, chat_exception_text1): - if not file: - return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 - if isinstance(file, str): - files = [file] - else: - files = file - if not files: - return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 - chat_exception_list = [] - for file1 in files: - try: - if hasattr(file1, 'name'): - file1 = file1.name - with open(file1, "rt") as f: - new_chats = json.loads(f.read()) - for chat1_k, chat1_v in new_chats.items(): - # ignore chat1_k, regenerate and de-dup to avoid loss - chat_state1, _ = save_chat(chat1_v, chat_state1, chat_is_list=True) - except BaseException as e: - t, v, tb = sys.exc_info() - ex = ''.join(traceback.format_exception(t, v, tb)) - ex_str = "File %s exception: %s" % (file1, str(e)) - print(ex_str, flush=True) - chat_exception_list.append(ex_str) - chat_exception_text1 = '\n'.join(chat_exception_list) - return None, chat_state1, gr.update(choices=list(chat_state1.keys()), value=None), chat_exception_text1 - - # note for update_user_db_func output is ignored for db - chatup_change_event = chatsup_output.change(add_chats_from_file, - inputs=[chatsup_output, chat_state, radio_chats, - chat_exception_text], - outputs=[chatsup_output, chat_state, radio_chats, - chat_exception_text], - queue=False, - api_name='add_to_chats' if allow_api else None) - - clear_chat_event = clear_chat_btn.click(fn=clear_texts, - inputs=[text_output, text_output2] + text_outputs, - outputs=[text_output, text_output2] + text_outputs, - queue=False, api_name='clear' if allow_api else None) \ - .then(deselect_radio_chats, inputs=None, outputs=radio_chats, queue=False) \ - .then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) - - clear_event = save_chat_btn.click(save_chat, - inputs=[text_output, text_output2] + text_outputs + [chat_state], - outputs=[chat_state, radio_chats], - api_name='save_chat' if allow_api 
else None) - if kwargs['score_model']: - clear_event2 = clear_event.then(clear_scores, outputs=[score_text, score_text2, score_text_nochat]) - - # NOTE: clear of instruction/iinput for nochat has to come after score, - # because score for nochat consumes actual textbox, while chat consumes chat history filled by user() - no_chat_args = dict(fn=fun, - inputs=[model_state, my_db_state, selection_docs_state] + inputs_list, - outputs=text_output_nochat, - queue=queue, - ) - submit_event_nochat = submit_nochat.click(**no_chat_args, api_name='submit_nochat' if allow_api else None) \ - .then(clear_torch_cache) \ - .then(**score_args_nochat, api_name='instruction_bot_score_nochat' if allow_api else None, queue=queue) \ - .then(clear_instruct, None, instruction_nochat) \ - .then(clear_instruct, None, iinput_nochat) \ - .then(clear_torch_cache) - # copy of above with text box submission - submit_event_nochat2 = instruction_nochat.submit(**no_chat_args) \ - .then(clear_torch_cache) \ - .then(**score_args_nochat, queue=queue) \ - .then(clear_instruct, None, instruction_nochat) \ - .then(clear_instruct, None, iinput_nochat) \ - .then(clear_torch_cache) - - submit_event_nochat_api = submit_nochat_api.click(fun_with_dict_str, - inputs=[model_state, my_db_state, selection_docs_state, - inputs_dict_str], - outputs=text_output_nochat_api, - queue=True, # required for generator - api_name='submit_nochat_api' if allow_api else None) \ - .then(clear_torch_cache) - - def load_model(model_name, lora_weights, server_name, model_state_old, prompt_type_old, load_8bit, - use_gpu_id, gpu_id): - # ensure no API calls reach here - if is_public: - raise RuntimeError("Illegal access for %s" % model_name) - # ensure old model removed from GPU memory - if kwargs['debug']: - print("Pre-switch pre-del GPU memory: %s" % get_torch_allocated(), flush=True) - - model0 = model_state0['model'] - if isinstance(model_state_old['model'], str) and model0 is not None: - # best can do, move model loaded at first to CPU - model0.cpu() - - if model_state_old['model'] is not None and not isinstance(model_state_old['model'], str): - try: - model_state_old['model'].cpu() - except Exception as e: - # sometimes hit NotImplementedError: Cannot copy out of meta tensor; no data! 
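# Illustrative sketch (not from the original file): the generic "unload before load" pattern used
# by load_model() - move the old model off the GPU, drop the reference, then empty the CUDA cache.
# old_model is hypothetical; torch is assumed to be importable as elsewhere in this file.
import gc
import torch

def unload_model(old_model):
    try:
        old_model.cpu()  # can fail with NotImplementedError for meta tensors, as noted above
    except (NotImplementedError, RuntimeError):
        pass
    del old_model
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()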
- print("Unable to put model on CPU: %s" % str(e), flush=True) - del model_state_old['model'] - model_state_old['model'] = None - - if model_state_old['tokenizer'] is not None and not isinstance(model_state_old['tokenizer'], str): - del model_state_old['tokenizer'] - model_state_old['tokenizer'] = None - - clear_torch_cache() - if kwargs['debug']: - print("Pre-switch post-del GPU memory: %s" % get_torch_allocated(), flush=True) - - if model_name is None or model_name == no_model_str: - # no-op if no model, just free memory - # no detranscribe needed for model, never go into evaluate - lora_weights = no_lora_str - server_name = no_server_str - return [None, None, None, model_name, server_name], \ - model_name, lora_weights, server_name, prompt_type_old, \ - gr.Slider.update(maximum=256), \ - gr.Slider.update(maximum=256) - - # don't deepcopy, can contain model itself - all_kwargs1 = all_kwargs.copy() - all_kwargs1['base_model'] = model_name.strip() - all_kwargs1['load_8bit'] = load_8bit - all_kwargs1['use_gpu_id'] = use_gpu_id - all_kwargs1['gpu_id'] = int(gpu_id) # detranscribe - model_lower = model_name.strip().lower() - if model_lower in inv_prompt_type_to_model_lower: - prompt_type1 = inv_prompt_type_to_model_lower[model_lower] - else: - prompt_type1 = prompt_type_old - - # detranscribe - if lora_weights == no_lora_str: - lora_weights = '' - all_kwargs1['lora_weights'] = lora_weights.strip() - if server_name == no_server_str: - server_name = '' - all_kwargs1['inference_server'] = server_name.strip() - - model1, tokenizer1, device1 = get_model(reward_type=False, - **get_kwargs(get_model, exclude_names=['reward_type'], - **all_kwargs1)) - clear_torch_cache() - - tokenizer_base_model = model_name - prompt_dict1, error0 = get_prompt(prompt_type1, '', - chat=False, context='', reduced=False, making_context=False, - return_dict=True) - model_state_new = dict(model=model1, tokenizer=tokenizer1, device=device1, - base_model=model_name, tokenizer_base_model=tokenizer_base_model, - lora_weights=lora_weights, inference_server=server_name, - prompt_type=prompt_type1, prompt_dict=prompt_dict1, - ) - - max_max_new_tokens1 = get_max_max_new_tokens(model_state_new, **kwargs) - - if kwargs['debug']: - print("Post-switch GPU memory: %s" % get_torch_allocated(), flush=True) - return model_state_new, model_name, lora_weights, server_name, prompt_type1, \ - gr.Slider.update(maximum=max_max_new_tokens1), \ - gr.Slider.update(maximum=max_max_new_tokens1) - - def get_prompt_str(prompt_type1, prompt_dict1, which=0): - if prompt_type1 in ['', None]: - print("Got prompt_type %s: %s" % (which, prompt_type1), flush=True) - return str({}) - prompt_dict1, prompt_dict_error = get_prompt(prompt_type1, prompt_dict1, chat=False, context='', - reduced=False, making_context=False, return_dict=True) - if prompt_dict_error: - return str(prompt_dict_error) - else: - # return so user can manipulate if want and use as custom - return str(prompt_dict1) - - get_prompt_str_func1 = functools.partial(get_prompt_str, which=1) - get_prompt_str_func2 = functools.partial(get_prompt_str, which=2) - prompt_type.change(fn=get_prompt_str_func1, inputs=[prompt_type, prompt_dict], outputs=prompt_dict, queue=False) - prompt_type2.change(fn=get_prompt_str_func2, inputs=[prompt_type2, prompt_dict2], outputs=prompt_dict2, - queue=False) - - def dropdown_prompt_type_list(x): - return gr.Dropdown.update(value=x) - - def chatbot_list(x, model_used_in): - return gr.Textbox.update(label=f'h2oGPT [Model: {model_used_in}]') - - load_model_args = 
dict(fn=load_model, - inputs=[model_choice, lora_choice, server_choice, model_state, prompt_type, - model_load8bit_checkbox, model_use_gpu_id_checkbox, model_gpu], - outputs=[model_state, model_used, lora_used, server_used, - # if prompt_type changes, prompt_dict will change via change rule - prompt_type, max_new_tokens, min_new_tokens, - ]) - prompt_update_args = dict(fn=dropdown_prompt_type_list, inputs=prompt_type, outputs=prompt_type) - chatbot_update_args = dict(fn=chatbot_list, inputs=[text_output, model_used], outputs=text_output) - nochat_update_args = dict(fn=chatbot_list, inputs=[text_output_nochat, model_used], outputs=text_output_nochat) - load_model_event = load_model_button.click(**load_model_args, - api_name='load_model' if allow_api and is_public else None) \ - .then(**prompt_update_args) \ - .then(**chatbot_update_args) \ - .then(**nochat_update_args) \ - .then(clear_torch_cache) - - load_model_args2 = dict(fn=load_model, - inputs=[model_choice2, lora_choice2, server_choice2, model_state2, prompt_type2, - model_load8bit_checkbox2, model_use_gpu_id_checkbox2, model_gpu2], - outputs=[model_state2, model_used2, lora_used2, server_used2, - # if prompt_type2 changes, prompt_dict2 will change via change rule - prompt_type2, max_new_tokens2, min_new_tokens2 - ]) - prompt_update_args2 = dict(fn=dropdown_prompt_type_list, inputs=prompt_type2, outputs=prompt_type2) - chatbot_update_args2 = dict(fn=chatbot_list, inputs=[text_output2, model_used2], outputs=text_output2) - load_model_event2 = load_model_button2.click(**load_model_args2, - api_name='load_model2' if allow_api and is_public else None) \ - .then(**prompt_update_args2) \ - .then(**chatbot_update_args2) \ - .then(clear_torch_cache) - - def dropdown_model_lora_server_list(model_list0, model_x, - lora_list0, lora_x, - server_list0, server_x, - model_used1, lora_used1, server_used1, - model_used2, lora_used2, server_used2, - ): - model_new_state = [model_list0[0] + [model_x]] - model_new_options = [*model_new_state[0]] - x1 = model_x if model_used1 == no_model_str else model_used1 - x2 = model_x if model_used2 == no_model_str else model_used2 - ret1 = [gr.Dropdown.update(value=x1, choices=model_new_options), - gr.Dropdown.update(value=x2, choices=model_new_options), - '', model_new_state] - - lora_new_state = [lora_list0[0] + [lora_x]] - lora_new_options = [*lora_new_state[0]] - # don't switch drop-down to added lora if already have model loaded - x1 = lora_x if model_used1 == no_model_str else lora_used1 - x2 = lora_x if model_used2 == no_model_str else lora_used2 - ret2 = [gr.Dropdown.update(value=x1, choices=lora_new_options), - gr.Dropdown.update(value=x2, choices=lora_new_options), - '', lora_new_state] - - server_new_state = [server_list0[0] + [server_x]] - server_new_options = [*server_new_state[0]] - # don't switch drop-down to added server if already have model loaded - x1 = server_x if model_used1 == no_model_str else server_used1 - x2 = server_x if model_used2 == no_model_str else server_used2 - ret3 = [gr.Dropdown.update(value=x1, choices=server_new_options), - gr.Dropdown.update(value=x2, choices=server_new_options), - '', server_new_state] - - return tuple(ret1 + ret2 + ret3) - - add_model_lora_server_event = \ - add_model_lora_server_button.click(fn=dropdown_model_lora_server_list, - inputs=[model_options_state, new_model] + - [lora_options_state, new_lora] + - [server_options_state, new_server] + - [model_used, lora_used, server_used] + - [model_used2, lora_used2, server_used2], - outputs=[model_choice, 
model_choice2, new_model, model_options_state] + - [lora_choice, lora_choice2, new_lora, lora_options_state] + - [server_choice, server_choice2, new_server, - server_options_state], - queue=False) - - go_event = go_btn.click(lambda: gr.update(visible=False), None, go_btn, api_name="go" if allow_api else None, - queue=False) \ - .then(lambda: gr.update(visible=True), None, normal_block, queue=False) \ - .then(**load_model_args, queue=False).then(**prompt_update_args, queue=False) - - def compare_textbox_fun(x): - return gr.Textbox.update(visible=x) - - def compare_column_fun(x): - return gr.Column.update(visible=x) - - def compare_prompt_fun(x): - return gr.Dropdown.update(visible=x) - - def slider_fun(x): - return gr.Slider.update(visible=x) - - compare_checkbox.select(compare_textbox_fun, compare_checkbox, text_output2, - api_name="compare_checkbox" if allow_api else None) \ - .then(compare_column_fun, compare_checkbox, col_model2) \ - .then(compare_prompt_fun, compare_checkbox, prompt_type2) \ - .then(compare_textbox_fun, compare_checkbox, score_text2) \ - .then(slider_fun, compare_checkbox, max_new_tokens2) \ - .then(slider_fun, compare_checkbox, min_new_tokens2) - # FIXME: add score_res2 in condition, but do better - - # callback for logging flagged input/output - callback.setup(inputs_list + [text_output, text_output2] + text_outputs, "flagged_data_points") - flag_btn.click(lambda *args: callback.flag(args), inputs_list + [text_output, text_output2] + text_outputs, - None, - preprocess=False, - api_name='flag' if allow_api else None, queue=False) - flag_btn_nochat.click(lambda *args: callback.flag(args), inputs_list + [text_output_nochat], None, - preprocess=False, - api_name='flag_nochat' if allow_api else None, queue=False) - - def get_system_info(): - if is_public: - time.sleep(10) # delay to avoid spam since queue=False - return gr.Textbox.update(value=system_info_print()) - - system_event = system_btn.click(get_system_info, outputs=system_text, - api_name='system_info' if allow_api else None, queue=False) - - def get_system_info_dict(system_input1, **kwargs1): - if system_input1 != os.getenv("ADMIN_PASS", ""): - return json.dumps({}) - exclude_list = ['admin_pass', 'examples'] - sys_dict = {k: v for k, v in kwargs1.items() if - isinstance(v, (str, int, bool, float)) and k not in exclude_list} - try: - sys_dict.update(system_info()) - except Exception as e: - # protection - print("Exception: %s" % str(e), flush=True) - return json.dumps(sys_dict) - - system_kwargs = all_kwargs.copy() - system_kwargs.update(dict(command=str(' '.join(sys.argv)))) - get_system_info_dict_func = functools.partial(get_system_info_dict, **all_kwargs) - - system_dict_event = system_btn2.click(get_system_info_dict_func, - inputs=system_input, - outputs=system_text2, - api_name='system_info_dict' if allow_api else None, - queue=False, # queue to avoid spam - ) - - def get_hash(): - return kwargs['git_hash'] - - system_event = system_btn3.click(get_hash, - outputs=system_text3, - api_name='system_hash' if allow_api else None, - queue=False, - ) - - def count_chat_tokens(model_state1, chat1, prompt_type1, prompt_dict1, - memory_restriction_level1=0, - keep_sources_in_context1=False, - ): - if model_state1 and not isinstance(model_state1['tokenizer'], str): - tokenizer = model_state1['tokenizer'] - elif model_state0 and not isinstance(model_state0['tokenizer'], str): - tokenizer = model_state0['tokenizer'] - else: - tokenizer = None - if tokenizer is not None: - langchain_mode1 = 'LLM' - 
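# Illustrative sketch (not from the original file): the chat-token count reported below is simply
# the sequence length produced by the model's tokenizer for the rebuilt context. The model name
# here ("gpt2") is a hypothetical stand-in for whatever tokenizer is loaded in model_state.
from transformers import AutoTokenizer
tok = AutoTokenizer.from_pretrained("gpt2")
n_tokens = tok("Hello, how many tokens is this?", return_tensors="pt")["input_ids"].shape[1]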
add_chat_history_to_context1 = True - # fake user message to mimic bot() - chat1 = copy.deepcopy(chat1) - chat1 = chat1 + [['user_message1', None]] - model_max_length1 = tokenizer.model_max_length - context1 = history_to_context(chat1, langchain_mode1, - add_chat_history_to_context1, - prompt_type1, prompt_dict1, chat1, - model_max_length1, - memory_restriction_level1, keep_sources_in_context1) - return str(tokenizer(context1, return_tensors="pt")['input_ids'].shape[1]) - else: - return "N/A" - - count_chat_tokens_func = functools.partial(count_chat_tokens, - memory_restriction_level1=memory_restriction_level, - keep_sources_in_context1=kwargs['keep_sources_in_context']) - count_tokens_event = count_chat_tokens_btn.click(fn=count_chat_tokens, - inputs=[model_state, text_output, prompt_type, prompt_dict], - outputs=chat_token_count, - api_name='count_tokens' if allow_api else None) - - # don't pass text_output, don't want to clear output, just stop it - # cancel only stops outer generation, not inner generation or non-generation - stop_btn.click(lambda: None, None, None, - cancels=submits1 + submits2 + submits3 + submits4 + - [submit_event_nochat, submit_event_nochat2] + - [eventdb1, eventdb2, eventdb3] + - [eventdb7, eventdb8, eventdb9, eventdb12] + - db_events + - [clear_event] + - [submit_event_nochat_api, submit_event_nochat] + - [load_model_event, load_model_event2] + - [count_tokens_event] - , - queue=False, api_name='stop' if allow_api else None).then(clear_torch_cache, queue=False) - - demo.load(None, None, None, _js=get_dark_js() if kwargs['dark'] else None) - - demo.queue(concurrency_count=kwargs['concurrency_count'], api_open=kwargs['api_open']) - favicon_path = "h2o-logo.svg" - if not os.path.isfile(favicon_path): - print("favicon_path=%s not found" % favicon_path, flush=True) - favicon_path = None - - scheduler = BackgroundScheduler() - scheduler.add_job(func=clear_torch_cache, trigger="interval", seconds=20) - if is_public and \ - kwargs['base_model'] not in non_hf_types: - # FIXME: disable for gptj, langchain or gpt4all modify print itself - # FIXME: and any multi-threaded/async print will enter model output! - scheduler.add_job(func=ping, trigger="interval", seconds=60) - if is_public or os.getenv('PING_GPU'): - scheduler.add_job(func=ping_gpu, trigger="interval", seconds=60 * 10) - scheduler.start() - - # import control - if kwargs['langchain_mode'] == 'Disabled' and \ - os.environ.get("TEST_LANGCHAIN_IMPORT") and \ - kwargs['base_model'] not in non_hf_types: - assert 'gpt_langchain' not in sys.modules, "Dev bug, import of langchain when should not have" - assert 'langchain' not in sys.modules, "Dev bug, import of langchain when should not have" - - demo.launch(share=kwargs['share'], server_name="0.0.0.0", show_error=True, - favicon_path=favicon_path, prevent_thread_lock=True, - auth=kwargs['auth']) - if kwargs['verbose']: - print("Started GUI", flush=True) - if kwargs['block_gradio_exit']: - demo.block_thread() - - -def get_inputs_list(inputs_dict, model_lower, model_id=1): - """ - map gradio objects in locals() to inputs for evaluate(). 
- :param inputs_dict: - :param model_lower: - :param model_id: Which model (1 or 2) of 2 - :return: - """ - inputs_list_names = list(inspect.signature(evaluate).parameters) - inputs_list = [] - inputs_dict_out = {} - for k in inputs_list_names: - if k == 'kwargs': - continue - if k in input_args_list + inputs_kwargs_list: - # these are added at use time for args or partial for kwargs, not taken as input - continue - if 'mbart-' not in model_lower and k in ['src_lang', 'tgt_lang']: - continue - if model_id == 2: - if k == 'prompt_type': - k = 'prompt_type2' - if k == 'prompt_used': - k = 'prompt_used2' - if k == 'max_new_tokens': - k = 'max_new_tokens2' - if k == 'min_new_tokens': - k = 'min_new_tokens2' - inputs_list.append(inputs_dict[k]) - inputs_dict_out[k] = inputs_dict[k] - return inputs_list, inputs_dict_out - - -def get_sources(db1s, langchain_mode, dbs=None, docs_state0=None): - for k in db1s: - set_userid(db1s[k]) - - if langchain_mode in ['LLM']: - source_files_added = "NA" - source_list = [] - elif langchain_mode in ['wiki_full']: - source_files_added = "Not showing wiki_full, takes about 20 seconds and makes 4MB file." \ - " Ask jon.mckinney@h2o.ai for file if required." - source_list = [] - elif langchain_mode in db1s and len(db1s[langchain_mode]) == 2 and db1s[langchain_mode][0] is not None: - db1 = db1s[langchain_mode] - from gpt_langchain import get_metadatas - metadatas = get_metadatas(db1[0]) - source_list = sorted(set([x['source'] for x in metadatas])) - source_files_added = '\n'.join(source_list) - elif langchain_mode in dbs and dbs[langchain_mode] is not None: - from gpt_langchain import get_metadatas - db1 = dbs[langchain_mode] - metadatas = get_metadatas(db1) - source_list = sorted(set([x['source'] for x in metadatas])) - source_files_added = '\n'.join(source_list) - else: - source_list = [] - source_files_added = "None" - sources_dir = "sources_dir" - makedirs(sources_dir) - sources_file = os.path.join(sources_dir, 'sources_%s_%s' % (langchain_mode, str(uuid.uuid4()))) - with open(sources_file, "wt") as f: - f.write(source_files_added) - source_list = docs_state0 + source_list - return sources_file, source_list - - -def set_userid(db1): - # can only call this after function called so for specific userr, not in gr.State() that occurs during app init - assert db1 is not None and len(db1) == 2 - if db1[1] is None: - # uuid in db is used as user ID - db1[1] = str(uuid.uuid4()) - - -def update_user_db(file, db1s, selection_docs_state1, chunk, chunk_size, langchain_mode, dbs=None, **kwargs): - kwargs.update(selection_docs_state1) - if file is None: - raise RuntimeError("Don't use change, use input") - - try: - return _update_user_db(file, db1s=db1s, chunk=chunk, chunk_size=chunk_size, - langchain_mode=langchain_mode, dbs=dbs, - **kwargs) - except BaseException as e: - print(traceback.format_exc(), flush=True) - # gradio has issues if except, so fail semi-gracefully, else would hang forever in processing textbox - ex_str = "Exception: %s" % str(e) - source_files_added = """\ - - -

    - Sources:
    -

    -
    - {0} -
    - - - """.format(ex_str) - doc_exception_text = str(e) - return None, langchain_mode, source_files_added, doc_exception_text - finally: - clear_torch_cache() - - -def get_lock_file(db1, langchain_mode): - set_userid(db1) - assert len(db1) == 2 and db1[1] is not None and isinstance(db1[1], str) - user_id = db1[1] - base_path = 'locks' - makedirs(base_path) - lock_file = os.path.join(base_path, "db_%s_%s.lock" % (langchain_mode.replace(' ', '_'), user_id)) - return lock_file - - -def _update_user_db(file, - db1s=None, - chunk=None, chunk_size=None, - dbs=None, db_type=None, - langchain_mode='UserData', - langchain_modes=None, # unused but required as part of selection_docs_state1 - langchain_mode_paths=None, - visible_langchain_modes=None, - use_openai_embedding=None, - hf_embedding_model=None, - caption_loader=None, - enable_captions=None, - captions_model=None, - enable_ocr=None, - enable_pdf_ocr=None, - verbose=None, - n_jobs=-1, - is_url=None, is_txt=None, - ): - assert db1s is not None - assert chunk is not None - assert chunk_size is not None - assert use_openai_embedding is not None - assert hf_embedding_model is not None - assert caption_loader is not None - assert enable_captions is not None - assert captions_model is not None - assert enable_ocr is not None - assert enable_pdf_ocr is not None - assert verbose is not None - - if dbs is None: - dbs = {} - assert isinstance(dbs, dict), "Wrong type for dbs: %s" % str(type(dbs)) - # assert db_type in ['faiss', 'chroma'], "db_type %s not supported" % db_type - from gpt_langchain import add_to_db, get_db, path_to_docs - # handle case of list of temp buffer - if isinstance(file, list) and len(file) > 0 and hasattr(file[0], 'name'): - file = [x.name for x in file] - # handle single file of temp buffer - if hasattr(file, 'name'): - file = file.name - if not isinstance(file, (list, tuple, typing.Generator)) and isinstance(file, str): - file = [file] - - if langchain_mode == LangChainMode.DISABLED.value: - return None, langchain_mode, get_source_files(), "" - - if langchain_mode in [LangChainMode.LLM.value]: - # then switch to MyData, so langchain_mode also becomes way to select where upload goes - # but default to mydata if nothing chosen, since safest - if LangChainMode.MY_DATA.value in visible_langchain_modes: - langchain_mode = LangChainMode.MY_DATA.value - - if langchain_mode_paths is None: - langchain_mode_paths = {} - user_path = langchain_mode_paths.get(langchain_mode) - # UserData or custom, which has to be from user's disk - if user_path is not None: - # move temp files from gradio upload to stable location - for fili, fil in enumerate(file): - if isinstance(fil, str) and os.path.isfile(fil): # not url, text - new_fil = os.path.normpath(os.path.join(user_path, os.path.basename(fil))) - if os.path.normpath(os.path.abspath(fil)) != os.path.normpath(os.path.abspath(new_fil)): - if os.path.isfile(new_fil): - remove(new_fil) - try: - shutil.move(fil, new_fil) - except FileExistsError: - pass - file[fili] = new_fil - - if verbose: - print("Adding %s" % file, flush=True) - sources = path_to_docs(file if not is_url and not is_txt else None, - verbose=verbose, - n_jobs=n_jobs, - chunk=chunk, chunk_size=chunk_size, - url=file if is_url else None, - text=file if is_txt else None, - enable_captions=enable_captions, - captions_model=captions_model, - enable_ocr=enable_ocr, - enable_pdf_ocr=enable_pdf_ocr, - caption_loader=caption_loader, - ) - exceptions = [x for x in sources if x.metadata.get('exception')] - exceptions_strs = 
[x.metadata['exception'] for x in exceptions] - sources = [x for x in sources if 'exception' not in x.metadata] - - # below must at least come after langchain_mode is modified in case was LLM -> MyData, - # so original langchain mode changed - for k in db1s: - set_userid(db1s[k]) - db1 = get_db1(db1s, langchain_mode) - - lock_file = get_lock_file(db1s[LangChainMode.MY_DATA.value], langchain_mode) # user-level lock, not db-level lock - with filelock.FileLock(lock_file): - if langchain_mode in db1s: - if db1[0] is not None: - # then add - db, num_new_sources, new_sources_metadata = add_to_db(db1[0], sources, db_type=db_type, - use_openai_embedding=use_openai_embedding, - hf_embedding_model=hf_embedding_model) - else: - # in testing expect: - # assert len(db1) == 2 and db1[1] is None, "Bad MyData db: %s" % db1 - # for production hit, when user gets clicky: - assert len(db1) == 2, "Bad %s db: %s" % (langchain_mode, db1) - assert db1[1] is not None, "db hash was None, not allowed" - # then create - # if added has to original state and didn't change, then would be shared db for all users - persist_directory = os.path.join(scratch_base_dir, 'db_dir_%s_%s' % (langchain_mode, db1[1])) - db = get_db(sources, use_openai_embedding=use_openai_embedding, - db_type=db_type, - persist_directory=persist_directory, - langchain_mode=langchain_mode, - hf_embedding_model=hf_embedding_model) - if db is not None: - db1[0] = db - source_files_added = get_source_files(db=db1[0], exceptions=exceptions) - return None, langchain_mode, source_files_added, '\n'.join(exceptions_strs) - else: - from gpt_langchain import get_persist_directory - persist_directory = get_persist_directory(langchain_mode) - if langchain_mode in dbs and dbs[langchain_mode] is not None: - # then add - db, num_new_sources, new_sources_metadata = add_to_db(dbs[langchain_mode], sources, db_type=db_type, - use_openai_embedding=use_openai_embedding, - hf_embedding_model=hf_embedding_model) - else: - # then create. Or might just be that dbs is unfilled, then it will fill, then add - db = get_db(sources, use_openai_embedding=use_openai_embedding, - db_type=db_type, - persist_directory=persist_directory, - langchain_mode=langchain_mode, - hf_embedding_model=hf_embedding_model) - dbs[langchain_mode] = db - # NOTE we do not return db, because function call always same code path - # return dbs[langchain_mode] - # db in this code path is updated in place - source_files_added = get_source_files(db=dbs[langchain_mode], exceptions=exceptions) - return None, langchain_mode, source_files_added, '\n'.join(exceptions_strs) - - -def get_db(db1s, langchain_mode, dbs=None): - db1 = get_db1(db1s, langchain_mode) - lock_file = get_lock_file(db1s[LangChainMode.MY_DATA.value], langchain_mode) - - with filelock.FileLock(lock_file): - if langchain_mode in ['wiki_full']: - # NOTE: avoid showing full wiki. 
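The per-user database writes above all funnel through a filelock.FileLock scoped to the user hash and collection name. A minimal sketch of that locking idiom (the lock path below is made up for illustration; get_lock_file builds the real one):

import os
import filelock

os.makedirs("locks", exist_ok=True)
lock_file = os.path.join("locks", "db_MyData_0000-fake-uuid.lock")  # pattern mirrors get_lock_file
with filelock.FileLock(lock_file):
    # add_to_db()/get_db() style mutations would run here; a second worker hitting
    # the same user's collection blocks until the lock is released
    pass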
Takes about 30 seconds over about 90k entries, but not useful for now - db = None - elif langchain_mode in db1s and len(db1) == 2 and db1[0] is not None: - db = db1[0] - elif dbs is not None and langchain_mode in dbs and dbs[langchain_mode] is not None: - db = dbs[langchain_mode] - else: - db = None - return db - - -def get_source_files_given_langchain_mode(db1s, langchain_mode='UserData', dbs=None): - db = get_db(db1s, langchain_mode, dbs=dbs) - if langchain_mode in ['LLM'] or db is None: - return "Sources: N/A" - return get_source_files(db=db, exceptions=None) - - -def get_source_files(db=None, exceptions=None, metadatas=None): - if exceptions is None: - exceptions = [] - - # only should be one source, not confused - # assert db is not None or metadatas is not None - # clicky user - if db is None and metadatas is None: - return "No Sources at all" - - if metadatas is None: - source_label = "Sources:" - if db is not None: - from gpt_langchain import get_metadatas - metadatas = get_metadatas(db) - else: - metadatas = [] - adding_new = False - else: - source_label = "New Sources:" - adding_new = True - - # below automatically de-dups - from gpt_langchain import get_url - small_dict = {get_url(x['source'], from_str=True, short_name=True): get_short_name(x.get('head')) for x in - metadatas} - # if small_dict is empty dict, that's ok - df = pd.DataFrame(small_dict.items(), columns=['source', 'head']) - df.index = df.index + 1 - df.index.name = 'index' - source_files_added = tabulate.tabulate(df, headers='keys', tablefmt='unsafehtml') - - if exceptions: - exception_metadatas = [x.metadata for x in exceptions] - small_dict = {get_url(x['source'], from_str=True, short_name=True): get_short_name(x.get('exception')) for x in - exception_metadatas} - # if small_dict is empty dict, that's ok - df = pd.DataFrame(small_dict.items(), columns=['source', 'exception']) - df.index = df.index + 1 - df.index.name = 'index' - exceptions_html = tabulate.tabulate(df, headers='keys', tablefmt='unsafehtml') - else: - exceptions_html = '' - - if metadatas and exceptions: - source_files_added = """\ - - -

    - {0}
    -

    -
    - {1} - {2} -
    - - - """.format(source_label, source_files_added, exceptions_html) - elif metadatas: - source_files_added = """\ - - -

    - {0}
    -

    -
    - {1} -
    - - - """.format(source_label, source_files_added) - elif exceptions_html: - source_files_added = """\ - - -

    - Exceptions:
    -

    -
    - {0} -
    - - - """.format(exceptions_html) - else: - if adding_new: - source_files_added = "No New Sources" - else: - source_files_added = "No Sources" - - return source_files_added - - -def update_and_get_source_files_given_langchain_mode(db1s, langchain_mode, chunk, chunk_size, - dbs=None, first_para=None, - text_limit=None, - langchain_mode_paths=None, db_type=None, load_db_if_exists=None, - n_jobs=None, verbose=None): - has_path = {k: v for k, v in langchain_mode_paths.items() if v} - if langchain_mode in [LangChainMode.LLM.value, LangChainMode.MY_DATA.value]: - # then assume user really meant UserData, to avoid extra clicks in UI, - # since others can't be on disk, except custom user modes, which they should then select to query it - if LangChainMode.USER_DATA.value in has_path: - langchain_mode = LangChainMode.USER_DATA.value - - db = get_db(db1s, langchain_mode, dbs=dbs) - - from gpt_langchain import make_db - db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=False, - hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2", - first_para=first_para, text_limit=text_limit, - chunk=chunk, - chunk_size=chunk_size, - langchain_mode=langchain_mode, - langchain_mode_paths=langchain_mode_paths, - db_type=db_type, - load_db_if_exists=load_db_if_exists, - db=db, - n_jobs=n_jobs, - verbose=verbose) - # during refreshing, might have "created" new db since not in dbs[] yet, so insert back just in case - # so even if persisted, not kept up-to-date with dbs memory - if langchain_mode in db1s: - db1s[langchain_mode][0] = db - else: - dbs[langchain_mode] = db - - # return only new sources with text saying such - return get_source_files(db=None, exceptions=None, metadatas=new_sources_metadata) - - -def get_db1(db1s, langchain_mode1): - if langchain_mode1 in db1s: - db1 = db1s[langchain_mode1] - else: - # indicates to code that not scratch database - db1 = [None, None] - return db1 diff --git a/gradio_themes.py b/gradio_themes.py deleted file mode 100644 index 79f075df4273bbd609c626d4e6fb76ab499a96aa..0000000000000000000000000000000000000000 --- a/gradio_themes.py +++ /dev/null @@ -1,231 +0,0 @@ -from __future__ import annotations - -from typing import Iterable - -from gradio.themes.soft import Soft -from gradio.themes import Color, Size -from gradio.themes.utils import colors, sizes, fonts - -h2o_yellow = Color( - name="yellow", - c50="#fffef2", - c100="#fff9e6", - c200="#ffecb3", - c300="#ffe28c", - c400="#ffd659", - c500="#fec925", - c600="#e6ac00", - c700="#bf8f00", - c800="#a67c00", - c900="#664d00", - c950="#403000", -) -h2o_gray = Color( - name="gray", - c50="#f8f8f8", - c100="#e5e5e5", - c200="#cccccc", - c300="#b2b2b2", - c400="#999999", - c500="#7f7f7f", - c600="#666666", - c700="#4c4c4c", - c800="#333333", - c900="#191919", - c950="#0d0d0d", -) - - -text_xsm = Size( - name="text_xsm", - xxs="4px", - xs="5px", - sm="6px", - md="7px", - lg="8px", - xl="10px", - xxl="12px", -) - - -spacing_xsm = Size( - name="spacing_xsm", - xxs="1px", - xs="1px", - sm="1px", - md="2px", - lg="3px", - xl="5px", - xxl="7px", -) - - -radius_xsm = Size( - name="radius_xsm", - xxs="1px", - xs="1px", - sm="1px", - md="2px", - lg="3px", - xl="5px", - xxl="7px", -) - - -class H2oTheme(Soft): - def __init__( - self, - *, - primary_hue: colors.Color | str = h2o_yellow, - secondary_hue: colors.Color | str = h2o_yellow, - neutral_hue: colors.Color | str = h2o_gray, - spacing_size: sizes.Size | str = sizes.spacing_md, - radius_size: sizes.Size | str = sizes.radius_md, - text_size: sizes.Size | str = 
sizes.text_lg, - font: fonts.Font - | str - | Iterable[fonts.Font | str] = ( - fonts.GoogleFont("Montserrat"), - "ui-sans-serif", - "system-ui", - "sans-serif", - ), - font_mono: fonts.Font - | str - | Iterable[fonts.Font | str] = ( - fonts.GoogleFont("IBM Plex Mono"), - "ui-monospace", - "Consolas", - "monospace", - ), - ): - super().__init__( - primary_hue=primary_hue, - secondary_hue=secondary_hue, - neutral_hue=neutral_hue, - spacing_size=spacing_size, - radius_size=radius_size, - text_size=text_size, - font=font, - font_mono=font_mono, - ) - super().set( - link_text_color="#3344DD", - link_text_color_hover="#3344DD", - link_text_color_visited="#3344DD", - link_text_color_dark="#74abff", - link_text_color_hover_dark="#a3c8ff", - link_text_color_active_dark="#a3c8ff", - link_text_color_visited_dark="#74abff", - button_primary_text_color="*neutral_950", - button_primary_text_color_dark="*neutral_950", - button_primary_background_fill="*primary_500", - button_primary_background_fill_dark="*primary_500", - block_label_background_fill="*primary_500", - block_label_background_fill_dark="*primary_500", - block_label_text_color="*neutral_950", - block_label_text_color_dark="*neutral_950", - block_title_text_color="*neutral_950", - block_title_text_color_dark="*neutral_950", - block_background_fill_dark="*neutral_950", - body_background_fill="*neutral_50", - body_background_fill_dark="*neutral_900", - background_fill_primary_dark="*block_background_fill", - block_radius="0 0 8px 8px", - checkbox_label_text_color_selected_dark='#000000', - #checkbox_label_text_size="*text_xs", # too small for iPhone etc. but good if full large screen zoomed to fit - checkbox_label_text_size="*text_sm", - #radio_circle="""url("data:image/svg+xml,%3csvg viewBox='0 0 32 32' fill='white' xmlns='http://www.w3.org/2000/svg'%3e%3ccircle cx='32' cy='32' r='1'/%3e%3c/svg%3e")""", - #checkbox_border_width=1, - #heckbox_border_width_dark=1, - ) - - -class SoftTheme(Soft): - def __init__( - self, - *, - primary_hue: colors.Color | str = colors.indigo, - secondary_hue: colors.Color | str = colors.indigo, - neutral_hue: colors.Color | str = colors.gray, - spacing_size: sizes.Size | str = sizes.spacing_md, - radius_size: sizes.Size | str = sizes.radius_md, - text_size: sizes.Size | str = sizes.text_md, - font: fonts.Font - | str - | Iterable[fonts.Font | str] = ( - fonts.GoogleFont("Montserrat"), - "ui-sans-serif", - "system-ui", - "sans-serif", - ), - font_mono: fonts.Font - | str - | Iterable[fonts.Font | str] = ( - fonts.GoogleFont("IBM Plex Mono"), - "ui-monospace", - "Consolas", - "monospace", - ), - ): - super().__init__( - primary_hue=primary_hue, - secondary_hue=secondary_hue, - neutral_hue=neutral_hue, - spacing_size=spacing_size, - radius_size=radius_size, - text_size=text_size, - font=font, - font_mono=font_mono, - ) - super().set( - checkbox_label_text_size="*text_sm", - ) - - -h2o_logo = '' - - -def get_h2o_title(title, description): - # NOTE: Check full width desktop, smallest width browser desktop, iPhone browsers to ensure no overlap etc. - return f"""
-                    {description}
-                    {h2o_logo}
-                    {title}
    - """ - - -def get_simple_title(title, description): - return f"""{description}

    {title}

    """ - - -def get_dark_js(): - return """() => { - if (document.querySelectorAll('.dark').length) { - document.querySelectorAll('.dark').forEach(el => el.classList.remove('dark')); - } else { - document.querySelector('body').classList.add('dark'); - } - }""" diff --git a/gradio_utils/__init__.py b/gradio_utils/__init__.py deleted file mode 100644 index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..0000000000000000000000000000000000000000 diff --git a/gradio_utils/__pycache__/__init__.cpython-310.pyc b/gradio_utils/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index d241355589af12eb5291c68408a8b07ca2b773ca..0000000000000000000000000000000000000000 Binary files a/gradio_utils/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/gradio_utils/__pycache__/css.cpython-310.pyc b/gradio_utils/__pycache__/css.cpython-310.pyc deleted file mode 100644 index b20611b32b3a1209501b9a8dde07598ab38049d6..0000000000000000000000000000000000000000 Binary files a/gradio_utils/__pycache__/css.cpython-310.pyc and /dev/null differ diff --git a/gradio_utils/__pycache__/grclient.cpython-310.pyc b/gradio_utils/__pycache__/grclient.cpython-310.pyc deleted file mode 100644 index fae2390da62d66ce5a73f3b207536f3f25ebf8ab..0000000000000000000000000000000000000000 Binary files a/gradio_utils/__pycache__/grclient.cpython-310.pyc and /dev/null differ diff --git a/gradio_utils/__pycache__/prompt_form.cpython-310.pyc b/gradio_utils/__pycache__/prompt_form.cpython-310.pyc deleted file mode 100644 index 5e9983b4967f07603adb3b1bf86e7afce605299f..0000000000000000000000000000000000000000 Binary files a/gradio_utils/__pycache__/prompt_form.cpython-310.pyc and /dev/null differ diff --git a/gradio_utils/css.py b/gradio_utils/css.py deleted file mode 100644 index 7db8bee879c89a28d36b2f7f5d9c1183e76c1b1c..0000000000000000000000000000000000000000 --- a/gradio_utils/css.py +++ /dev/null @@ -1,60 +0,0 @@ -def get_css(kwargs) -> str: - if kwargs['h2ocolors']: - css_code = """footer {visibility: hidden;} - body{background:linear-gradient(#f5f5f5,#e5e5e5);} - body.dark{background:linear-gradient(#000000,#0d0d0d);} - """ - else: - css_code = """footer {visibility: hidden}""" - - css_code += make_css_base() - return css_code - - -def make_css_base() -> str: - css1 = """ - #col_container {margin-left: auto; margin-right: auto; text-align: left;} - """ - return css1 + """ - @import url('https://fonts.googleapis.com/css2?family=Source+Sans+Pro:wght@400;600&display=swap'); - - body.dark{#warning {background-color: #555555};} - - #small_btn { - margin: 0.6em 0em 0.55em 0; - max-width: 20em; - min-width: 5em !important; - height: 5em; - font-size: 14px !important; - } - - #prompt-form { - border: 1px solid var(--primary-500) !important; - } - - #prompt-form.block { - border-radius: var(--block-radius) !important; - } - - #prompt-form textarea { - border: 1px solid rgb(209, 213, 219); - } - - #prompt-form label > div { - margin-top: 4px; - } - - button.primary:hover { - background-color: var(--primary-600) !important; - transition: .2s; - } - - #prompt-form-area { - margin-bottom: 2.5rem; - } - .chatsmall chatbot {font-size: 10px !important} - - .gradio-container { - max-width: none !important; - } - """ diff --git a/gradio_utils/grclient.py b/gradio_utils/grclient.py deleted file mode 100644 index 8346a61cad99d492f8a10de17851454488364b83..0000000000000000000000000000000000000000 --- a/gradio_utils/grclient.py +++ /dev/null @@ -1,82 +0,0 @@ -import traceback -from typing import Callable -import os - -from 
gradio_client.client import Job - -os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1' - -from gradio_client import Client - - -class GradioClient(Client): - """ - Parent class of gradio client - To handle automatically refreshing client if detect gradio server changed - """ - - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs - super().__init__(*args, **kwargs) - self.server_hash = self.get_server_hash() - - def get_server_hash(self): - """ - Get server hash using super without any refresh action triggered - Returns: git hash of gradio server - """ - return super().submit(api_name='/system_hash').result() - - def refresh_client_if_should(self): - # get current hash in order to update api_name -> fn_index map in case gradio server changed - # FIXME: Could add cli api as hash - server_hash = self.get_server_hash() - if self.server_hash != server_hash: - self.refresh_client() - self.server_hash = server_hash - else: - self.reset_session() - - def refresh_client(self): - """ - Ensure every client call is independent - Also ensure map between api_name and fn_index is updated in case server changed (e.g. restarted with new code) - Returns: - """ - # need session hash to be new every time, to avoid "generator already executing" - self.reset_session() - - client = Client(*self.args, **self.kwargs) - for k, v in client.__dict__.items(): - setattr(self, k, v) - - def submit( - self, - *args, - api_name: str | None = None, - fn_index: int | None = None, - result_callbacks: Callable | list[Callable] | None = None, - ) -> Job: - # Note predict calls submit - try: - self.refresh_client_if_should() - job = super().submit(*args, api_name=api_name, fn_index=fn_index) - except Exception as e: - print("Hit e=%s" % str(e), flush=True) - # force reconfig in case only that - self.refresh_client() - job = super().submit(*args, api_name=api_name, fn_index=fn_index) - - # see if immediately failed - e = job.future._exception - if e is not None: - print("GR job failed: %s %s" % (str(e), ''.join(traceback.format_tb(e.__traceback__))), flush=True) - # force reconfig in case only that - self.refresh_client() - job = super().submit(*args, api_name=api_name, fn_index=fn_index) - e2 = job.future._exception - if e2 is not None: - print("GR job failed again: %s\n%s" % (str(e2), ''.join(traceback.format_tb(e2.__traceback__))), flush=True) - - return job diff --git a/gradio_utils/prompt_form.py b/gradio_utils/prompt_form.py deleted file mode 100644 index 34707d44de1d9eb21b7caef4e5345b11c4c9bd28..0000000000000000000000000000000000000000 --- a/gradio_utils/prompt_form.py +++ /dev/null @@ -1,95 +0,0 @@ -import os -import math - -import gradio as gr - - -def make_chatbots(output_label0, output_label0_model2, **kwargs): - text_outputs = [] - chat_kwargs = [] - for model_state_lock in kwargs['model_states']: - if os.environ.get('DEBUG_MODEL_LOCK'): - model_name = model_state_lock["base_model"] + " : " + model_state_lock["inference_server"] - else: - model_name = model_state_lock["base_model"] - output_label = f'h2oGPT [{model_name}]' - min_width = 250 if kwargs['gradio_size'] in ['small', 'large', 'medium'] else 160 - chat_kwargs.append(dict(label=output_label, visible=kwargs['model_lock'], elem_classes='chatsmall', - height=kwargs['height'] or 400, min_width=min_width)) - - if kwargs['model_lock_columns'] == -1: - kwargs['model_lock_columns'] = len(kwargs['model_states']) - if kwargs['model_lock_columns'] is None: - kwargs['model_lock_columns'] = 3 - - ncols = kwargs['model_lock_columns'] - if 
kwargs['model_states'] == 0: - nrows = 0 - else: - nrows = math.ceil(len(kwargs['model_states']) / kwargs['model_lock_columns']) - - if kwargs['model_lock_columns'] == 0: - # not using model_lock - pass - elif nrows <= 1: - with gr.Row(): - for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']): - text_outputs.append(gr.Chatbot(**chat_kwargs1)) - elif nrows == kwargs['model_states']: - with gr.Row(): - for chat_kwargs1, model_state_lock in zip(chat_kwargs, kwargs['model_states']): - text_outputs.append(gr.Chatbot(**chat_kwargs1)) - elif nrows == 2: - with gr.Row(): - for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])): - if mii >= len(kwargs['model_states']) / 2: - continue - text_outputs.append(gr.Chatbot(**chat_kwargs1)) - with gr.Row(): - for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])): - if mii < len(kwargs['model_states']) / 2: - continue - text_outputs.append(gr.Chatbot(**chat_kwargs1)) - elif nrows == 3: - with gr.Row(): - for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])): - if mii >= 1 * len(kwargs['model_states']) / 3: - continue - text_outputs.append(gr.Chatbot(**chat_kwargs1)) - with gr.Row(): - for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])): - if mii < 1 * len(kwargs['model_states']) / 3 or mii >= 2 * len(kwargs['model_states']) / 3: - continue - text_outputs.append(gr.Chatbot(**chat_kwargs1)) - with gr.Row(): - for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])): - if mii < 2 * len(kwargs['model_states']) / 3: - continue - text_outputs.append(gr.Chatbot(**chat_kwargs1)) - elif nrows >= 4: - with gr.Row(): - for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])): - if mii >= 1 * len(kwargs['model_states']) / 4: - continue - text_outputs.append(gr.Chatbot(**chat_kwargs1)) - with gr.Row(): - for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])): - if mii < 1 * len(kwargs['model_states']) / 4 or mii >= 2 * len(kwargs['model_states']) / 4: - continue - text_outputs.append(gr.Chatbot(**chat_kwargs1)) - with gr.Row(): - for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])): - if mii < 2 * len(kwargs['model_states']) / 4 or mii >= 3 * len(kwargs['model_states']) / 4: - continue - text_outputs.append(gr.Chatbot(**chat_kwargs1)) - with gr.Row(): - for mii, (chat_kwargs1, model_state_lock) in enumerate(zip(chat_kwargs, kwargs['model_states'])): - if mii < 3 * len(kwargs['model_states']) / 4: - continue - text_outputs.append(gr.Chatbot(**chat_kwargs1)) - - with gr.Row(): - text_output = gr.Chatbot(label=output_label0, visible=not kwargs['model_lock'], height=kwargs['height'] or 400) - text_output2 = gr.Chatbot(label=output_label0_model2, - visible=False and not kwargs['model_lock'], height=kwargs['height'] or 400) - return text_output, text_output2, text_outputs diff --git a/h2o-logo.svg b/h2o-logo.svg deleted file mode 100644 index d6b04435700ffae6284031d15b2220ea53bdce7f..0000000000000000000000000000000000000000 --- a/h2o-logo.svg +++ /dev/null @@ -1 +0,0 @@ - \ No newline at end of file diff --git a/h2oai_pipeline.py b/h2oai_pipeline.py deleted file mode 100644 index 368f49fbd81993a200311a267a43649e9ea0bfca..0000000000000000000000000000000000000000 --- a/h2oai_pipeline.py +++ /dev/null @@ -1,201 +0,0 @@ 
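The deleted h2oai_pipeline.py below defines H2OTextGenerationPipeline, a transformers TextGenerationPipeline subclass that injects instruction prompting and stopping criteria. A usage sketch in the style of the h2oGPT model cards (the model name, device_map, and generation settings are assumptions, not taken from this diff):

from transformers import AutoModelForCausalLM, AutoTokenizer
from h2oai_pipeline import H2OTextGenerationPipeline

model_name = "h2oai/h2ogpt-oig-oasst1-512-6_9b"  # example model; any causal LM with a known prompt_type works
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name, device_map="auto")

generate_text = H2OTextGenerationPipeline(model=model, tokenizer=tokenizer, prompt_type="human_bot")
res = generate_text("Why is drinking water so healthy?", max_new_tokens=100)
print(res[0]["generated_text"])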
-import os - -from transformers import TextGenerationPipeline -from transformers.pipelines.text_generation import ReturnType - -from stopping import get_stopping -from prompter import Prompter, PromptType - - -class H2OTextGenerationPipeline(TextGenerationPipeline): - def __init__(self, *args, debug=False, chat=False, stream_output=False, - sanitize_bot_response=False, - use_prompter=True, prompter=None, - context='', iinput='', - prompt_type=None, prompt_dict=None, - max_input_tokens=2048 - 256, **kwargs): - """ - HF-like pipeline, but handle instruction prompting and stopping (for some models) - :param args: - :param debug: - :param chat: - :param stream_output: - :param sanitize_bot_response: - :param use_prompter: Whether to use prompter. If pass prompt_type, will make prompter - :param prompter: prompter, can pass if have already - :param prompt_type: prompt_type, e.g. human_bot. See prompt_type to model mapping in from prompter.py. - If use_prompter, then will make prompter and use it. - :param prompt_dict: dict of get_prompt(, return_dict=True) for prompt_type=custom - :param max_input_tokens: - :param kwargs: - """ - super().__init__(*args, **kwargs) - self.prompt_text = None - self.use_prompter = use_prompter - self.prompt_type = prompt_type - self.prompt_dict = prompt_dict - self.prompter = prompter - self.context = context - self.iinput = iinput - if self.use_prompter: - if self.prompter is not None: - assert self.prompter.prompt_type is not None - else: - self.prompter = Prompter(self.prompt_type, self.prompt_dict, debug=debug, chat=chat, - stream_output=stream_output) - self.human = self.prompter.humanstr - self.bot = self.prompter.botstr - self.can_stop = True - else: - self.prompter = None - self.human = None - self.bot = None - self.can_stop = False - self.sanitize_bot_response = sanitize_bot_response - self.max_input_tokens = max_input_tokens # not for generate, so ok that not kwargs - - @staticmethod - def limit_prompt(prompt_text, tokenizer, max_prompt_length=None): - verbose = bool(int(os.getenv('VERBOSE_PIPELINE', '0'))) - - if hasattr(tokenizer, 'model_max_length'): - # model_max_length only defined for generate.py, not raw use of h2oai_pipeline.py - model_max_length = tokenizer.model_max_length - if max_prompt_length is not None: - model_max_length = min(model_max_length, max_prompt_length) - # cut at some upper likely limit to avoid excessive tokenization etc - # upper bound of 10 chars/token, e.g. 
special chars sometimes are long - if len(prompt_text) > model_max_length * 10: - len0 = len(prompt_text) - prompt_text = prompt_text[-model_max_length * 10:] - if verbose: - print("Cut of input: %s -> %s" % (len0, len(prompt_text)), flush=True) - else: - # unknown - model_max_length = None - - num_prompt_tokens = None - if model_max_length is not None: - # can't wait for "hole" if not plain prompt_type, since would lose prefix like : - # For https://github.com/h2oai/h2ogpt/issues/192 - for trial in range(0, 3): - prompt_tokens = tokenizer(prompt_text)['input_ids'] - num_prompt_tokens = len(prompt_tokens) - if num_prompt_tokens > model_max_length: - # conservative by using int() - chars_per_token = int(len(prompt_text) / num_prompt_tokens) - # keep tail, where question is if using langchain - prompt_text = prompt_text[-model_max_length * chars_per_token:] - if verbose: - print("reducing %s tokens, assuming average of %s chars/token for %s characters" % ( - num_prompt_tokens, chars_per_token, len(prompt_text)), flush=True) - else: - if verbose: - print("using %s tokens with %s chars" % (num_prompt_tokens, len(prompt_text)), flush=True) - break - - # Why Below False: don't limit max_new_tokens more, just rely upon stopping to reach limit of model - if False: - # if input prompt is some number of tokens, despite user request, can't have max_new_tokens more - # - assert num_prompt_tokens is not None - if self.prompt_type not in [PromptType.plain.name, PromptType.plain.value]: - # then give room for prompt - fudge = 20 - else: - fudge = 0 - max_new_tokens = max(0, min(generate_kwargs['max_new_tokens'], - model_max_length - (num_prompt_tokens + fudge))) - if max_new_tokens < generate_kwargs['max_new_tokens']: - if verbose: - print("Reduced max_new_tokens from %s -> %s" % ( - generate_kwargs['max_new_tokens'], max_new_tokens)) - generate_kwargs['max_new_tokens'] = max_new_tokens - return prompt_text, num_prompt_tokens - - def preprocess(self, prompt_text, prefix="", handle_long_generation=None, **generate_kwargs): - prompt_text, num_prompt_tokens = H2OTextGenerationPipeline.limit_prompt(prompt_text, self.tokenizer) - - data_point = dict(context=self.context, instruction=prompt_text, input=self.iinput) - if self.prompter is not None: - prompt_text = self.prompter.generate_prompt(data_point) - self.prompt_text = prompt_text - if handle_long_generation is None: - # forces truncation of inputs to avoid critical failure - handle_long_generation = None # disable with new approaches - return super().preprocess(prompt_text, prefix=prefix, handle_long_generation=handle_long_generation, - **generate_kwargs) - - def postprocess(self, model_outputs, return_type=ReturnType.FULL_TEXT, clean_up_tokenization_spaces=True): - records = super().postprocess(model_outputs, return_type=return_type, - clean_up_tokenization_spaces=clean_up_tokenization_spaces) - for rec in records: - if self.use_prompter: - outputs = rec['generated_text'] - outputs = self.prompter.get_response(outputs, prompt=self.prompt_text, - sanitize_bot_response=self.sanitize_bot_response) - elif self.bot and self.human: - outputs = rec['generated_text'].split(self.bot)[1].split(self.human)[0] - else: - outputs = rec['generated_text'] - rec['generated_text'] = outputs - print("prompt: %s\noutputs: %s\n\n" % (self.prompt_text, outputs), flush=True) - return records - - def _forward(self, model_inputs, **generate_kwargs): - if self.can_stop: - stopping_criteria = get_stopping(self.prompt_type, self.prompt_dict, - self.tokenizer, self.device, - 
human=self.human, bot=self.bot, - model_max_length=self.tokenizer.model_max_length) - generate_kwargs['stopping_criteria'] = stopping_criteria - # return super()._forward(model_inputs, **generate_kwargs) - return self.__forward(model_inputs, **generate_kwargs) - - # FIXME: Copy-paste of original _forward, but removed copy.deepcopy() - # FIXME: https://github.com/h2oai/h2ogpt/issues/172 - def __forward(self, model_inputs, **generate_kwargs): - input_ids = model_inputs["input_ids"] - attention_mask = model_inputs.get("attention_mask", None) - # Allow empty prompts - if input_ids.shape[1] == 0: - input_ids = None - attention_mask = None - in_b = 1 - else: - in_b = input_ids.shape[0] - prompt_text = model_inputs.pop("prompt_text") - - ## If there is a prefix, we may need to adjust the generation length. Do so without permanently modifying - ## generate_kwargs, as some of the parameterization may come from the initialization of the pipeline. - # generate_kwargs = copy.deepcopy(generate_kwargs) - prefix_length = generate_kwargs.pop("prefix_length", 0) - if prefix_length > 0: - has_max_new_tokens = "max_new_tokens" in generate_kwargs or ( - "generation_config" in generate_kwargs - and generate_kwargs["generation_config"].max_new_tokens is not None - ) - if not has_max_new_tokens: - generate_kwargs["max_length"] = generate_kwargs.get("max_length") or self.model.config.max_length - generate_kwargs["max_length"] += prefix_length - has_min_new_tokens = "min_new_tokens" in generate_kwargs or ( - "generation_config" in generate_kwargs - and generate_kwargs["generation_config"].min_new_tokens is not None - ) - if not has_min_new_tokens and "min_length" in generate_kwargs: - generate_kwargs["min_length"] += prefix_length - - # BS x SL - generated_sequence = self.model.generate(input_ids=input_ids, attention_mask=attention_mask, **generate_kwargs) - out_b = generated_sequence.shape[0] - if self.framework == "pt": - generated_sequence = generated_sequence.reshape(in_b, out_b // in_b, *generated_sequence.shape[1:]) - elif self.framework == "tf": - from transformers import is_tf_available - if is_tf_available(): - import tensorflow as tf - generated_sequence = tf.reshape(generated_sequence, - (in_b, out_b // in_b, *generated_sequence.shape[1:])) - else: - raise ValueError("TF not avaialble.") - return {"generated_sequence": generated_sequence, "input_ids": input_ids, "prompt_text": prompt_text} diff --git a/iterators/__init__.py b/iterators/__init__.py deleted file mode 100644 index d800eac15a042c02c0d8b31f086db83ade229a53..0000000000000000000000000000000000000000 --- a/iterators/__init__.py +++ /dev/null @@ -1,4 +0,0 @@ -from .timeout_iterator import TimeoutIterator, AsyncTimeoutIterator -from .iterator_pipe import IteratorPipe, AsyncIteratorPipe - -__all__ = ["TimeoutIterator", "AsyncTimeoutIterator", "IteratorPipe", "AsyncIteratorPipe"] \ No newline at end of file diff --git a/iterators/__pycache__/__init__.cpython-310.pyc b/iterators/__pycache__/__init__.cpython-310.pyc deleted file mode 100644 index 92a8d3cef6c28e4df6e8911b2f5ce838445b618c..0000000000000000000000000000000000000000 Binary files a/iterators/__pycache__/__init__.cpython-310.pyc and /dev/null differ diff --git a/iterators/__pycache__/iterator_pipe.cpython-310.pyc b/iterators/__pycache__/iterator_pipe.cpython-310.pyc deleted file mode 100644 index 3aa9bfb7ba8fe0fbd81f3ea4cba3062fc2508fff..0000000000000000000000000000000000000000 Binary files a/iterators/__pycache__/iterator_pipe.cpython-310.pyc and /dev/null differ diff --git 
a/iterators/__pycache__/timeout_iterator.cpython-310.pyc b/iterators/__pycache__/timeout_iterator.cpython-310.pyc deleted file mode 100644 index 08cfdd39f1e7e962957fc1959d9bce884d8085e7..0000000000000000000000000000000000000000 Binary files a/iterators/__pycache__/timeout_iterator.cpython-310.pyc and /dev/null differ diff --git a/iterators/iterator_pipe.py b/iterators/iterator_pipe.py deleted file mode 100644 index 90883b08ee6c5fbb7a575a7f1176f124b4d66134..0000000000000000000000000000000000000000 --- a/iterators/iterator_pipe.py +++ /dev/null @@ -1,93 +0,0 @@ -import queue -import asyncio - - -class IteratorPipe: - """ - Iterator Pipe creates an iterator that can be fed in data from another block of code or thread of execution - """ - - def __init__(self, sentinel=object()): - self._q = queue.Queue() - self._sentinel = sentinel - self._sentinel_pushed = False - self._closed = False - - def __iter__(self): - return self - - def __next__(self): - if self._closed: - raise StopIteration - - data = self._q.get(block=True) - if data is self._sentinel: - self._closed = True - raise StopIteration - - return data - - def put(self, data) -> bool: - """ - Pushes next item to Iterator and returns True - If iterator has been closed via close(), doesn't push anything and returns False - """ - if self._sentinel_pushed: - return False - - self._q.put(data) - return True - - def close(self): - """ - Close is idempotent. Calling close multiple times is safe - Iterator will raise StopIteration only after all elements pushed before close have been iterated - """ - # make close idempotent - if not self._sentinel_pushed: - self._sentinel_pushed = True - self._q.put(self._sentinel) - - -class AsyncIteratorPipe: - - def __init__(self, sentinel=object()): - self._q = asyncio.Queue() - self._sentinel = sentinel - self._sentinel_pushed = False - self._closed = False - - def __aiter__(self): - return self - - async def __anext__(self): - if self._closed: - raise StopAsyncIteration - - data = await self._q.get() - if data is self._sentinel: - self._closed = True - raise StopAsyncIteration - - return data - - async def put(self, data) -> bool: - """ - Pushes next item to Iterator and returns True - If iterator has been closed via close(), doesn't push anything and returns False - """ - if self._sentinel_pushed: - return False - - await self._q.put(data) - return True - - async def close(self): - """ - Close is idempotent. Calling close multiple times is safe - Iterator will raise StopIteration only after all elements pushed before close have been iterated - """ - # make close idempotent - if not self._sentinel_pushed: - self._sentinel_pushed = True - await self._q.put(self._sentinel) diff --git a/iterators/timeout_iterator.py b/iterators/timeout_iterator.py deleted file mode 100644 index d6f760e4b67448538dc95328a58c1eb1b1958471..0000000000000000000000000000000000000000 --- a/iterators/timeout_iterator.py +++ /dev/null @@ -1,170 +0,0 @@ -import queue -import asyncio -import threading -import traceback - - -class TimeoutIterator: - """ - Wrapper class to add timeout feature to synchronous iterators - - timeout: timeout for next(). Default=ZERO_TIMEOUT i.e. no timeout or blocking calls to next. Updated using set_timeout() - - sentinel: the object returned by iterator when timeout happens - - reset_on_next: if set to True, timeout is reset to the value of ZERO_TIMEOUT on each iteration - - TimeoutIterator uses a thread internally. - The thread stops once the iterator exhausts or raises an exception during iteration. 
- - Any exceptions raised within the wrapped iterator are propagated as it is. - Exception is raised when all elements generated by the actual iterator before exception have been consumed - Timeout can be set dynamically before going for iteration - """ - ZERO_TIMEOUT = 0.0 - - def __init__(self, iterator, timeout=0.0, sentinel=object(), reset_on_next=False, raise_on_exception=True): - self._iterator = iterator - self._timeout = timeout - self._sentinel = sentinel - self._reset_on_next = reset_on_next - self._raise_on_exception = raise_on_exception - - self._interrupt = False - self._done = False - self._buffer = queue.Queue() - self._thread = threading.Thread(target=self.__lookahead) - self._thread.start() - - def get_sentinel(self): - return self._sentinel - - def set_reset_on_next(self, reset_on_next): - self._reset_on_next = reset_on_next - - def set_timeout(self, timeout: float): - """ - Set timeout for next iteration - """ - self._timeout = timeout - - def interrupt(self): - """ - interrupt and stop the underlying thread. - the thread actually dies only after interrupt has been set and - the underlying iterator yields a value after that. - """ - self._interrupt = True - - def __iter__(self): - return self - - def __next__(self): - """ - yield the result from iterator - if timeout > 0: - yield data if available. - otherwise yield sentinal - """ - if self._done: - raise StopIteration - - data = self._sentinel - try: - if self._timeout > self.ZERO_TIMEOUT: - data = self._buffer.get(timeout=self._timeout) - else: - data = self._buffer.get() - except queue.Empty: - pass - finally: - # see if timeout needs to be reset - if self._reset_on_next: - self._timeout = self.ZERO_TIMEOUT - - # propagate any exceptions including StopIteration - if isinstance(data, BaseException): - self._done = True - if isinstance(data, StopIteration): - raise data - ex = ''.join(traceback.format_tb(data.__traceback__)) - print("Generation Failed: %s %s" % (str(data), str(ex)), flush=True) - if self._raise_on_exception: - raise data - else: - return data - - return data - - def __lookahead(self): - try: - while True: - self._buffer.put(next(self._iterator)) - if self._interrupt: - raise StopIteration() - except BaseException as e: - self._buffer.put(e) - - -class AsyncTimeoutIterator: - """ - Async version of TimeoutIterator. 
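A usage sketch for the synchronous TimeoutIterator above (the slow generator is synthetic): the consumer keeps polling for output and receives the sentinel back instead of blocking when nothing has arrived within the timeout.

import time
from iterators import TimeoutIterator

def slow_tokens():
    for tok in ["hello", "world"]:
        time.sleep(2)  # stand-in for a model that streams tokens slowly
        yield tok

it = TimeoutIterator(slow_tokens(), timeout=0.5)
for item in it:
    if item is it.get_sentinel():
        continue  # nothing within 0.5s; a caller could refresh the UI or check a stop flag here
    print(item)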
See method documentation of TimeoutIterator - """ - ZERO_TIMEOUT = 0.0 - - def __init__(self, iterator, timeout=0.0, sentinel=object(), reset_on_next=False): - self._iterator = iterator - self._timeout = timeout - self._sentinel = sentinel - self._reset_on_next = reset_on_next - - self._interrupt = False - self._done = False - self._buffer = asyncio.Queue() - self._task = asyncio.get_event_loop().create_task(self.__lookahead()) - - def get_sentinel(self): - return self._sentinel - - def set_reset_on_next(self, reset_on_next): - self._reset_on_next = reset_on_next - - def set_timeout(self, timeout: float): - self._timeout = timeout - - def interrupt(self): - self._interrupt = True - - def __aiter__(self): - return self - - async def __anext__(self): - if self._done: - raise StopAsyncIteration - - data = self._sentinel - try: - if self._timeout > self.ZERO_TIMEOUT: - data = await asyncio.wait_for(self._buffer.get(), self._timeout) - else: - data = await self._buffer.get() - except asyncio.TimeoutError: - pass - finally: - # see if timeout needs to be reset - if self._reset_on_next: - self._timeout = self.ZERO_TIMEOUT - - # propagate any exceptions including StopIteration - if isinstance(data, BaseException): - self._done = True - raise data - - return data - - async def __lookahead(self): - try: - while True: - data = await self._iterator.__anext__() - await self._buffer.put(data) - if self._interrupt: - raise StopAsyncIteration() - except BaseException as e: - await self._buffer.put(e) diff --git a/loaders.py b/loaders.py deleted file mode 100644 index 18e360e2bdc45e7bddfc6f0e24d1e9099ae2f73c..0000000000000000000000000000000000000000 --- a/loaders.py +++ /dev/null @@ -1,61 +0,0 @@ -import functools - - -def get_loaders(model_name, reward_type, llama_type=None, load_gptq=''): - # NOTE: Some models need specific new prompt_type - # E.g. t5_xxl_true_nli_mixture has input format: "premise: PREMISE_TEXT hypothesis: HYPOTHESIS_TEXT".) 
- if load_gptq: - from transformers import AutoTokenizer - from auto_gptq import AutoGPTQForCausalLM - use_triton = False - functools.partial(AutoGPTQForCausalLM.from_quantized, quantize_config=None, use_triton=use_triton) - return AutoGPTQForCausalLM.from_quantized, AutoTokenizer - if llama_type is None: - llama_type = "llama" in model_name.lower() - if llama_type: - from transformers import LlamaForCausalLM, LlamaTokenizer - return LlamaForCausalLM.from_pretrained, LlamaTokenizer - elif 'distilgpt2' in model_name.lower(): - from transformers import AutoModelForCausalLM, AutoTokenizer - return AutoModelForCausalLM.from_pretrained, AutoTokenizer - elif 'gpt2' in model_name.lower(): - from transformers import GPT2LMHeadModel, GPT2Tokenizer - return GPT2LMHeadModel.from_pretrained, GPT2Tokenizer - elif 'mbart-' in model_name.lower(): - from transformers import MBartForConditionalGeneration, MBart50TokenizerFast - return MBartForConditionalGeneration.from_pretrained, MBart50TokenizerFast - elif 't5' == model_name.lower() or \ - 't5-' in model_name.lower() or \ - 'flan-' in model_name.lower(): - from transformers import AutoTokenizer, T5ForConditionalGeneration - return T5ForConditionalGeneration.from_pretrained, AutoTokenizer - elif 'bigbird' in model_name: - from transformers import BigBirdPegasusForConditionalGeneration, AutoTokenizer - return BigBirdPegasusForConditionalGeneration.from_pretrained, AutoTokenizer - elif 'bart-large-cnn-samsum' in model_name or 'flan-t5-base-samsum' in model_name: - from transformers import pipeline - return pipeline, "summarization" - elif reward_type or 'OpenAssistant/reward-model'.lower() in model_name.lower(): - from transformers import AutoModelForSequenceClassification, AutoTokenizer - return AutoModelForSequenceClassification.from_pretrained, AutoTokenizer - else: - from transformers import AutoTokenizer, AutoModelForCausalLM - model_loader = AutoModelForCausalLM - tokenizer_loader = AutoTokenizer - return model_loader.from_pretrained, tokenizer_loader - - -def get_tokenizer(tokenizer_loader, tokenizer_base_model, local_files_only, resume_download, use_auth_token): - tokenizer = tokenizer_loader.from_pretrained(tokenizer_base_model, - local_files_only=local_files_only, - resume_download=resume_download, - use_auth_token=use_auth_token, - padding_side='left') - - tokenizer.pad_token_id = 0 # different from the eos token - # when generating, we will use the logits of right-most token to predict the next token - # so the padding should be on the left, - # e.g. 
see: https://huggingface.co/transformers/v4.11.3/model_doc/t5.html#inference - tokenizer.padding_side = "left" # Allow batched inference - - return tokenizer diff --git a/prompter.py b/prompter.py deleted file mode 100644 index 31d4d89f816a0da0fcbd9cab56d69b40254f8401..0000000000000000000000000000000000000000 --- a/prompter.py +++ /dev/null @@ -1,871 +0,0 @@ -import os -import ast -import time -from enums import PromptType # also supports imports from this file from other files - -non_hf_types = ['gpt4all_llama', 'llama', 'gptj'] - -prompt_type_to_model_name = { - 'plain': [ - 'EleutherAI/gpt-j-6B', - 'EleutherAI/pythia-6.9b', - 'EleutherAI/pythia-12b', - 'EleutherAI/pythia-12b-deduped', - 'EleutherAI/gpt-neox-20b', - 'openlm-research/open_llama_7b_700bt_preview', - 'decapoda-research/llama-7b-hf', - 'decapoda-research/llama-13b-hf', - 'decapoda-research/llama-30b-hf', - 'decapoda-research/llama-65b-hf', - 'facebook/mbart-large-50-many-to-many-mmt', - 'philschmid/bart-large-cnn-samsum', - 'philschmid/flan-t5-base-samsum', - 'gpt2', - 'distilgpt2', - 'mosaicml/mpt-7b-storywriter', - ], - 'gptj': ['gptj', 'gpt4all_llama'], - 'prompt_answer': [ - 'h2oai/h2ogpt-gm-oasst1-en-1024-20b', - 'h2oai/h2ogpt-gm-oasst1-en-1024-12b', - 'h2oai/h2ogpt-gm-oasst1-multilang-1024-20b', - 'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b', - 'h2oai/h2ogpt-gm-oasst1-multilang-2048-falcon-7b-v2', - 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v3', - 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b', - 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-7b-v2', - 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v1', - 'h2oai/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2', - 'h2oai/h2ogpt-gm-oasst1-en-xgen-7b-8k', - 'h2oai/h2ogpt-gm-oasst1-multilang-xgen-7b-8k', - 'TheBloke/h2ogpt-gm-oasst1-en-2048-falcon-40b-v2-GPTQ', - ], - 'prompt_answer_openllama': [ - 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt', - 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-300bt-v2', - 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b-preview-700bt', - 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-7b', - 'h2oai/h2ogpt-gm-oasst1-en-2048-open-llama-13b', - ], - 'instruct': ['TheBloke/llama-30b-supercot-SuperHOT-8K-fp16'], # https://huggingface.co/TheBloke/llama-30b-supercot-SuperHOT-8K-fp16#prompting - 'instruct_with_end': ['databricks/dolly-v2-12b'], - 'quality': [], - 'human_bot': [ - 'h2oai/h2ogpt-oasst1-512-12b', - 'h2oai/h2ogpt-oasst1-512-20b', - 'h2oai/h2ogpt-oig-oasst1-256-6_9b', - 'h2oai/h2ogpt-oig-oasst1-512-6_9b', - 'h2oai/h2ogpt-oig-oasst1-256-6.9b', # legacy - 'h2oai/h2ogpt-oig-oasst1-512-6.9b', # legacy - 'h2oai/h2ogpt-research-oasst1-512-30b', - 'h2oai/h2ogpt-research-oasst1-llama-65b', - 'h2oai/h2ogpt-oasst1-falcon-40b', - 'h2oai/h2ogpt-oig-oasst1-falcon-40b', - ], - 'dai_faq': [], - 'summarize': [], - 'simple_instruct': ['t5-small', 't5-large', 'google/flan-t5', 'google/flan-t5-xxl', 'google/flan-ul2'], - 'instruct_vicuna': ['AlekseyKorshuk/vicuna-7b', 'TheBloke/stable-vicuna-13B-HF', 'junelee/wizard-vicuna-13b'], - 'human_bot_orig': ['togethercomputer/GPT-NeoXT-Chat-Base-20B'], - "open_assistant": ['OpenAssistant/oasst-sft-7-llama-30b-xor', 'oasst-sft-7-llama-30b'], - "wizard_lm": ['ehartford/WizardLM-7B-Uncensored', 'ehartford/WizardLM-13B-Uncensored'], - "wizard_mega": ['openaccess-ai-collective/wizard-mega-13b'], - "instruct_simple": ['JosephusCheung/Guanaco'], - "wizard_vicuna": ['ehartford/Wizard-Vicuna-13B-Uncensored'], - "wizard2": ['llama'], - "mptinstruct": ['mosaicml/mpt-30b-instruct', 'mosaicml/mpt-7b-instruct', 
'mosaicml/mpt-30b-instruct'], - "mptchat": ['mosaicml/mpt-7b-chat', 'mosaicml/mpt-30b-chat', 'TheBloke/mpt-30B-chat-GGML'], - "vicuna11": ['lmsys/vicuna-33b-v1.3'], - "falcon": ['tiiuae/falcon-40b-instruct', 'tiiuae/falcon-40b', 'tiiuae/falcon-7b-instruct', 'tiiuae/falcon-7b'], - "llama2": [ - 'meta-llama/Llama-2-7b-chat-hf', - 'meta-llama/Llama-2-13b-chat-hf', - 'meta-llama/Llama-2-34b-chat-hf', - 'meta-llama/Llama-2-70b-chat-hf', - ], - # could be plain, but default is correct prompt_type for default TheBloke model ggml-wizardLM-7B.q4_2.bin -} -if os.getenv('OPENAI_API_KEY'): - prompt_type_to_model_name.update({ - "openai": ["text-davinci-003", "text-curie-001", "text-babbage-001", "text-ada-001"], - "openai_chat": ["gpt-3.5-turbo", "gpt-3.5-turbo-16k"], - }) - -inv_prompt_type_to_model_name = {v.strip(): k for k, l in prompt_type_to_model_name.items() for v in l} -inv_prompt_type_to_model_lower = {v.strip().lower(): k for k, l in prompt_type_to_model_name.items() for v in l} - -prompt_types_strings = [] -for p in PromptType: - prompt_types_strings.extend([p.name]) - -prompt_types = [] -for p in PromptType: - prompt_types.extend([p.name, p.value, str(p.value)]) - - -def get_prompt(prompt_type, prompt_dict, chat, context, reduced, making_context, return_dict=False): - prompt_dict_error = '' - generates_leading_space = False - - if prompt_type == PromptType.custom.name and not isinstance(prompt_dict, dict): - try: - prompt_dict = ast.literal_eval(prompt_dict) - except BaseException as e: - prompt_dict_error = str(e) - if prompt_dict_error: - promptA = None - promptB = None - PreInstruct = None - PreInput = '' - PreResponse = '' - terminate_response = None - chat_sep = '' - chat_turn_sep = '' - humanstr = '' - botstr = '' - generates_leading_space = False - elif prompt_type in [PromptType.custom.value, str(PromptType.custom.value), - PromptType.custom.name]: - promptA = prompt_dict.get('promptA', '') - promptB = prompt_dict.get('promptB', '') - PreInstruct = prompt_dict.get('PreInstruct', '') - PreInput = prompt_dict.get('PreInput', '') - PreResponse = prompt_dict.get('PreResponse', '') - terminate_response = prompt_dict.get('terminate_response', None) - chat_sep = prompt_dict.get('chat_sep', '\n') - chat_turn_sep = prompt_dict.get('chat_turn_sep', '\n') - humanstr = prompt_dict.get('humanstr', '') - botstr = prompt_dict.get('botstr', '') - elif prompt_type in [PromptType.plain.value, str(PromptType.plain.value), - PromptType.plain.name]: - promptA = promptB = PreInstruct = PreInput = PreResponse = None - terminate_response = [] - chat_turn_sep = chat_sep = '' - # plain should have None for human/bot, so nothing truncated out, not '' that would truncate after first token - humanstr = None - botstr = None - elif prompt_type == 'simple_instruct': - promptA = promptB = PreInstruct = PreInput = PreResponse = None - terminate_response = [] - chat_turn_sep = chat_sep = '\n' - humanstr = None - botstr = None - elif prompt_type in [PromptType.instruct.value, str(PromptType.instruct.value), - PromptType.instruct.name] + [PromptType.instruct_with_end.value, - str(PromptType.instruct_with_end.value), - PromptType.instruct_with_end.name]: - promptA = 'Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.\n' if not ( - chat and reduced) else '' - promptB = 'Below is an instruction that describes a task. 
Write a response that appropriately completes the request.\n' if not ( - chat and reduced) else '' - - PreInstruct = """ -### Instruction: -""" - - PreInput = """ -### Input: -""" - - PreResponse = """ -### Response: -""" - if prompt_type in [PromptType.instruct_with_end.value, str(PromptType.instruct_with_end.value), - PromptType.instruct_with_end.name]: - terminate_response = ['### End'] - else: - terminate_response = None - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.quality.value, str(PromptType.quality.value), - PromptType.quality.name]: - promptA = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction as applied on the Input.\n' if not ( - chat and reduced) else '' - promptB = 'Write a detailed high-quality, accurate, fair, Response with about 100 words by following the Instruction.\n' if not ( - chat and reduced) else '' - - PreInstruct = """ -### Instruction: -""" - - PreInput = """ -### Input: -""" - - PreResponse = """ -### Response: -""" - terminate_response = None - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct # first thing human says - botstr = PreResponse # first thing bot says - elif prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value), - PromptType.human_bot.name] + [PromptType.human_bot_orig.value, - str(PromptType.human_bot_orig.value), - PromptType.human_bot_orig.name]: - human = ':' - bot = ":" - if reduced or context or prompt_type in [PromptType.human_bot.value, str(PromptType.human_bot.value), - PromptType.human_bot.name]: - preprompt = '' - else: - cur_date = time.strftime('%Y-%m-%d') - cur_time = time.strftime('%H:%M:%S %p %Z') - - PRE_PROMPT = """\ -Current Date: {} -Current Time: {} - -""" - preprompt = PRE_PROMPT.format(cur_date, cur_time) - start = '' - promptB = promptA = '%s%s' % (preprompt, start) - - PreInstruct = human + ' ' - - PreInput = None - - if making_context: - # when making context, want it to appear as-if LLM generated, which starts with space after : - PreResponse = bot + ' ' - else: - # normally LLM adds space after this, because was how trained. - # if add space here, non-unique tokenization will often make LLM produce wrong output - PreResponse = bot - - terminate_response = ['\n' + human, '\n' + bot, human, bot, PreResponse] - chat_turn_sep = chat_sep = '\n' - humanstr = human # tag before human talks - botstr = bot # tag before bot talks - generates_leading_space = True - elif prompt_type in [PromptType.dai_faq.value, str(PromptType.dai_faq.value), - PromptType.dai_faq.name]: - promptA = '' - promptB = 'Answer the following Driverless AI question.\n' - - PreInstruct = """ -### Driverless AI frequently asked question: -""" - - PreInput = None - - PreResponse = """ -### Driverless AI documentation answer: -""" - terminate_response = ['\n\n'] - chat_turn_sep = chat_sep = terminate_response - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.summarize.value, str(PromptType.summarize.value), - PromptType.summarize.name]: - promptA = promptB = PreInput = '' - PreInstruct = '## Main Text\n\n' - PreResponse = '\n\n## Summary\n\n' - terminate_response = None - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.instruct_vicuna.value, str(PromptType.instruct_vicuna.value), - PromptType.instruct_vicuna.name]: - promptA = promptB = "A chat between a curious human and an artificial intelligence assistant. 
" \ - "The assistant gives helpful, detailed, and polite answers to the human's questions." if not ( - chat and reduced) else '' - - PreInstruct = """ -### Human: -""" - - PreInput = None - - PreResponse = """ -### Assistant: -""" - terminate_response = [ - '### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.prompt_answer.value, str(PromptType.prompt_answer.value), - PromptType.prompt_answer.name]: - preprompt = '' - prompt_tokens = "<|prompt|>" - answer_tokens = "<|answer|>" - start = '' - promptB = promptA = '%s%s' % (preprompt, start) - PreInstruct = prompt_tokens - PreInput = None - PreResponse = answer_tokens - eos = '<|endoftext|>' # neox eos - humanstr = prompt_tokens - botstr = answer_tokens - terminate_response = [humanstr, PreResponse, eos] - chat_sep = eos - chat_turn_sep = eos - elif prompt_type in [PromptType.prompt_answer_openllama.value, str(PromptType.prompt_answer_openllama.value), - PromptType.prompt_answer_openllama.name]: - preprompt = '' - prompt_tokens = "<|prompt|>" - answer_tokens = "<|answer|>" - start = '' - promptB = promptA = '%s%s' % (preprompt, start) - PreInstruct = prompt_tokens - PreInput = None - PreResponse = answer_tokens - eos = '' # llama eos - humanstr = prompt_tokens - botstr = answer_tokens - terminate_response = [humanstr, PreResponse, eos] - chat_sep = eos - chat_turn_sep = eos - elif prompt_type in [PromptType.open_assistant.value, str(PromptType.open_assistant.value), - PromptType.open_assistant.name]: - # From added_tokens.json - preprompt = '' - prompt_tokens = "<|prompter|>" - answer_tokens = "<|assistant|>" - start = '' - promptB = promptA = '%s%s' % (preprompt, start) - PreInstruct = prompt_tokens - PreInput = None - PreResponse = answer_tokens - pend = "<|prefix_end|>" - eos = "" - humanstr = prompt_tokens - botstr = answer_tokens - terminate_response = [humanstr, PreResponse, pend, eos] - chat_turn_sep = chat_sep = eos - elif prompt_type in [PromptType.wizard_lm.value, str(PromptType.wizard_lm.value), - PromptType.wizard_lm.name]: - # https://github.com/ehartford/WizardLM/blob/main/src/train_freeform.py - preprompt = '' - start = '' - promptB = promptA = '%s%s' % (preprompt, start) - PreInstruct = "" - PreInput = None - PreResponse = "\n\n### Response\n" - eos = "" - terminate_response = [PreResponse, eos] - chat_turn_sep = chat_sep = eos - humanstr = promptA - botstr = PreResponse - elif prompt_type in [PromptType.wizard_mega.value, str(PromptType.wizard_mega.value), - PromptType.wizard_mega.name]: - preprompt = '' - start = '' - promptB = promptA = '%s%s' % (preprompt, start) - PreInstruct = """ -### Instruction: -""" - PreInput = None - PreResponse = """ -### Assistant: -""" - terminate_response = [PreResponse] - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.instruct_vicuna2.value, str(PromptType.instruct_vicuna2.value), - PromptType.instruct_vicuna2.name]: - promptA = promptB = "" if not (chat and reduced) else '' - - PreInstruct = """ -HUMAN: -""" - - PreInput = None - - PreResponse = """ -ASSISTANT: -""" - terminate_response = [ - 'HUMAN:'] # but only allow terminate after prompt is found correctly, else can't terminate - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.instruct_vicuna3.value, str(PromptType.instruct_vicuna3.value), - 
PromptType.instruct_vicuna3.name]: - promptA = promptB = "" if not (chat and reduced) else '' - - PreInstruct = """ -### User: -""" - - PreInput = None - - PreResponse = """ -### Assistant: -""" - terminate_response = [ - '### User:'] # but only allow terminate after prompt is found correctly, else can't terminate - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.wizard2.value, str(PromptType.wizard2.value), - PromptType.wizard2.name]: - # https://huggingface.co/TheBloke/WizardLM-7B-uncensored-GGML - preprompt = """Below is an instruction that describes a task. Write a response that appropriately completes the request.""" if not ( - chat and reduced) else '' - start = '' - promptB = promptA = '%s%s' % (preprompt, start) - PreInstruct = """ -### Instruction: -""" - PreInput = None - PreResponse = """ -### Response: -""" - terminate_response = [PreResponse] - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.wizard3.value, str(PromptType.wizard3.value), - PromptType.wizard3.name]: - # https://huggingface.co/TheBloke/wizardLM-13B-1.0-GGML - preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions.""" if not ( - chat and reduced) else '' - start = '' - promptB = promptA = '%s%s' % (preprompt, start) - PreInstruct = """USER: """ - PreInput = None - PreResponse = """ASSISTANT: """ - terminate_response = [PreResponse] - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.wizard_vicuna.value, str(PromptType.wizard_vicuna.value), - PromptType.wizard_vicuna.name]: - preprompt = '' - start = '' - promptB = promptA = '%s%s' % (preprompt, start) - PreInstruct = """USER: """ - PreInput = None - PreResponse = """ASSISTANT: """ - terminate_response = [PreResponse] - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct - botstr = PreResponse - - elif prompt_type in [PromptType.instruct_simple.value, str(PromptType.instruct_simple.value), - PromptType.instruct_simple.name]: - promptB = promptA = '' if not (chat and reduced) else '' - - PreInstruct = """ -### Instruction: -""" - - PreInput = """ -### Input: -""" - - PreResponse = """ -### Response: -""" - terminate_response = None - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.openai.value, str(PromptType.openai.value), - PromptType.openai.name]: - preprompt = """The following is a conversation with an AI assistant. The assistant is helpful, creative, clever, and very friendly.""" if not ( - chat and reduced) else '' - start = '' - promptB = promptA = '%s%s' % (preprompt, start) - PreInstruct = "\nHuman: " - PreInput = None - PreResponse = "\nAI:" - terminate_response = [PreResponse] + [" Human:", " AI:"] - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.gptj.value, str(PromptType.gptj.value), - PromptType.gptj.name]: - preprompt = "### Instruction:\n The prompt below is a question to answer, a task to complete, or a conversation to respond to; decide which and write an appropriate response." 
if not ( - chat and reduced) else '' - start = '' - promptB = promptA = '%s%s' % (preprompt, start) - PreInstruct = "\n### Prompt: " - PreInput = None - PreResponse = "\n### Response: " - terminate_response = [PreResponse] + ["Prompt:", "Response:"] - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.openai_chat.value, str(PromptType.openai_chat.value), - PromptType.openai_chat.name]: - # prompting and termination all handled by endpoint - preprompt = """""" - start = '' - promptB = promptA = '%s%s' % (preprompt, start) - PreInstruct = "" - PreInput = None - PreResponse = "" - terminate_response = [] - chat_turn_sep = chat_sep = '\n' - humanstr = None - botstr = None - elif prompt_type in [PromptType.vicuna11.value, str(PromptType.vicuna11.value), - PromptType.vicuna11.name]: - preprompt = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. """ if not ( - chat and reduced) else '' - start = '' - promptB = promptA = '%s%s' % (preprompt, start) - eos = '' - PreInstruct = """USER: """ - PreInput = None - PreResponse = """ASSISTANT:""" - terminate_response = [PreResponse] - chat_sep = ' ' - chat_turn_sep = eos - humanstr = PreInstruct - botstr = PreResponse - - if making_context: - # when making context, want it to appear as-if LLM generated, which starts with space after : - PreResponse = PreResponse + ' ' - else: - # normally LLM adds space after this, because was how trained. - # if add space here, non-unique tokenization will often make LLM produce wrong output - PreResponse = PreResponse - elif prompt_type in [PromptType.mptinstruct.value, str(PromptType.mptinstruct.value), - PromptType.mptinstruct.name]: - # https://huggingface.co/mosaicml/mpt-30b-instruct#formatting - promptA = promptB = 'Below is an instruction that describes a task. Write a response that appropriately completes the request.\n' if not ( - chat and reduced) else '' - - PreInstruct = """ -### Instruction -""" - - PreInput = """ -### Input -""" - - PreResponse = """ -### Response -""" - terminate_response = None - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.mptchat.value, str(PromptType.mptchat.value), - PromptType.mptchat.name]: - # https://huggingface.co/TheBloke/mpt-30B-chat-GGML#prompt-template - promptA = promptB = """<|im_start|>system\nA conversation between a user and an LLM-based AI assistant. The assistant gives helpful and honest answers.\n<|im_end|>""" if not ( - chat and reduced) else '' - - PreInstruct = """<|im_start|>user -""" - - PreInput = None - - PreResponse = """<|im_end|><|im_start|>assistant -""" - terminate_response = ['<|im_end|>'] - chat_sep = '' - chat_turn_sep = '<|im_end|>' - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.falcon.value, str(PromptType.falcon.value), - PromptType.falcon.name]: - promptA = promptB = "" if not (chat and reduced) else '' - - PreInstruct = """User: """ - - PreInput = None - - PreResponse = """Assistant:""" - terminate_response = ['\nUser', "<|endoftext|>"] - chat_sep = '\n\n' - chat_turn_sep = '\n\n' - humanstr = PreInstruct - botstr = PreResponse - if making_context: - # when making context, want it to appear as-if LLM generated, which starts with space after : - PreResponse = 'Assistant: ' - else: - # normally LLM adds space after this, because was how trained. 
- # if add space here, non-unique tokenization will often make LLM produce wrong output - PreResponse = PreResponse - # generates_leading_space = True - elif prompt_type in [PromptType.guanaco.value, str(PromptType.guanaco.value), - PromptType.guanaco.name]: - # https://huggingface.co/TheBloke/guanaco-65B-GPTQ - promptA = promptB = "" if not (chat and reduced) else '' - - PreInstruct = """### Human: """ - - PreInput = None - - PreResponse = """### Assistant:""" - terminate_response = ['### Human:'] # but only allow terminate after prompt is found correctly, else can't terminate - chat_turn_sep = chat_sep = '\n' - humanstr = PreInstruct - botstr = PreResponse - elif prompt_type in [PromptType.llama2.value, str(PromptType.llama2.value), - PromptType.llama2.name]: - PreInstruct = "" - llama2_sys = "<>\nYou are a helpful, respectful and honest assistant. Always answer as helpfully as possible, while being safe. Your answers should not include any harmful, unethical, racist, sexist, toxic, dangerous, or illegal content. Please ensure that your responses are socially unbiased and positive in nature.\n\nIf a question does not make any sense, or is not factually coherent, explain why instead of answering something not correct. If you don't know the answer to a question, please don't share false information.\n<>\n\n" - prompt = "[INST] " - enable_sys = False # too much safety, hurts accuracy - if not (chat and reduced): - if enable_sys: - promptA = promptB = prompt + llama2_sys - else: - promptA = promptB = prompt - else: - promptA = promptB = '' - PreInput = None - PreResponse = "" - terminate_response = ["[INST]", ""] - chat_sep = ' [/INST]' - chat_turn_sep = ' [INST] ' - humanstr = PreInstruct - botstr = PreResponse - if making_context: - PreResponse += " " - else: - raise RuntimeError("No such prompt_type=%s" % prompt_type) - - if isinstance(terminate_response, (tuple, list)): - assert '' not in terminate_response, "Bad terminate_response" - - ret_dict = dict(promptA=promptA, promptB=promptB, PreInstruct=PreInstruct, PreInput=PreInput, - PreResponse=PreResponse, terminate_response=terminate_response, chat_sep=chat_sep, - chat_turn_sep=chat_turn_sep, - humanstr=humanstr, botstr=botstr, - generates_leading_space=generates_leading_space) - - if return_dict: - return ret_dict, prompt_dict_error - else: - return tuple(list(ret_dict.values())) - - -def generate_prompt(data_point, prompt_type, prompt_dict, chat, reduced, making_context): - context = data_point.get('context') - if context is None: - context = '' - instruction = data_point.get('instruction') - input = data_point.get('input') - output = data_point.get('output') - prompt_type = data_point.get('prompt_type', prompt_type) - prompt_dict = data_point.get('prompt_dict', prompt_dict) - assert prompt_type in prompt_types, "Bad prompt type: %s" % prompt_type - promptA, promptB, PreInstruct, PreInput, PreResponse, \ - terminate_response, chat_sep, chat_turn_sep, humanstr, botstr, \ - generates_leading_space = get_prompt(prompt_type, prompt_dict, chat, - context, reduced, making_context) - - # could avoid if reduce=True, but too complex for parent functions to handle - prompt = context - - if input and promptA: - prompt += f"""{promptA}""" - elif promptB: - prompt += f"""{promptB}""" - - if instruction and PreInstruct is not None and input and PreInput is not None: - prompt += f"""{PreInstruct}{instruction}{PreInput}{input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif instruction and input and PreInstruct is None and 
PreInput is not None: - prompt += f"""{PreInput}{instruction} -{input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif input and instruction and PreInput is None and PreInstruct is not None: - prompt += f"""{PreInstruct}{instruction} -{input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif instruction and PreInstruct is not None: - prompt += f"""{PreInstruct}{instruction}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif input and PreInput is not None: - prompt += f"""{PreInput}{input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif input and instruction and PreInput is not None: - prompt += f"""{PreInput}{instruction}{input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif input and instruction and PreInstruct is not None: - prompt += f"""{PreInstruct}{instruction}{input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif input and instruction: - # i.e. for simple_instruct - prompt += f"""{instruction}: {input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif input: - prompt += f"""{input}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - elif instruction: - prompt += f"""{instruction}""" - prompt = inject_chatsep(prompt_type, prompt, chat_sep=chat_sep) - - if PreResponse is not None: - prompt += f"""{PreResponse}""" - pre_response = PreResponse # Don't use strip - else: - pre_response = '' - - if output: - prompt += f"""{output}""" - - return prompt, pre_response, terminate_response, chat_sep, chat_turn_sep - - -def inject_chatsep(prompt_type, prompt, chat_sep=None): - if chat_sep: - # only add new line if structured prompt, while 'plain' is just generation of next tokens from input - prompt += chat_sep - return prompt - - -class Prompter(object): - def __init__(self, prompt_type, prompt_dict, debug=False, chat=False, stream_output=False, repeat_penalty=True, - allowed_repeat_line_length=10): - self.prompt_type = prompt_type - self.prompt_dict = prompt_dict - self.debug = debug - self.chat = chat - self.stream_output = stream_output - self.repeat_penalty = repeat_penalty - self.allowed_repeat_line_length = allowed_repeat_line_length - self.prompt = None - context = "" # not for chat context - reduced = False # not for chat context - making_context = False # not for chat context - self.promptA, self.promptB, self.PreInstruct, self.PreInput, self.PreResponse, \ - self.terminate_response, self.chat_sep, self.chat_turn_sep, self.humanstr, self.botstr, \ - self.generates_leading_space = \ - get_prompt(self.prompt_type, self.prompt_dict, chat, context, reduced, making_context) - self.pre_response = self.PreResponse - - def generate_prompt(self, data_point, reduced=None): - """ - data_point['context'] is assumed to be like a system prompt or pre-conversation, not inserted after user prompt - :param data_point: - :param reduced: - :return: - """ - reduced = data_point.get('context') not in ['', None] if reduced is None else reduced - making_context = False # whether really making final prompt or just generating context - prompt, _, _, _, _ = generate_prompt(data_point, self.prompt_type, self.prompt_dict, self.chat, reduced, - making_context) - if self.debug: - print("prompt: %s" % prompt, flush=True) - # if have context, should have always reduced and only preappend promptA/B here - if data_point.get('context'): - if data_point.get('input') and self.promptA: - prompt = self.promptA + 
prompt - elif self.promptB: - prompt = self.promptB + prompt - - self.prompt = prompt - return prompt - - def get_response(self, outputs, prompt=None, sanitize_bot_response=False): - if isinstance(outputs, str): - outputs = [outputs] - if self.debug: - print("output:\n%s" % '\n\n'.join(outputs), flush=True) - if prompt is not None: - self.prompt = prompt - - def clean_response(response): - meaningless_words = ['', '', '<|endoftext|>'] - for word in meaningless_words: - response = response.replace(word, "") - if sanitize_bot_response: - from better_profanity import profanity - response = profanity.censor(response) - if self.generates_leading_space and isinstance(response, str) and len(response) > 0 and response[0] == ' ': - response = response[1:] - return response - - def clean_repeats(response): - lines = response.split('\n') - new_lines = [] - [new_lines.append(line) for line in lines if - line not in new_lines or len(line) < self.allowed_repeat_line_length] - if self.debug and len(lines) != len(new_lines): - print("cleaned repeats: %s %s" % (len(lines), len(new_lines)), flush=True) - response = '\n'.join(new_lines) - return response - - multi_output = len(outputs) > 1 - - for oi, output in enumerate(outputs): - if self.prompt_type in [PromptType.plain.value, str(PromptType.plain.value), PromptType.plain.name]: - output = clean_response(output) - elif prompt is None: - # then use most basic parsing like pipeline - if not self.botstr: - pass - elif self.botstr in output: - if self.humanstr: - output = clean_response(output.split(self.botstr)[1].split(self.humanstr)[0]) - else: - # i.e. use after bot but only up to next bot - output = clean_response(output.split(self.botstr)[1].split(self.botstr)[0]) - else: - # output = clean_response(output) - # assume just not printed yet - output = "" - else: - # find first instance of prereponse - # prompt sometimes has odd characters, that mutate length, - # so can't go by length alone - if self.pre_response: - outputi = output.find(prompt) - if outputi >= 0: - output = output[outputi + len(prompt):] - allow_terminate = True - else: - # subtraction is risky due to space offsets sometimes, so only do if necessary - output = output[len(prompt) - len(self.pre_response):] - # [1] to avoid repeated pre_response, just take first (after prompt - pre_response for chat) - if self.pre_response in output: - output = output.split(self.pre_response)[1] - allow_terminate = True - else: - if output: - print("Failure of parsing or not enough output yet: %s" % output, flush=True) - allow_terminate = False - else: - allow_terminate = True - output = output[len(prompt):] - # clean after subtract prompt out, so correct removal of pre_response - output = clean_response(output) - if self.repeat_penalty: - output = clean_repeats(output) - if self.terminate_response and allow_terminate: - finds = [] - for term in self.terminate_response: - finds.append(output.find(term)) - finds = [x for x in finds if x >= 0] - if len(finds) > 0: - termi = finds[0] - output = output[:termi] - else: - output = output - if multi_output: - # prefix with output counter - output = "\n=========== Output %d\n\n" % (1 + oi) + output - if oi > 0: - # post fix outputs with seperator - output += '\n' - output = self.fix_text(self.prompt_type, output) - outputs[oi] = output - # join all outputs, only one extra new line between outputs - output = '\n'.join(outputs) - if self.debug: - print("outputclean:\n%s" % '\n\n'.join(outputs), flush=True) - return output - - @staticmethod - def 
fix_text(prompt_type1, text1): - if prompt_type1 == 'human_bot': - # hack bug in vLLM with stopping, stops right, but doesn't return last token - hfix = ' bool: - for stopi, stop in enumerate(self.stops): - if torch.all((stop == input_ids[0][-len(stop):])).item(): - self.num_stops[stopi] += 1 - if self.num_stops[stopi] >= self.encounters[stopi % len(self.encounters)]: - # print("Stopped", flush=True) - return True - if self.model_max_length is not None and input_ids[0].shape[0] >= self.model_max_length: - # critical limit - return True - # print("Tokens: %s" % input_ids[0].cpu().numpy(), flush=True) - # print("Stop Tokens: %s" % [x.cpu().numpy() for x in self.stops], flush=True) - return False - - -def get_stopping(prompt_type, prompt_dict, tokenizer, device, human=':', bot=":", model_max_length=None): - # FIXME: prompt_dict unused currently - if prompt_type in [PromptType.human_bot.name, PromptType.instruct_vicuna.name, PromptType.instruct_with_end.name]: - if prompt_type == PromptType.human_bot.name: - # encounters = [prompt.count(human) + 1, prompt.count(bot) + 1] - # stopping only starts once output is beyond prompt - # 1 human is enough to trigger, but need 2 bots, because very first view back will be bot we added - stop_words = [human, bot, '\n' + human, '\n' + bot] - encounters = [1, 2] - elif prompt_type == PromptType.instruct_vicuna.name: - # even below is not enough, generic strings and many ways to encode - stop_words = [ - '### Human:', - """ -### Human:""", - """ -### Human: -""", - '### Assistant:', - """ -### Assistant:""", - """ -### Assistant: -""", - ] - encounters = [1, 2] - else: - # some instruct prompts have this as end, doesn't hurt to stop on it since not common otherwise - stop_words = ['### End'] - encounters = [1] - stop_words_ids = [ - tokenizer(stop_word, return_tensors='pt')['input_ids'].squeeze() for stop_word in stop_words] - # handle single token case - stop_words_ids = [x if len(x.shape) > 0 else torch.tensor([x]) for x in stop_words_ids] - stop_words_ids = [x for x in stop_words_ids if x.shape[0] > 0] - # avoid padding in front of tokens - if tokenizer._pad_token: # use hidden variable to avoid annoying properly logger bug - stop_words_ids = [x[1:] if x[0] == tokenizer.pad_token_id and len(x) > 1 else x for x in stop_words_ids] - # handle fake \n added - stop_words_ids = [x[1:] if y[0] == '\n' else x for x, y in zip(stop_words_ids, stop_words)] - # build stopper - stopping_criteria = StoppingCriteriaList( - [StoppingCriteriaSub(stops=stop_words_ids, encounters=encounters, device=device, - model_max_length=model_max_length)]) - else: - stopping_criteria = StoppingCriteriaList() - return stopping_criteria diff --git a/utils.py b/utils.py deleted file mode 100644 index f8d57c85ad3da1b22c7df24f4821a653c27ae239..0000000000000000000000000000000000000000 --- a/utils.py +++ /dev/null @@ -1,1080 +0,0 @@ -import contextlib -import functools -import hashlib -import inspect -import os -import gc -import pathlib -import pickle -import random -import shutil -import subprocess -import sys -import threading -import time -import traceback -import zipfile -from datetime import datetime - -import filelock -import requests, uuid -from typing import Tuple, Callable, Dict -from tqdm.auto import tqdm -from joblib import Parallel -from concurrent.futures import ProcessPoolExecutor -import numpy as np -import pandas as pd - - -def set_seed(seed: int): - """ - Sets the seed of the entire notebook so results are the same every time we run. - This is for REPRODUCIBILITY. 
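    Illustrative usage (a sketch, not from the original docstring): rs = set_seed(1234)
    seeds numpy, Python's random module, and torch (CPU and CUDA), forces cuDNN into
    deterministic mode, and returns a numpy RandomState seeded with the same value,
    so e.g. rs.randint(0, 10, size=3) produces the same triple on every run.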
- """ - import torch - np.random.seed(seed) - random_state = np.random.RandomState(seed) - random.seed(seed) - torch.manual_seed(seed) - torch.cuda.manual_seed(seed) - torch.backends.cudnn.deterministic = True - torch.backends.cudnn.benchmark = False - os.environ['PYTHONHASHSEED'] = str(seed) - return random_state - - -def flatten_list(lis): - """Given a list, possibly nested to any level, return it flattened.""" - new_lis = [] - for item in lis: - if type(item) == type([]): - new_lis.extend(flatten_list(item)) - else: - new_lis.append(item) - return new_lis - - -def clear_torch_cache(): - import torch - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - gc.collect() - - -def ping(): - try: - print('Ping: %s' % str(datetime.now()), flush=True) - except AttributeError: - # some programs wrap print and will fail with flush passed - pass - - -def ping_gpu(): - try: - print('Ping_GPU: %s %s' % (str(datetime.now()), system_info()), flush=True) - except AttributeError: - # some programs wrap print and will fail with flush passed - pass - try: - ping_gpu_memory() - except Exception as e: - print('Ping_GPU memory failure: %s' % str(e), flush=True) - - -def ping_gpu_memory(): - from models.gpu_mem_track import MemTracker - gpu_tracker = MemTracker() # define a GPU tracker - from torch.cuda import memory_summary - gpu_tracker.track() - - -def get_torch_allocated(): - import torch - return torch.cuda.memory_allocated() - - -def get_device(): - import torch - if torch.cuda.is_available(): - device = "cuda" - elif torch.backends.mps.is_built(): - device = "mps" - else: - device = "cpu" - - return device - - -def system_info(): - import psutil - - system = {} - # https://stackoverflow.com/questions/48951136/plot-multiple-graphs-in-one-plot-using-tensorboard - # https://arshren.medium.com/monitoring-your-devices-in-python-5191d672f749 - try: - temps = psutil.sensors_temperatures(fahrenheit=False) - if 'coretemp' in temps: - coretemp = temps['coretemp'] - temp_dict = {k.label: k.current for k in coretemp} - for k, v in temp_dict.items(): - system['CPU_C/%s' % k] = v - except AttributeError: - pass - - # https://github.com/gpuopenanalytics/pynvml/blob/master/help_query_gpu.txt - try: - from pynvml.smi import nvidia_smi - nvsmi = nvidia_smi.getInstance() - - gpu_power_dict = {'W_gpu%d' % i: x['power_readings']['power_draw'] for i, x in - enumerate(nvsmi.DeviceQuery('power.draw')['gpu'])} - for k, v in gpu_power_dict.items(): - system['GPU_W/%s' % k] = v - - gpu_temp_dict = {'C_gpu%d' % i: x['temperature']['gpu_temp'] for i, x in - enumerate(nvsmi.DeviceQuery('temperature.gpu')['gpu'])} - for k, v in gpu_temp_dict.items(): - system['GPU_C/%s' % k] = v - - gpu_memory_free_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['free'] for i, x in - enumerate(nvsmi.DeviceQuery('memory.free')['gpu'])} - gpu_memory_total_dict = {'MiB_gpu%d' % i: x['fb_memory_usage']['total'] for i, x in - enumerate(nvsmi.DeviceQuery('memory.total')['gpu'])} - gpu_memory_frac_dict = {k: gpu_memory_free_dict[k] / gpu_memory_total_dict[k] for k in gpu_memory_total_dict} - for k, v in gpu_memory_frac_dict.items(): - system[f'GPU_M/%s' % k] = v - except (KeyError, ModuleNotFoundError): - pass - system['hash'] = get_githash() - - return system - - -def system_info_print(): - try: - df = pd.DataFrame.from_dict(system_info(), orient='index') - # avoid slamming GPUs - time.sleep(1) - return df.to_markdown() - except Exception as e: - return "Error: %s" % str(e) - - -def zip_data(root_dirs=None, zip_file=None, 
base_dir='./', fail_any_exception=False): - try: - return _zip_data(zip_file=zip_file, base_dir=base_dir, root_dirs=root_dirs) - except Exception as e: - traceback.print_exc() - print('Exception in zipping: %s' % str(e)) - if not fail_any_exception: - raise - - -def _zip_data(root_dirs=None, zip_file=None, base_dir='./'): - if isinstance(root_dirs, str): - root_dirs = [root_dirs] - if zip_file is None: - datetime_str = str(datetime.now()).replace(" ", "_").replace(":", "_") - host_name = os.getenv('HF_HOSTNAME', 'emptyhost') - zip_file = "data_%s_%s.zip" % (datetime_str, host_name) - assert root_dirs is not None - if not os.path.isdir(os.path.dirname(zip_file)) and os.path.dirname(zip_file): - os.makedirs(os.path.dirname(zip_file), exist_ok=True) - with zipfile.ZipFile(zip_file, "w") as expt_zip: - for root_dir in root_dirs: - if root_dir is None: - continue - for root, d, files in os.walk(root_dir): - for file in files: - file_to_archive = os.path.join(root, file) - assert os.path.exists(file_to_archive) - path_to_archive = os.path.relpath(file_to_archive, base_dir) - expt_zip.write(filename=file_to_archive, arcname=path_to_archive) - return zip_file, zip_file - - -def save_generate_output(prompt=None, output=None, base_model=None, save_dir=None, where_from='unknown where from', - extra_dict={}): - try: - return _save_generate_output(prompt=prompt, output=output, base_model=base_model, save_dir=save_dir, - where_from=where_from, extra_dict=extra_dict) - except Exception as e: - traceback.print_exc() - print('Exception in saving: %s' % str(e)) - - -def _save_generate_output(prompt=None, output=None, base_model=None, save_dir=None, where_from='unknown where from', - extra_dict={}): - """ - Save conversation to .json, row by row. - json_file_path is path to final JSON file. If not in ., then will attempt to make directories. 
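    Illustrative example (inferred from the code below, not part of the original
    docstring): with save_dir="saved_chats" (a hypothetical name), each call appends one
    comma-terminated JSON object such as
    {"prompt": "...", "text": "...", "time": "...", "base_model": "...", "where_from": "..."}
    to saved_chats/history.json, so wrapping the file contents in [ and ] yields a valid JSON list.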
- Appends if file exists - """ - prompt = '' if prompt is None else prompt - output = '' if output is None else output - assert save_dir, "save_dir must be provided" - if os.path.exists(save_dir) and not os.path.isdir(save_dir): - raise RuntimeError("save_dir already exists and is not a directory!") - os.makedirs(save_dir, exist_ok=True) - import json - dict_to_save = dict(prompt=prompt, text=output, time=time.ctime(), base_model=base_model, where_from=where_from) - dict_to_save.update(extra_dict) - with filelock.FileLock("save_dir.lock"): - # lock logging in case have concurrency - with open(os.path.join(save_dir, "history.json"), "a") as f: - # just add [ at start, and ] at end, and have proper JSON dataset - f.write( - " " + json.dumps( - dict_to_save - ) + ",\n" - ) - - -def s3up(filename): - try: - return _s3up(filename) - except Exception as e: - traceback.print_exc() - print('Exception for file %s in s3up: %s' % (filename, str(e))) - return "Failed to upload %s: Error: %s" % (filename, str(e)) - - -def _s3up(filename): - import boto3 - - aws_access_key_id = os.getenv('AWS_SERVER_PUBLIC_KEY') - aws_secret_access_key = os.getenv('AWS_SERVER_SECRET_KEY') - bucket = os.getenv('AWS_BUCKET') - assert aws_access_key_id, "Set AWS key" - assert aws_secret_access_key, "Set AWS secret" - assert bucket, "Set AWS Bucket" - - s3 = boto3.client('s3', - aws_access_key_id=os.getenv('AWS_SERVER_PUBLIC_KEY'), - aws_secret_access_key=os.getenv('AWS_SERVER_SECRET_KEY'), - ) - ret = s3.upload_file( - Filename=filename, - Bucket=os.getenv('AWS_BUCKET'), - Key=filename, - ) - if ret in [None, '']: - return "Successfully uploaded %s" % filename - - -def get_githash(): - try: - githash = subprocess.run(['git', 'rev-parse', 'HEAD'], stdout=subprocess.PIPE).stdout.decode('utf-8')[0:-1] - except: - githash = '' - return githash - - -def copy_code(run_id): - """ - copy code to track changes - :param run_id: - :return: - """ - rnd_num = str(random.randint(0, 2 ** 31)) - run_id = 'run_' + str(run_id) - os.makedirs(run_id, exist_ok=True) - me_full = os.path.join(pathlib.Path(__file__).parent.resolve(), __file__) - me_file = os.path.basename(__file__) - new_me = os.path.join(run_id, me_file + '_' + get_githash()) - if os.path.isfile(new_me): - new_me = os.path.join(run_id, me_file + '_' + get_githash() + '_' + rnd_num) - shutil.copy(me_full, new_me) - else: - shutil.copy(me_full, new_me) - - -class NullContext(threading.local): - """No-op context manager, executes block without doing any additional processing. - - Used as a stand-in if a particular block of code is only sometimes - used with a normal context manager: - """ - - def __init__(self, *args, **kwargs): - pass - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_value, exc_traceback): - self.finally_act() - - def finally_act(self): - pass - - -def wrapped_partial(func, *args, **kwargs): - """ - Give partial properties of normal function, like __name__ attribute etc. 
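    Illustrative example (not from the original docstring): p = wrapped_partial(int, base=2)
    behaves like functools.partial(int, base=2), so p("101") returns 5, but it also carries
    the wrapped function's metadata, meaning p.__name__ is "int" while a plain partial has
    no __name__ attribute at all.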
- :param func: - :param args: - :param kwargs: - :return: - """ - partial_func = functools.partial(func, *args, **kwargs) - functools.update_wrapper(partial_func, func) - return partial_func - - -class ThreadException(Exception): - pass - - -class EThread(threading.Thread): - # Function that raises the custom exception - def __init__(self, group=None, target=None, name=None, - args=(), kwargs=None, *, daemon=None, streamer=None, bucket=None): - self.bucket = bucket - self.streamer = streamer - self.exc = None - self._return = None - super().__init__(group=group, target=target, name=name, args=args, kwargs=kwargs, daemon=daemon) - - def run(self): - # Variable that stores the exception, if raised by someFunction - try: - if self._target is not None: - self._return = self._target(*self._args, **self._kwargs) - except BaseException as e: - print("thread exception: %s" % str(sys.exc_info())) - self.bucket.put(sys.exc_info()) - self.exc = e - if self.streamer: - print("make stop: %s" % str(sys.exc_info()), flush=True) - self.streamer.do_stop = True - finally: - # Avoid a refcycle if the thread is running a function with - # an argument that has a member that points to the thread. - del self._target, self._args, self._kwargs - - def join(self, timeout=None): - threading.Thread.join(self) - # Since join() returns in caller thread - # we re-raise the caught exception - # if any was caught - if self.exc: - raise self.exc - return self._return - - -def import_matplotlib(): - import matplotlib - matplotlib.use('agg') - # KEEP THESE HERE! START - import matplotlib.pyplot as plt - import pandas as pd - # to avoid dlopen deadlock in fork - import pandas.core.computation.expressions as pd_expressions - import pandas._libs.groupby as pd_libgroupby - import pandas._libs.reduction as pd_libreduction - import pandas.core.algorithms as pd_algorithms - import pandas.core.common as pd_com - import numpy as np - # KEEP THESE HERE! END - - -def get_sha(value): - return hashlib.md5(str(value).encode('utf-8')).hexdigest() - - -def sanitize_filename(name): - """ - Sanitize file *base* names. 
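    Illustrative example (not from the original docstring):
    sanitize_filename("my report (v1):final.txt") returns "my_report__v1__final.txt";
    names longer than roughly 250 characters are additionally shortened around an md5
    hash of the full name so the result stays within common filesystem limits.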
- :param name: name to sanitize - :return: - """ - bad_chars = ['[', ']', ',', '/', '\\', '\\w', '\\s', '-', '+', '\"', '\'', '>', '<', ' ', '=', ')', '(', ':', '^'] - for char in bad_chars: - name = name.replace(char, "_") - - length = len(name) - file_length_limit = 250 # bit smaller than 256 for safety - sha_length = 32 - real_length_limit = file_length_limit - (sha_length + 2) - if length > file_length_limit: - sha = get_sha(name) - half_real_length_limit = max(1, int(real_length_limit / 2)) - name = name[0:half_real_length_limit] + "_" + sha + "_" + name[length - half_real_length_limit:length] - - return name - - -def shutil_rmtree(*args, **kwargs): - return shutil.rmtree(*args, **kwargs) - - -def remove(path: str): - try: - if path is not None and os.path.exists(path): - if os.path.isdir(path): - shutil_rmtree(path, ignore_errors=True) - else: - with contextlib.suppress(FileNotFoundError): - os.remove(path) - except: - pass - - -def makedirs(path, exist_ok=True): - """ - Avoid some inefficiency in os.makedirs() - :param path: - :param exist_ok: - :return: - """ - if os.path.isdir(path) and os.path.exists(path): - assert exist_ok, "Path already exists" - return path - os.makedirs(path, exist_ok=exist_ok) - - -def atomic_move_simple(src, dst): - try: - shutil.move(src, dst) - except (shutil.Error, FileExistsError): - pass - remove(src) - - -def download_simple(url, dest=None, print_func=None): - if print_func is not None: - print_func("BEGIN get url %s" % str(url)) - if url.startswith("file://"): - from requests_file import FileAdapter - s = requests.Session() - s.mount('file://', FileAdapter()) - url_data = s.get(url, stream=True) - else: - url_data = requests.get(url, stream=True) - if dest is None: - dest = os.path.basename(url) - if url_data.status_code != requests.codes.ok: - msg = "Cannot get url %s, code: %s, reason: %s" % ( - str(url), - str(url_data.status_code), - str(url_data.reason), - ) - raise requests.exceptions.RequestException(msg) - url_data.raw.decode_content = True - makedirs(os.path.dirname(dest), exist_ok=True) - uuid_tmp = str(uuid.uuid4())[:6] - dest_tmp = dest + "_dl_" + uuid_tmp + ".tmp" - with open(dest_tmp, "wb") as f: - shutil.copyfileobj(url_data.raw, f) - atomic_move_simple(dest_tmp, dest) - if print_func is not None: - print_func("END get url %s" % str(url)) - - -def download(url, dest=None, dest_path=None): - if dest_path is not None: - dest = os.path.join(dest_path, os.path.basename(url)) - if os.path.isfile(dest): - print("already downloaded %s -> %s" % (url, dest)) - return dest - elif dest is not None: - if os.path.exists(dest): - print("already downloaded %s -> %s" % (url, dest)) - return dest - else: - uuid_tmp = "dl2_" + str(uuid.uuid4())[:6] - dest = uuid_tmp + os.path.basename(url) - - print("downloading %s to %s" % (url, dest)) - - if url.startswith("file://"): - from requests_file import FileAdapter - s = requests.Session() - s.mount('file://', FileAdapter()) - url_data = s.get(url, stream=True) - else: - url_data = requests.get(url, stream=True) - - if url_data.status_code != requests.codes.ok: - msg = "Cannot get url %s, code: %s, reason: %s" % ( - str(url), str(url_data.status_code), str(url_data.reason)) - raise requests.exceptions.RequestException(msg) - url_data.raw.decode_content = True - dirname = os.path.dirname(dest) - if dirname != "" and not os.path.isdir(dirname): - makedirs(os.path.dirname(dest), exist_ok=True) - uuid_tmp = "dl3_" + str(uuid.uuid4())[:6] - dest_tmp = dest + "_" + uuid_tmp + ".tmp" - with open(dest_tmp, 'wb') as 
f: - shutil.copyfileobj(url_data.raw, f) - try: - shutil.move(dest_tmp, dest) - except FileExistsError: - pass - remove(dest_tmp) - return dest - - -def get_url(x, from_str=False, short_name=False): - if not from_str: - source = x.metadata['source'] - else: - source = x - if short_name: - source_name = get_short_name(source) - else: - source_name = source - if source.startswith('http://') or source.startswith('https://'): - return """
    %s""" % ( - source, source_name) - else: - return """%s""" % ( - source, source_name) - - -def get_short_name(name, maxl=50): - if name is None: - return '' - length = len(name) - if length > maxl: - allow_length = maxl - 3 - half_allowed = max(1, int(allow_length / 2)) - name = name[0:half_allowed] + "..." + name[length - half_allowed:length] - return name - - -def cuda_vis_check(total_gpus): - """Helper function to count GPUs by environment variable - Stolen from Jon's h2o4gpu utils - """ - cudavis = os.getenv("CUDA_VISIBLE_DEVICES") - which_gpus = [] - if cudavis is not None: - # prune away white-space, non-numerics, - # except commas for simple checking - cudavis = "".join(cudavis.split()) - import re - cudavis = re.sub("[^0-9,]", "", cudavis) - - lencudavis = len(cudavis) - if lencudavis == 0: - total_gpus = 0 - else: - total_gpus = min( - total_gpus, - os.getenv("CUDA_VISIBLE_DEVICES").count(",") + 1) - which_gpus = os.getenv("CUDA_VISIBLE_DEVICES").split(",") - which_gpus = [int(x) for x in which_gpus] - else: - which_gpus = list(range(0, total_gpus)) - - return total_gpus, which_gpus - - -def get_ngpus_vis(raise_if_exception=True): - ngpus_vis1 = 0 - - shell = False - if shell: - cmd = "nvidia-smi -L 2> /dev/null" - else: - cmd = ["nvidia-smi", "-L"] - - try: - timeout = 5 * 3 - o = subprocess.check_output(cmd, shell=shell, timeout=timeout) - lines = o.decode("utf-8").splitlines() - ngpus_vis1 = 0 - for line in lines: - if 'Failed to initialize NVML' not in line: - ngpus_vis1 += 1 - except (FileNotFoundError, subprocess.CalledProcessError, OSError): - # GPU systems might not have nvidia-smi, so can't fail - pass - except subprocess.TimeoutExpired as e: - print('Failed get_ngpus_vis: %s' % str(e)) - if raise_if_exception: - raise - - ngpus_vis1, which_gpus = cuda_vis_check(ngpus_vis1) - return ngpus_vis1 - - -def get_mem_gpus(raise_if_exception=True, ngpus=None): - totalmem_gpus1 = 0 - usedmem_gpus1 = 0 - freemem_gpus1 = 0 - - if ngpus == 0: - return totalmem_gpus1, usedmem_gpus1, freemem_gpus1 - - try: - cmd = "nvidia-smi -q 2> /dev/null | grep -A 3 'FB Memory Usage'" - o = subprocess.check_output(cmd, shell=True, timeout=15) - lines = o.decode("utf-8").splitlines() - for line in lines: - if 'Total' in line: - totalmem_gpus1 += int(line.split()[2]) * 1024 ** 2 - if 'Used' in line: - usedmem_gpus1 += int(line.split()[2]) * 1024 ** 2 - if 'Free' in line: - freemem_gpus1 += int(line.split()[2]) * 1024 ** 2 - except (FileNotFoundError, subprocess.CalledProcessError, OSError): - # GPU systems might not have nvidia-smi, so can't fail - pass - except subprocess.TimeoutExpired as e: - print('Failed get_mem_gpus: %s' % str(e)) - if raise_if_exception: - raise - - return totalmem_gpus1, usedmem_gpus1, freemem_gpus1 - - -class ForkContext(threading.local): - """ - Set context for forking - Ensures state is returned once done - """ - - def __init__(self, args=None, kwargs=None, forkdata_capable=True): - """ - :param args: - :param kwargs: - :param forkdata_capable: whether fork is forkdata capable and will use copy-on-write forking of args/kwargs - """ - self.forkdata_capable = forkdata_capable - if self.forkdata_capable: - self.has_args = args is not None - self.has_kwargs = kwargs is not None - forkdatacontext.args = args - forkdatacontext.kwargs = kwargs - else: - self.has_args = False - self.has_kwargs = False - - def __enter__(self): - try: - # flush all outputs so doesn't happen during fork -- don't print/log inside ForkContext contexts! 
- sys.stdout.flush() - sys.stderr.flush() - except BaseException as e: - # exit not called if exception, and don't want to leave forkdatacontext filled in that case - print("ForkContext failure on enter: %s" % str(e)) - self.finally_act() - raise - return self - - def __exit__(self, exc_type, exc_value, exc_traceback): - self.finally_act() - - def finally_act(self): - """ - Done when exception hit or exit is reached in context - first reset forkdatacontext as crucial to have reset even if later 2 calls fail - :return: None - """ - if self.forkdata_capable and (self.has_args or self.has_kwargs): - forkdatacontext._reset() - - -class _ForkDataContext(threading.local): - def __init__( - self, - args=None, - kwargs=None, - ): - """ - Global context for fork to carry data to subprocess instead of relying upon copy/pickle/serialization - - :param args: args - :param kwargs: kwargs - """ - assert isinstance(args, (tuple, type(None))) - assert isinstance(kwargs, (dict, type(None))) - self.__args = args - self.__kwargs = kwargs - - @property - def args(self) -> Tuple: - """returns args""" - return self.__args - - @args.setter - def args(self, args): - if self.__args is not None: - raise AttributeError( - "args cannot be overwritten: %s %s" % (str(self.__args), str(self.__kwargs)) - ) - - self.__args = args - - @property - def kwargs(self) -> Dict: - """returns kwargs""" - return self.__kwargs - - @kwargs.setter - def kwargs(self, kwargs): - if self.__kwargs is not None: - raise AttributeError( - "kwargs cannot be overwritten: %s %s" % (str(self.__args), str(self.__kwargs)) - ) - - self.__kwargs = kwargs - - def _reset(self): - """Reset fork arg-kwarg context to default values""" - self.__args = None - self.__kwargs = None - - def get_args_kwargs(self, func, args, kwargs) -> Tuple[Callable, Tuple, Dict]: - if self.__args: - args = self.__args[1:] - if not func: - assert len(self.__args) > 0, "if have no func, must have in args" - func = self.__args[0] # should always be there - if self.__kwargs: - kwargs = self.__kwargs - try: - return func, args, kwargs - finally: - forkdatacontext._reset() - - @staticmethod - def get_args_kwargs_for_traced_func(func, args, kwargs): - """ - Return args/kwargs out of forkdatacontext when using copy-on-write way of passing args/kwargs - :param func: actual function ran by _traced_func, which itself is directly what mppool treats as function - :param args: - :param kwargs: - :return: func, args, kwargs from forkdatacontext if used, else originals - """ - # first 3 lines are debug - func_was_None = func is None - args_was_None_or_empty = args is None or len(args) == 0 - kwargs_was_None_or_empty = kwargs is None or len(kwargs) == 0 - - forkdatacontext_args_was_None = forkdatacontext.args is None - forkdatacontext_kwargs_was_None = forkdatacontext.kwargs is None - func, args, kwargs = forkdatacontext.get_args_kwargs(func, args, kwargs) - using_forkdatacontext = func_was_None and func is not None # pulled func out of forkdatacontext.__args[0] - assert forkdatacontext.args is None, "forkdatacontext.args should be None after get_args_kwargs" - assert forkdatacontext.kwargs is None, "forkdatacontext.kwargs should be None after get_args_kwargs" - - proc_type = kwargs.get('proc_type', 'SUBPROCESS') - if using_forkdatacontext: - assert proc_type == "SUBPROCESS" or proc_type == "SUBPROCESS" - if proc_type == "NORMAL": - assert forkdatacontext_args_was_None, "if no fork, expect forkdatacontext.args None entering _traced_func" - assert forkdatacontext_kwargs_was_None, "if no fork, 
expect forkdatacontext.kwargs None entering _traced_func" - assert func is not None, "function should not be None, indicates original args[0] was None or args was None" - - return func, args, kwargs - - -forkdatacontext = _ForkDataContext() - - -def _traced_func(func, *args, **kwargs): - func, args, kwargs = forkdatacontext.get_args_kwargs_for_traced_func(func, args, kwargs) - return func(*args, **kwargs) - - -def call_subprocess_onetask(func, args=None, kwargs=None): - import platform - if platform.system() in ['Darwin', 'Windows']: - return func(*args, **kwargs) - if isinstance(args, list): - args = tuple(args) - if args is None: - args = () - if kwargs is None: - kwargs = {} - args = list(args) - args = [func] + args - args = tuple(args) - with ForkContext(args=args, kwargs=kwargs): - args = (None,) - kwargs = {} - with ProcessPoolExecutor(max_workers=1) as executor: - future = executor.submit(_traced_func, *args, **kwargs) - return future.result() - - -class ProgressParallel(Parallel): - def __init__(self, use_tqdm=True, total=None, *args, **kwargs): - self._use_tqdm = use_tqdm - self._total = total - super().__init__(*args, **kwargs) - - def __call__(self, *args, **kwargs): - with tqdm(disable=not self._use_tqdm, total=self._total) as self._pbar: - return Parallel.__call__(self, *args, **kwargs) - - def print_progress(self): - if self._total is None: - self._pbar.total = self.n_dispatched_tasks - self._pbar.n = self.n_completed_tasks - self._pbar.refresh() - - -def get_kwargs(func, exclude_names=None, **kwargs): - func_names = list(inspect.signature(func).parameters) - missing_kwargs = [x for x in func_names if x not in kwargs] - if exclude_names: - for k in exclude_names: - if k in missing_kwargs: - missing_kwargs.remove(k) - if k in func_names: - func_names.remove(k) - assert not missing_kwargs, "Missing %s" % missing_kwargs - kwargs = {k: v for k, v in kwargs.items() if k in func_names} - return kwargs - - -import pkg_resources - -have_faiss = False - -try: - assert pkg_resources.get_distribution('faiss') is not None - have_faiss = True -except (pkg_resources.DistributionNotFound, AssertionError): - pass -try: - assert pkg_resources.get_distribution('faiss_gpu') is not None - have_faiss = True -except (pkg_resources.DistributionNotFound, AssertionError): - pass -try: - assert pkg_resources.get_distribution('faiss_cpu') is not None - have_faiss = True -except (pkg_resources.DistributionNotFound, AssertionError): - pass - - -def hash_file(file): - try: - import hashlib - - # BUF_SIZE is totally arbitrary, change for your app! - BUF_SIZE = 65536 # lets read stuff in 64kb chunks! - - md5 = hashlib.md5() - # sha1 = hashlib.sha1() - - with open(file, 'rb') as f: - while True: - data = f.read(BUF_SIZE) - if not data: - break - md5.update(data) - # sha1.update(data) - except BaseException as e: - print("Cannot hash %s due to %s" % (file, str(e))) - traceback.print_exc() - md5 = None - return md5.hexdigest() - - -def start_faulthandler(): - # If hit server or any subprocess with signal SIGUSR1, it'll print out all threads stack trace, but wont't quit or coredump - # If more than one fork tries to write at same time, then looks corrupted. 
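    # Illustrative note (added, not in the original): on platforms where
    # faulthandler.register is available, once this has run the stack of every thread
    # can be dumped without stopping the process by sending SIGUSR1 from a shell,
    # e.g. kill -USR1 <pid> (pid is a placeholder); faulthandler writes the tracebacks
    # to stderr.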
- import faulthandler - - # SIGUSR1 in h2oai/__init__.py as well - faulthandler.enable() - if hasattr(faulthandler, 'register'): - # windows/mac - import signal - faulthandler.register(signal.SIGUSR1) - - -def get_hf_server(inference_server): - inf_split = inference_server.split(" ") - assert len(inf_split) == 1 or len(inf_split) == 3 - inference_server = inf_split[0] - if len(inf_split) == 3: - headers = {"authorization": "%s %s" % (inf_split[1], inf_split[2])} - else: - headers = None - return inference_server, headers - - -class FakeTokenizer: - """ - 1) For keeping track of model_max_length - 2) For when model doesn't directly expose tokenizer but need to count tokens - """ - - def __init__(self, model_max_length=2048, encoding_name="cl100k_base"): - # dont' push limit, since if using fake tokenizer, only estimate, and seen underestimates by order 250 - self.model_max_length = model_max_length - 250 - self.encoding_name = encoding_name - # The first time this runs, it will require an internet connection to download. Later runs won't need an internet connection. - import tiktoken - self.encoding = tiktoken.get_encoding(self.encoding_name) - - def encode(self, x, *args, return_tensors="pt", **kwargs): - input_ids = self.encoding.encode(x, disallowed_special=()) - if return_tensors == 'pt' and isinstance(input_ids, list): - import torch - input_ids = torch.tensor(input_ids) - return dict(input_ids=input_ids) - - def decode(self, x, *args, **kwargs): - # input is input_ids[0] form - return self.encoding.decode(x) - - def num_tokens_from_string(self, prompt: str) -> int: - """Returns the number of tokens in a text string.""" - num_tokens = len(self.encoding.encode(prompt)) - return num_tokens - - def __call__(self, x, *args, **kwargs): - return self.encode(x, *args, **kwargs) - - -def get_local_ip(): - import socket - s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM) - try: - # doesn't even have to be reachable - s.connect(('10.255.255.255', 1)) - IP = s.getsockname()[0] - except Exception: - IP = '127.0.0.1' - finally: - s.close() - return IP - - -try: - assert pkg_resources.get_distribution('langchain') is not None - have_langchain = True -except (pkg_resources.DistributionNotFound, AssertionError): - have_langchain = False - -import distutils.spawn - -have_tesseract = distutils.spawn.find_executable("tesseract") -have_libreoffice = distutils.spawn.find_executable("libreoffice") - -import pkg_resources - -try: - assert pkg_resources.get_distribution('arxiv') is not None - assert pkg_resources.get_distribution('pymupdf') is not None - have_arxiv = True -except (pkg_resources.DistributionNotFound, AssertionError): - have_arxiv = False - -try: - assert pkg_resources.get_distribution('pymupdf') is not None - have_pymupdf = True -except (pkg_resources.DistributionNotFound, AssertionError): - have_pymupdf = False - -try: - assert pkg_resources.get_distribution('selenium') is not None - have_selenium = True -except (pkg_resources.DistributionNotFound, AssertionError): - have_selenium = False - -try: - assert pkg_resources.get_distribution('playwright') is not None - have_playwright = True -except (pkg_resources.DistributionNotFound, AssertionError): - have_playwright = False - -# disable, hangs too often -have_playwright = False - - -def set_openai(inference_server): - if inference_server.startswith('vllm'): - import openai_vllm - openai_vllm.api_key = "EMPTY" - inf_type = inference_server.split(':')[0] - ip_vllm = inference_server.split(':')[1] - port_vllm = inference_server.split(':')[2] - 
openai_vllm.api_base = f"http://{ip_vllm}:{port_vllm}/v1" - return openai_vllm, inf_type - else: - import openai - openai.api_key = os.getenv("OPENAI_API_KEY") - openai.api_base = os.environ.get("OPENAI_API_BASE", "https://api.openai.com/v1") - inf_type = inference_server - return openai, inf_type - - -visible_langchain_modes_file = 'visible_langchain_modes.pkl' - - -def save_collection_names(langchain_modes, visible_langchain_modes, langchain_mode_paths, LangChainMode, db1s): - """ - extra controls if UserData type of MyData type - """ - - # use first default MyData hash as general user hash to maintain file - # if user moves MyData from langchain modes, db will still survive, so can still use hash - scratch_collection_names = list(db1s.keys()) - user_hash = db1s.get(LangChainMode.MY_DATA.value, '')[1] - - llms = ['LLM', 'Disabled'] - - scratch_langchain_modes = [x for x in langchain_modes if x in scratch_collection_names] - scratch_visible_langchain_modes = [x for x in visible_langchain_modes if x in scratch_collection_names] - scratch_langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if - k in scratch_collection_names and k not in llms} - - user_langchain_modes = [x for x in langchain_modes if x not in scratch_collection_names] - user_visible_langchain_modes = [x for x in visible_langchain_modes if x not in scratch_collection_names] - user_langchain_mode_paths = {k: v for k, v in langchain_mode_paths.items() if - k not in scratch_collection_names and k not in llms} - - base_path = 'locks' - makedirs(base_path) - - # user - extra = '' - file = "%s%s" % (visible_langchain_modes_file, extra) - with filelock.FileLock(os.path.join(base_path, "%s.lock" % file)): - with open(file, 'wb') as f: - pickle.dump((user_langchain_modes, user_visible_langchain_modes, user_langchain_mode_paths), f) - - # scratch - extra = user_hash - file = "%s%s" % (visible_langchain_modes_file, extra) - with filelock.FileLock(os.path.join(base_path, "%s.lock" % file)): - with open(file, 'wb') as f: - pickle.dump((scratch_langchain_modes, scratch_visible_langchain_modes, scratch_langchain_mode_paths), f) - - -def load_collection_enum(extra): - """ - extra controls if UserData type of MyData type - """ - file = "%s%s" % (visible_langchain_modes_file, extra) - langchain_modes_from_file = [] - visible_langchain_modes_from_file = [] - langchain_mode_paths_from_file = {} - if os.path.isfile(visible_langchain_modes_file): - try: - with filelock.FileLock("%s.lock" % file): - with open(file, 'rb') as f: - langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file = pickle.load( - f) - except BaseException as e: - print("Cannot load %s, ignoring error: %s" % (file, str(e)), flush=True) - for k, v in langchain_mode_paths_from_file.items(): - if v is not None and not os.path.isdir(v) and isinstance(v, str): - # assume was deleted, but need to make again to avoid extra code elsewhere - makedirs(v) - return langchain_modes_from_file, visible_langchain_modes_from_file, langchain_mode_paths_from_file - - -def remove_collection_enum(): - remove(visible_langchain_modes_file) diff --git a/utils_langchain.py b/utils_langchain.py deleted file mode 100644 index d50110fa0dc664a95dc99b3fa47053287507b689..0000000000000000000000000000000000000000 --- a/utils_langchain.py +++ /dev/null @@ -1,64 +0,0 @@ -from typing import Any, Dict, List, Union, Optional -import time -import queue - -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import LLMResult - - -class 
StreamingGradioCallbackHandler(BaseCallbackHandler): - """ - Similar to H2OTextIteratorStreamer that is for HF backend, but here LangChain backend - """ - def __init__(self, timeout: Optional[float] = None, block=True): - super().__init__() - self.text_queue = queue.SimpleQueue() - self.stop_signal = None - self.do_stop = False - self.timeout = timeout - self.block = block - - def on_llm_start( - self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any - ) -> None: - """Run when LLM starts running. Clean the queue.""" - while not self.text_queue.empty(): - try: - self.text_queue.get(block=False) - except queue.Empty: - continue - - def on_llm_new_token(self, token: str, **kwargs: Any) -> None: - """Run on new LLM token. Only available when streaming is enabled.""" - self.text_queue.put(token) - - def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None: - """Run when LLM ends running.""" - self.text_queue.put(self.stop_signal) - - def on_llm_error( - self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any - ) -> None: - """Run when LLM errors.""" - self.text_queue.put(self.stop_signal) - - def __iter__(self): - return self - - def __next__(self): - while True: - try: - value = self.stop_signal # value looks unused in pycharm, not true - if self.do_stop: - print("hit stop", flush=True) - # could raise or break, maybe best to raise and make parent see if any exception in thread - raise StopIteration() - # break - value = self.text_queue.get(block=self.block, timeout=self.timeout) - break - except queue.Empty: - time.sleep(0.01) - if value == self.stop_signal: - raise StopIteration() - else: - return value
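For context on how the removed StreamingGradioCallbackHandler was meant to be consumed, here is a minimal usage sketch. It is an editorial addition, not part of the deleted file; the LangChain chat model class, the model name, and the prompt are assumptions, and the handler itself is the one defined in the (removed) utils_langchain.py above.

# Minimal usage sketch (assumes langchain and an OpenAI API key are available;
# the model name below is illustrative, not taken from the removed code).
from threading import Thread

from langchain.chat_models import ChatOpenAI

# The handler buffers tokens in a queue; iterating it yields tokens until
# on_llm_end pushes the stop signal (None).
callback = StreamingGradioCallbackHandler(timeout=None, block=True)
llm = ChatOpenAI(model_name="gpt-3.5-turbo", streaming=True, callbacks=[callback])

# Run the blocking LLM call in a background thread so the main thread can
# consume tokens from the handler's queue as they arrive.
thread = Thread(target=llm.predict, args=("Write a haiku about rivers.",))
thread.start()

answer = ""
for token in callback:  # iteration stops when the LLM run finishes or errors
    answer += token
    # e.g. push the partial `answer` into a Gradio textbox here for streaming UI updates

thread.join()
print(answer)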