Spaces:
Runtime error
Runtime error
| """This file should be imported only and only if you want to run the UI locally.""" | |
| import itertools | |
| import logging | |
| from pathlib import Path | |
| from typing import Any | |
| import gradio as gr | |
| from fastapi import FastAPI | |
| from gradio.themes.utils.colors import slate | |
| from llama_index.llms import ChatMessage, MessageRole | |
| from app._config import settings | |
| from app.components.embedding.component import EmbeddingComponent | |
| from app.components.llm.component import LLMComponent | |
| from app.components.node_store.component import NodeStoreComponent | |
| from app.components.vector_store.component import VectorStoreComponent | |
| from app.enums import PROJECT_ROOT_PATH | |
| from app.server.chat.service import ChatService | |
| from app.server.ingest.service import IngestService | |
| from app.ui.schemas import Source | |
# Browser tab title; also rendered as the page heading.
UI_TAB_TITLE = "Agriculture Chatbot"
# Inserted between an answer and its numbered source list. Chat history is
# split on this marker so the sources are not re-sent to the model.
SOURCES_SEPARATOR = "\n\n Sources: \n"

# This module's directory, expressed relative to the project root so the
# avatar path derived from it is relative too.
THIS_DIRECTORY_RELATIVE = Path(__file__).parent.relative_to(PROJECT_ROOT_PATH)
# Bot avatar handed to the Chatbot component's ``avatar_images``.
AVATAR_BOT = THIS_DIRECTORY_RELATIVE.joinpath("dodge_ava.jpg")

