from pathlib import Path

import chainlit as cl
from chainlit.types import AskFileResponse
from llama_index import (
    ServiceContext,
    VectorStoreIndex,
    download_loader,
)
from llama_index.embeddings import HuggingFaceEmbedding
from llama_index.llms import OpenAI
from llama_index.postprocessor.cohere_rerank import CohereRerank
from llama_index.query_engine import SubQuestionQueryEngine
from llama_index.tools import QueryEngineTool, ToolMetadata
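# Assumed setup (not shown in this file): OPENAI_API_KEY must be set in the
# environment for the OpenAI LLM, and COHERE_API_KEY for CohereRerank; both
# clients fall back to these environment variables when no key is passed.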
def process_file(file: AskFileResponse):
    import tempfile

    # Write the uploaded bytes to a temporary file on disk so PDFReader
    # can open it by path.
    with tempfile.NamedTemporaryFile(suffix=".pdf", delete=False) as temp_file:
        temp_file.write(file.content)
        temp_path = temp_file.name

    PDFReader = download_loader("PDFReader")
    loader = PDFReader()
    documents = loader.load_data(Path(temp_path))
    return documents
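# Note: PDFReader typically returns one Document per PDF page, so the index
# built below is over page-level chunks.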
@cl.on_chat_start
async def on_chat_start():
    files = None
    # Wait for the user to upload a file.
    while files is None:
        files = await cl.AskFileMessage(
            content="Please upload a PDF file to begin!",
            accept=["application/pdf"],
            max_size_mb=20,
            timeout=180,
        ).send()
    file = files[0]

    msg = cl.Message(
        content=f"Processing `{file.name}`...", disable_human_feedback=True
    )
    await msg.send()
    # Load and parse the uploaded PDF.
    documents = process_file(file)

    # One ServiceContext shared by indexing and querying: the fine-tuned
    # embedding model plus the OpenAI LLM.
    llm = OpenAI(model="gpt-4-1106-preview", temperature=0)
    embed_model = HuggingFaceEmbedding(model_name="ai-maker-space/chatlgo-finetuned")
    service_context = ServiceContext.from_defaults(
        embed_model=embed_model,
        llm=llm,
    )

    index = VectorStoreIndex.from_documents(
        documents=documents, service_context=service_context, show_progress=True
    )

    # Retrieve the top 10 nodes by embedding similarity, then let Cohere's
    # reranker keep only the 5 most relevant.
    cohere_rerank = CohereRerank(top_n=5)
    query_engine = index.as_query_engine(
        similarity_top_k=10,
        node_postprocessors=[cohere_rerank],
        service_context=service_context,
    )
    query_engine_tools = [
        QueryEngineTool(
            query_engine=query_engine,
            metadata=ToolMetadata(
                name="mit_theses",
                description="A collection of MIT theses.",
            ),
        ),
    ]

    query_engine = SubQuestionQueryEngine.from_defaults(
        query_engine_tools=query_engine_tools,
        service_context=service_context,
    )

    cl.user_session.set("query_engine", query_engine)
@cl.on_message
async def main(message: cl.Message):
    query_engine = cl.user_session.get("query_engine")
    # query() is synchronous; wrap it with make_async so the event loop
    # stays responsive while the engine runs.
    response = await cl.make_async(query_engine.query)(message.content)
    response_message = cl.Message(content=str(response))
    await response_message.send()
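# To run locally (assuming this file is saved as app.py):
#   chainlit run app.py -w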