Spaces:
Running
on
Zero
Running
on
Zero
| import io | |
| import logging | |
| import base64 | |
| import requests | |
| import torch | |
| from typing import Optional | |
| from comfy.comfy_types.node_typing import IO, ComfyNodeABC | |
| from comfy_api.input_impl.video_types import VideoFromFile | |
| from comfy_api_nodes.apis import ( | |
| Veo2GenVidRequest, | |
| Veo2GenVidResponse, | |
| Veo2GenVidPollRequest, | |
| Veo2GenVidPollResponse | |
| ) | |
| from comfy_api_nodes.apis.client import ( | |
| ApiEndpoint, | |
| HttpMethod, | |
| SynchronousOperation, | |
| PollingOperation, | |
| ) | |
| from comfy_api_nodes.apinode_utils import ( | |
| downscale_image_tensor, | |
| tensor_to_base64_string | |
| ) | |
# Typical time a Veo generation takes; handed to PollingOperation as
# ``estimated_duration`` (presumably in seconds — confirm against PollingOperation).
AVERAGE_DURATION_VIDEO_GEN = 32
def convert_image_to_base64(image: torch.Tensor):
    """Encode an image tensor as a base64 string for the request payload.

    The tensor is passed through ``downscale_image_tensor`` with a pixel
    budget of 2048 * 2048, and the result is handed to
    ``tensor_to_base64_string``.

    Args:
        image: Image tensor to encode, or ``None``.

    Returns:
        Whatever ``tensor_to_base64_string`` produces for the downscaled
        tensor, or ``None`` when no image was supplied.
    """
    if image is not None:
        pixel_budget = 2048 * 2048
        return tensor_to_base64_string(
            downscale_image_tensor(image, total_pixels=pixel_budget)
        )
    return None
def get_video_url_from_response(poll_response: Veo2GenVidPollResponse) -> Optional[str]:
    """Return the ``gcsUri`` of the first video in a poll response.

    Args:
        poll_response: Poll result whose ``response`` attribute may carry a
            ``videos`` sequence.

    Returns:
        ``str(gcsUri)`` of the first video, or ``None`` when the response is
        missing/empty, carries no videos, or the first video has no
        ``gcsUri``.
    """
    payload = poll_response.response
    if not payload:
        return None

    # A missing attribute and an empty list are treated the same way.
    videos = getattr(payload, "videos", None)
    if not videos:
        return None

    uri = getattr(videos[0], "gcsUri", None)
    return str(uri) if uri else None
