Spaces:
Sleeping
Sleeping
import streamlit as st

from medrag_multi_modal.assistant import LLMClient, MedQAAssistant
from medrag_multi_modal.retrieval.text_retrieval import (
    BM25sRetriever,
    ContrieverRetriever,
    MedCPTRetriever,
    NVEmbed2Retriever,
)

# Model names offered in the "LLM Model" sidebar dropdown.
ALL_AVAILABLE_MODELS = [
    "gemini-1.5-flash-latest",
    "gemini-1.5-pro-latest",
    "gpt-4o",
    "gpt-4o-mini",
]

# ---------------------------------------------------------------------------
# Sidebar: all user-tunable settings. Streamlit re-executes the script on each
# interaction, so the names bound here always hold the current widget values.
# ---------------------------------------------------------------------------
with st.sidebar:
    st.title("Configuration Settings")

    # NOTE(review): `project_name` is collected here but is not used by any
    # code visible in this file.
    project_name = st.text_input(
        label="Project Name",
        value="ml-colabs/medrag-multi-modal",
        placeholder="wandb project name",
        help="format: wandb_username/wandb_project_name",
    )
    chunk_dataset_id = st.selectbox(
        label="Chunk Dataset ID",
        options=["ashwiniai/medrag-text-corpus-chunks"],
    )
    llm_model = st.selectbox(label="LLM Model", options=ALL_AVAILABLE_MODELS)

    # How many chunks to retrieve for the query and for the options.
    top_k_chunks_for_query = st.slider(
        label="Top K Chunks for Query", min_value=1, max_value=20, value=5
    )
    top_k_chunks_for_options = st.slider(
        label="Top K Chunks for Options", min_value=1, max_value=20, value=3
    )
    rely_only_on_context = st.checkbox(label="Rely Only on Context", value=False)

    # The empty first entry is the initial selection; the assistant is only
    # built once a real retriever has been chosen.
    retriever_type = st.selectbox(
        label="Retriever Type",
        options=["", "BM25S", "Contriever", "MedCPT", "NV-Embed-v2"],
    )
# Retrievers that are built from a Hub index plus the selected chunk dataset,
# keyed by their "Retriever Type" sidebar option.
_RETRIEVERS_WITH_CHUNK_DATASET = {
    "Contriever": (
        ContrieverRetriever,
        "ashwiniai/medrag-text-corpus-chunks-contriever",
    ),
    "MedCPT": (
        MedCPTRetriever,
        "ashwiniai/medrag-text-corpus-chunks-medcpt",
    ),
    "NV-Embed-v2": (
        NVEmbed2Retriever,
        "ashwiniai/medrag-text-corpus-chunks-nv-embed-2",
    ),
}


@st.cache_resource(show_spinner="Loading retriever index...")
def _load_retriever(retriever_name: str, chunk_dataset: str):
    """Build the retriever chosen in the sidebar, cached across reruns.

    Streamlit re-executes the whole script on every widget change and chat
    message. Without caching, ``from_index`` would run again on each of those
    reruns. ``st.cache_resource`` keeps one instance per
    ``(retriever_name, chunk_dataset)`` pair.

    NOTE(review): cached resources are shared by all sessions of the server
    process; confirm the retrievers' query path does not mutate state.

    Args:
        retriever_name: A non-empty "Retriever Type" option.
        chunk_dataset: Chunk dataset ID forwarded to retrievers that need it.

    Returns:
        The retriever created by the matching class's ``from_index``.

    Raises:
        ValueError: If ``retriever_name`` is not a known retriever option.
    """
    if retriever_name == "BM25S":
        # The BM25S index is loaded without a chunk dataset.
        return BM25sRetriever.from_index(
            index_repo_id="ashwiniai/medrag-text-corpus-chunks-bm25s"
        )
    try:
        retriever_cls, index_repo_id = _RETRIEVERS_WITH_CHUNK_DATASET[retriever_name]
    except KeyError:
        raise ValueError(f"Unknown retriever type: {retriever_name!r}") from None
    return retriever_cls.from_index(
        index_repo_id=index_repo_id,
        chunk_dataset=chunk_dataset,
    )


if retriever_type != "":
    # Keep the client under its own name; `llm_model` stays the model-name
    # string selected in the sidebar.
    llm_client = LLMClient(model_name=llm_model)
    medqa_assistant = MedQAAssistant(
        llm_client=llm_client,
        retriever=_load_retriever(retriever_type, chunk_dataset_id),
        top_k_chunks_for_query=top_k_chunks_for_query,
        top_k_chunks_for_options=top_k_chunks_for_options,
        # Forward the sidebar checkbox so toggling it affects the assistant.
        # NOTE(review): assumes MedQAAssistant exposes a
        # `rely_only_on_context` field — confirm against the class.
        rely_only_on_context=rely_only_on_context,
    )

    with st.chat_message("assistant"):
        # Static text with no HTML, so unsafe_allow_html is not needed.
        st.markdown(
            """
            Hi! I am Medrag, your medical assistant. You can ask me any questions about medicine and the life sciences.

            I am currently a work-in-progress, so please bear with my stupidity and overall lack of knowledge.

            **Note:** I am not a medical professional, so please do not rely on my answers for medical decisions.
            Please consult a medical professional for any medical advice.

            In order to learn more about how I am being developed, please visit [soumik12345/medrag-multi-modal](https://github.com/soumik12345/medrag-multi-modal).
            """
        )

    query = st.chat_input("Enter your question here")
    if query:
        with st.chat_message("user"):
            st.markdown(query)
        with st.chat_message("assistant"):
            # Show progress inside the assistant bubble while retrieval and
            # generation run.
            with st.spinner("Thinking..."):
                response = medqa_assistant.predict(query=query)
            st.markdown(response.response)