Commit ceaeef3
Parent(s): 7934a8e
update: FigureAnnotatorFromPageImage
medrag_multi_modal/assistant/__init__.py CHANGED

@@ -1,5 +1,5 @@
-from .figure_annotation import
+from .figure_annotation import FigureAnnotatorFromPageImage
 from .llm_client import ClientType, LLMClient
 from .medqa_assistant import MedQAAssistant
 
-__all__ = ["LLMClient", "ClientType", "MedQAAssistant", "
+__all__ = ["LLMClient", "ClientType", "MedQAAssistant", "FigureAnnotatorFromPageImage"]
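With this change FigureAnnotatorFromPageImage is re-exported from the package root alongside the existing classes. A minimal import sketch, assuming the package is installed as medrag_multi_modal; constructing the annotator by keyword is an assumption based only on the llm_client field visible in figure_annotation.py below (weave.Model subclasses are pydantic models, so keyword construction is the usual pattern):

# Minimal import sketch; the package path medrag_multi_modal is assumed.
from medrag_multi_modal.assistant import (
    ClientType,
    FigureAnnotatorFromPageImage,
    LLMClient,
)

# Keyword construction is an assumption; only the llm_client field is shown in this commit.
llm_client = LLMClient(model_name="gemini-1.5-flash", client_type=ClientType.GEMINI)
annotator = FigureAnnotatorFromPageImage(llm_client=llm_client)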
medrag_multi_modal/assistant/figure_annotation.py CHANGED

@@ -10,7 +10,7 @@ from ..utils import get_wandb_artifact, read_jsonl_file
 from .llm_client import LLMClient
 
 
-class
+class FigureAnnotatorFromPageImage(weave.Model):
     llm_client: LLMClient
 
     @weave.op()
@@ -24,6 +24,7 @@ You are presented with a page from a scientific textbook.
 You are to first identify the number of figures in the image.
 Then you are to identify the figure IDs associated with each figure in the image.
 Then, you are to extract the exact figure descriptions from the image.
+You need to output the figure IDs and descriptions in a structured manner as a JSON object.
 
 Here are some clues you need to follow:
 1. Figure IDs are unique identifiers for each figure in the image.
@@ -33,6 +34,8 @@ Here are some clues you need to follow:
 5. The text in the image is written in English and is present in a two-column format.
 6. There is a clear distinction between the figure caption and the regular text in the image in the form of extra white space.
 7. There might be multiple figures present in the image.
+8. The figures may or may not have a distinct border against a white background.
+9. There might be multiple figures present in the image. You are to carefully identify all the figures in the image.
 """,
 user_prompt=[page_image],
 )
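The new prompt line asks the model to return the figure IDs and descriptions as a JSON object. A hedged sketch of how a caller might consume that structured output, continuing the import sketch above; the predict entry point and the exact JSON keys are assumptions, since only the prompt text appears in this diff:

import json

def annotate_page(annotator, page_image) -> dict:
    # predict() is the conventional weave.Model entry point; the real method
    # name on FigureAnnotatorFromPageImage is not shown in this commit.
    raw = annotator.predict(page_image)
    # The prompt requests a JSON object of figure IDs and descriptions,
    # e.g. {"Figure 3.1": "Micrograph of ..."}; the exact keys are an assumption.
    return json.loads(raw) if isinstance(raw, str) else raw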
medrag_multi_modal/assistant/llm_client.py CHANGED

@@ -14,11 +14,59 @@ class ClientType(str, Enum):
     MISTRAL = "mistral"
 
 
+GOOGLE_MODELS = [
+    "gemini-1.0-pro-latest",
+    "gemini-1.0-pro",
+    "gemini-pro",
+    "gemini-1.0-pro-001",
+    "gemini-1.0-pro-vision-latest",
+    "gemini-pro-vision",
+    "gemini-1.5-pro-latest",
+    "gemini-1.5-pro-001",
+    "gemini-1.5-pro-002",
+    "gemini-1.5-pro",
+    "gemini-1.5-pro-exp-0801",
+    "gemini-1.5-pro-exp-0827",
+    "gemini-1.5-flash-latest",
+    "gemini-1.5-flash-001",
+    "gemini-1.5-flash-001-tuning",
+    "gemini-1.5-flash",
+    "gemini-1.5-flash-exp-0827",
+    "gemini-1.5-flash-002",
+    "gemini-1.5-flash-8b",
+    "gemini-1.5-flash-8b-001",
+    "gemini-1.5-flash-8b-latest",
+    "gemini-1.5-flash-8b-exp-0827",
+    "gemini-1.5-flash-8b-exp-0924",
+]
+
+MISTRAL_MODELS = [
+    "ministral-3b-latest",
+    "ministral-8b-latest",
+    "mistral-large-latest",
+    "mistral-small-latest",
+    "codestral-latest",
+    "pixtral-12b-2409",
+    "open-mistral-nemo",
+    "open-codestral-mamba",
+    "open-mistral-7b",
+    "open-mixtral-8x7b",
+    "open-mixtral-8x22b",
+]
+
+
 class LLMClient(weave.Model):
     model_name: str
-    client_type: ClientType
+    client_type: Optional[ClientType]
 
-    def __init__(self, model_name: str, client_type: ClientType):
+    def __init__(self, model_name: str, client_type: Optional[ClientType] = None):
+        if client_type is None:
+            if model_name in GOOGLE_MODELS:
+                client_type = ClientType.GEMINI
+            elif model_name in MISTRAL_MODELS:
+                client_type = ClientType.MISTRAL
+            else:
+                raise ValueError(f"Invalid model name: {model_name}")
         super().__init__(model_name=model_name, client_type=client_type)
 
     @weave.op()
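A usage sketch of the new constructor behavior shown above: when client_type is omitted, it is inferred from the model name via GOOGLE_MODELS and MISTRAL_MODELS, and a name outside both lists raises ValueError. The import path follows the updated package __init__:

from medrag_multi_modal.assistant import ClientType, LLMClient

# client_type is inferred from the model name when omitted.
gemini_client = LLMClient(model_name="gemini-1.5-flash")    # -> ClientType.GEMINI
mistral_client = LLMClient(model_name="pixtral-12b-2409")   # -> ClientType.MISTRAL

# An explicit client_type still works and skips the lookup.
explicit_client = LLMClient(model_name="mistral-small-latest", client_type=ClientType.MISTRAL)

# A model name outside GOOGLE_MODELS and MISTRAL_MODELS raises ValueError.
try:
    LLMClient(model_name="gpt-4o")
except ValueError as err:
    print(err)  # Invalid model name: gpt-4o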