Spaces:
Runtime error
Runtime error
Commit
·
70d9de4
1
Parent(s):
96bca50
add: utility for torch backend
Browse files
medrag_multi_modal/retrieval/contriever_retrieval.py
CHANGED
|
@@ -15,7 +15,7 @@ from transformers import (
|
|
| 15 |
|
| 16 |
import wandb
|
| 17 |
|
| 18 |
-
from ..utils import get_wandb_artifact
|
| 19 |
from .common import SimilarityMetric, argsort_scores, mean_pooling
|
| 20 |
|
| 21 |
|
|
@@ -150,7 +150,7 @@ class ContrieverRetriever(weave.Model):
|
|
| 150 |
os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
|
| 151 |
) as f:
|
| 152 |
vector_index = f.get_tensor("vector_index")
|
| 153 |
-
device = torch.device(
|
| 154 |
vector_index = vector_index.to(device)
|
| 155 |
chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows]
|
| 156 |
return cls(
|
|
@@ -199,7 +199,7 @@ class ContrieverRetriever(weave.Model):
|
|
| 199 |
list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
|
| 200 |
"""
|
| 201 |
query = [query]
|
| 202 |
-
device = torch.device(
|
| 203 |
with torch.no_grad():
|
| 204 |
query_embedding = self.encode(query).to(device)
|
| 205 |
if metric == SimilarityMetric.EUCLIDEAN:
|
|
|
|
| 15 |
|
| 16 |
import wandb
|
| 17 |
|
| 18 |
+
from ..utils import get_wandb_artifact, get_torch_backend
|
| 19 |
from .common import SimilarityMetric, argsort_scores, mean_pooling
|
| 20 |
|
| 21 |
|
|
|
|
| 150 |
os.path.join(artifact_dir, "vector_index.safetensors"), framework="pt"
|
| 151 |
) as f:
|
| 152 |
vector_index = f.get_tensor("vector_index")
|
| 153 |
+
device = torch.device(get_torch_backend())
|
| 154 |
vector_index = vector_index.to(device)
|
| 155 |
chunk_dataset = [dict(row) for row in weave.ref(chunk_dataset_name).get().rows]
|
| 156 |
return cls(
|
|
|
|
| 199 |
list: A list of dictionaries, each containing a retrieved chunk and its relevance score.
|
| 200 |
"""
|
| 201 |
query = [query]
|
| 202 |
+
device = torch.device(get_torch_backend())
|
| 203 |
with torch.no_grad():
|
| 204 |
query_embedding = self.encode(query).to(device)
|
| 205 |
if metric == SimilarityMetric.EUCLIDEAN:
|
medrag_multi_modal/utils.py
CHANGED
|
@@ -1,3 +1,4 @@
|
|
|
|
|
| 1 |
import wandb
|
| 2 |
|
| 3 |
|
|
@@ -14,3 +15,13 @@ def get_wandb_artifact(
|
|
| 14 |
if get_metadata:
|
| 15 |
return artifact_dir, artifact.metadata
|
| 16 |
return artifact_dir
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
import wandb
|
| 3 |
|
| 4 |
|
|
|
|
| 15 |
if get_metadata:
|
| 16 |
return artifact_dir, artifact.metadata
|
| 17 |
return artifact_dir
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def get_torch_backend():
|
| 21 |
+
if torch.cuda.is_available():
|
| 22 |
+
return "cuda"
|
| 23 |
+
if torch.backends.mps.is_available():
|
| 24 |
+
if torch.backends.mps.is_built():
|
| 25 |
+
return "mps"
|
| 26 |
+
return "cpu"
|
| 27 |
+
return "cpu"
|