Spaces:
Runtime error
Runtime error
| import weave | |
| from ..retrieval import SimilarityMetric | |
| from .figure_annotation import FigureAnnotatorFromPageImage | |
| from .llm_client import LLMClient | |
class MedQAAssistant(weave.Model):
    """
    `MedQAAssistant` is a class designed to assist with medical queries by leveraging a
    language model client, a retriever model, and a figure annotator.

    !!! example "Usage Example"
        ```python
        import weave
        from dotenv import load_dotenv

        from medrag_multi_modal.assistant import (
            FigureAnnotatorFromPageImage,
            LLMClient,
            MedQAAssistant,
        )
        from medrag_multi_modal.retrieval import MedCPTRetriever

        load_dotenv()
        weave.init(project_name="ml-colabs/medrag-multi-modal")

        llm_client = LLMClient(model_name="gemini-1.5-flash")
        retriever = MedCPTRetriever.from_wandb_artifact(
            chunk_dataset_name="grays-anatomy-chunks:v0",
            index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
        )
        figure_annotator = FigureAnnotatorFromPageImage(
            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
        )
        medqa_assistant = MedQAAssistant(
            llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
        )
        medqa_assistant.predict(query="What is ribosome?")
        ```

    Args:
        llm_client (LLMClient): The language model client used to generate responses.
        retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
        figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
        top_k_chunks (int): The number of top chunks to retrieve based on similarity metric.
        retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
    """

    llm_client: LLMClient
    retriever: weave.Model
    figure_annotator: FigureAnnotatorFromPageImage
    top_k_chunks: int = 2
    retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE

    def predict(self, query: str) -> str:
        """
        Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
        from a medical document and using a language model to generate the final response.

        This function performs the following steps:

        1. Retrieves relevant text chunks from the medical document based on the query using the retriever model.
        2. Extracts the text and the unique page indices from the retrieved chunks; the page indices are
            sorted so that every later step processes them in a deterministic, ascending order.
        3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
        4. Constructs a system prompt and user prompt combining the query, retrieved text chunks, and figure descriptions.
        5. Uses the language model client to generate a response based on the constructed prompts.
        6. Appends the source information (1-based page numbers) to the generated response. When no chunk
            was retrieved there is no page to cite, so the response is returned without a source line.

        Args:
            query (str): The medical query to be answered.

        Returns:
            str: The generated response to the query, including source information whenever at least
                one page was retrieved.
        """
        retrieved_chunks = self.retriever.predict(
            query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
        )

        # Single pass over the retrieved chunks: each chunk is read through its
        # "text" and "page_idx" keys.
        retrieved_chunk_texts = []
        unique_page_indices = set()
        for chunk in retrieved_chunks:
            retrieved_chunk_texts.append(chunk["text"])
            unique_page_indices.add(int(chunk["page_idx"]))

        # Set iteration order is not guaranteed; sort once so that both the
        # figure descriptions passed to the LLM and the cited page numbers
        # come out in ascending page order on every run.
        page_indices = sorted(unique_page_indices)

        figure_descriptions = []
        for page_idx in page_indices:
            # The annotator's result is indexed by the requested page index.
            figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
                page_idx
            ]
            figure_descriptions.extend(
                item["figure_description"] for item in figure_annotations
            )

        system_prompt = (
            "\nYou are an expert in medical science. You are given a query and "
            "a list of chunks from a medical document.\n"
        )
        response = self.llm_client.predict(
            system_prompt=system_prompt,
            user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
        )

        # Without any retrieved page the citation would read "Page  from ...",
        # so skip it rather than emit a malformed source line.
        if not page_indices:
            return response

        # Page indices are 0-based; the citation shows 1-based page numbers.
        page_numbers = ", ".join(str(page_idx + 1) for page_idx in page_indices)
        page_label = "Pages" if len(page_indices) > 1 else "Page"
        response += f"\n\n**Source:** {page_label} {page_numbers} from Gray's Anatomy"
        return response