import glob
import inspect
import os
import pathlib
import pickle
import queue
import random
import shutil
import subprocess
import sys
import tempfile
import traceback
import uuid
import zipfile
from collections import defaultdict
from datetime import datetime
from functools import reduce
from operator import concat

from joblib import Parallel, delayed
from langchain.embeddings import HuggingFaceInstructEmbeddings
from tqdm import tqdm

from enums import DocumentChoices
from generate import gen_hyper
from prompter import non_hf_types, PromptType
from utils import wrapped_partial, EThread, import_matplotlib, sanitize_filename, makedirs, get_url, flatten_list, \
    get_device, ProgressParallel, remove, hash_file, clear_torch_cache

import_matplotlib()

import numpy as np
import pandas as pd
import requests
from langchain.chains.qa_with_sources import load_qa_with_sources_chain

# Deliberately not imported:
# GCSDirectoryLoader, GCSFileLoader
# OutlookMessageLoader  # GPL3
# ImageCaptionLoader  # use our own wrapper
# ReadTheDocsLoader  # no special file, some path, so have to give as special option
from langchain.document_loaders import PyPDFLoader, TextLoader, CSVLoader, PythonLoader, TomlLoader, \
    UnstructuredURLLoader, UnstructuredHTMLLoader, UnstructuredWordDocumentLoader, UnstructuredMarkdownLoader, \
    EverNoteLoader, UnstructuredEmailLoader, UnstructuredODTLoader, UnstructuredPowerPointLoader, \
    UnstructuredEPubLoader, UnstructuredImageLoader, UnstructuredRTFLoader, ArxivLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.chains.question_answering import load_qa_chain
from langchain.docstore.document import Document
from langchain import PromptTemplate
from langchain.vectorstores import Chroma


def get_db(sources, use_openai_embedding=False, db_type='faiss',
           persist_directory="db_dir", load_db_if_exists=True,
           langchain_mode='notset',
           collection_name=None,
           hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
    if not sources:
        return None

    # get embedding model
    embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
    assert collection_name is not None or langchain_mode != 'notset'
    if collection_name is None:
        collection_name = langchain_mode.replace(' ', '_')

    # Create vector database
    if db_type == 'faiss':
        from langchain.vectorstores import FAISS
        db = FAISS.from_documents(sources, embedding)
    elif db_type == 'weaviate':
        import weaviate
        from weaviate.embedded import EmbeddedOptions
        from langchain.vectorstores import Weaviate

        if os.getenv('WEAVIATE_URL', None):
            client = _create_local_weaviate_client()
        else:
            client = weaviate.Client(
                embedded_options=EmbeddedOptions()
            )
        index_name = collection_name.capitalize()
        db = Weaviate.from_documents(documents=sources, embedding=embedding, client=client, by_text=False,
                                     index_name=index_name)
    elif db_type == 'chroma':
        assert persist_directory is not None
        os.makedirs(persist_directory, exist_ok=True)

        # see if already actually have persistent db, and deal with possible changes in embedding
        db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
                             hf_embedding_model, verbose=False)
        if db is None:
            db = Chroma.from_documents(documents=sources,
                                       embedding=embedding,
                                       persist_directory=persist_directory,
                                       collection_name=collection_name,
                                       anonymized_telemetry=False)
            db.persist()
            clear_embedding(db)
            save_embed(db, use_openai_embedding, hf_embedding_model)
        else:
            # then just add
            db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
                                                                  use_openai_embedding=use_openai_embedding,
                                                                  hf_embedding_model=hf_embedding_model)
    else:
        raise RuntimeError("No such db_type=%s" % db_type)

    return db
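

# Illustrative usage sketch (added for documentation, never called by the module):
# build an in-memory FAISS db from a couple of Documents and query it.  File names
# here are hypothetical.
def _example_get_db():
    docs = [Document(page_content="FAISS keeps vectors in memory", metadata={"source": "demo1.txt"}),
            Document(page_content="Chroma persists vectors to disk", metadata={"source": "demo2.txt"})]
    db = get_db(docs, db_type='faiss', langchain_mode='UserData')
    return db.similarity_search("Which store persists to disk?", k=1)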


def _get_unique_sources_in_weaviate(db):
    batch_size = 100
    id_source_list = []
    result = db._client.data_object.get(class_name=db._index_name, limit=batch_size)

    while result['objects']:
        id_source_list += [(obj['id'], obj['properties']['source']) for obj in result['objects']]
        last_id = id_source_list[-1][0]
        result = db._client.data_object.get(class_name=db._index_name, limit=batch_size, after=last_id)

    unique_sources = {source for _, source in id_source_list}
    return unique_sources


def add_to_db(db, sources, db_type='faiss',
              avoid_dup_by_file=False,
              avoid_dup_by_content=True,
              use_openai_embedding=False,
              hf_embedding_model=None):
    assert hf_embedding_model is not None
    num_new_sources = len(sources)
    if not sources:
        return db, num_new_sources, []
    if db_type == 'faiss':
        db.add_documents(sources)
    elif db_type == 'weaviate':
        # FIXME: only control by file name, not hash yet
        if avoid_dup_by_file or avoid_dup_by_content:
            unique_sources = _get_unique_sources_in_weaviate(db)
            sources = [x for x in sources if x.metadata['source'] not in unique_sources]
        num_new_sources = len(sources)
        if num_new_sources == 0:
            return db, num_new_sources, []
        db.add_documents(documents=sources)
    elif db_type == 'chroma':
        collection = db.get()
        # files we already have:
        metadata_files = set([x['source'] for x in collection['metadatas']])
        if avoid_dup_by_file:
            # Too weak in case file changed content, assume parent shouldn't pass True for this for now
            raise RuntimeError("Not desired code path")
            sources = [x for x in sources if x.metadata['source'] not in metadata_files]
        if avoid_dup_by_content:
            # look at hash, instead of page_content
            # migration: If no hash previously, avoid updating,
            # since don't know if need to update and may be expensive to redo all unhashed files
            metadata_hash_ids = set(
                [x['hashid'] for x in collection['metadatas'] if 'hashid' in x and x['hashid'] not in ["None", None]])
            # avoid sources with same hash
            sources = [x for x in sources if x.metadata.get('hashid') not in metadata_hash_ids]
            # get new file names that match existing file names; delete existing files we are overriding
            dup_metadata_files = set([x.metadata['source'] for x in sources if x.metadata['source'] in metadata_files])
            print("Removing %s duplicate files from db because ingesting those as new documents" % len(
                dup_metadata_files), flush=True)
            client_collection = db._client.get_collection(name=db._collection.name,
                                                          embedding_function=db._collection._embedding_function)
            for dup_file in dup_metadata_files:
                dup_file_meta = dict(source=dup_file)
                try:
                    client_collection.delete(where=dup_file_meta)
                except KeyError:
                    pass
        num_new_sources = len(sources)
        if num_new_sources == 0:
            return db, num_new_sources, []
        db.add_documents(documents=sources)
        db.persist()
        clear_embedding(db)
        save_embed(db, use_openai_embedding, hf_embedding_model)
    else:
        raise RuntimeError("No such db_type=%s" % db_type)

    new_sources_metadata = [x.metadata for x in sources]

    return db, num_new_sources, new_sources_metadata
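

# Illustrative sketch (added for documentation): incremental ingest into an existing
# chroma db, relying on the 'hashid' metadata written by add_meta() to skip unchanged
# files.  `db` and `new_docs` are hypothetical inputs.
def _example_add_to_db(db, new_docs):
    db, num_new, new_meta = add_to_db(db, new_docs, db_type='chroma',
                                      use_openai_embedding=False,
                                      hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2")
    print("Added %s new chunks" % num_new, flush=True)
    return db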


def create_or_update_db(db_type, persist_directory, collection_name,
                        sources, use_openai_embedding, add_if_exists, verbose, hf_embedding_model):
    if db_type == 'weaviate':
        import weaviate
        from weaviate.embedded import EmbeddedOptions

        if os.getenv('WEAVIATE_URL', None):
            client = _create_local_weaviate_client()
        else:
            client = weaviate.Client(
                embedded_options=EmbeddedOptions()
            )

        index_name = collection_name.replace(' ', '_').capitalize()
        if client.schema.exists(index_name) and not add_if_exists:
            client.schema.delete_class(index_name)
            if verbose:
                print("Removing %s" % index_name, flush=True)
    elif db_type == 'chroma':
        if not os.path.isdir(persist_directory) or not add_if_exists:
            if os.path.isdir(persist_directory):
                if verbose:
                    print("Removing %s" % persist_directory, flush=True)
                remove(persist_directory)
            if verbose:
                print("Generating db", flush=True)

    if not add_if_exists:
        if verbose:
            print("Generating db", flush=True)
    else:
        if verbose:
            print("Loading and updating db", flush=True)

    db = get_db(sources,
                use_openai_embedding=use_openai_embedding,
                db_type=db_type,
                persist_directory=persist_directory,
                langchain_mode=collection_name,
                hf_embedding_model=hf_embedding_model)

    return db


def get_embedding(use_openai_embedding, hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2"):
    # Get embedding model
    if use_openai_embedding:
        assert os.getenv("OPENAI_API_KEY") is not None, "Set ENV OPENAI_API_KEY"
        from langchain.embeddings import OpenAIEmbeddings
        embedding = OpenAIEmbeddings()
    else:
        # to ensure can fork without deadlock
        from langchain.embeddings import HuggingFaceEmbeddings

        device, torch_dtype, context_class = get_device_dtype()
        model_kwargs = dict(device=device)
        if 'instructor' in hf_embedding_model:
            encode_kwargs = {'normalize_embeddings': True}
            embedding = HuggingFaceInstructEmbeddings(model_name=hf_embedding_model,
                                                      model_kwargs=model_kwargs,
                                                      encode_kwargs=encode_kwargs)
        else:
            embedding = HuggingFaceEmbeddings(model_name=hf_embedding_model, model_kwargs=model_kwargs)
    return embedding
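

# Illustrative sketch (added for documentation): embed a query string with the
# default HF model; no OpenAI key required on this path.
def _example_get_embedding():
    embedding = get_embedding(use_openai_embedding=False)
    vec = embedding.embed_query("hello world")
    return len(vec)  # 384 for all-MiniLM-L6-v2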


def get_answer_from_sources(chain, sources, question):
    return chain(
        {
            "input_documents": sources,
            "question": question,
        },
        return_only_outputs=True,
    )["output_text"]


def get_llm(use_openai_model=False, model_name=None, model=None,
            tokenizer=None, stream_output=False,
            do_sample=False,
            temperature=0.1,
            top_k=40,
            top_p=0.7,
            num_beams=1,
            max_new_tokens=256,
            min_new_tokens=1,
            early_stopping=False,
            max_time=180,
            repetition_penalty=1.0,
            num_return_sequences=1,
            prompt_type=None,
            prompt_dict=None,
            prompter=None,
            verbose=False,
            ):
    if use_openai_model:
        from langchain.llms import OpenAI
        llm = OpenAI(temperature=0)
        model_name = 'openai'
        streamer = None
        prompt_type = 'plain'
    elif model_name in non_hf_types:
        from gpt4all_llm import get_llm_gpt4all
        llm = get_llm_gpt4all(model_name, model=model, max_new_tokens=max_new_tokens,
                              temperature=temperature,
                              repetition_penalty=repetition_penalty,
                              top_k=top_k,
                              top_p=top_p,
                              verbose=verbose,
                              )
        streamer = None
        prompt_type = 'plain'
    else:
        from transformers import AutoTokenizer, AutoModelForCausalLM

        if model is None:
            # only used if didn't pass model in
            assert tokenizer is None
            prompt_type = 'human_bot'
            model_name = 'h2oai/h2ogpt-oasst1-512-12b'
            # model_name = 'h2oai/h2ogpt-oig-oasst1-512-6_9b'
            # model_name = 'h2oai/h2ogpt-oasst1-512-20b'
            tokenizer = AutoTokenizer.from_pretrained(model_name)
            device, torch_dtype, context_class = get_device_dtype()

            with context_class(device):
                load_8bit = True
                # FIXME: for now not to spread across hetero GPUs
                # device_map={"": 0} if load_8bit and device == 'cuda' else "auto"
                device_map = {"": 0} if device == 'cuda' else "auto"
                model = AutoModelForCausalLM.from_pretrained(model_name,
                                                             device_map=device_map,
                                                             torch_dtype=torch_dtype,
                                                             load_in_8bit=load_8bit)

        max_max_tokens = tokenizer.model_max_length
        gen_kwargs = dict(do_sample=do_sample,
                          temperature=temperature,
                          top_k=top_k,
                          top_p=top_p,
                          num_beams=num_beams,
                          max_new_tokens=max_new_tokens,
                          min_new_tokens=min_new_tokens,
                          early_stopping=early_stopping,
                          max_time=max_time,
                          repetition_penalty=repetition_penalty,
                          num_return_sequences=num_return_sequences,
                          return_full_text=True,
                          handle_long_generation='hole')
        assert len(set(gen_hyper).difference(gen_kwargs.keys())) == 0

        if stream_output:
            skip_prompt = False
            from generate import H2OTextIteratorStreamer
            decoder_kwargs = {}
            streamer = H2OTextIteratorStreamer(tokenizer, skip_prompt=skip_prompt, block=False, **decoder_kwargs)
            gen_kwargs.update(dict(streamer=streamer))
        else:
            streamer = None

        from h2oai_pipeline import H2OTextGenerationPipeline
        pipe = H2OTextGenerationPipeline(model=model, use_prompter=True,
                                         prompter=prompter,
                                         prompt_type=prompt_type,
                                         prompt_dict=prompt_dict,
                                         sanitize_bot_response=True,
                                         chat=False, stream_output=stream_output,
                                         tokenizer=tokenizer,
                                         max_input_tokens=max_max_tokens - max_new_tokens,
                                         **gen_kwargs)
        # pipe.task = "text-generation"
        # below makes it listen only to our prompt removal,
        # not built-in prompt removal that is less general and not specific to our model
        pipe.task = "text2text-generation"

        from langchain.llms import HuggingFacePipeline
        llm = HuggingFacePipeline(pipeline=pipe)
    return llm, model_name, streamer, prompt_type
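

# Illustrative sketch (added for documentation): the OpenAI path avoids any local
# model download; requires OPENAI_API_KEY in the environment.
def _example_get_llm():
    llm, model_name, streamer, prompt_type = get_llm(use_openai_model=True)
    return llm("Say hello in one word.")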


def get_device_dtype():
    # torch.device("cuda") leads to cuda:x cuda:y mismatches for multi-GPU consistently
    import torch
    n_gpus = torch.cuda.device_count() if torch.cuda.is_available() else 0
    device = 'cpu' if n_gpus == 0 else 'cuda'
    # from utils import NullContext
    # context_class = NullContext if n_gpus > 1 or n_gpus == 0 else context_class
    context_class = torch.device
    torch_dtype = torch.float16 if device == 'cuda' else torch.float32
    return device, torch_dtype, context_class


def get_wiki_data(title, first_paragraph_only, text_limit=None, take_head=True):
    """
    Get wikipedia data from online
    :param title: page title to fetch
    :param first_paragraph_only: only fetch the intro extract
    :param text_limit: optional character limit on the extract
    :param take_head: if limiting, keep the head of the text, else keep the tail
    :return: Document with the extract and a wikipedia source URL
    """
    filename = 'wiki_%s_%s_%s_%s.data' % (first_paragraph_only, title, text_limit, take_head)
    url = f"https://en.wikipedia.org/w/api.php?format=json&action=query&prop=extracts&explaintext=1&titles={title}"
    if first_paragraph_only:
        url += "&exintro=1"
    import json
    if not os.path.isfile(filename):
        data = requests.get(url).json()
        json.dump(data, open(filename, 'wt'))
    else:
        data = json.load(open(filename, "rt"))
    page_content = list(data["query"]["pages"].values())[0]["extract"]
    if take_head is not None and text_limit is not None:
        # keep head or tail of the extract, respectively
        page_content = page_content[:text_limit] if take_head else page_content[-text_limit:]
    title_url = str(title).replace(' ', '_')
    return Document(
        page_content=page_content,
        metadata={"source": f"https://en.wikipedia.org/wiki/{title_url}"},
    )
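

# Illustrative sketch (added for documentation): fetch (and locally cache) the intro
# paragraph of one article, truncated to 500 characters.
def _example_get_wiki_data():
    doc = get_wiki_data('Unix', first_paragraph_only=True, text_limit=500)
    return doc.metadata['source'], len(doc.page_content)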


def get_wiki_sources(first_para=True, text_limit=None):
    """
    Get specific named sources from wikipedia
    :param first_para: only use first paragraph of each article
    :param text_limit: optional character limit per article
    :return: list of Documents
    """
    default_wiki_sources = ['Unix', 'Microsoft_Windows', 'Linux']
    # env var, if set, is a comma-separated list of titles
    # (list() of the raw string would split it into characters)
    env_wiki_sources = os.getenv('WIKI_SOURCES')
    wiki_sources = env_wiki_sources.split(',') if env_wiki_sources else default_wiki_sources
    return [get_wiki_data(x, first_para, text_limit=text_limit) for x in wiki_sources]


def get_github_docs(repo_owner, repo_name):
    """
    Access github from specific repo
    :param repo_owner: github org or user
    :param repo_name: repo name
    :return: generator of Documents, one per markdown file
    """
    with tempfile.TemporaryDirectory() as d:
        subprocess.check_call(
            f"git clone --depth 1 https://github.com/{repo_owner}/{repo_name}.git .",
            cwd=d,
            shell=True,
        )
        git_sha = (
            subprocess.check_output("git rev-parse HEAD", shell=True, cwd=d)
            .decode("utf-8")
            .strip()
        )
        repo_path = pathlib.Path(d)
        markdown_files = list(repo_path.glob("*/*.md")) + list(
            repo_path.glob("*/*.mdx")
        )
        for markdown_file in markdown_files:
            with open(markdown_file, "r") as f:
                relative_path = markdown_file.relative_to(repo_path)
                github_url = f"https://github.com/{repo_owner}/{repo_name}/blob/{git_sha}/{relative_path}"
                yield Document(page_content=f.read(), metadata={"source": github_url})


def get_dai_pickle(dest="."):
    from huggingface_hub import hf_hub_download
    # True for case when locally already logged in with correct token, so don't have to set key
    token = os.getenv('HUGGINGFACE_API_TOKEN', True)
    path_to_zip_file = hf_hub_download('h2oai/dai_docs', 'dai_docs.pickle', token=token, repo_type='dataset')
    shutil.copy(path_to_zip_file, dest)


def get_dai_docs(from_hf=False, get_pickle=True):
    """
    Consume DAI documentation, or consume from public pickle
    :param from_hf: get DAI docs from HF, then generate pickle for later use by LangChain
    :param get_pickle: Avoid raw DAI docs, just get pickle directly from HF
    :return: list of Documents
    """
    import pickle

    if get_pickle:
        get_dai_pickle()

    dai_store = 'dai_docs.pickle'
    dst = "working_dir_docs"
    if not os.path.isfile(dai_store):
        from create_data import setup_dai_docs
        dst = setup_dai_docs(dst=dst, from_hf=from_hf)

        import glob
        files = list(glob.glob(os.path.join(dst, '*rst'), recursive=True))

        basedir = os.path.abspath(os.getcwd())
        from create_data import rst_to_outputs
        new_outputs = rst_to_outputs(files)
        os.chdir(basedir)

        pickle.dump(new_outputs, open(dai_store, 'wb'))
    else:
        new_outputs = pickle.load(open(dai_store, 'rb'))

    sources = []
    for line, file in new_outputs:
        # gradio requires any linked file to be with app.py
        sym_src = os.path.abspath(os.path.join(dst, file))
        sym_dst = os.path.abspath(os.path.join(os.getcwd(), file))
        if os.path.lexists(sym_dst):
            os.remove(sym_dst)
        os.symlink(sym_src, sym_dst)
        itm = Document(page_content=line, metadata={"source": file})
        # NOTE: yield has issues when going into db, loses metadata
        # yield itm
        sources.append(itm)
    return sources


import distutils.spawn

have_tesseract = distutils.spawn.find_executable("tesseract")
have_libreoffice = distutils.spawn.find_executable("libreoffice")

import pkg_resources

try:
    assert pkg_resources.get_distribution('arxiv') is not None
    assert pkg_resources.get_distribution('pymupdf') is not None
    have_arxiv = True
except (pkg_resources.DistributionNotFound, AssertionError):
    have_arxiv = False

try:
    assert pkg_resources.get_distribution('pymupdf') is not None
    have_pymupdf = True
except (pkg_resources.DistributionNotFound, AssertionError):
    have_pymupdf = False

image_types = ["png", "jpg", "jpeg"]
non_image_types = ["pdf", "txt", "csv", "toml", "py", "rst", "rtf",
                   "md", "html",
                   "enex", "eml", "epub", "odt", "pptx", "ppt",
                   "zip", "urls",
                   ]
# "msg",  # GPL3

if have_libreoffice:
    non_image_types.extend(["docx", "doc"])

file_types = non_image_types + image_types


def add_meta(docs1, file):
    file_extension = pathlib.Path(file).suffix
    hashid = hash_file(file)
    if not isinstance(docs1, list):
        docs1 = [docs1]
    [x.metadata.update(dict(input_type=file_extension, date=str(datetime.now()), hashid=hashid)) for x in docs1]
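

# Illustrative sketch (added for documentation): add_meta() mutates Documents in
# place, stamping extension, ingest date, and the file hash later used for dedup.
def _example_add_meta(some_file):
    doc = Document(page_content=open(some_file).read(), metadata={"source": some_file})
    add_meta(doc, some_file)
    return doc.metadata['hashid']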


def file_to_doc(file, base_path=None, verbose=False, fail_any_exception=False,
                chunk=True, chunk_size=512,
                is_url=False, is_txt=False,
                enable_captions=True,
                captions_model=None,
                enable_ocr=False, caption_loader=None,
                headsize=50):
    if file is None:
        if fail_any_exception:
            raise RuntimeError("Unexpected None file")
        else:
            return []
    doc1 = []  # in case no support, or disabled support
    if base_path is None and not is_txt and not is_url:
        # then assume want to persist but don't care which path used
        # can't be in base_path
        dir_name = os.path.dirname(file)
        base_name = os.path.basename(file)
        # if from gradio, will have its own temp uuid too, but that's ok
        base_name = sanitize_filename(base_name) + "_" + str(uuid.uuid4())[:10]
        base_path = os.path.join(dir_name, base_name)
    if is_url:
        if file.lower().startswith('arxiv:'):
            query = file.lower().split('arxiv:')
            if len(query) == 2 and have_arxiv:
                query = query[1]
                docs1 = ArxivLoader(query=query, load_max_docs=20, load_all_available_meta=True).load()
                # ensure string, sometimes None
                [[x.metadata.update({k: str(v)}) for k, v in x.metadata.items()] for x in docs1]
                query_url = f"https://arxiv.org/abs/{query}"
                [x.metadata.update(
                    dict(source=x.metadata.get('entry_id', query_url), query=query_url,
                         input_type='arxiv', head=x.metadata.get('Title', ''), date=str(datetime.now()))) for x in
                    docs1]
            else:
                docs1 = []
        else:
            docs1 = UnstructuredURLLoader(urls=[file]).load()
            [x.metadata.update(dict(input_type='url', date=str(datetime.now()))) for x in docs1]
        doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
    elif is_txt:
        base_path = "user_paste"
        source_file = os.path.join(base_path, "_%s" % str(uuid.uuid4())[:10])
        makedirs(os.path.dirname(source_file), exist_ok=True)
        with open(source_file, "wt") as f:
            f.write(file)
        metadata = dict(source=source_file, date=str(datetime.now()), input_type='pasted txt')
        doc1 = Document(page_content=file, metadata=metadata)
    elif file.lower().endswith('.html') or file.lower().endswith('.mhtml'):
        docs1 = UnstructuredHTMLLoader(file_path=file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
    elif (file.lower().endswith('.docx') or file.lower().endswith('.doc')) and have_libreoffice:
        docs1 = UnstructuredWordDocumentLoader(file_path=file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
    elif file.lower().endswith('.odt'):
        docs1 = UnstructuredODTLoader(file_path=file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
    elif file.lower().endswith('pptx') or file.lower().endswith('ppt'):
        docs1 = UnstructuredPowerPointLoader(file_path=file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
    elif file.lower().endswith('.txt'):
        # use UnstructuredFileLoader ?
        docs1 = TextLoader(file, encoding="utf8", autodetect_encoding=True).load()
        # makes just one, but big one
        doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
        add_meta(doc1, file)
    elif file.lower().endswith('.rtf'):
        docs1 = UnstructuredRTFLoader(file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
    elif file.lower().endswith('.md'):
        docs1 = UnstructuredMarkdownLoader(file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
    elif file.lower().endswith('.enex'):
        docs1 = EverNoteLoader(file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
    elif file.lower().endswith('.epub'):
        docs1 = UnstructuredEPubLoader(file).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
    elif file.lower().endswith('.jpeg') or file.lower().endswith('.jpg') or file.lower().endswith('.png'):
        docs1 = []
        if have_tesseract and enable_ocr:
            # OCR, somewhat works, but not great
            docs1.extend(UnstructuredImageLoader(file).load())
            add_meta(docs1, file)
        if enable_captions:
            # BLIP
            if caption_loader is not None and not isinstance(caption_loader, (str, bool)):
                # assumes didn't fork into this process with joblib, else can deadlock
                caption_loader.set_image_paths([file])
                docs1c = caption_loader.load()
                add_meta(docs1c, file)
                [x.metadata.update(dict(head=x.page_content[:headsize].strip())) for x in docs1c]
                docs1.extend(docs1c)
            else:
                from image_captions import H2OImageCaptionLoader
                caption_loader = H2OImageCaptionLoader(caption_gpu=caption_loader == 'gpu',
                                                       blip_model=captions_model,
                                                       blip_processor=captions_model)
                caption_loader.set_image_paths([file])
                docs1c = caption_loader.load()
                add_meta(docs1c, file)
                [x.metadata.update(dict(head=x.page_content[:headsize].strip())) for x in docs1c]
                docs1.extend(docs1c)
        for doci in docs1:
            # OCR docs from UnstructuredImageLoader lack 'image_path', so fall back to the file itself
            doci.metadata['source'] = doci.metadata.get('image_path', file)
            doci.metadata['hash'] = hash_file(doci.metadata['source'])
        if docs1:
            doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
    elif file.lower().endswith('.msg'):
        raise RuntimeError("Not supported, GPL3 license")
        # docs1 = OutlookMessageLoader(file).load()
        # docs1[0].metadata['source'] = file
    elif file.lower().endswith('.eml'):
        try:
            docs1 = UnstructuredEmailLoader(file).load()
            add_meta(docs1, file)
            doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
        except ValueError as e:
            if 'text/html content not found in email' in str(e):
                # e.g. plain/text dict key exists, but not html content
                # doc1 = TextLoader(file, encoding="utf8").load()
                docs1 = UnstructuredEmailLoader(file, content_source="text/plain").load()
                add_meta(docs1, file)
                doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
            else:
                raise
    # elif file.lower().endswith('.gcsdir'):
    #    doc1 = GCSDirectoryLoader(project_name, bucket, prefix).load()
    # elif file.lower().endswith('.gcsfile'):
    #    doc1 = GCSFileLoader(project_name, bucket, blob).load()
    elif file.lower().endswith('.rst'):
        with open(file, "r") as f:
            doc1 = Document(page_content=f.read(), metadata={"source": file})
        add_meta(doc1, file)
    elif file.lower().endswith('.pdf'):
        env_gpt4all_file = ".env_gpt4all"
        from dotenv import dotenv_values
        env_kwargs = dotenv_values(env_gpt4all_file)
        pdf_class_name = env_kwargs.get('PDF_CLASS_NAME', 'PyMuPDFParser')
        if have_pymupdf and pdf_class_name == 'PyMuPDFParser':
            # GPL, only use if installed
            from langchain.document_loaders import PyMuPDFLoader
            # load() still chunks by pages, but every page has title at start to help
            doc1 = PyMuPDFLoader(file).load()
        else:
            # open-source fallback
            # load() still chunks by pages, but every page has title at start to help
            doc1 = PyPDFLoader(file).load()
        # Some PDFs return nothing or junk from PDFMinerLoader
        add_meta(doc1, file)
    elif file.lower().endswith('.csv'):
        doc1 = CSVLoader(file).load()
        add_meta(doc1, file)
    elif file.lower().endswith('.py'):
        doc1 = PythonLoader(file).load()
        add_meta(doc1, file)
    elif file.lower().endswith('.toml'):
        doc1 = TomlLoader(file).load()
        add_meta(doc1, file)
    elif file.lower().endswith('.urls'):
        with open(file, "r") as f:
            docs1 = UnstructuredURLLoader(urls=f.readlines()).load()
        add_meta(docs1, file)
        doc1 = chunk_sources(docs1, chunk=chunk, chunk_size=chunk_size)
    elif file.lower().endswith('.zip'):
        with zipfile.ZipFile(file, 'r') as zip_ref:
            # don't put into temporary path, since want to keep references to docs inside zip
            # so just extract in path where file already lives
            zip_ref.extractall(base_path)
            # recurse
            doc1 = path_to_docs(base_path, verbose=verbose, fail_any_exception=fail_any_exception)
    else:
        raise RuntimeError("No file handler for %s" % os.path.basename(file))

    # allow doc1 to be list or not.  If not list, did not chunk yet, so chunk now
    # if list of length one, don't trust and chunk it
    if not isinstance(doc1, list):
        if chunk:
            docs = chunk_sources([doc1], chunk=chunk, chunk_size=chunk_size)
        else:
            docs = [doc1]
    elif isinstance(doc1, list) and len(doc1) == 1:
        if chunk:
            docs = chunk_sources(doc1, chunk=chunk, chunk_size=chunk_size)
        else:
            docs = doc1
    else:
        docs = doc1

    assert isinstance(docs, list)
    return docs
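

# Illustrative sketch (added for documentation): ingest a single markdown file into
# ~512-character chunks; the path is hypothetical.
def _example_file_to_doc():
    docs = file_to_doc('README.md', chunk=True, chunk_size=512, enable_captions=False)
    return [d.metadata['source'] for d in docs]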


def path_to_doc1(file, verbose=False, fail_any_exception=False, return_file=True,
                 chunk=True, chunk_size=512,
                 is_url=False, is_txt=False,
                 enable_captions=True,
                 captions_model=None,
                 enable_ocr=False, caption_loader=None):
    if verbose:
        if is_url:
            print("Ingesting URL: %s" % file, flush=True)
        elif is_txt:
            print("Ingesting Text: %s" % file, flush=True)
        else:
            print("Ingesting file: %s" % file, flush=True)
    res = None
    try:
        # don't pass base_path=path, would infinitely recurse
        res = file_to_doc(file, base_path=None, verbose=verbose, fail_any_exception=fail_any_exception,
                          chunk=chunk, chunk_size=chunk_size,
                          is_url=is_url, is_txt=is_txt,
                          enable_captions=enable_captions,
                          captions_model=captions_model,
                          enable_ocr=enable_ocr,
                          caption_loader=caption_loader)
    except BaseException as e:
        print("Failed to ingest %s due to %s" % (file, traceback.format_exc()))
        if fail_any_exception:
            raise
        else:
            exception_doc = Document(
                page_content='',
                metadata={"source": file, "exception": str(e), "traceback": traceback.format_exc()})
            res = [exception_doc]
    if return_file:
        base_tmp = "temp_path_to_doc1"
        if not os.path.isdir(base_tmp):
            os.makedirs(base_tmp, exist_ok=True)
        filename = os.path.join(base_tmp, str(uuid.uuid4()) + ".tmp.pickle")
        with open(filename, 'wb') as f:
            pickle.dump(res, f)
        return filename
    return res


def path_to_docs(path_or_paths, verbose=False, fail_any_exception=False, n_jobs=-1,
                 chunk=True, chunk_size=512,
                 url=None, text=None,
                 enable_captions=True,
                 captions_model=None,
                 caption_loader=None,
                 enable_ocr=False,
                 existing_files=[],
                 existing_hash_ids={},
                 ):
    globs_image_types = []
    globs_non_image_types = []
    if not path_or_paths and not url and not text:
        return []
    elif url:
        globs_non_image_types = [url]
    elif text:
        globs_non_image_types = [text]
    elif isinstance(path_or_paths, str):
        # single path, only consume allowed files
        path = path_or_paths
        # Below globs should match patterns in file_to_doc()
        [globs_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
         for ftype in image_types]
        [globs_non_image_types.extend(glob.glob(os.path.join(path, "./**/*.%s" % ftype), recursive=True))
         for ftype in non_image_types]
    else:
        # list/tuple of files (consume what can, and exception those that selected but cannot consume so user knows)
        assert isinstance(path_or_paths, (list, tuple)), "Wrong type for path_or_paths: %s" % type(path_or_paths)
        # reform out of allowed types
        globs_image_types.extend(flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in image_types]))
        # could do below:
        # globs_non_image_types = flatten_list([[x for x in path_or_paths if x.endswith(y)] for y in non_image_types])
        # But instead, allow fail so can collect unsupported too
        set_globs_image_types = set(globs_image_types)
        globs_non_image_types.extend([x for x in path_or_paths if x not in set_globs_image_types])

    # filter out any files to skip (e.g. if already processed them)
    # this is easy, but too aggressive in case a file changed content, so parent probably passed existing_files=[]
    assert not existing_files, "DEV: assume not using this approach"
    if existing_files:
        set_skip_files = set(existing_files)
        globs_image_types = [x for x in globs_image_types if x not in set_skip_files]
        globs_non_image_types = [x for x in globs_non_image_types if x not in set_skip_files]
    if existing_hash_ids:
        # assume consistent with add_meta() use of hash_file(file)
        # also assume consistent with get_existing_hash_ids for dict creation
        # assume hashable values
        existing_hash_ids_set = set(existing_hash_ids.items())
        hash_ids_all_image = set({x: hash_file(x) for x in globs_image_types}.items())
        hash_ids_all_non_image = set({x: hash_file(x) for x in globs_non_image_types}.items())
        # don't use symmetric diff.  If file is gone, ignore and don't remove or something
        # just consider existing files (key) having new hash or not (value)
        new_files_image = set(dict(hash_ids_all_image - existing_hash_ids_set).keys())
        new_files_non_image = set(dict(hash_ids_all_non_image - existing_hash_ids_set).keys())
        globs_image_types = [x for x in globs_image_types if x in new_files_image]
        globs_non_image_types = [x for x in globs_non_image_types if x in new_files_non_image]

    # could use generator, but messes up metadata handling in recursive case
    if (caption_loader and not isinstance(caption_loader, (bool, str)) and
            caption_loader.device != 'cpu') or \
            get_device() == 'cuda':
        # to avoid deadlocks, presume was preloaded and so can't fork due to cuda context
        n_jobs_image = 1
    else:
        n_jobs_image = n_jobs

    return_file = True  # local choice
    is_url = url is not None
    is_txt = text is not None
    kwargs = dict(verbose=verbose, fail_any_exception=fail_any_exception,
                  return_file=return_file,
                  chunk=chunk, chunk_size=chunk_size,
                  is_url=is_url,
                  is_txt=is_txt,
                  enable_captions=enable_captions,
                  captions_model=captions_model,
                  caption_loader=caption_loader,
                  enable_ocr=enable_ocr,
                  )

    if n_jobs != 1 and len(globs_non_image_types) > 1:
        # avoid nesting, e.g. upload 1 zip and then inside many files
        # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
        documents = ProgressParallel(n_jobs=n_jobs, verbose=10 if verbose else 0, backend='multiprocessing')(
            delayed(path_to_doc1)(file, **kwargs) for file in globs_non_image_types
        )
    else:
        documents = [path_to_doc1(file, **kwargs) for file in tqdm(globs_non_image_types)]

    # do images separately since can't fork after cuda in parent, so can't be parallel
    if n_jobs_image != 1 and len(globs_image_types) > 1:
        # avoid nesting, e.g. upload 1 zip and then inside many files
        # harder to handle if upload many zips with many files, inner parallel one will be disabled by joblib
        image_documents = ProgressParallel(n_jobs=n_jobs_image, verbose=10 if verbose else 0,
                                           backend='multiprocessing')(
            delayed(path_to_doc1)(file, **kwargs) for file in globs_image_types
        )
    else:
        image_documents = [path_to_doc1(file, **kwargs) for file in tqdm(globs_image_types)]

    # add image docs in
    documents += image_documents

    if return_file:
        # then documents really are files
        files = documents.copy()
        documents = []
        for fil in files:
            with open(fil, 'rb') as f:
                documents.extend(pickle.load(f))
            # remove temp pickle
            os.remove(fil)
    else:
        documents = reduce(concat, documents)
    return documents
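

# Illustrative sketch (added for documentation): recursively ingest a directory of
# supported file types; the directory name is hypothetical.
def _example_path_to_docs():
    docs = path_to_docs('user_path', n_jobs=1, chunk=True, chunk_size=512)
    return len(docs)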


def prep_langchain(persist_directory,
                   load_db_if_exists,
                   db_type, use_openai_embedding, langchain_mode, user_path,
                   hf_embedding_model, n_jobs=-1, kwargs_make_db={}):
    """
    do prep first time, involving downloads
    # FIXME: Add github caching then add here
    :return:
    """
    assert langchain_mode not in ['MyData'], "Should not prep scratch data"

    db_dir_exists = os.path.isdir(persist_directory)

    if db_dir_exists and user_path is None:
        print("Prep: persist_directory=%s exists, using" % persist_directory, flush=True)
        db = get_existing_db(None, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
                             hf_embedding_model)
    else:
        if db_dir_exists and user_path is not None:
            print("Prep: persist_directory=%s exists, user_path=%s passed, adding any changed or new documents" % (
                persist_directory, user_path), flush=True)
        elif not db_dir_exists:
            print("Prep: persist_directory=%s does not exist, regenerating" % persist_directory, flush=True)
        db = None
        if langchain_mode in ['All', 'DriverlessAI docs']:
            # FIXME: Could also just use dai_docs.pickle directly and upload that
            get_dai_docs(from_hf=True)

        if langchain_mode in ['All', 'wiki']:
            get_wiki_sources(first_para=kwargs_make_db['first_para'], text_limit=kwargs_make_db['text_limit'])

        langchain_kwargs = kwargs_make_db.copy()
        langchain_kwargs.update(locals())
        db, num_new_sources, new_sources_metadata = make_db(**langchain_kwargs)

    return db


import posthog

posthog.disabled = True


class FakeConsumer(object):
    def __init__(self, *args, **kwargs):
        pass

    def run(self):
        pass

    def pause(self):
        pass

    def upload(self):
        pass

    def next(self):
        pass

    def request(self, batch):
        pass


posthog.Consumer = FakeConsumer


def check_update_chroma_embedding(db, use_openai_embedding, hf_embedding_model, langchain_mode):
    changed_db = False
    if load_embed(db) != (use_openai_embedding, hf_embedding_model):
        print("Detected new embedding, updating db: %s" % langchain_mode, flush=True)
        # handle embedding changes
        db_get = db.get()
        sources = [Document(page_content=result[0], metadata=result[1] or {})
                   for result in zip(db_get['documents'], db_get['metadatas'])]
        # delete index, has to be redone
        persist_directory = db._persist_directory
        shutil.move(persist_directory, persist_directory + "_" + str(uuid.uuid4()) + ".bak")
        db_type = 'chroma'
        load_db_if_exists = False
        db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
                    persist_directory=persist_directory, load_db_if_exists=load_db_if_exists,
                    langchain_mode=langchain_mode,
                    collection_name=None,
                    hf_embedding_model=hf_embedding_model)
        if False:
            # below doesn't work if db already in memory, so have to switch to new db as above
            # upsert does new embedding, but if index already in memory, complains about size mismatch etc.
            client_collection = db._client.get_collection(name=db._collection.name,
                                                          embedding_function=db._collection._embedding_function)
            client_collection.upsert(ids=db_get['ids'], metadatas=db_get['metadatas'], documents=db_get['documents'])
        changed_db = True
        print("Done updating db for new embedding: %s" % langchain_mode, flush=True)

    return db, changed_db


def get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
                    hf_embedding_model, verbose=False, check_embedding=True):
    if load_db_if_exists and db_type == 'chroma' and os.path.isdir(persist_directory) and os.path.isdir(
            os.path.join(persist_directory, 'index')):
        if db is None:
            if verbose:
                print("DO Loading db: %s" % langchain_mode, flush=True)
            embedding = get_embedding(use_openai_embedding, hf_embedding_model=hf_embedding_model)
            from chromadb.config import Settings
            client_settings = Settings(anonymized_telemetry=False,
                                       chroma_db_impl="duckdb+parquet",
                                       persist_directory=persist_directory)
            db = Chroma(persist_directory=persist_directory, embedding_function=embedding,
                        collection_name=langchain_mode.replace(' ', '_'),
                        client_settings=client_settings)
            if verbose:
                print("DONE Loading db: %s" % langchain_mode, flush=True)
        else:
            if verbose:
                print("USING already-loaded db: %s" % langchain_mode, flush=True)
        if check_embedding:
            db_trial, changed_db = check_update_chroma_embedding(db, use_openai_embedding, hf_embedding_model,
                                                                 langchain_mode)
            if changed_db:
                db = db_trial
                # only call persist if really changed db, else takes too long for large db
                db.persist()
                clear_embedding(db)
        save_embed(db, use_openai_embedding, hf_embedding_model)
        return db
    return None


def clear_embedding(db):
    # don't keep on GPU, wastes memory, push back onto CPU and only put back on GPU once again embed
    db._embedding_function.client.cpu()
    clear_torch_cache()


def make_db(**langchain_kwargs):
    func_names = list(inspect.signature(_make_db).parameters)
    missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
    defaults_db = {k: v.default for k, v in dict(inspect.signature(run_qa_db).parameters).items()}
    for k in missing_kwargs:
        if k in defaults_db:
            langchain_kwargs[k] = defaults_db[k]
    # final check for missing
    missing_kwargs = [x for x in func_names if x not in langchain_kwargs]
    assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
    # only keep actual used
    langchain_kwargs = {k: v for k, v in langchain_kwargs.items() if k in func_names}
    return _make_db(**langchain_kwargs)


def save_embed(db, use_openai_embedding, hf_embedding_model):
    embed_info_file = os.path.join(db._persist_directory, 'embed_info')
    with open(embed_info_file, 'wb') as f:
        pickle.dump((use_openai_embedding, hf_embedding_model), f)
    return use_openai_embedding, hf_embedding_model


def load_embed(db):
    embed_info_file = os.path.join(db._persist_directory, 'embed_info')
    if os.path.isfile(embed_info_file):
        with open(embed_info_file, 'rb') as f:
            use_openai_embedding, hf_embedding_model = pickle.load(f)
    else:
        # migration, assume defaults
        use_openai_embedding, hf_embedding_model = False, "sentence-transformers/all-MiniLM-L6-v2"
    return use_openai_embedding, hf_embedding_model
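

# Illustrative sketch (added for documentation): the embed_info file written by
# save_embed() lets a later reload detect embedding changes via load_embed().
# Assumes a chroma-backed `db` (has _persist_directory).
def _example_embed_roundtrip(db):
    save_embed(db, False, "sentence-transformers/all-MiniLM-L6-v2")
    assert load_embed(db) == (False, "sentence-transformers/all-MiniLM-L6-v2")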


def get_persist_directory(langchain_mode):
    return 'db_dir_%s' % langchain_mode  # single place, no special names for each case


def _make_db(use_openai_embedding=False,
             hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
             first_para=False, text_limit=None,
             chunk=True, chunk_size=512,
             langchain_mode=None,
             user_path=None,
             db_type='faiss',
             load_db_if_exists=True,
             db=None,
             n_jobs=-1,
             verbose=False):
    persist_directory = get_persist_directory(langchain_mode)
    # see if can get persistent chroma db
    db_trial = get_existing_db(db, persist_directory, load_db_if_exists, db_type, use_openai_embedding, langchain_mode,
                               hf_embedding_model, verbose=verbose)
    if db_trial is not None:
        db = db_trial

    sources = []
    if (not db and langchain_mode not in ['MyData']) or \
            (user_path is not None and langchain_mode in ['UserData']):
        # Should not make MyData db this way, why avoided, only upload from UI
        assert langchain_mode not in ['MyData'], "Should not make MyData db this way"
        if verbose:
            if langchain_mode in ['UserData']:
                if user_path is not None:
                    print("Checking if changed or new sources in %s, and generating sources from them" % user_path,
                          flush=True)
                elif db is None:
                    print("user_path not passed and no db, no sources", flush=True)
                else:
                    print("user_path not passed, using only existing db, no new sources", flush=True)
            else:
                print("Generating %s sources" % langchain_mode, flush=True)
    if langchain_mode in ['wiki_full', 'All', "'All'"]:
        from read_wiki_full import get_all_documents
        small_test = None
        print("Generating new wiki", flush=True)
        sources1 = get_all_documents(small_test=small_test, n_jobs=os.cpu_count() // 2)
        print("Got new wiki", flush=True)
        if chunk:
            sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
            print("Chunked new wiki", flush=True)
        sources.extend(sources1)
    if langchain_mode in ['wiki', 'All', "'All'"]:
        sources1 = get_wiki_sources(first_para=first_para, text_limit=text_limit)
        if chunk:
            sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
        sources.extend(sources1)
    if langchain_mode in ['github h2oGPT', 'All', "'All'"]:
        # sources = get_github_docs("dagster-io", "dagster")
        sources1 = get_github_docs("h2oai", "h2ogpt")
        # FIXME: always chunk for now
        sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
        sources.extend(sources1)
    if langchain_mode in ['DriverlessAI docs', 'All', "'All'"]:
        sources1 = get_dai_docs(from_hf=True)
        if chunk and False:  # FIXME: DAI docs are already chunked well, should only chunk more if over limit
            sources1 = chunk_sources(sources1, chunk=chunk, chunk_size=chunk_size)
        sources.extend(sources1)
    if langchain_mode in ['All', 'UserData']:
        if user_path:
            if db is not None:
                # NOTE: Ignore file names for now, only go by hash ids
                # existing_files = get_existing_files(db)
                existing_files = []
                existing_hash_ids = get_existing_hash_ids(db)
            else:
                # pretend no existing files so won't filter
                existing_files = []
                existing_hash_ids = {}
            # chunk internally for speed over multiple docs
            # FIXME: If first had old Hash=None and switch embeddings,
            # then re-embed, and then hit here and reload so have hash, and then re-embed.
            sources1 = path_to_docs(user_path, n_jobs=n_jobs, chunk=chunk, chunk_size=chunk_size,
                                    existing_files=existing_files, existing_hash_ids=existing_hash_ids)
            new_metadata_sources = set([x.metadata['source'] for x in sources1])
            if new_metadata_sources:
                print("Loaded %s new files as sources to add to UserData" % len(new_metadata_sources), flush=True)
                if verbose:
                    print("Files added: %s" % '\n'.join(new_metadata_sources), flush=True)
            sources.extend(sources1)
            print("Loaded %s sources for potentially adding to UserData" % len(sources), flush=True)
        else:
            print("Chose UserData but user_path is empty/None", flush=True)
    if False and langchain_mode in ['urls', 'All', "'All'"]:
        # from langchain.document_loaders import UnstructuredURLLoader
        # loader = UnstructuredURLLoader(urls=urls)
        urls = ["https://www.birdsongsf.com/who-we-are/"]
        from langchain.document_loaders import PlaywrightURLLoader
        loader = PlaywrightURLLoader(urls=urls, remove_selectors=["header", "footer"])
        sources1 = loader.load()
        sources.extend(sources1)
    if not sources:
        if verbose:
            if db is not None:
                print("langchain_mode %s has no new sources, nothing to add to db" % langchain_mode, flush=True)
            else:
                print("langchain_mode %s has no sources, not making new db" % langchain_mode, flush=True)
        return db, 0, []
    if verbose:
        if db is None:
            print("Generating db", flush=True)
        else:
            print("Adding to db", flush=True)
    if not db:
        if sources:
            db = get_db(sources, use_openai_embedding=use_openai_embedding, db_type=db_type,
                        persist_directory=persist_directory, langchain_mode=langchain_mode,
                        hf_embedding_model=hf_embedding_model)
            if verbose:
                print("Generated db", flush=True)
        else:
            print("Did not generate db since no sources", flush=True)
        new_sources_metadata = [x.metadata for x in sources]
    elif user_path is not None and langchain_mode in ['UserData']:
        print("Existing db, potentially adding %s sources from user_path=%s" % (len(sources), user_path), flush=True)
        db, num_new_sources, new_sources_metadata = add_to_db(db, sources, db_type=db_type,
                                                              use_openai_embedding=use_openai_embedding,
                                                              hf_embedding_model=hf_embedding_model)
        print("Existing db, added %s new sources from user_path=%s" % (num_new_sources, user_path), flush=True)
    else:
        new_sources_metadata = [x.metadata for x in sources]

    return db, len(new_sources_metadata), new_sources_metadata


def get_existing_files(db):
    collection = db.get()
    metadata_sources = set([x['source'] for x in collection['metadatas']])
    return metadata_sources


def get_existing_hash_ids(db):
    collection = db.get()
    # assume consistency, that any prior hashed source was single hashed file at the time among all source chunks
    metadata_hash_ids = {x['source']: x.get('hashid') for x in collection['metadatas']}
    return metadata_hash_ids


source_prefix = "Sources [Score | Link]:"
source_postfix = "End Sources<p>"


def run_qa_db(**kwargs):
    func_names = list(inspect.signature(_run_qa_db).parameters)
    # hard-coded defaults
    kwargs['answer_with_sources'] = True
    kwargs['sanitize_bot_response'] = True
    kwargs['show_rank'] = False
    missing_kwargs = [x for x in func_names if x not in kwargs]
    assert not missing_kwargs, "Missing kwargs: %s" % missing_kwargs
    # only keep actual used
    kwargs = {k: v for k, v in kwargs.items() if k in func_names}
    try:
        return _run_qa_db(**kwargs)
    finally:
        clear_torch_cache()
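

# Illustrative sketch (added for documentation): run_qa_db() asserts that every
# _run_qa_db() parameter is present in kwargs, so real callers build a complete
# kwargs dict first.  `all_kwargs` is a hypothetical such dict.
def _example_run_qa_db(all_kwargs):
    all_kwargs.update(dict(query="What is Driverless AI?", langchain_mode='DriverlessAI docs'))
    for response, extra in run_qa_db(**all_kwargs):
        print(response, flush=True)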


def _run_qa_db(query=None,
               use_openai_model=False, use_openai_embedding=False,
               first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512,
               user_path=None,
               detect_user_path_changes_every_query=False,
               db_type='faiss',
               model_name=None, model=None, tokenizer=None,
               hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2",
               stream_output=False,
               prompter=None,
               prompt_type=None,
               prompt_dict=None,
               answer_with_sources=True,
               cut_distanct=1.1,
               sanitize_bot_response=True,
               show_rank=False,
               load_db_if_exists=False,
               db=None,
               do_sample=False,
               temperature=0.1,
               top_k=40,
               top_p=0.7,
               num_beams=1,
               max_new_tokens=256,
               min_new_tokens=1,
               early_stopping=False,
               max_time=180,
               repetition_penalty=1.0,
               num_return_sequences=1,
               langchain_mode=None,
               document_choice=[DocumentChoices.All_Relevant.name],
               n_jobs=-1,
               verbose=False,
               cli=False):
    """

    :param query:
    :param use_openai_model:
    :param use_openai_embedding:
    :param first_para:
    :param text_limit:
    :param top_k_docs: number of document chunks to retrieve
    :param chunk:
    :param chunk_size:
    :param user_path: user path to glob recursively from
    :param db_type: 'faiss' for in-memory db or 'chroma' or 'weaviate' for persistent db
    :param model_name: model name, used to switch behaviors
    :param model: pre-initialized model, else will make new one
    :param tokenizer: pre-initialized tokenizer, else will make new one.  Required not None if model is not None
    :param answer_with_sources: whether to append source references to the answer
    :return:
    """
    assert query is not None
    assert prompter is not None or prompt_type is not None or model is None  # if model is None, then will generate
    if prompter is not None:
        prompt_type = prompter.prompt_type
        prompt_dict = prompter.prompt_dict
    if model is not None:
        assert prompt_type is not None
        if prompt_type == PromptType.custom.name:
            assert prompt_dict is not None  # should at least be {} or ''
        else:
            prompt_dict = ''
    assert len(set(gen_hyper).difference(inspect.signature(get_llm).parameters)) == 0
    llm, model_name, streamer, prompt_type_out = get_llm(use_openai_model=use_openai_model, model_name=model_name,
                                                         model=model, tokenizer=tokenizer,
                                                         stream_output=stream_output,
                                                         do_sample=do_sample,
                                                         temperature=temperature,
                                                         top_k=top_k,
                                                         top_p=top_p,
                                                         num_beams=num_beams,
                                                         max_new_tokens=max_new_tokens,
                                                         min_new_tokens=min_new_tokens,
                                                         early_stopping=early_stopping,
                                                         max_time=max_time,
                                                         repetition_penalty=repetition_penalty,
                                                         num_return_sequences=num_return_sequences,
                                                         prompt_type=prompt_type,
                                                         prompt_dict=prompt_dict,
                                                         prompter=prompter,
                                                         verbose=verbose,
                                                         )

    if model_name in non_hf_types:
        # FIXME: for now, streams to stdout/stderr currently
        stream_output = False

    use_context = False
    scores = []
    chain = None

    if isinstance(document_choice, str):
        # support string as well
        document_choice = [document_choice]
    # get first DocumentChoices as command to use, ignore others
    doc_choices_set = set([x.name for x in list(DocumentChoices)])
    cmd = [x for x in document_choice if x in doc_choices_set]
    cmd = None if len(cmd) == 0 else cmd[0]
    # now have cmd, filter out for only docs
    document_choice = [x for x in document_choice if x not in doc_choices_set]

    func_names = list(inspect.signature(get_similarity_chain).parameters)
    sim_kwargs = {k: v for k, v in locals().items() if k in func_names}
    missing_kwargs = [x for x in func_names if x not in sim_kwargs]
    assert not missing_kwargs, "Missing: %s" % missing_kwargs
    docs, chain, scores, use_context = get_similarity_chain(**sim_kwargs)
    if cmd in [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]:
        formatted_doc_chunks = '\n\n'.join([get_url(x) + '\n\n' + x.page_content for x in docs])
        yield formatted_doc_chunks, ''
        return
    if chain is None and model_name not in non_hf_types:
        # can only return if HF type
        return

    if stream_output:
        answer = None
        assert streamer is not None
        import queue
        bucket = queue.Queue()
        thread = EThread(target=chain, streamer=streamer, bucket=bucket)
        thread.start()
        outputs = ""
        prompt = None  # FIXME
        try:
            for new_text in streamer:
                # print("new_text: %s" % new_text, flush=True)
                if bucket.qsize() > 0 or thread.exc:
                    thread.join()
                outputs += new_text
                if prompter:  # and False:  # FIXME: pipeline can already use prompter
                    output1 = prompter.get_response(outputs, prompt=prompt,
                                                    sanitize_bot_response=sanitize_bot_response)
                    yield output1, ''
                else:
                    yield outputs, ''
        except BaseException:
            # if any exception, raise that exception if was from thread, first
            if thread.exc:
                raise thread.exc
            raise
        finally:
            # in case no exception and didn't join with thread yet, then join
            if not thread.exc:
                answer = thread.join()
        # in case raise StopIteration or broke queue loop in streamer, but still have exception
        if thread.exc:
            raise thread.exc
        # FIXME: answer is not string outputs from streamer.  How to get actual final output?
        # answer = outputs
    else:
        answer = chain()

    if not use_context:
        ret = answer['output_text']
        extra = ''
        yield ret, extra
    elif answer is not None:
        ret, extra = get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=verbose)
        yield ret, extra
    return
| def get_similarity_chain(query=None, | |
| use_openai_model=False, use_openai_embedding=False, | |
| first_para=False, text_limit=None, top_k_docs=4, chunk=True, chunk_size=512, | |
| user_path=None, | |
| detect_user_path_changes_every_query=False, | |
| db_type='faiss', | |
| model_name=None, | |
| hf_embedding_model="sentence-transformers/all-MiniLM-L6-v2", | |
| prompt_type=None, | |
| prompt_dict=None, | |
| cut_distanct=1.1, | |
| load_db_if_exists=False, | |
| db=None, | |
| langchain_mode=None, | |
| document_choice=[DocumentChoices.All_Relevant.name], | |
| n_jobs=-1, | |
| # beyond run_db_query: | |
| llm=None, | |
| verbose=False, | |
| cmd=None, | |
| ): | |
| # determine whether context from the docs is planned to be used in the prompt | |
| if (not use_openai_model and prompt_type not in ['plain']) or model_name in non_hf_types: | |
| if langchain_mode in ['Disabled', 'ChatLLM', 'LLM']: | |
| use_context = False | |
| else: | |
| use_context = True | |
| else: | |
| use_context = True | |
| # https://github.com/hwchase17/langchain/issues/1946 | |
| # FIXME: Seems there is no way to get the size of a chroma db, to limit top_k_docs and avoid | |
| # "Chroma collection MyData contains fewer than 4 elements."-type logger errors | |
| k_db = 1000 if db_type == 'chroma' else top_k_docs # top_k_docs=100 works ok too | |
| # FIXME: For All just go over all dbs instead of a separate db for All | |
| if not detect_user_path_changes_every_query and db is not None: | |
| # avoid looking at user_path during similarity search db handling, | |
| # if already have db and not updating from user_path every query | |
| # but if db is None, no db yet loaded (e.g. from prep), so allow user_path to be whatever it was | |
| user_path = None | |
| db, num_new_sources, new_sources_metadata = make_db(use_openai_embedding=use_openai_embedding, | |
| hf_embedding_model=hf_embedding_model, | |
| first_para=first_para, text_limit=text_limit, | |
| chunk=chunk, | |
| chunk_size=chunk_size, | |
| langchain_mode=langchain_mode, | |
| user_path=user_path, | |
| db_type=db_type, | |
| load_db_if_exists=load_db_if_exists, | |
| db=db, | |
| n_jobs=n_jobs, | |
| verbose=verbose) | |
| if db and use_context: | |
| if not isinstance(db, Chroma): | |
| # only chroma supports filtering | |
| filter_kwargs = {} | |
| else: | |
| # if here then some cmd + documents selected or just documents selected | |
| if len(document_choice) >= 2: | |
| or_filter = [{"source": {"$eq": x}} for x in document_choice] | |
| filter_kwargs = dict(filter={"$or": or_filter}) | |
| elif len(document_choice) == 1: | |
| # work around chroma bug with degenerate (single-element) $or filters | |
| one_filter = [{"source": {"$eq": x}} for x in document_choice][0] | |
| filter_kwargs = dict(filter=one_filter) | |
| else: | |
| # shouldn't reach | |
| filter_kwargs = {} | |
| if cmd == DocumentChoices.Just_LLM.name: | |
| docs = [] | |
| scores = [] | |
| elif cmd == DocumentChoices.Only_All_Sources.name: | |
| if isinstance(db, Chroma): | |
| db_get = db._collection.get(where=filter_kwargs.get('filter')) | |
| else: | |
| db_get = db.get() | |
| # similar to langchain's chroma's _results_to_docs_and_scores | |
| docs_with_score = [(Document(page_content=result[0], metadata=result[1] or {}), 0) | |
| for result in zip(db_get['documents'], db_get['metadatas'])][:top_k_docs] | |
| docs = [x[0] for x in docs_with_score] | |
| scores = [x[1] for x in docs_with_score] | |
| else: | |
| docs_with_score = db.similarity_search_with_score(query, k=k_db, **filter_kwargs)[:top_k_docs] | |
| # cut off so no high distance docs/sources considered | |
| docs = [x[0] for x in docs_with_score if x[1] < cut_distanct] | |
| scores = [x[1] for x in docs_with_score if x[1] < cut_distanct] | |
| if len(scores) > 0 and verbose: | |
| print("Distance: min: %s max: %s mean: %s median: %s" % | |
| (scores[0], scores[-1], np.mean(scores), np.median(scores)), flush=True) | |
| else: | |
| docs = [] | |
| scores = [] | |
| if not docs and use_context and model_name not in non_hf_types: | |
| # if HF type and have no docs, can bail out | |
| return docs, None, [], False | |
| if cmd in [DocumentChoices.All_Relevant_Only_Sources.name, DocumentChoices.Only_All_Sources.name]: | |
| # no LLM use | |
| return docs, None, [], False | |
| common_words_file = "data/NGSL_1.2_stats.csv.zip" | |
| if os.path.isfile(common_words_file): | |
| df = pd.read_csv(common_words_file) | |
| import string | |
| reduced_query = query.translate(str.maketrans(string.punctuation, ' ' * len(string.punctuation))).strip() | |
| reduced_query_words = reduced_query.split(' ') | |
| set_common = set(df['Lemma'].values.tolist()) | |
| num_common = len([x for x in reduced_query_words if x.lower() in set_common]) | |
| frac_common = num_common / len(reduced_query_words) if reduced_query else 0 | |
| # FIXME: report to user bad query that uses too many common words | |
| if verbose: | |
| print("frac_common: %s" % frac_common, flush=True) | |
| if len(docs) == 0: | |
| # then avoid inserting an empty context block (== delimiters) into the prompt | |
| use_context = False | |
| if (not use_openai_model and prompt_type not in ['plain']) or model_name in non_hf_types: | |
| # instruct-like, rather than the few-shot prompt_type='plain' default | |
| # but sources would then confuse the model by how they are inserted among the rest of the text, so avoid them | |
| prefix = "" | |
| if langchain_mode in ['Disabled', 'ChatLLM', 'LLM'] or not use_context: | |
| template = """%s{context}{question}""" % prefix | |
| else: | |
| template = """%s | |
| == | |
| {context} | |
| == | |
| {question}""" % prefix | |
| prompt = PromptTemplate( | |
| # input_variables=["summaries", "question"], | |
| input_variables=["context", "question"], | |
| template=template, | |
| ) | |
| chain = load_qa_chain(llm, prompt=prompt) | |
| else: | |
| chain = load_qa_with_sources_chain(llm) | |
| if not use_context: | |
| chain_kwargs = dict(input_documents=[], question=query) | |
| else: | |
| chain_kwargs = dict(input_documents=docs, question=query) | |
| target = wrapped_partial(chain, chain_kwargs) | |
| return docs, target, scores, use_context | |
| def get_sources_answer(query, answer, scores, show_rank, answer_with_sources, verbose=False): | |
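| """Format the chain's answer and append a deduplicated, ranked source list. | |
| Distances are mapped to [0, 1] relevance via max(0.0, 1.5 - score) / 1.5, | |
| keeping the best score per source URL. | |
| Returns (answer_text_with_optional_sources, sources_text). | |
| """ | |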
| if verbose: | |
| print("query: %s" % query, flush=True) | |
| print("answer: %s" % answer['output_text'], flush=True) | |
| if len(answer['input_documents']) == 0: | |
| extra = '' | |
| ret = answer['output_text'] + extra | |
| return ret, extra | |
| # link | |
| answer_sources = [(max(0.0, 1.5 - score) / 1.5, get_url(doc)) for score, doc in | |
| zip(scores, answer['input_documents'])] | |
| answer_sources_dict = defaultdict(list) | |
| # accumulate all scores per source URL | |
| for score, url in answer_sources: | |
| answer_sources_dict[url].append(score) | |
| answers_dict = {} | |
| for url, scores_url in answer_sources_dict.items(): | |
| answers_dict[url] = np.max(scores_url) | |
| answer_sources = [(score, url) for url, score in answers_dict.items()] | |
| answer_sources.sort(key=lambda x: x[0], reverse=True) | |
| if show_rank: | |
| # answer_sources = ['%d | %s' % (1 + rank, url) for rank, (score, url) in enumerate(answer_sources)] | |
| # sorted_sources_urls = "Sources [Rank | Link]:<br>" + "<br>".join(answer_sources) | |
| answer_sources = ['%s' % url for score, url in answer_sources] | |
| sorted_sources_urls = "Ranked Sources:<br>" + "<br>".join(answer_sources) | |
| else: | |
| answer_sources = ['<li>%.2g | %s</li>' % (score, url) for score, url in answer_sources] | |
| sorted_sources_urls = f"{source_prefix}<p><ul>" + "<p>".join(answer_sources) | |
| sorted_sources_urls += f"</ul></p>{source_postfix}" | |
| if not answer['output_text'].endswith('\n'): | |
| answer['output_text'] += '\n' | |
| if answer_with_sources: | |
| extra = '\n' + sorted_sources_urls | |
| else: | |
| extra = '' | |
| ret = answer['output_text'] + extra | |
| return ret, extra | |
| def chunk_sources(sources, chunk=True, chunk_size=512): | |
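| """Split each source Document into chunks of at most chunk_size characters, | |
| copying the source's metadata onto every chunk; returns sources unchanged | |
| when chunk is False. | |
| """ | |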
| if not chunk: | |
| return sources | |
| source_chunks = [] | |
| # Below for known separator | |
| # splitter = CharacterTextSplitter(separator=" ", chunk_size=chunk_size, chunk_overlap=0) | |
| splitter = RecursiveCharacterTextSplitter(chunk_size=chunk_size, chunk_overlap=0) | |
| for source in sources: | |
| # print(source.metadata['source'], flush=True) | |
| for chunky in splitter.split_text(source.page_content): | |
| source_chunks.append(Document(page_content=chunky, metadata=source.metadata)) | |
| return source_chunks | |
| def get_db_from_hf(dest=".", db_dir='db_dir_DriverlessAI_docs.zip'): | |
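| """Download one prebuilt db zip from the h2oai/db_dirs HF dataset and extract | |
| it into dest, removing any pre-existing persist directory of the same name. | |
| Returns the path to the downloaded zip file. | |
| """ | |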
| from huggingface_hub import hf_hub_download | |
| # default of True handles the case of already being logged in locally with the correct token, so no key needs to be set | |
| token = os.getenv('HUGGINGFACE_API_TOKEN', True) | |
| path_to_zip_file = hf_hub_download('h2oai/db_dirs', db_dir, token=token, repo_type='dataset') | |
| with zipfile.ZipFile(path_to_zip_file, 'r') as zip_ref: | |
| persist_directory = os.path.dirname(zip_ref.namelist()[0]) | |
| remove(persist_directory) | |
| zip_ref.extractall(dest) | |
| return path_to_zip_file | |
| # Note: some extracted dir names contain a space, while the zip file names do not | |
| some_db_zips = [['db_dir_DriverlessAI_docs.zip', 'db_dir_DriverlessAI docs', 'CC-BY-NC license'], | |
| ['db_dir_UserData.zip', 'db_dir_UserData', 'CC-BY license for ArXiv'], | |
| ['db_dir_github_h2oGPT.zip', 'db_dir_github h2oGPT', 'ApacheV2 license'], | |
| ['db_dir_wiki.zip', 'db_dir_wiki', 'CC-BY-SA Wikipedia license'], | |
| # ['db_dir_wiki_full.zip', 'db_dir_wiki_full.zip', '23GB, 05/04/2023 CC-BY-SA Wiki license'], | |
| ] | |
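| # all_db_zips additionally includes the large full-wiki db | |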
| all_db_zips = some_db_zips + \ | |
| [['db_dir_wiki_full.zip', 'db_dir_wiki_full.zip', '23GB, 05/04/2023 CC-BY-SA Wiki license'], | |
| ] | |
| def get_some_dbs_from_hf(dest='.', db_zips=None): | |
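| """Fetch and extract the given prebuilt db zips (defaults to some_db_zips), | |
| then verify the zip, the extracted directory, and its index all exist. | |
| """ | |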
| if db_zips is None: | |
| db_zips = some_db_zips | |
| for db_dir, dir_expected, license1 in db_zips: | |
| path_to_zip_file = get_db_from_hf(dest=dest, db_dir=db_dir) | |
| assert os.path.isfile(path_to_zip_file), "Missing zip in %s" % path_to_zip_file | |
| if dir_expected: | |
| assert os.path.isdir(os.path.join(dest, dir_expected)), "Missing path for %s" % dir_expected | |
| assert os.path.isdir(os.path.join(dest, dir_expected, 'index')), "Missing index in %s" % dir_expected | |
| def _create_local_weaviate_client(): | |
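| """Create a Weaviate client for WEAVIATE_URL (default http://localhost:8080), | |
| using password auth when WEAVIATE_USERNAME and WEAVIATE_PASSWORD are set. | |
| Returns the client, or None if the connection fails. | |
| """ | |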
| import weaviate # needed here; weaviate is otherwise only imported inside get_db | |
| WEAVIATE_URL = os.getenv('WEAVIATE_URL', "http://localhost:8080") | |
| WEAVIATE_USERNAME = os.getenv('WEAVIATE_USERNAME') | |
| WEAVIATE_PASSWORD = os.getenv('WEAVIATE_PASSWORD') | |
| WEAVIATE_SCOPE = os.getenv('WEAVIATE_SCOPE', "offline_access") | |
| resource_owner_config = None | |
| if WEAVIATE_USERNAME is not None and WEAVIATE_PASSWORD is not None: | |
| resource_owner_config = weaviate.AuthClientPassword( | |
| username=WEAVIATE_USERNAME, | |
| password=WEAVIATE_PASSWORD, | |
| scope=WEAVIATE_SCOPE | |
| ) | |
| try: | |
| client = weaviate.Client(WEAVIATE_URL, auth_client_secret=resource_owner_config) | |
| return client | |
| except Exception as e: | |
| print(f"Failed to create Weaviate client: {e}") | |
| return None | |
| if __name__ == '__main__': | |
| pass | |
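| # Hypothetical usage sketch (this module is normally driven by the h2oGPT app rather | |
| # than run directly); assumes sources and llm were prepared elsewhere: | |
| # db = get_db(sources, db_type='chroma', langchain_mode='UserData') | |
| # docs, chain, scores, use_context = get_similarity_chain( | |
| # query="What is Driverless AI?", llm=llm, db=db, db_type='chroma', | |
| # langchain_mode='UserData') | |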