Spaces:
Running
Running
| """ | |
| Text-to-video functionality handler for AI-Inferoxy AI Hub. | |
| Handles text-to-video generation with multiple providers. | |
| """ | |
| import os | |
| import gradio as gr | |
| import tempfile | |
| import io | |
| from concurrent.futures import ThreadPoolExecutor, TimeoutError as FutureTimeoutError | |
| from huggingface_hub import InferenceClient | |
| from huggingface_hub.errors import HfHubHTTPError | |
| from requests.exceptions import ConnectionError | |
| from hf_token_utils import get_proxy_token, report_token_status | |
| from utils import ( | |
| validate_proxy_key, | |
| format_error_message, | |
| format_success_message, | |
| ) | |
# Timeout configuration for video generation.
# Value is in seconds: it is passed as the ``timeout`` of the worker future in
# generate_video() and echoed in the timeout messages shown to the user.
VIDEO_GENERATION_TIMEOUT = 600  # up to 10 minutes, videos can be slow
def generate_video(
    prompt: str,
    model_name: str,
    provider: str,
    num_inference_steps: int | None = None,
    guidance_scale: float | None = None,
    seed: int | None = None,
    client_name: str | None = None,
):
    """
    Generate a video using the specified model and provider through AI-Inferoxy.

    A provider token is requested from the proxy (authenticated with the
    ``PROXY_KEY`` environment variable), the generation request runs in a
    worker thread bounded by ``VIDEO_GENERATION_TIMEOUT`` seconds, and the
    outcome (success or error) is reported back with ``report_token_status``.

    Args:
        prompt: Text prompt forwarded to the text-to-video call.
        model_name: Model identifier forwarded as ``model``.
        provider: Provider name handed to ``InferenceClient``.
        num_inference_steps: Forwarded only when not ``None``.
        guidance_scale: Forwarded only when not ``None``.
        seed: Forwarded unless it is ``None`` or ``-1``.
        client_name: Optional name attached to every token status report.

    Returns:
        ``(video_path_or_url, status_message)`` on success, or
        ``(None, error_message)`` when validation or generation fails.
    """
    # Validate proxy API key
    is_valid, error_msg = validate_proxy_key()
    if not is_valid:
        return None, error_msg

    proxy_api_key = os.getenv("PROXY_KEY")
    token_id = None

    def report_error(message: str) -> None:
        # token_id is looked up at call time, so this is a no-op when the
        # failure happened before a token was obtained.
        if token_id:
            report_token_status(
                token_id,
                "error",
                message,
                api_key=proxy_api_key,
                client_name=client_name,
            )

    try:
        # Get token from AI-Inferoxy proxy server
        print("π Video: Requesting token from proxy...")
        token, token_id = get_proxy_token(api_key=proxy_api_key)
        print(f"β Video: Got token: {token_id}")
        print(f"π¬ Video: Using model='{model_name}', provider='{provider}'")

        # Create client with specified provider
        client = InferenceClient(
            provider=provider,
            api_key=token,
        )

        # Prepare generation parameters; optional knobs are only sent when set
        generation_params: dict = {
            "model": model_name,
            "prompt": prompt,
        }
        if num_inference_steps is not None:
            generation_params["num_inference_steps"] = num_inference_steps
        if guidance_scale is not None:
            generation_params["guidance_scale"] = guidance_scale
        if seed is not None and seed != -1:
            generation_params["seed"] = seed

        print(f"π‘ Video: Making generation request with {VIDEO_GENERATION_TIMEOUT}s timeout...")

        # The executor is deliberately NOT used as a context manager: leaving a
        # ``with`` block calls shutdown(wait=True), which would block until the
        # worker finishes and defeat the timeout below.
        executor = ThreadPoolExecutor(max_workers=1)
        try:
            future = executor.submit(client.text_to_video, **generation_params)
            try:
                video = future.result(timeout=VIDEO_GENERATION_TIMEOUT)
            except FutureTimeoutError:
                # cancel() cannot stop a call that is already running; it only
                # prevents a not-yet-started task from being picked up.
                future.cancel()
                raise TimeoutError(f"Video generation timed out after {VIDEO_GENERATION_TIMEOUT} seconds")
        finally:
            # Return immediately; an in-flight request is left to finish in
            # the background instead of stalling this call.
            executor.shutdown(wait=False, cancel_futures=True)

        print(f"ποΈ Video: Generation completed! Type: {type(video)}")

        # Convert output to a path or URL Gradio can handle
        video_output = _coerce_video_output(video)

        # Report successful token usage
        if token_id:
            report_token_status(token_id, "success", api_key=proxy_api_key, client_name=client_name)

        return video_output, format_success_message("Video generated", f"using {model_name} on {provider}")

    except ConnectionError as e:
        error_msg = f"Cannot connect to AI-Inferoxy server: {str(e)}"
        print(f"π Video connection error: {error_msg}")
        report_error(error_msg)
        return None, format_error_message("Connection Error", "Unable to connect to the proxy server. Please check if it's running.")

    except TimeoutError as e:
        error_msg = f"Video generation timed out: {str(e)}"
        print(f"β° Video timeout: {error_msg}")
        report_error(error_msg)
        return None, format_error_message("Timeout Error", f"Video generation took too long (>{VIDEO_GENERATION_TIMEOUT//60} minutes). Try a shorter prompt.")

    except HfHubHTTPError as e:
        error_msg = str(e)
        print(f"π€ Video HF error: {error_msg}")
        report_error(error_msg)

        # Prefer the real HTTP status when the exception carries a response;
        # fall back to searching the message text only when it does not.
        status_code = getattr(getattr(e, "response", None), "status_code", None)

        def has_status(code: int) -> bool:
            if status_code is not None:
                return status_code == code
            return str(code) in error_msg

        if has_status(401):
            return None, format_error_message("Authentication Error", "Invalid or expired API token. The proxy will provide a new token on retry.")
        if has_status(402):
            return None, format_error_message("Quota Exceeded", "API quota exceeded. The proxy will try alternative providers.")
        if has_status(429):
            return None, format_error_message("Rate Limited", "Too many requests. Please wait a moment and try again.")
        return None, format_error_message("HuggingFace API Error", error_msg)

    except Exception as e:
        # Top-level boundary: every remaining failure becomes an error tuple.
        error_msg = str(e)
        print(f"β Video unexpected error: {error_msg}")
        report_error(error_msg)
        return None, format_error_message("Unexpected Error", f"An unexpected error occurred: {error_msg}")
