Spaces:
Running
on
Zero
Running
on
Zero
| # Project EmbodiedGen | |
| # | |
| # Copyright (c) 2025 Horizon Robotics. All Rights Reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or | |
| # implied. See the License for the specific language governing | |
| # permissions and limitations under the License. | |
| import base64 | |
| import logging | |
| import os | |
| from io import BytesIO | |
| from typing import Optional | |
| import yaml | |
| from openai import AzureOpenAI, OpenAI # pip install openai | |
| from PIL import Image | |
| from tenacity import ( | |
| retry, | |
| stop_after_attempt, | |
| stop_after_delay, | |
| wait_random_exponential, | |
| ) | |
| from embodied_gen.utils.process_media import combine_images_to_base64 | |
| logging.basicConfig(level=logging.INFO) | |
| logger = logging.getLogger(__name__) | |
class GPTclient:
    """A client to interact with the GPT model via OpenAI or Azure API.

    The client wraps either ``AzureOpenAI`` (when an ``api_version`` is
    supplied) or ``OpenAI`` (otherwise) and exposes a single ``query`` helper
    that sends one text prompt, optionally accompanied by images, to the
    chat-completions endpoint.
    """

    # MIME type advertised in the data URL for each supported file suffix.
    _MIME_BY_SUFFIX = {
        ".png": "image/png",
        ".jpg": "image/jpeg",
        ".jpeg": "image/jpeg",
        ".webp": "image/webp",
        ".bmp": "image/bmp",
        ".gif": "image/gif",
    }
    # Label used when the real encoding is unknown (raw base64 input).
    _DEFAULT_MIME = "image/png"

    def __init__(
        self,
        endpoint: str,
        api_key: str,
        model_name: str = "yfb-gpt-4o",
        api_version: Optional[str] = None,
        verbose: bool = False,
    ):
        """Creates the underlying API client.

        Args:
            endpoint (str): Azure endpoint or OpenAI-compatible base URL.
            api_key (str): API key passed to the underlying client.
            model_name (str): Model (deployment) name sent with every request.
            api_version (Optional[str]): When not ``None`` an ``AzureOpenAI``
                client is created with this version, otherwise a plain
                ``OpenAI`` client is created with ``endpoint`` as base URL.
            verbose (bool): If True, ``query`` logs the prompt and response.
        """
        if api_version is not None:
            self.client = AzureOpenAI(
                azure_endpoint=endpoint,
                api_key=api_key,
                api_version=api_version,
            )
        else:
            self.client = OpenAI(
                base_url=endpoint,
                api_key=api_key,
            )

        self.endpoint = endpoint
        self.model_name = model_name
        # Suffixes (lower-case, with dot) that mark a string as an image path.
        self.image_formats = set(self._MIME_BY_SUFFIX)
        self.verbose = verbose

        logger.info("Using GPT model: %s.", self.model_name)

    def completion_with_backoff(self, **kwargs):
        """Forwards ``kwargs`` to ``client.chat.completions.create``.

        Note: this method makes a single attempt; no retry or backoff is
        applied here, errors propagate to the caller.
        """
        return self.client.chat.completions.create(**kwargs)

    def _image_to_data_url(self, img: "str | os.PathLike | Image.Image") -> str:
        """Converts one image input into a ``data:`` URL for the chat API.

        Args:
            img: A ready-made ``data:`` URL, a raw base64 string, a path to a
                local image file (suffix listed in ``self.image_formats``) or
                a ``PIL.Image.Image``.

        Returns:
            str: ``data:<mime>;base64,<payload>`` with a MIME type matching
            the encoded bytes where it can be determined.

        Raises:
            FileNotFoundError: If ``img`` looks like an image path but the
                file does not exist.
            TypeError: If ``img`` is of an unsupported type.
        """
        # Strings / path-likes are checked first: they are by far the most
        # common input and need no PIL involvement.
        if isinstance(img, (str, os.PathLike)):
            candidate = os.fspath(img)
            if not isinstance(candidate, str):
                raise TypeError(
                    f"Unsupported image path type: {type(candidate).__name__}"
                )
            if candidate.startswith("data:"):
                # Already a complete data URL; wrapping it again would
                # produce an invalid URL.
                return candidate

            suffix = os.path.splitext(candidate)[-1].lower()
            if suffix not in self.image_formats:
                # Not an image path: treat it as an already-encoded base64
                # payload whose format is unknown.
                return f"data:{self._DEFAULT_MIME};base64,{candidate}"

            if not os.path.exists(candidate):
                raise FileNotFoundError(f"Image file not found: {img}")
            with open(candidate, "rb") as f:
                encoded = base64.b64encode(f.read()).decode("utf-8")
            mime = self._MIME_BY_SUFFIX.get(suffix, self._DEFAULT_MIME)
            return f"data:{mime};base64,{encoded}"

        if isinstance(img, Image.Image):
            # Keep the source encoding when known; images created in memory
            # have no format and are serialised as PNG.
            fmt = img.format or "PNG"
            buffer = BytesIO()
            img.save(buffer, format=fmt)
            encoded = base64.b64encode(buffer.getvalue()).decode("utf-8")
            mime = Image.MIME.get(fmt.upper(), f"image/{fmt.lower()}")
            return f"data:{mime};base64,{encoded}"

        raise TypeError(
            f"Unsupported image input type: {type(img).__name__}"
        )

    def query(
        self,
        text_prompt: str,
        image_base64: "Optional[list[str | Image.Image]]" = None,
        system_role: Optional[str] = None,
    ) -> Optional[str]:
        """Queries the GPT model with a text and optional image prompts.

        Args:
            text_prompt (str): The main text input that the model responds to.
            image_base64 (Optional[List[str]]): A list of image base64 strings
                or local image paths or PIL.Image to accompany the text
                prompt. A single item (not wrapped in a list) is accepted too.
            system_role (Optional[str]): Optional system-level instructions
                that specify the behavior of the assistant.

        Returns:
            Optional[str]: The response content generated by the model based on
                the prompt. Returns `None` if the API call fails (the error is
                logged).

        Raises:
            FileNotFoundError: If an image path does not exist. Image inputs
                are validated before the API call, so this is not swallowed.
            TypeError: If an image input has an unsupported type.
        """
        if system_role is None:
            system_role = "You are a highly knowledgeable assistant specializing in physics, engineering, and object properties."  # noqa

        content_user = [
            {
                "type": "text",
                "text": text_prompt,
            },
        ]

        # Process images if provided
        if image_base64 is not None:
            images = (
                image_base64
                if isinstance(image_base64, (list, tuple))
                else [image_base64]
            )
            for img in images:
                content_user.append(
                    {
                        "type": "image_url",
                        "image_url": {"url": self._image_to_data_url(img)},
                    }
                )

        payload = {
            "model": self.model_name,
            "messages": [
                {"role": "system", "content": system_role},
                {"role": "user", "content": content_user},
            ],
            "temperature": 0.1,
            "max_tokens": 500,
            "top_p": 0.1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "stop": None,
        }

        response = None
        try:
            completion = self.completion_with_backoff(**payload)
            response = completion.choices[0].message.content
        except Exception as e:
            # Best-effort contract: callers receive None instead of an
            # exception when the remote call fails.
            logger.error("Error GPTclint %s API call: %s", self.endpoint, e)
            response = None

        if self.verbose:
            logger.info("Prompt: %s", text_prompt)
            logger.info("Response: %s", response)

        return response
