import os
from threading import Thread

import gradio as gr
import spaces
import torch
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

# Hugging Face token with access to the gated google/gemma-7b-it weights
token = os.environ["HF_TOKEN"]

# load gemma-7b-it in half precision on GPU, full precision on CPU
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
model = AutoModelForCausalLM.from_pretrained(
    "google/gemma-7b-it",
    torch_dtype=dtype,
    token=token,
)
tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b-it", token=token)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# embedding model for retrieval; queries must be encoded with the same model
# that produced the dataset's precomputed "embedding" column
RAG = SentenceTransformer("mixedbread-ai/mxbai-embed-large-v1")
TOP_K = 3

# prepare data: a pre-embedded 3K-row subset of Wikipedia
# (the full dataset is too large for this demo)
data = load_dataset("not-lain/wikipedia-small-3000-embedded", split="train")
# build a FAISS index over the precomputed embedding column
data.add_faiss_index("embedding")
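# Note: add_faiss_index builds the index in memory on every startup. For faster
# cold starts you could persist it between runs with the datasets API
# (a sketch; the file name is hypothetical):
#   data.save_faiss_index("embedding", "wikipedia.faiss")
#   data.load_faiss_index("embedding", "wikipedia.faiss")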

def search(query: str, k: int = TOP_K):
    embedded_query = RAG.encode(query)
    scores, retrieved_examples = data.get_nearest_examples(
        "embedding", embedded_query, k=k
    )
    return retrieved_examples
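
# Illustrative usage (not part of the app): a call like
#   search("who invented the telephone")
# returns a dict of columns such as {"title": [...], "text": [...], "url": [...]},
# ordered best match first; title/text/url are the columns used below.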

def prepare_prompt(query, retrieved_examples):
    prompt = (
        f"Query: {query}\nContinue to answer the query by using the Search Results:\n"
    )
    # reverse the results so the highest-scoring passage sits last in the
    # prompt, i.e. closest to where the model starts generating
    titles = retrieved_examples["title"][::-1]
    texts = retrieved_examples["text"][::-1]
    urls = retrieved_examples["url"][::-1]
    for i in range(TOP_K):
        prompt += f"Title: {titles[i]}, Text: {texts[i]}\n"
    return prompt, (titles, urls)
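
# Shape of the prompt built above (illustrative placeholders, not real data):
#   Query: <question>
#   Continue to answer the query by using the Search Results:
#   Title: <3rd-best title>, Text: <3rd-best passage>
#   Title: <2nd-best title>, Text: <2nd-best passage>
#   Title: <best title>, Text: <best passage>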

# assumption: this Space targets ZeroGPU hardware (which would explain the
# spaces import); spaces.GPU requests a GPU slot for the duration of the call
@spaces.GPU
def talk(message, history):
    retrieved_examples = search(message)
    message, metadata = prepare_prompt(message, retrieved_examples)
    resources = "\nRESOURCES:\n"
    # metadata is a (titles, urls) pair, so zip it into (title, url) tuples
    for title, url in zip(*metadata):
        resources += f"[{title}]({url}), "
    chat = []
    for item in history:
        chat.append({"role": "user", "content": item[0]})
        if item[1] is not None:
            # strip the resources footer from past replies before re-feeding them
            cleaned_past = item[1].split("\nRESOURCES:\n")[0]
            chat.append({"role": "assistant", "content": cleaned_past})
    chat.append({"role": "user", "content": message})
    messages = tokenizer.apply_chat_template(
        chat, tokenize=False, add_generation_prompt=True
    )
    # tokenize the templated chat string
    model_inputs = tokenizer([messages], return_tensors="pt").to(device)
    streamer = TextIteratorStreamer(
        tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True
    )
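    # Illustrative: with gemma's chat template, `messages` above looks roughly
    # like "<start_of_turn>user\n{prompt}<end_of_turn>\n<start_of_turn>model\n"
    # (the exact markup comes from tokenizer.apply_chat_template)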
    generate_kwargs = dict(
        model_inputs,
        streamer=streamer,
        max_new_tokens=1024,
        do_sample=True,
        top_p=0.95,
        top_k=1000,
        temperature=0.75,
        num_beams=1,
    )
    # run generation in a background thread so tokens can be streamed from the
    # main thread as they arrive
    t = Thread(target=model.generate, kwargs=generate_kwargs)
    t.start()
    # accumulate the generated text and yield it incrementally
    partial_text = ""
    for new_text in streamer:
        partial_text += new_text
        yield partial_text
    # append the source links once generation has finished
    partial_text += resources
    yield partial_text
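
# Illustrative usage (outside Gradio): talk() is a generator, so it can be
# driven manually, e.g.
#   for partial in talk("what is machine learning", history=[]):
#       print(partial)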

TITLE = "RAG"

DESCRIPTION = """
A RAG pipeline with a chatbot feature

Resources used to build this project:

* embedding model: https://huggingface.co/mixedbread-ai/mxbai-embed-large-v1
* dataset: https://huggingface.co/datasets/not-lain/wikipedia-small-3000-embedded (mxbai-embed-large-v1 was used to create the embedding column)
* faiss docs: https://huggingface.co/docs/datasets/v2.18.0/en/package_reference/main_classes#datasets.Dataset.add_faiss_index
* chatbot: google/gemma-7b-it

If you want to support my work, please click on the heart react button ❤️🤗

<sub><sup><sub><sup>psst, I am still open for work, please reach out to me at https://not-lain.github.io/</sup></sub></sup></sub>
"""

demo = gr.ChatInterface(
    fn=talk,
    chatbot=gr.Chatbot(
        show_label=True,
        show_share_button=True,
        show_copy_button=True,
        likeable=True,
        layout="bubble",
        bubble_full_width=False,
    ),
    theme="soft",
    examples=[["what is machine learning"]],
    title=TITLE,
    description=DESCRIPTION,
)
demo.launch()