# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ========= Copyright 2023-2024 @ CAMEL-AI.org. All Rights Reserved. =========
import base64
import io
import os
from typing import Any, Optional, Union

import requests
from PIL import Image

from camel.embeddings import BaseEmbedding
from camel.types.enums import EmbeddingModelType
from camel.utils import api_keys_required


class JinaEmbedding(BaseEmbedding[Union[str, Image.Image]]):
    r"""Provides text and image embedding functionalities using Jina AI's
    API.

    Args:
        model_type (EmbeddingModelType, optional): The model to use for
            embeddings. (default: :obj:`JINA_EMBEDDINGS_V3`)
        api_key (Optional[str], optional): The API key for authenticating
            with Jina AI. (default: :obj:`None`)
        dimensions (Optional[int], optional): The dimension of the output
            embeddings. (default: :obj:`None`)
        embedding_type (Optional[str], optional): The format of the returned
            embeddings. (default: :obj:`None`)
        task (Optional[str], optional): The type of task for text embeddings.
            Options: retrieval.query, retrieval.passage, text-matching,
            classification, separation. (default: :obj:`None`)
        late_chunking (bool, optional): If True, all sentences in the input
            are concatenated and treated as a single input for late
            chunking. (default: :obj:`False`)
        normalized (bool, optional): If True, embeddings are normalized to
            unit L2 norm. (default: :obj:`False`)
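
    Example:
        A minimal usage sketch (assuming a valid ``JINA_API_KEY`` is set in
        the environment)::

            embedding = JinaEmbedding()
            vectors = embedding.embed_list(["What is the capital of France?"])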
"""
    @api_keys_required([("api_key", 'JINA_API_KEY')])
    def __init__(
        self,
        model_type: EmbeddingModelType = EmbeddingModelType.JINA_EMBEDDINGS_V3,
        api_key: Optional[str] = None,
        dimensions: Optional[int] = None,
        embedding_type: Optional[str] = None,
        task: Optional[str] = None,
        late_chunking: bool = False,
        normalized: bool = False,
    ) -> None:
        if not model_type.is_jina:
            raise ValueError(
                f"Model type {model_type} is not a Jina model. "
                "Please use a valid Jina model type."
            )
        self.model_type = model_type
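        # Use the caller-supplied dimension if given; otherwise fall back
        # to the model's native output dimension.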
        if dimensions is None:
            self.output_dim = model_type.output_dim
        else:
            self.output_dim = dimensions
        self._api_key = api_key or os.environ.get("JINA_API_KEY")
        self.embedding_type = embedding_type
        self.task = task
        self.late_chunking = late_chunking
        self.normalized = normalized
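        # REST endpoint and auth headers for the Jina embeddings API.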
        self.url = 'https://api.jina.ai/v1/embeddings'
        self.headers = {
            'Content-Type': 'application/json',
            'Accept': 'application/json',
            'Authorization': f'Bearer {self._api_key}',
        }

    def embed_list(
        self,
        objs: list[Union[str, Image.Image]],
        **kwargs: Any,
    ) -> list[list[float]]:
        r"""Generates embeddings for the given texts or images.

        Args:
            objs (list[Union[str, Image.Image]]): The texts or images for
                which to generate the embeddings.
            **kwargs (Any): Extra kwargs passed to the embedding API. Not
                used in this implementation.

        Returns:
            list[list[float]]: A list of embeddings, one per input object,
                each represented as a list of floating-point numbers.

        Raises:
            ValueError: If the input type is not supported.
            RuntimeError: If the API request fails.
        """
        input_data = []
        for obj in objs:
            if isinstance(obj, str):
                if self.model_type == EmbeddingModelType.JINA_CLIP_V2:
                    input_data.append({"text": obj})
                else:
                    input_data.append(obj)  # type: ignore[arg-type]
            elif isinstance(obj, Image.Image):
                if self.model_type != EmbeddingModelType.JINA_CLIP_V2:
                    raise ValueError(
                        f"Model {self.model_type} does not support "
                        "image input. Use JINA_CLIP_V2 for image embeddings."
                    )
                # Convert the PIL Image to a base64-encoded PNG string
                buffered = io.BytesIO()
                obj.save(buffered, format="PNG")
                img_str = base64.b64encode(buffered.getvalue()).decode()
                input_data.append({"image": img_str})
            else:
                raise ValueError(
                    f"Input type {type(obj)} is not supported. "
                    "Must be either str or PIL.Image"
                )
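        # Assemble the request body; optional fields are only included when
        # they were explicitly set on the instance.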
        data = {
            "model": self.model_type.value,
            "input": input_data,
            "embedding_type": "float",
        }
        if self.embedding_type is not None:
            data["embedding_type"] = self.embedding_type
        if self.task is not None:
            data["task"] = self.task
        if self.late_chunking:
            data["late_chunking"] = self.late_chunking  # type: ignore[assignment]
        if self.normalized:
            data["normalized"] = self.normalized  # type: ignore[assignment]
        try:
            response = requests.post(
                self.url, headers=self.headers, json=data, timeout=180
            )
            response.raise_for_status()
            result = response.json()
            return [item["embedding"] for item in result["data"]]
        except requests.exceptions.RequestException as e:
            raise RuntimeError(
                f"Failed to get embeddings from Jina AI: {e}"
            ) from e

    def get_output_dim(self) -> int:
        r"""Returns the output dimension of the embeddings.

        Returns:
            int: The dimensionality of the embedding for the current model.
        """
        return self.output_dim