# Module-level default client, built once at import time from the YAML config
# with optional environment-variable overrides.
# NOTE: the config path is relative to the current working directory.
with open("embodied_gen/utils/gpt_config.yaml", "r", encoding="utf-8") as f:
    # An empty YAML document parses to None; normalise it to a mapping so the
    # lookups below fail with a clear KeyError instead of a TypeError.
    config = yaml.safe_load(f) or {}

agent_type = config["agent_type"]
# The selected agent section may be missing or empty (None).
agent_config = config.get(agent_type) or {}

# Prefer environment variables, fallback to YAML config. ``or`` is used rather
# than a ``get`` default so that an empty variable counts as "unset" instead
# of shadowing the YAML value with an empty string.
endpoint = os.environ.get("ENDPOINT") or agent_config.get("endpoint")
api_key = os.environ.get("API_KEY") or agent_config.get("api_key")
# An empty api_version must become None: GPTclient selects the Azure client
# for any value that is not None.
api_version = (
    os.environ.get("API_VERSION") or agent_config.get("api_version") or None
)
model_name = os.environ.get("MODEL_NAME") or agent_config.get("model_name")

_client_kwargs = {
    "endpoint": endpoint,
    "api_key": api_key,
    "api_version": api_version,
}
if model_name:
    # Only override the constructor default when a model is configured;
    # passing None explicitly would replace the default with None.
    _client_kwargs["model_name"] = model_name

GPT_CLIENT = GPTclient(**_client_kwargs)
if __name__ == "__main__":
    # Manual smoke test: one multi-image query followed by a text-only query.
    # Requires network access, valid credentials and the sample assets.
    sample_images = [
        "apps/assets/example_image/sample_02.jpg",
        "apps/assets/example_image/sample_03.jpg",
    ]

    # The endpoint may be None when nothing is configured; treat that as
    # "not openrouter" instead of failing on the membership test.
    if "openrouter" in (GPT_CLIENT.endpoint or ""):
        # test1 (openrouter): images are merged into a single base64 image.
        response = GPT_CLIENT.query(
            text_prompt="What is the content in each image?",
            image_base64=combine_images_to_base64(
                sample_images
            ),  # input raw image_path if only one image
        )
        print(response)
    else:
        # test1 (other endpoints): images are sent individually as PIL
        # images. Context managers make sure the file handles are closed.
        with Image.open(sample_images[0]) as first, Image.open(
            sample_images[1]
        ) as second:
            response = GPT_CLIENT.query(
                text_prompt="What is the content in the images?",
                image_base64=[first, second],
            )
        print(response)

    # test2: text prompt
    response = GPT_CLIENT.query(
        text_prompt="What is the capital of China?"
    )
    print(response)