Spaces:
Build error
```python
# main.py
import spaces
from torch.nn import DataParallel
from torch import Tensor
from transformers import AutoTokenizer, AutoModel
from huggingface_hub import InferenceClient
from openai import OpenAI
from langchain_community.document_loaders import UnstructuredFileLoader
from langchain_chroma import Chroma
from chromadb import Documents, EmbeddingFunction, Embeddings
from chromadb.config import Settings
from chromadb import HttpClient
import os
import re
import uuid
import gradio as gr
import torch
import torch.nn.functional as F
from dotenv import load_dotenv
from utils import load_env_variables, parse_and_route
from globalvars import API_BASE, intention_prompt, tasks, system_message, model_name, metadata_prompt

load_dotenv()

os.environ['PYTORCH_CUDA_ALLOC_CONF'] = 'max_split_size_mb:30'
os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
os.environ['CUDA_CACHE_DISABLE'] = '1'

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

### Utils
hf_token, yi_token = load_env_variables()

def clear_cuda_cache():
    torch.cuda.empty_cache()

client = OpenAI(api_key=yi_token, base_url=API_BASE)

class EmbeddingGenerator:
    def __init__(self, model_name: str, token: str, intention_client):
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name, token=token, trust_remote_code=True)
        self.model = AutoModel.from_pretrained(model_name, token=token, trust_remote_code=True).to(self.device)
        self.intention_client = intention_client

    def clear_cuda_cache(self):
        torch.cuda.empty_cache()

    def compute_embeddings(self, input_text: str):
        # Get the intention
        intention_completion = self.intention_client.chat.completions.create(
            model="yi-large",
            messages=[
                {"role": "system", "content": intention_prompt},
                {"role": "user", "content": input_text}
            ]
        )
        # The OpenAI v1 client returns objects, not dicts, so use .message.content
        intention_output = intention_completion.choices[0].message.content

        # Parse and route the intention
        parsed_task = parse_and_route(intention_output)
        selected_task = list(parsed_task.keys())[0]

        # Construct the prompt
        try:
            task_description = tasks[selected_task]
        except KeyError:
            print(f"Selected task not found: {selected_task}")
            return f"Error: Task '{selected_task}' not found. Please select a valid task."
        query_prefix = f"Instruct: {task_description}\nQuery: "
        queries = [query_prefix + input_text]  # prepend the instruct prefix so it is actually used

        # Get the metadata
        metadata_completion = self.intention_client.chat.completions.create(
            model="yi-large",
            messages=[
                {"role": "system", "content": metadata_prompt},
                {"role": "user", "content": input_text}
            ]
        )
        metadata_output = metadata_completion.choices[0].message.content
        metadata = self.extract_metadata(metadata_output)

        # Get the embeddings
        with torch.no_grad():
            inputs = self.tokenizer(queries, return_tensors='pt', padding=True, truncation=True, max_length=4096).to(self.device)
            outputs = self.model(**inputs)
            query_embeddings = outputs.last_hidden_state.mean(dim=1)

        # Normalize embeddings
        query_embeddings = F.normalize(query_embeddings, p=2, dim=1)
        embeddings_list = query_embeddings.detach().cpu().numpy().tolist()

        # Include metadata in the embeddings
        embeddings_with_metadata = [{"embedding": emb, "metadata": metadata} for emb in embeddings_list]
        self.clear_cuda_cache()
        return embeddings_with_metadata

    def extract_metadata(self, metadata_output: str):
        # Regex pattern to extract key-value pairs
        pattern = re.compile(r'\"(\w+)\": \"([^\"]+)\"')
        matches = pattern.findall(metadata_output)
        metadata = {key: value for key, value in matches}
        return metadata

class MyEmbeddingFunction(EmbeddingFunction):
    def __init__(self, embedding_generator: EmbeddingGenerator):
        self.embedding_generator = embedding_generator

    def __call__(self, input: Documents) -> Embeddings:
        embeddings = [self.embedding_generator.compute_embeddings(doc) for doc in input]
        embeddings = [item for sublist in embeddings for item in sublist]
        return embeddings

def load_documents(file_path: str, mode: str = "elements"):
    loader = UnstructuredFileLoader(file_path, mode=mode)
    docs = loader.load()
    return [doc.page_content for doc in docs]

def initialize_chroma(collection_name: str, embedding_function: MyEmbeddingFunction):
    client = HttpClient(host='localhost', port=8000, settings=Settings(allow_reset=True, anonymized_telemetry=False))
    client.reset()  # resets the database
    collection = client.create_collection(collection_name)
    return client, collection

def add_documents_to_chroma(client, collection, documents: list, embedding_function: MyEmbeddingFunction):
    for doc in documents:
        # compute_embeddings returns {"embedding", "metadata"} dicts; chromadb wants
        # raw vectors plus optional metadatas, so unpack them here
        embs = embedding_function([doc])
        collection.add(
            ids=[str(uuid.uuid1())],
            documents=[doc],
            embeddings=[e["embedding"] for e in embs],
            metadatas=[e["metadata"] for e in embs],
        )

def query_chroma(client, collection_name: str, query_text: str, embedding_function: MyEmbeddingFunction):
    # Query the chromadb collection directly: the langchain Chroma wrapper expects a
    # langchain Embeddings object, not a chromadb EmbeddingFunction
    collection = client.get_collection(collection_name)
    query_embs = embedding_function([query_text])
    results = collection.query(query_embeddings=[e["embedding"] for e in query_embs], n_results=4)
    return results["documents"][0]

# Initialize clients
intention_client = OpenAI(api_key=yi_token, base_url=API_BASE)
embedding_generator = EmbeddingGenerator(model_name=model_name, token=hf_token, intention_client=intention_client)
embedding_function = MyEmbeddingFunction(embedding_generator=embedding_generator)
chroma_client, chroma_collection = initialize_chroma(collection_name="Tonic-instruct", embedding_function=embedding_function)

def respond(
    message,
    history: list[tuple[str, str]],
    system_message,
    max_tokens,
    temperature,
    top_p,
):
    retrieved_text = query_documents(message)
    messages = [{"role": "system", "content": system_message}]
    for val in history:
        if val[0]:
            messages.append({"role": "user", "content": val[0]})
        if val[1]:
            messages.append({"role": "assistant", "content": val[1]})
    messages.append({"role": "user", "content": f"{retrieved_text}\n\n{message}"})

    response = ""
    # The OpenAI-compatible client streams via chat.completions.create
    # (chat_completion belongs to huggingface_hub's InferenceClient)
    stream = intention_client.chat.completions.create(
        model="yi-large",
        messages=messages,
        max_tokens=max_tokens,
        stream=True,
        temperature=temperature,
        top_p=top_p,
    )
    for chunk in stream:
        token = chunk.choices[0].delta.content
        if token:  # the final streamed chunk carries content=None
            response += token
            yield response

def upload_documents(files):
    for file in files:
        # UnstructuredFileLoader exposes load(), not load_documents(); reuse the helper above,
        # which already returns plain text strings
        docs = load_documents(file.name)
        add_documents_to_chroma(chroma_client, chroma_collection, docs, embedding_function)
    return "Documents uploaded and processed successfully!"

def query_documents(query):
    # query_chroma needs the client, collection name, and embedding function as well
    results = query_chroma(chroma_client, "Tonic-instruct", query, embedding_function)
    return "\n\n".join(results)

with gr.Blocks() as demo:
    with gr.Tab("Upload Documents"):
        with gr.Row():
            document_upload = gr.File(file_count="multiple", file_types=["document"])
            upload_button = gr.Button("Upload and Process")
        upload_button.click(upload_documents, inputs=document_upload, outputs=gr.Text())

    with gr.Tab("Ask Questions"):
        with gr.Row():
            chat_interface = gr.ChatInterface(
                respond,
                additional_inputs=[
                    gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
                    gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
                    gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
                    gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)"),
                ],
            )
        query_input = gr.Textbox(label="Query")
        query_button = gr.Button("Query")
        query_output = gr.Textbox()
        query_button.click(query_documents, inputs=query_input, outputs=query_output)

if __name__ == "__main__":
    demo.launch()
```
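
For reference, a minimal `requirements.txt` sketch covering the imports in `main.py` above. The package names are inferred from the imports, versions are left unpinned, and `utils.py` / `globalvars.py` are assumed to be local modules in the Space repo rather than pip packages; a build error on Spaces often comes down to one of these entries missing.

```
spaces
torch
transformers
huggingface_hub
openai
python-dotenv
gradio
chromadb
langchain-community
langchain-chroma
unstructured
```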