logger = logging.getLogger(__name__)
class PrivateGptUi:
    """Gradio UI that wires file ingestion and chat services into one page.

    The ``gr.Blocks`` tree is built lazily on the first call to
    ``get_ui_blocks`` and cached on the instance afterwards.
    """

    def __init__(
        self,
        ingest_service: IngestService,
        chat_service: ChatService,
    ) -> None:
        """Store the backing services and seed the system prompt.

        Args:
            ingest_service: Lists ingested documents and ingests uploads.
            chat_service: Produces completions for chat messages.
        """
        self._ingest_service = ingest_service
        self._chat_service = chat_service
        # Cache the UI blocks; populated on the first get_ui_blocks() call.
        self._ui_block = None
        # Start from the configured default; the UI textbox may override it.
        self._system_prompt = self._get_default_system_prompt()

    def _chat(self, message: str, history: list[list[str]], *_: Any) -> Any:
        """Answer ``message`` using the chat service and the visible history.

        Args:
            message: The new user message.
            history: Gradio chat history as ``[user, assistant]`` pairs.
            *_: Extra ChatInterface inputs (upload button, system prompt
                textbox). They are ignored here: the system prompt is tracked
                separately through ``_set_system_prompt``.

        Returns:
            The assistant reply. When the chat service returned sources, a
            numbered source list follows, separated by ``SOURCES_SEPARATOR``.
        """

        def build_history() -> list[ChatMessage]:
            """Convert Gradio ``[user, assistant]`` pairs into chat messages."""
            history_messages: list[ChatMessage] = list(
                itertools.chain.from_iterable(
                    (
                        ChatMessage(content=interaction[0], role=MessageRole.USER),
                        ChatMessage(
                            # Drop the appended "Sources" listing so it is not
                            # fed back to the model as assistant output. A
                            # missing reply (None) is treated as empty text.
                            content=(interaction[1] or "").split(SOURCES_SEPARATOR)[0],
                            role=MessageRole.ASSISTANT,
                        ),
                    )
                    for interaction in history
                )
            )
            # Keep at most the 20 *most recent* messages (10 exchanges) to
            # limit context size. Slicing from the end keeps the turns the new
            # question most likely refers to. Because 20 is even, the window
            # always starts on a user message.
            return history_messages[-20:]

        new_message = ChatMessage(content=message, role=MessageRole.USER)
        all_messages = [*build_history(), new_message]
        # If a system prompt is set, prepend it as a system message.
        if self._system_prompt:
            all_messages.insert(
                0,
                ChatMessage(
                    content=self._system_prompt,
                    role=MessageRole.SYSTEM,
                ),
            )

        completion = self._chat_service.chat(messages=all_messages)
        full_response = completion.response
        if completion.sources:
            curated_sources = Source.curate_sources(completion.sources)
            sources_text = "\n\n\n".join(
                f"{index}. {source.file} (page {source.page})"
                for index, source in enumerate(curated_sources, start=1)
            )
            full_response += SOURCES_SEPARATOR + sources_text
        return full_response

    @staticmethod
    def _get_default_system_prompt() -> str:
        """Return the default system prompt configured in settings.

        This is a staticmethod so that it can be called as
        ``self._get_default_system_prompt()``. A plain function without
        ``self`` raises ``TypeError`` when called through an instance.
        """
        return settings.DEFAULT_QUERY_SYSTEM_PROMPT

    def _set_system_prompt(self, system_prompt_input: str) -> None:
        """Use ``system_prompt_input`` as the system prompt for later queries.

        A blank value restores the configured default instead of silently
        disabling the system prompt. This happens, for example, when the
        textbox is focused and then left without typing. The textbox shows
        the default prompt as its placeholder, so restoring it matches what
        the user sees.
        """
        if not system_prompt_input or not system_prompt_input.strip():
            system_prompt_input = self._get_default_system_prompt()
        logger.info("Setting system prompt to: %s", system_prompt_input)
        self._system_prompt = system_prompt_input

    def _list_ingested_files(self) -> list[list[str]]:
        """Return unique ingested file names as single-column rows.

        Documents without metadata are skipped. Names are sorted so the
        displayed list has a stable order between refreshes.
        """
        files = set()
        for ingested_document in self._ingest_service.list_ingested():
            if ingested_document.doc_metadata is None:
                # Skipping documents without metadata
                continue
            files.add(
                ingested_document.doc_metadata.get(
                    "file_name", "[FILE NAME MISSING]"
                )
            )
        # key=str guards against non-string metadata values breaking the sort.
        return [[file_name] for file_name in sorted(files, key=str)]

    def _upload_file(self, files: list[str]) -> None:
        """Ingest the uploaded files, keyed by their base file name.

        Returns None. This is wired as the output of the upload event, which
        clears the file list. The list's ``change`` handler then repopulates
        it via ``_list_ingested_files``.
        """
        logger.debug("Loading count=%s files", len(files))
        paths = [Path(file) for file in files]
        self._ingest_service.bulk_ingest([(path.name, path) for path in paths])

    def _build_ui_blocks(self) -> gr.Blocks:
        """Build the page layout and wire its event handlers.

        The left column holds the upload button and the ingested-file list.
        The right column holds the chat interface.
        """
        logger.debug("Creating the UI blocks")
        with gr.Blocks(
            title=UI_TAB_TITLE,
            theme=gr.themes.Soft(primary_hue=slate),
            css=".logo { "
            "display:flex;"
            "height: 80px;"
            "border-radius: 8px;"
            "align-content: center;"
            "justify-content: center;"
            "align-items: center;"
            "}"
            ".logo img { height: 25% }"
            ".contain { display: flex !important; flex-direction: column !important; }"
            "#component-0, #component-3, #component-10, #component-8  { height: 100% !important; }"
            "#chatbot { flex-grow: 1 !important; overflow: auto !important;}"
            "#col { height: calc(100vh - 112px - 16px) !important; }",
        ) as blocks:
            with gr.Row():
                # Well-formed header. The previous markup self-closed the div
                # and left the closing tag without its ">".
                gr.HTML(f"<div class='logo'><h1>{UI_TAB_TITLE}</h1></div>")

            with gr.Row(equal_height=False):
                with gr.Column(scale=3):
                    upload_button = gr.components.UploadButton(
                        "Upload File(s)",
                        type="filepath",
                        file_count="multiple",
                        size="sm",
                    )
                    ingested_dataset = gr.List(
                        self._list_ingested_files,
                        headers=["File name"],
                        label="Ingested Files",
                        interactive=False,
                        render=False,  # Rendered under the button
                    )
                    upload_button.upload(
                        self._upload_file,
                        inputs=upload_button,
                        outputs=ingested_dataset,
                    )
                    # Refresh the list whenever its value changes (e.g. after
                    # the upload handler above clears it).
                    ingested_dataset.change(
                        self._list_ingested_files,
                        outputs=ingested_dataset,
                    )
                    ingested_dataset.render()
                    system_prompt_input = gr.Textbox(
                        placeholder=self._system_prompt,
                        label="System Prompt",
                        lines=2,
                        interactive=True,
                        render=False,
                    )
                    # On blur, set system prompt to use in queries
                    system_prompt_input.blur(
                        self._set_system_prompt,
                        inputs=system_prompt_input,
                    )

                with gr.Column(scale=7, elem_id="col"):
                    _ = gr.ChatInterface(
                        self._chat,
                        chatbot=gr.Chatbot(
                            label=f"LLM: {settings.LLM_MODE}",
                            show_copy_button=True,
                            elem_id="chatbot",
                            render=False,
                            avatar_images=(
                                None,
                                AVATAR_BOT,
                            ),
                        ),
                        additional_inputs=[upload_button, system_prompt_input],
                    )
        return blocks

    def get_ui_blocks(self) -> gr.Blocks:
        """Return the cached UI blocks, building them on first use."""
        if self._ui_block is None:
            self._ui_block = self._build_ui_blocks()
        return self._ui_block

    def mount_in_app(self, app: FastAPI, path: str) -> None:
        """Enable queuing on the UI and mount it into ``app`` at ``path``."""
        blocks = self.get_ui_blocks()
        blocks.queue()
        logger.info("Mounting the gradio UI, at path=%s", path)
        gr.mount_gradio_app(app, blocks, path=path)
if __name__ == "__main__":
    # Build each component once, in the same order as before, and hand the
    # very same instances to both services.
    shared_components = (
        LLMComponent(),
        VectorStoreComponent(),
        EmbeddingComponent(),
        NodeStoreComponent(),
    )
    ingest_service = IngestService(*shared_components)
    chat_service = ChatService(*shared_components)

    # Standalone launch: enable queuing, then serve without the API page.
    _blocks = PrivateGptUi(ingest_service, chat_service).get_ui_blocks()
    _blocks.queue()
    _blocks.launch(debug=False, show_api=False)