class VeoVideoGenerationNode(ComfyNodeABC):
    """
    Generates videos from text prompts using Google's Veo API.

    This node can create videos from text descriptions and optional image inputs,
    with control over parameters like aspect ratio, duration, and more.
    """

    # Timeout (seconds) applied to the download of the finished video so a
    # stalled connection cannot block the node forever.
    _DOWNLOAD_TIMEOUT_SECONDS = 60

    @classmethod
    def INPUT_TYPES(s):
        """Describe the node's required, optional and hidden inputs.

        Declared as a classmethod so it can be called on the class itself
        (``VeoVideoGenerationNode.INPUT_TYPES()``) without an instance.
        """
        return {
            "required": {
                "prompt": (
                    IO.STRING,
                    {
                        "multiline": True,
                        "default": "",
                        "tooltip": "Text description of the video",
                    },
                ),
                "aspect_ratio": (
                    IO.COMBO,
                    {
                        "options": ["16:9", "9:16"],
                        "default": "16:9",
                        "tooltip": "Aspect ratio of the output video",
                    },
                ),
            },
            "optional": {
                "negative_prompt": (
                    IO.STRING,
                    {
                        "multiline": True,
                        "default": "",
                        "tooltip": "Negative text prompt to guide what to avoid in the video",
                    },
                ),
                "duration_seconds": (
                    IO.INT,
                    {
                        "default": 5,
                        "min": 5,
                        "max": 8,
                        "step": 1,
                        "display": "number",
                        "tooltip": "Duration of the output video in seconds",
                    },
                ),
                "enhance_prompt": (
                    IO.BOOLEAN,
                    {
                        "default": True,
                        "tooltip": "Whether to enhance the prompt with AI assistance",
                    }
                ),
                "person_generation": (
                    IO.COMBO,
                    {
                        "options": ["ALLOW", "BLOCK"],
                        "default": "ALLOW",
                        "tooltip": "Whether to allow generating people in the video",
                    },
                ),
                "seed": (
                    IO.INT,
                    {
                        "default": 0,
                        "min": 0,
                        "max": 0xFFFFFFFF,
                        "step": 1,
                        "display": "number",
                        "control_after_generate": True,
                        "tooltip": "Seed for video generation (0 for random)",
                    },
                ),
                "image": (IO.IMAGE, {
                    "default": None,
                    "tooltip": "Optional reference image to guide video generation",
                }),
            },
            "hidden": {
                "auth_token": "AUTH_TOKEN_COMFY_ORG",
                "comfy_api_key": "API_KEY_COMFY_ORG",
                "unique_id": "UNIQUE_ID",
            },
        }

    RETURN_TYPES = (IO.VIDEO,)
    FUNCTION = "generate_video"
    CATEGORY = "api node/video/Veo"
    DESCRIPTION = "Generates videos from text prompts using Google's Veo API"
    API_NODE = True

    def generate_video(
        self,
        prompt,
        aspect_ratio="16:9",
        negative_prompt="",
        duration_seconds=5,
        enhance_prompt=True,
        person_generation="ALLOW",
        seed=0,
        image=None,
        unique_id: Optional[str] = None,
        **kwargs,
    ):
        """Start a Veo generation, poll until it is done and return the video.

        Args:
            prompt: Text description of the video.
            aspect_ratio: Sent as ``aspectRatio``.
            negative_prompt: Sent as ``negativePrompt`` only when non-empty.
            duration_seconds: Sent as ``durationSeconds``.
            enhance_prompt: Sent as ``enhancePrompt``.
            person_generation: Sent as ``personGeneration``.
            seed: Sent as ``seed`` only when greater than 0.
            image: Optional image tensor; when given it is base64-encoded and
                attached to the request instance as ``image/png``.
            unique_id: Forwarded to the polling operation as ``node_id``.
            **kwargs: Remaining (hidden) inputs, forwarded unchanged as
                ``auth_kwargs`` to both API operations.

        Returns:
            A one-element tuple holding a ``VideoFromFile`` built from an
            in-memory ``BytesIO`` of the video bytes.

        Raises:
            Exception: If the poll response carries an error, the content was
                filtered, or no usable video data came back.
            requests.RequestException: If downloading the video from its URL
                fails, times out, or returns an HTTP error status.
        """
        # Build the single request instance (prompt plus optional image).
        instance = {
            "prompt": prompt
        }
        if image is not None:
            image_base64 = convert_image_to_base64(image)
            if image_base64:
                instance["image"] = {
                    "bytesBase64Encoded": image_base64,
                    "mimeType": "image/png"
                }
        instances = [instance]

        # Always-present generation parameters.
        parameters = {
            "aspectRatio": aspect_ratio,
            "personGeneration": person_generation,
            "durationSeconds": duration_seconds,
            "enhancePrompt": enhance_prompt,
        }
        # Optional parameters are only sent when they carry a value.
        if negative_prompt:
            parameters["negativePrompt"] = negative_prompt
        if seed > 0:
            parameters["seed"] = seed

        # Initial request to start video generation.
        initial_operation = SynchronousOperation(
            endpoint=ApiEndpoint(
                path="/proxy/veo/generate",
                method=HttpMethod.POST,
                request_model=Veo2GenVidRequest,
                response_model=Veo2GenVidResponse
            ),
            request=Veo2GenVidRequest(
                instances=instances,
                parameters=parameters
            ),
            auth_kwargs=kwargs,
        )
        initial_response = initial_operation.execute()
        operation_name = initial_response.name
        logging.info("Veo generation started with operation name: %s", operation_name)

        def status_extractor(response):
            # Report "completed" as soon as the operation is done, whether it
            # succeeded or failed; errors are inspected after polling ends.
            return "completed" if response.done else "pending"

        def progress_extractor(response):
            # Could be enhanced if the API provides progress information.
            return None

        poll_operation = PollingOperation(
            poll_endpoint=ApiEndpoint(
                path="/proxy/veo/poll",
                method=HttpMethod.POST,
                request_model=Veo2GenVidPollRequest,
                response_model=Veo2GenVidPollResponse
            ),
            completed_statuses=["completed"],
            failed_statuses=[],  # No failed statuses, errors are handled after polling
            status_extractor=status_extractor,
            progress_extractor=progress_extractor,
            request=Veo2GenVidPollRequest(
                operationName=operation_name
            ),
            auth_kwargs=kwargs,
            poll_interval=5.0,
            result_url_extractor=get_video_url_from_response,
            node_id=unique_id,
            estimated_duration=AVERAGE_DURATION_VIDEO_GEN,
        )
        poll_response = poll_operation.execute()

        # Check for an error reported in the final poll response.
        if hasattr(poll_response, 'error') and poll_response.error:
            error_message = f"Veo API error: {poll_response.error.message} (code: {poll_response.error.code})"
            logging.error(error_message)
            raise Exception(error_message)

        response = poll_response.response

        # Check for RAI filtered content. The count is read defensively: a
        # missing attribute or a None value is treated as "nothing filtered"
        # instead of raising TypeError on the ``> 0`` comparison.
        filtered_count = getattr(response, 'raiMediaFilteredCount', None) or 0
        if filtered_count > 0:
            reasons = getattr(response, 'raiMediaFilteredReasons', None)
            if reasons:
                reason = reasons[0]
                error_message = f"Content filtered by Google's Responsible AI practices: {reason} ({filtered_count} videos filtered.)"
            else:
                error_message = f"Content filtered by Google's Responsible AI practices ({filtered_count} videos filtered.)"
            logging.error(error_message)
            raise Exception(error_message)

        # Extract video data from the first returned video.
        videos = getattr(response, 'videos', None) if response else None
        if not videos:
            raise Exception("Video generation completed but no video was returned")
        video = videos[0]

        encoded_video = getattr(video, 'bytesBase64Encoded', None)
        video_url = getattr(video, 'gcsUri', None)
        if encoded_video:
            # Video delivered inline as base64.
            video_data = base64.b64decode(encoded_video)
        elif video_url:
            # Video delivered by URL: download it with a timeout and fail on
            # HTTP error statuses so an error page is never treated as video.
            video_response = requests.get(
                str(video_url), timeout=self._DOWNLOAD_TIMEOUT_SECONDS
            )
            video_response.raise_for_status()
            video_data = video_response.content
        else:
            raise Exception("Video returned but no data or URL was provided")

        if not video_data:
            raise Exception("No video data was returned")

        logging.info("Video generation completed successfully")

        # Wrap the raw bytes in a file-like object for VideoFromFile.
        video_io = io.BytesIO(video_data)
        return (VideoFromFile(video_io),)
# Register the node
# Maps the node identifier to the class that implements it.
NODE_CLASS_MAPPINGS = {
    "VeoVideoGenerationNode": VeoVideoGenerationNode,
}
# Maps the same identifier to a human-readable name
# (presumably the label shown in the ComfyUI front end — confirm).
NODE_DISPLAY_NAME_MAPPINGS = {
    "VeoVideoGenerationNode": "Google Veo2 Video Generation",
}