Spaces:
Runtime error
Runtime error
Commit
·
49cde8e
1
Parent(s):
01ed12d
add: MedQAAssistant
Browse files
medrag_multi_modal/assistant/__init__.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
| 1 |
from .llm_client import LLMClient
|
|
|
|
| 2 |
|
| 3 |
-
__all__ = ["LLMClient"]
|
|
|
|
| 1 |
from .llm_client import LLMClient
|
| 2 |
+
from .medqa_assistant import MedQAAssistant
|
| 3 |
|
| 4 |
+
__all__ = ["LLMClient", "MedQAAssistant"]
|
medrag_multi_modal/assistant/llm_client.py
CHANGED
|
@@ -29,7 +29,7 @@ class LLMClient(weave.Model):
|
|
| 29 |
schema: Optional[Any] = None,
|
| 30 |
) -> Union[str, Any]:
|
| 31 |
import google.generativeai as genai
|
| 32 |
-
|
| 33 |
system_prompt = (
|
| 34 |
[system_prompt] if isinstance(system_prompt, str) else system_prompt
|
| 35 |
)
|
|
|
|
| 29 |
schema: Optional[Any] = None,
|
| 30 |
) -> Union[str, Any]:
|
| 31 |
import google.generativeai as genai
|
| 32 |
+
|
| 33 |
system_prompt = (
|
| 34 |
[system_prompt] if isinstance(system_prompt, str) else system_prompt
|
| 35 |
)
|
medrag_multi_modal/assistant/medqa_assistant.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Optional
|
| 2 |
+
|
| 3 |
+
import weave
|
| 4 |
+
from PIL import Image
|
| 5 |
+
|
| 6 |
+
from ..retrieval import SimilarityMetric
|
| 7 |
+
from .llm_client import LLMClient
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class MedQAAssistant(weave.Model):
|
| 11 |
+
llm_client: LLMClient
|
| 12 |
+
retriever: weave.Model
|
| 13 |
+
top_k_chunks: int = 2
|
| 14 |
+
retrieval_similarity_metric: SimilarityMetric = SimilarityMetric.COSINE
|
| 15 |
+
|
| 16 |
+
@weave.op()
|
| 17 |
+
def predict(self, query: str, image: Optional[Image.Image] = None) -> str:
|
| 18 |
+
_image = image
|
| 19 |
+
retrieved_chunks = self.retriever.predict(
|
| 20 |
+
query, top_k=self.top_k_chunks, metric=self.retrieval_similarity_metric
|
| 21 |
+
)
|
| 22 |
+
retrieved_chunks = [chunk["text"] for chunk in retrieved_chunks]
|
| 23 |
+
system_prompt = """
|
| 24 |
+
You are a medical expert. You are given a query and a list of chunks from a medical document.
|
| 25 |
+
"""
|
| 26 |
+
return self.llm_client.predict(
|
| 27 |
+
system_prompt=system_prompt, user_prompt=retrieved_chunks
|
| 28 |
+
)
|