Commit 6c6905f
Parent(s): e197ad0
update: MedQAAssistant + FigureAnnotatorFromPageImage
medrag_multi_modal/assistant/figure_annotation.py CHANGED
@@ -1,12 +1,11 @@
 import os
 from glob import glob
-from typing import Union
+from typing import Optional, Union
 
 import cv2
 import weave
 from PIL import Image
 from pydantic import BaseModel
-from rich.progress import track
 
 from ..utils import get_wandb_artifact, read_jsonl_file
 from .llm_client import LLMClient
@@ -23,7 +22,8 @@ class FigureAnnotations(BaseModel):
 
 class FigureAnnotatorFromPageImage(weave.Model):
     """
-    `FigureAnnotatorFromPageImage` is a class that leverages two LLM clients to annotate
+    `FigureAnnotatorFromPageImage` is a class that leverages two LLM clients to annotate
+    figures from a page image of a scientific textbook.
 
     !!! example "Example Usage"
         ```python
@@ -39,19 +39,35 @@ class FigureAnnotatorFromPageImage(weave.Model):
         figure_annotator = FigureAnnotatorFromPageImage(
             figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
             structured_output_llm_client=LLMClient(model_name="gpt-4o"),
+            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
         )
-        annotations = figure_annotator.predict(
-            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6"
-        )
+        annotations = figure_annotator.predict(page_idx=34)
         ```
 
-    Args:
-        figure_extraction_llm_client (LLMClient): An LLM client used to extract figure annotations
-            from the page image.
+    Args:
+        figure_extraction_llm_client (LLMClient): An LLM client used to extract figure annotations
+            from the page image.
+        structured_output_llm_client (LLMClient): An LLM client used to convert the extracted
+            annotations into a structured format.
+        image_artifact_address (Optional[str]): The address of the image artifact containing the
+            page images.
     """
 
    figure_extraction_llm_client: LLMClient
    structured_output_llm_client: LLMClient
+    _artifact_dir: str
+
+    def __init__(
+        self,
+        figure_extraction_llm_client: LLMClient,
+        structured_output_llm_client: LLMClient,
+        image_artifact_address: Optional[str] = None,
+    ):
+        super().__init__(
+            figure_extraction_llm_client=figure_extraction_llm_client,
+            structured_output_llm_client=structured_output_llm_client,
+        )
+        self._artifact_dir = get_wandb_artifact(image_artifact_address, "dataset")
 
    @weave.op()
    def annotate_figures(
@@ -92,7 +108,7 @@ Here are some clues you need to follow:
        )
 
    @weave.op()
-    def predict(self, page_idx: int,
+    def predict(self, page_idx: int) -> dict[int, list[FigureAnnotation]]:
        """
        Predicts figure annotations for a specific page in a document.
 
@@ -105,22 +121,23 @@ Here are some clues you need to follow:
 
        Args:
            page_idx (int): The index of the page to annotate.
-            image_artifact_address (str): The address of the image artifact containing the
+            image_artifact_address (str): The address of the image artifact containing the
+                page images.
 
        Returns:
-            dict: A dictionary containing the page index as the key and the extracted figure
-                annotations as the value.
+            dict: A dictionary containing the page index as the key and the extracted figure
+                annotations as the value.
        """
-
-        metadata = read_jsonl_file(os.path.join(
+
+        metadata = read_jsonl_file(os.path.join(self._artifact_dir, "metadata.jsonl"))
        annotations = {}
-        for item in
            if item["page_idx"] == page_idx:
                page_image_file = os.path.join(
-
+                    self._artifact_dir, f"page{item['page_idx']}.png"
                )
                figure_image_files = glob(
-                    os.path.join(
+                    os.path.join(self._artifact_dir, f"page{item['page_idx']}_fig*.png")
                )
                if len(figure_image_files) > 0:
                    page_image = cv2.imread(page_image_file)
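The net effect of this change for callers: the image artifact address moves out of `predict()` and into the constructor, where `get_wandb_artifact` resolves it once into `self._artifact_dir`. A minimal before/after sketch of the call-site migration (the keyword form of the old call is inferred from the removed docstring Args; page index 34 comes from the docstring example):

```python
# Before this commit: the artifact address was passed on every call
# (keyword layout inferred from the old docstring, not shown verbatim in the diff).
annotations = figure_annotator.predict(
    page_idx=34,
    image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
)

# After this commit: the address is bound once at construction and predict()
# needs only the page index.
figure_annotator = FigureAnnotatorFromPageImage(
    figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
    structured_output_llm_client=LLMClient(model_name="gpt-4o"),
    image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
)
annotations = figure_annotator.predict(page_idx=34)

# The new return annotation is dict[int, list[FigureAnnotation]], so callers
# index back in by page; "figure_description" is the field MedQAAssistant
# reads from each annotation in the second diff below.
for item in annotations[34]:
    print(item["figure_description"])
```

Binding the artifact to the annotator instance is what lets `MedQAAssistant` below call it per page without threading an address through its own `predict`.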
medrag_multi_modal/assistant/medqa_assistant.py CHANGED
@@ -1,5 +1,3 @@
-from typing import Optional
-
 import weave
 
 from ..retrieval import SimilarityMetric
@@ -8,7 +6,50 @@ from .llm_client import LLMClient
 
 
 class MedQAAssistant(weave.Model):
-    """
+    """
+    `MedQAAssistant` is a class designed to assist with medical queries by leveraging a
+    language model client, a retriever model, and a figure annotator.
+
+    !!! example "Usage Example"
+        ```python
+        import weave
+        from dotenv import load_dotenv
+
+        from medrag_multi_modal.assistant import (
+            FigureAnnotatorFromPageImage,
+            LLMClient,
+            MedQAAssistant,
+        )
+        from medrag_multi_modal.retrieval import MedCPTRetriever
+
+        load_dotenv()
+        weave.init(project_name="ml-colabs/medrag-multi-modal")
+
+        llm_client = LLMClient(model_name="gemini-1.5-flash")
+
+        retriever=MedCPTRetriever.from_wandb_artifact(
+            chunk_dataset_name="grays-anatomy-chunks:v0",
+            index_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-medcpt:v0",
+        )
+
+        figure_annotator=FigureAnnotatorFromPageImage(
+            figure_extraction_llm_client=LLMClient(model_name="pixtral-12b-2409"),
+            structured_output_llm_client=LLMClient(model_name="gpt-4o"),
+            image_artifact_address="ml-colabs/medrag-multi-modal/grays-anatomy-images-marker:v6",
+        )
+        medqa_assistant = MedQAAssistant(
+            llm_client=llm_client, retriever=retriever, figure_annotator=figure_annotator
+        )
+        medqa_assistant.predict(query="What is ribosome?")
+        ```
+
+    Args:
+        llm_client (LLMClient): The language model client used to generate responses.
+        retriever (weave.Model): The model used to retrieve relevant chunks of text from a medical document.
+        figure_annotator (FigureAnnotatorFromPageImage): The annotator used to extract figure descriptions from pages.
+        top_k_chunks (int): The number of top chunks to retrieve based on similarity metric.
+        retrieval_similarity_metric (SimilarityMetric): The metric used to measure similarity for retrieval.
+    """
 
    llm_client: LLMClient
    retriever: weave.Model
@@ -17,7 +58,25 @@ class MedQAAssistant(weave.Model):
    retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
 
    @weave.op()
-    def predict(self, query: str
+    def predict(self, query: str) -> str:
+        """
+        Generates a response to a medical query by retrieving relevant text chunks and figure descriptions
+        from a medical document and using a language model to generate the final response.
+
+        This function performs the following steps:
+        1. Retrieves relevant text chunks from the medical document based on the query using the retriever model.
+        2. Extracts the text and page indices from the retrieved chunks.
+        3. Retrieves figure descriptions from the pages identified in the previous step using the figure annotator.
+        4. Constructs a system prompt and user prompt combining the query, retrieved text chunks, and figure descriptions.
+        5. Uses the language model client to generate a response based on the constructed prompts.
+        6. Appends the source information (page numbers) to the generated response.
+
+        Args:
+            query (str): The medical query to be answered.
+
+        Returns:
+            str: The generated response to the query, including source information.
+        """
        retrieved_chunks = self.retriever.predict(
            query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
        )
@@ -29,14 +88,13 @@ class MedQAAssistant(weave.Model):
            page_indices.add(int(chunk["page_idx"]))
 
        figure_descriptions = []
-
-
-
-
-
-
-
-        ]
+        for page_idx in page_indices:
+            figure_annotations = self.figure_annotator.predict(page_idx=page_idx)[
+                page_idx
+            ]
+            figure_descriptions += [
+                item["figure_description"] for item in figure_annotations
+            ]
 
        system_prompt = """
        You are an expert in medical science. You are given a query and a list of chunks from a medical document.
@@ -46,5 +104,5 @@ class MedQAAssistant(weave.Model):
            user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
        )
        page_numbers = ", ".join([str(int(page_idx) + 1) for page_idx in page_indices])
-        response += f"\n\n**Source:** {'Pages' if len(
+        response += f"\n\n**Source:** {'Pages' if len(page_indices) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
        return response
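For review, a condensed sketch of the data flow the rewritten `predict` now implements, with the six docstring steps marked. This is a paraphrase rather than the verbatim method body: the `chunk["text"]` key and the `self.llm_client.predict(system_prompt=..., user_prompt=...)` signature are assumptions, since those lines sit outside the visible hunks.

```python
# Sketch of MedQAAssistant.predict after this commit (not the verbatim body).
retrieved_chunks = self.retriever.predict(
    query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
)  # step 1

retrieved_chunk_texts = [chunk["text"] for chunk in retrieved_chunks]  # step 2 (assumed key)
page_indices = {int(chunk["page_idx"]) for chunk in retrieved_chunks}  # step 2

figure_descriptions = []  # step 3: one annotator call per retrieved page
for page_idx in page_indices:
    # figure_annotator.predict returns {page_idx: [annotation, ...]}, so index back in by page
    for item in self.figure_annotator.predict(page_idx=page_idx)[page_idx]:
        figure_descriptions.append(item["figure_description"])

response = self.llm_client.predict(  # steps 4-5 (assumed signature)
    system_prompt=system_prompt,
    user_prompt=[query, *retrieved_chunk_texts, *figure_descriptions],
)

# step 6: 0-based page indices become 1-based page numbers in the citation
page_numbers = ", ".join([str(int(page_idx) + 1) for page_idx in page_indices])
response += f"\n\n**Source:** {'Pages' if len(page_indices) > 1 else 'Page'} {page_numbers} from Gray's Anatomy"
```

Note the coupling this introduces: since `predict` no longer accepts an artifact address, the `figure_annotator` passed to `MedQAAssistant` must be constructed with `image_artifact_address` set.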