import os

import gradio as gr
from openai import AzureOpenAI
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import PyPDFLoader
from langchain_community.embeddings.sentence_transformer import SentenceTransformerEmbeddings
from langchain_community.vectorstores import Chroma
# Azure OpenAI client; the endpoint and API key are read from environment variables.
client = AzureOpenAI(
    azure_endpoint=os.environ['AZURE_OPENAI_ENDPOINT'],
    api_key=os.environ['AZURE_OPENAI_KEY'],
    api_version="2023-05-15"
)

# Name of the chat model deployment in the Azure OpenAI resource.
chat_model_deployment_name = "gpt-35-turbo"
# Embeddings are computed locally with the gte-small sentence transformer.
embedding_model = SentenceTransformerEmbeddings(model_name='thenlper/gte-small')

# Split documents into chunks of at most 512 tokens (counted with the
# cl100k_base tokenizer), with a 16-token overlap between consecutive chunks.
text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
    encoding_name='cl100k_base',
    chunk_size=512,
    chunk_overlap=16
)
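
# Illustrative splitter check (not part of the original listing): split_text()
# returns the raw string pieces the splitter would produce for a given text.
# print(text_splitter.split_text("Tesla designs and sells electric vehicles."))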
# Load the Tesla 2022 10-K PDF and split it into chunks.
# (The variable is renamed from tesla_10k_chunks_ada: the embeddings here are
# gte-small, not Ada, so the old suffix was misleading.)
pdf_file = "tsla-20221231-gen.pdf"
pdf_loader = PyPDFLoader(pdf_file)
tesla_10k_chunks = pdf_loader.load_and_split(text_splitter)

# Index the chunks in an in-memory Chroma collection.
tesla_10k_collection = 'tesla-10k-2022'
vectorstore = Chroma.from_documents(
    tesla_10k_chunks,
    embedding_model,
    collection_name=tesla_10k_collection
)
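
# Optional sanity check (illustrative; not part of the original listing):
# confirm that the PDF was split and indexed as expected.
print(f"Indexed {len(tesla_10k_chunks)} chunks into '{tesla_10k_collection}'")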
# Retrieve the 5 most similar chunks for each query.
retriever = vectorstore.as_retriever(
    search_type='similarity',
    search_kwargs={'k': 5}
)
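
# Quick retrieval check (a sketch; the query string is illustrative):
sample_docs = retriever.get_relevant_documents("What was Tesla's total revenue?")
for doc in sample_docs:
    print(doc.page_content[:100])  # preview the first 100 characters of each chunk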
qna_system_message = """
You are an assistant to a financial services firm that answers user queries on annual reports.
User questions are delimited by triple backticks, that is, ```.
The user input also contains the context you need to answer the question.
This context begins with the token: ###Context.
The context contains excerpts of a document relevant to the user query.
Answer only using the context provided in the input.
If the answer is not found in the context, respond "I don't know".
"""
qna_user_message_template = """
###Context
Here are some documents that are relevant to the question.
{context}
```
{question}
```
"""
def predict(user_input):
    # Fetch the top-k relevant chunks and join them into a single context string.
    relevant_document_chunks = retriever.get_relevant_documents(user_input)
    context_list = [d.page_content for d in relevant_document_chunks]
    context_for_query = "\n".join(context_list)

    prompt = [
        {'role': 'system', 'content': qna_system_message},
        {'role': 'user', 'content': qna_user_message_template.format(
            context=context_for_query,
            question=user_input
        )}
    ]

    try:
        response = client.chat.completions.create(
            model=chat_model_deployment_name,
            messages=prompt,
            temperature=0
        )
        prediction = response.choices[0].message.content
    except Exception as e:
        # Return the error as text so Gradio can display it,
        # rather than returning the raw exception object.
        prediction = f"Sorry, something went wrong: {e}"

    return prediction
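
# Example direct call (illustrative question; uncomment to test before launching the UI):
# print(predict("What was Tesla's total revenue in 2022?"))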
textbox = gr.Textbox(placeholder="Enter your query here", lines=6)

interface = gr.Interface(
    inputs=textbox, fn=predict, outputs="text",
    title="AMA on Tesla 2022 10-K",
    description="This web app answers questions about the contents of the Tesla 2022 10-K report.",
    article="Note that questions that are not relevant to the Tesla 10-K report will not be answered.",
    allow_flagging="manual", flagging_options=["Useful", "Not Useful"]
)

# Render the interface inside a Blocks app, then queue and launch it.
with gr.Blocks() as demo:
    interface.render()

demo.queue(concurrency_count=16)
demo.launch()