def handle_video_generation(
    prompt_val,
    model_val,
    provider_val,
    steps_val,
    guidance_val,
    seed_val,
    hf_token: gr.OAuthToken = None,
    hf_profile: gr.OAuthProfile = None,
):
    """
    Validate a text-to-video request coming from the UI and run it.

    Rejects an empty or whitespace-only prompt, then requires a Hugging Face
    access token on ``hf_token``. When both checks pass, the stripped prompt
    and the remaining values are forwarded unchanged to ``generate_video``,
    with the profile's username (if any) as ``client_name``.

    Returns whatever ``generate_video`` returns, or ``(None, error_message)``
    when a check fails.
    """
    # Falsy prompts (None, "") never reach .strip()
    cleaned_prompt = prompt_val.strip() if prompt_val else ""
    if not cleaned_prompt:
        return None, format_error_message("Validation Error", "Please enter a prompt for video generation")

    # getattr with a default already yields None when the object itself is None
    access_token = getattr(hf_token, "token", None)
    username = getattr(hf_profile, "username", None)
    if not access_token:
        return None, format_error_message("Access Required", "Please sign in with Hugging Face (sidebar Login button).")

    return generate_video(
        prompt=cleaned_prompt,
        model_name=model_val,
        provider=provider_val,
        num_inference_steps=steps_val,
        guidance_scale=guidance_val,
        seed=seed_val,
        client_name=username,
    )
| def _coerce_video_output(value): | |
| """Coerce various return types (bytes, str path/URL, BytesIO) into a filepath/URL for gr.Video.""" | |
| # Case 1: Direct URL or existing file path | |
| if isinstance(value, str): | |
| if value.startswith("http://") or value.startswith("https://"): | |
| return value | |
| if os.path.exists(value): | |
| return value | |
| # Unknown string; fall through to save as file | |
| # Case 2: Bytes-like content | |
| if isinstance(value, (bytes, bytearray)): | |
| data = bytes(value) | |
| suffix = _guess_video_suffix(data) | |
| return _write_temp_video(data, suffix) | |
| # Case 3: File-like object | |
| if isinstance(value, io.IOBase) or hasattr(value, "read"): | |
| try: | |
| data = value.read() | |
| if isinstance(data, (bytes, bytearray)): | |
| suffix = _guess_video_suffix(data) | |
| return _write_temp_video(bytes(data), suffix) | |
| except Exception: | |
| pass | |
| # Fallback: save string representation for debugging | |
| debug_bytes = str(type(value)).encode("utf-8") | |
| return _write_temp_video(debug_bytes, ".mp4") | |
| def _write_temp_video(data: bytes, suffix: str) -> str: | |
| tmp = tempfile.NamedTemporaryFile(delete=False, suffix=suffix) | |
| try: | |
| tmp.write(data) | |
| tmp.flush() | |
| finally: | |
| tmp.close() | |
| return tmp.name | |
| def _guess_video_suffix(data: bytes) -> str: | |
| header = data[:64] | |
| # MP4 often contains 'ftyp' box near start | |
| if b"ftyp" in header: | |
| return ".mp4" | |
| # WebM/Matroska magic number starts with 0x1A45DFA3 and often contains 'webm' | |
| if header.startswith(b"\x1aE\xdf\xa3") or b"webm" in header.lower(): | |
| return ".webm" | |
| # Default to mp4 | |
| return ".mp4" | |