from typing import Any, Callable, Optional, Union
import types, time, json
import httpx
from .base import BaseLLM
from litellm.utils import (
    ModelResponse,
    Choices,
    Message,
    CustomStreamWrapper,
    convert_to_model_response_object,
    Usage,
)
import aiohttp, requests
import litellm
from .prompt_templates.factory import prompt_factory, custom_prompt
from openai import OpenAI, AsyncOpenAI
class OpenAIError(Exception):
    def __init__(
        self,
        status_code,
        message,
        request: Optional[httpx.Request] = None,
        response: Optional[httpx.Response] = None,
    ):
        self.status_code = status_code
        self.message = message
        if request:
            self.request = request
        else:
            self.request = httpx.Request(
                method="POST", url="https://api.openai.com/v1"
            )
        if response:
            self.response = response
        else:
            self.response = httpx.Response(
                status_code=status_code, request=self.request
            )
        super().__init__(
            self.message
        )  # Call the base class constructor with the parameters it needs
class OpenAIConfig:
    """
    Reference: https://platform.openai.com/docs/api-reference/chat/create

    The class `OpenAIConfig` provides configuration for OpenAI's Chat API interface. Below are the parameters:
    - `frequency_penalty` (number or null): Defaults to 0. Allows a value between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, thereby reducing repetition.
    - `function_call` (string or object): This optional parameter controls how the model calls functions.
    - `functions` (array): An optional parameter. It is a list of functions for which the model may generate JSON inputs.
    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
    - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the chat completion.
    - `n` (integer or null): This optional parameter sets how many chat completion choices to generate for each input message.
    - `presence_penalty` (number or null): Defaults to 0. It penalizes new tokens based on whether they already appear in the text so far, increasing the model's likelihood to talk about new topics.
    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
    - `temperature` (number or null): Defines the sampling temperature to use, varying between 0 and 2.
    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    frequency_penalty: Optional[float] = None
    function_call: Optional[Union[str, dict]] = None
    functions: Optional[list] = None
    logit_bias: Optional[dict] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[float] = None
    stop: Optional[Union[str, list]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None

    def __init__(
        self,
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, dict]] = None,
        functions: Optional[list] = None,
        logit_bias: Optional[dict] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
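
# A minimal usage sketch (illustrative only, not part of the original module). `OpenAIConfig`
# stores any non-None constructor arguments as class-level defaults, and `get_config()` then
# returns the values that are currently set; the parameter values below are hypothetical:
#
#     OpenAIConfig(max_tokens=256, temperature=0.2)
#     OpenAIConfig.get_config()  # -> {"max_tokens": 256, "temperature": 0.2}
#
# Callers (e.g. litellm's param-mapping logic) can merge this dict into a request's
# optional_params before sending it to the OpenAI API.
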
class OpenAITextCompletionConfig:
    """
    Reference: https://platform.openai.com/docs/api-reference/completions/create

    The class `OpenAITextCompletionConfig` provides configuration for OpenAI's text completion API interface. Below are the parameters:
    - `best_of` (integer or null): This optional parameter generates server-side completions and returns the one with the highest log probability per token.
    - `echo` (boolean or null): This optional parameter will echo back the prompt in addition to the completion.
    - `frequency_penalty` (number or null): Defaults to 0. A number from -2.0 to 2.0, where positive values decrease the model's likelihood to repeat the same line.
    - `logit_bias` (map): This optional parameter modifies the likelihood of specified tokens appearing in the completion.
    - `logprobs` (integer or null): This optional parameter includes the log probabilities of the most likely tokens as well as the chosen tokens.
    - `max_tokens` (integer or null): This optional parameter sets the maximum number of tokens to generate in the completion.
    - `n` (integer or null): This optional parameter sets how many completions to generate for each prompt.
    - `presence_penalty` (number or null): Defaults to 0 and can be between -2.0 and 2.0. Positive values increase the model's likelihood to talk about new topics.
    - `stop` (string / array / null): Specifies up to 4 sequences where the API will stop generating further tokens.
    - `suffix` (string or null): Defines the suffix that comes after a completion of inserted text.
    - `temperature` (number or null): This optional parameter defines the sampling temperature to use.
    - `top_p` (number or null): An alternative to sampling with temperature, used for nucleus sampling.
    """

    best_of: Optional[int] = None
    echo: Optional[bool] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[dict] = None
    logprobs: Optional[int] = None
    max_tokens: Optional[int] = None
    n: Optional[int] = None
    presence_penalty: Optional[float] = None
    stop: Optional[Union[str, list]] = None
    suffix: Optional[str] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None

    def __init__(
        self,
        best_of: Optional[int] = None,
        echo: Optional[bool] = None,
        frequency_penalty: Optional[float] = None,
        logit_bias: Optional[dict] = None,
        logprobs: Optional[int] = None,
        max_tokens: Optional[int] = None,
        n: Optional[int] = None,
        presence_penalty: Optional[float] = None,
        stop: Optional[Union[str, list]] = None,
        suffix: Optional[str] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
    ) -> None:
        locals_ = locals()
        for key, value in locals_.items():
            if key != "self" and value is not None:
                setattr(self.__class__, key, value)

    @classmethod
    def get_config(cls):
        return {
            k: v
            for k, v in cls.__dict__.items()
            if not k.startswith("__")
            and not isinstance(
                v,
                (
                    types.FunctionType,
                    types.BuiltinFunctionType,
                    classmethod,
                    staticmethod,
                ),
            )
            and v is not None
        }
class OpenAIChatCompletion(BaseLLM):
    def __init__(self) -> None:
        super().__init__()

    def completion(
        self,
        model_response: ModelResponse,
        timeout: float,
        model: Optional[str] = None,
        messages: Optional[list] = None,
        print_verbose: Optional[Callable] = None,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        acompletion: bool = False,
        logging_obj=None,
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
        custom_prompt_dict: dict = {},
        client=None,
    ):
        super().completion()
        exception_mapping_worked = False
        try:
            if headers:
                optional_params["extra_headers"] = headers
            if model is None or messages is None:
                raise OpenAIError(status_code=422, message="Missing model or messages")
            if not isinstance(timeout, float):
                raise OpenAIError(
                    status_code=422, message="Timeout needs to be a float"
                )

            for _ in range(
                2
            ):  # if the call fails due to non-alternating messages, retry with reformatted messages
                data = {"model": model, "messages": messages, **optional_params}
                try:
                    max_retries = data.pop("max_retries", 2)
                    if acompletion is True:
                        if optional_params.get("stream", False):
                            return self.async_streaming(
                                logging_obj=logging_obj,
                                headers=headers,
                                data=data,
                                model=model,
                                api_base=api_base,
                                api_key=api_key,
                                timeout=timeout,
                                client=client,
                                max_retries=max_retries,
                            )
                        else:
                            return self.acompletion(
                                data=data,
                                headers=headers,
                                logging_obj=logging_obj,
                                model_response=model_response,
                                api_base=api_base,
                                api_key=api_key,
                                timeout=timeout,
                                client=client,
                                max_retries=max_retries,
                            )
                    elif optional_params.get("stream", False):
                        return self.streaming(
                            logging_obj=logging_obj,
                            headers=headers,
                            data=data,
                            model=model,
                            api_base=api_base,
                            api_key=api_key,
                            timeout=timeout,
                            client=client,
                            max_retries=max_retries,
                        )
                    else:
                        if not isinstance(max_retries, int):
                            raise OpenAIError(
                                status_code=422, message="max retries must be an int"
                            )
                        if client is None:
                            openai_client = OpenAI(
                                api_key=api_key,
                                base_url=api_base,
                                http_client=litellm.client_session,
                                timeout=timeout,
                                max_retries=max_retries,
                            )
                        else:
                            openai_client = client

                        ## LOGGING
                        logging_obj.pre_call(
                            input=messages,
                            api_key=openai_client.api_key,
                            additional_args={
                                "headers": headers,
                                "api_base": openai_client._base_url._uri_reference,
                                "acompletion": acompletion,
                                "complete_input_dict": data,
                            },
                        )

                        response = openai_client.chat.completions.create(**data, timeout=timeout)  # type: ignore
                        stringified_response = response.model_dump()
                        logging_obj.post_call(
                            input=messages,
                            api_key=api_key,
                            original_response=stringified_response,
                            additional_args={"complete_input_dict": data},
                        )
                        return convert_to_model_response_object(
                            response_object=stringified_response,
                            model_response_object=model_response,
                        )
                except Exception as e:
                    if "Conversation roles must alternate user/assistant" in str(
                        e
                    ) or "user and assistant roles should be alternating" in str(e):
                        # Reformat messages so user/assistant roles alternate: if there are two
                        # consecutive 'user' or two consecutive 'assistant' messages, insert a
                        # blank message with the opposite role between them for compatibility.
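                        # Illustrative example (hypothetical messages, not from the source):
                        #   [{"role": "user", "content": "hi"}, {"role": "user", "content": "hello"}]
                        # becomes
                        #   [{"role": "user", "content": "hi"},
                        #    {"role": "assistant", "content": ""},
                        #    {"role": "user", "content": "hello"}]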
                        new_messages = []
                        for i in range(len(messages) - 1):
                            new_messages.append(messages[i])
                            if messages[i]["role"] == messages[i + 1]["role"]:
                                if messages[i]["role"] == "user":
                                    new_messages.append(
                                        {"role": "assistant", "content": ""}
                                    )
                                else:
                                    new_messages.append({"role": "user", "content": ""})
                        new_messages.append(messages[-1])
                        messages = new_messages
                    elif "Last message must have role `user`" in str(e):
                        new_messages = messages
                        new_messages.append({"role": "user", "content": ""})
                        messages = new_messages
                    else:
                        raise e
        except OpenAIError as e:
            exception_mapping_worked = True
            raise e
        except Exception as e:
            if hasattr(e, "status_code"):
                raise OpenAIError(status_code=e.status_code, message=str(e))
            else:
                raise OpenAIError(status_code=500, message=str(e))
    async def acompletion(
        self,
        data: dict,
        model_response: ModelResponse,
        timeout: float,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        logging_obj=None,
        headers=None,
    ):
        response = None
        try:
            if client is None:
                openai_aclient = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.aclient_session,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            else:
                openai_aclient = client

            ## LOGGING
            logging_obj.pre_call(
                input=data["messages"],
                api_key=openai_aclient.api_key,
                additional_args={
                    "headers": {"Authorization": f"Bearer {openai_aclient.api_key}"},
                    "api_base": openai_aclient._base_url._uri_reference,
                    "acompletion": True,
                    "complete_input_dict": data,
                },
            )

            response = await openai_aclient.chat.completions.create(
                **data, timeout=timeout
            )
            stringified_response = response.model_dump()
            logging_obj.post_call(
                input=data["messages"],
                api_key=api_key,
                original_response=stringified_response,
                additional_args={"complete_input_dict": data},
            )
            return convert_to_model_response_object(
                response_object=stringified_response,
                model_response_object=model_response,
            )
        except Exception as e:
            raise e

    def streaming(
        self,
        logging_obj,
        timeout: float,
        data: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        headers=None,
    ):
        if client is None:
            openai_client = OpenAI(
                api_key=api_key,
                base_url=api_base,
                http_client=litellm.client_session,
                timeout=timeout,
                max_retries=max_retries,
            )
        else:
            openai_client = client

        ## LOGGING
        logging_obj.pre_call(
            input=data["messages"],
            api_key=api_key,
            additional_args={
                "headers": headers,
                "api_base": api_base,
                "acompletion": False,
                "complete_input_dict": data,
            },
        )

        response = openai_client.chat.completions.create(**data, timeout=timeout)
        streamwrapper = CustomStreamWrapper(
            completion_stream=response,
            model=model,
            custom_llm_provider="openai",
            logging_obj=logging_obj,
        )
        return streamwrapper

    async def async_streaming(
        self,
        logging_obj,
        timeout: float,
        data: dict,
        model: str,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        headers=None,
    ):
        response = None
        try:
            if client is None:
                openai_aclient = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.aclient_session,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            else:
                openai_aclient = client

            ## LOGGING
            logging_obj.pre_call(
                input=data["messages"],
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                    "acompletion": True,
                    "complete_input_dict": data,
                },
            )

            response = await openai_aclient.chat.completions.create(
                **data, timeout=timeout
            )
            streamwrapper = CustomStreamWrapper(
                completion_stream=response,
                model=model,
                custom_llm_provider="openai",
                logging_obj=logging_obj,
            )
            return streamwrapper
        except Exception as e:
            # Exceptions must be handled here; async exceptions don't get caught in sync callers.
            if response is not None and hasattr(response, "text"):
                raise OpenAIError(
                    status_code=500,
                    message=f"{str(e)}\n\nOriginal Response: {response.text}",
                )
            else:
                if type(e).__name__ == "ReadTimeout":
                    raise OpenAIError(status_code=408, message=f"{type(e).__name__}")
                elif hasattr(e, "status_code"):
                    raise OpenAIError(status_code=e.status_code, message=str(e))
                else:
                    raise OpenAIError(status_code=500, message=f"{str(e)}")
    async def aembedding(
        self,
        input: list,
        data: dict,
        model_response: ModelResponse,
        timeout: float,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        logging_obj=None,
    ):
        response = None
        try:
            if client is None:
                openai_aclient = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.aclient_session,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            else:
                openai_aclient = client
            response = await openai_aclient.embeddings.create(**data, timeout=timeout)  # type: ignore
            stringified_response = response.model_dump()
            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=stringified_response,
            )
            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="embedding")  # type: ignore
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                original_response=str(e),
            )
            raise e

    def embedding(
        self,
        model: str,
        input: list,
        timeout: float,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        model_response: Optional[litellm.utils.EmbeddingResponse] = None,
        logging_obj=None,
        optional_params=None,
        client=None,
        aembedding=None,
    ):
        super().embedding()
        exception_mapping_worked = False
        try:
            data = {"model": model, "input": input, **optional_params}
            max_retries = data.pop("max_retries", 2)
            if not isinstance(max_retries, int):
                raise OpenAIError(status_code=422, message="max retries must be an int")

            ## LOGGING
            logging_obj.pre_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data, "api_base": api_base},
            )

            if aembedding == True:
                response = self.aembedding(data=data, input=input, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
                return response
            if client is None:
                openai_client = OpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.client_session,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            else:
                openai_client = client

            ## COMPLETION CALL
            response = openai_client.embeddings.create(**data, timeout=timeout)  # type: ignore

            ## LOGGING
            logging_obj.post_call(
                input=input,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=response,
            )
            return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="embedding")  # type: ignore
        except OpenAIError as e:
            exception_mapping_worked = True
            raise e
        except Exception as e:
            if hasattr(e, "status_code"):
                raise OpenAIError(status_code=e.status_code, message=str(e))
            else:
                raise OpenAIError(status_code=500, message=str(e))
    async def aimage_generation(
        self,
        prompt: str,
        data: dict,
        model_response: ModelResponse,
        timeout: float,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        client=None,
        max_retries=None,
        logging_obj=None,
    ):
        response = None
        try:
            if client is None:
                openai_aclient = AsyncOpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.aclient_session,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            else:
                openai_aclient = client
            response = await openai_aclient.images.generate(**data, timeout=timeout)  # type: ignore
            stringified_response = response.model_dump()
            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=stringified_response,
            )
            return convert_to_model_response_object(response_object=stringified_response, model_response_object=model_response, response_type="image_generation")  # type: ignore
        except Exception as e:
            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                original_response=str(e),
            )
            raise e

    def image_generation(
        self,
        model: Optional[str],
        prompt: str,
        timeout: float,
        api_key: Optional[str] = None,
        api_base: Optional[str] = None,
        model_response: Optional[litellm.utils.ImageResponse] = None,
        logging_obj=None,
        optional_params=None,
        client=None,
        aimg_generation=None,
    ):
        exception_mapping_worked = False
        try:
            data = {"model": model, "prompt": prompt, **optional_params}
            max_retries = data.pop("max_retries", 2)
            if not isinstance(max_retries, int):
                raise OpenAIError(status_code=422, message="max retries must be an int")

            if aimg_generation == True:
                response = self.aimage_generation(data=data, prompt=prompt, logging_obj=logging_obj, model_response=model_response, api_base=api_base, api_key=api_key, timeout=timeout, client=client, max_retries=max_retries)  # type: ignore
                return response

            if client is None:
                openai_client = OpenAI(
                    api_key=api_key,
                    base_url=api_base,
                    http_client=litellm.client_session,
                    timeout=timeout,
                    max_retries=max_retries,
                )
            else:
                openai_client = client

            ## LOGGING
            logging_obj.pre_call(
                input=prompt,
                api_key=openai_client.api_key,
                additional_args={
                    "headers": {"Authorization": f"Bearer {openai_client.api_key}"},
                    "api_base": openai_client._base_url._uri_reference,
                    "acompletion": True,
                    "complete_input_dict": data,
                },
            )

            ## COMPLETION CALL
            response = openai_client.images.generate(**data, timeout=timeout)  # type: ignore

            ## LOGGING
            logging_obj.post_call(
                input=prompt,
                api_key=api_key,
                additional_args={"complete_input_dict": data},
                original_response=response,
            )
            # return response
            return convert_to_model_response_object(response_object=response.model_dump(), model_response_object=model_response, response_type="image_generation")  # type: ignore
        except OpenAIError as e:
            exception_mapping_worked = True
            raise e
        except Exception as e:
            if hasattr(e, "status_code"):
                raise OpenAIError(status_code=e.status_code, message=str(e))
            else:
                raise OpenAIError(status_code=500, message=str(e))
    async def ahealth_check(
        self,
        model: Optional[str],
        api_key: str,
        timeout: float,
        mode: str,
        messages: Optional[list] = None,
        input: Optional[list] = None,
        prompt: Optional[str] = None,
    ):
        client = AsyncOpenAI(api_key=api_key, timeout=timeout)
        if model is None and mode != "image_generation":
            raise Exception("model is not set")

        completion = None

        if mode == "completion":
            completion = await client.completions.with_raw_response.create(
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        elif mode == "chat":
            if messages is None:
                raise Exception("messages is not set")
            completion = await client.chat.completions.with_raw_response.create(
                model=model,  # type: ignore
                messages=messages,  # type: ignore
            )
        elif mode == "embedding":
            if input is None:
                raise Exception("input is not set")
            completion = await client.embeddings.with_raw_response.create(
                model=model,  # type: ignore
                input=input,  # type: ignore
            )
        elif mode == "image_generation":
            if prompt is None:
                raise Exception("prompt is not set")
            completion = await client.images.with_raw_response.generate(
                model=model,  # type: ignore
                prompt=prompt,  # type: ignore
            )
        else:
            raise Exception("mode not set")

        response = {}
        if completion is None or not hasattr(completion, "headers"):
            raise Exception("invalid completion response")

        if (
            completion.headers.get("x-ratelimit-remaining-requests", None) is not None
        ):  # not provided for dall-e requests
            response["x-ratelimit-remaining-requests"] = completion.headers[
                "x-ratelimit-remaining-requests"
            ]
        if completion.headers.get("x-ratelimit-remaining-tokens", None) is not None:
            response["x-ratelimit-remaining-tokens"] = completion.headers[
                "x-ratelimit-remaining-tokens"
            ]
        return response
class OpenAITextCompletion(BaseLLM):
    _client_session: httpx.Client

    def __init__(self) -> None:
        super().__init__()
        self._client_session = self.create_client_session()

    def validate_environment(self, api_key):
        headers = {
            "content-type": "application/json",
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        return headers

    def convert_to_model_response_object(
        self,
        response_object: Optional[dict] = None,
        model_response_object: Optional[ModelResponse] = None,
    ):
        try:
            ## RESPONSE OBJECT
            if response_object is None or model_response_object is None:
                raise ValueError("Error in response object format")
            choice_list = []
            for idx, choice in enumerate(response_object["choices"]):
                message = Message(content=choice["text"], role="assistant")
                choice = Choices(
                    finish_reason=choice["finish_reason"], index=idx, message=message
                )
                choice_list.append(choice)
            model_response_object.choices = choice_list

            if "usage" in response_object:
                model_response_object.usage = response_object["usage"]
            if "id" in response_object:
                model_response_object.id = response_object["id"]
            if "model" in response_object:
                model_response_object.model = response_object["model"]

            # Track the original response: if users make a litellm.text_completion()
            # request, we can return the original response.
            model_response_object._hidden_params["original_response"] = response_object

            return model_response_object
        except Exception as e:
            raise e
    def completion(
        self,
        model_response: ModelResponse,
        api_key: str,
        model: str,
        messages: list,
        timeout: float,
        print_verbose: Optional[Callable] = None,
        api_base: Optional[str] = None,
        logging_obj=None,
        acompletion: bool = False,
        optional_params=None,
        litellm_params=None,
        logger_fn=None,
        headers: Optional[dict] = None,
    ):
        super().completion()
        exception_mapping_worked = False
        try:
            if headers is None:
                headers = self.validate_environment(api_key=api_key)
            if model is None or messages is None:
                raise OpenAIError(status_code=422, message="Missing model or messages")

            api_base = f"{api_base}/completions"

            if (
                len(messages) > 0
                and "content" in messages[0]
                and type(messages[0]["content"]) == list
            ):
                prompt = messages[0]["content"]
            else:
                prompt = " ".join([message["content"] for message in messages])  # type: ignore

            # don't send max retries to the api, if set
            optional_params.pop("max_retries", None)

            data = {"model": model, "prompt": prompt, **optional_params}

            ## LOGGING
            logging_obj.pre_call(
                input=messages,
                api_key=api_key,
                additional_args={
                    "headers": headers,
                    "api_base": api_base,
                    "complete_input_dict": data,
                },
            )

            if acompletion == True:
                if optional_params.get("stream", False):
                    return self.async_streaming(
                        logging_obj=logging_obj,
                        api_base=api_base,
                        data=data,
                        headers=headers,
                        model_response=model_response,
                        model=model,
                        timeout=timeout,
                    )
                else:
                    return self.acompletion(api_base=api_base, data=data, headers=headers, model_response=model_response, prompt=prompt, api_key=api_key, logging_obj=logging_obj, model=model, timeout=timeout)  # type: ignore
            elif optional_params.get("stream", False):
                return self.streaming(
                    logging_obj=logging_obj,
                    api_base=api_base,
                    data=data,
                    headers=headers,
                    model_response=model_response,
                    model=model,
                    timeout=timeout,
                )
            else:
                response = httpx.post(
                    url=f"{api_base}", json=data, headers=headers, timeout=timeout
                )
                if response.status_code != 200:
                    raise OpenAIError(
                        status_code=response.status_code, message=response.text
                    )

                ## LOGGING
                logging_obj.post_call(
                    input=prompt,
                    api_key=api_key,
                    original_response=response,
                    additional_args={
                        "headers": headers,
                        "api_base": api_base,
                    },
                )

                ## RESPONSE OBJECT
                return self.convert_to_model_response_object(
                    response_object=response.json(),
                    model_response_object=model_response,
                )
        except Exception as e:
            raise e
    async def acompletion(
        self,
        logging_obj,
        api_base: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        prompt: str,
        api_key: str,
        model: str,
        timeout: float,
    ):
        async with httpx.AsyncClient(timeout=timeout) as client:
            try:
                response = await client.post(
                    api_base,
                    json=data,
                    headers=headers,
                    timeout=litellm.request_timeout,
                )
                response_json = response.json()
                if response.status_code != 200:
                    raise OpenAIError(
                        status_code=response.status_code, message=response.text
                    )

                ## LOGGING
                logging_obj.post_call(
                    input=prompt,
                    api_key=api_key,
                    original_response=response,
                    additional_args={
                        "headers": headers,
                        "api_base": api_base,
                    },
                )

                ## RESPONSE OBJECT
                return self.convert_to_model_response_object(
                    response_object=response_json, model_response_object=model_response
                )
            except Exception as e:
                raise e

    def streaming(
        self,
        logging_obj,
        api_base: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
    ):
        with httpx.stream(
            url=f"{api_base}",
            json=data,
            headers=headers,
            method="POST",
            timeout=timeout,
        ) as response:
            if response.status_code != 200:
                raise OpenAIError(
                    status_code=response.status_code, message=response.text
                )

            streamwrapper = CustomStreamWrapper(
                completion_stream=response.iter_lines(),
                model=model,
                custom_llm_provider="text-completion-openai",
                logging_obj=logging_obj,
            )
            for transformed_chunk in streamwrapper:
                yield transformed_chunk

    async def async_streaming(
        self,
        logging_obj,
        api_base: str,
        data: dict,
        headers: dict,
        model_response: ModelResponse,
        model: str,
        timeout: float,
    ):
        client = httpx.AsyncClient()
        async with client.stream(
            url=f"{api_base}",
            json=data,
            headers=headers,
            method="POST",
            timeout=timeout,
        ) as response:
            try:
                if response.status_code != 200:
                    raise OpenAIError(
                        status_code=response.status_code, message=response.text
                    )

                streamwrapper = CustomStreamWrapper(
                    completion_stream=response.aiter_lines(),
                    model=model,
                    custom_llm_provider="text-completion-openai",
                    logging_obj=logging_obj,
                )
                async for transformed_chunk in streamwrapper:
                    yield transformed_chunk
            except Exception as e:
                raise e