Spaces:
Running
Running
| from typing import Type | |
| from api.baseline import BaselineAPI | |
| from api.fireworks import FireworksAPI | |
| from api.flux import FluxAPI | |
| from api.pruna import PrunaAPI | |
| from api.pruna_dev import PrunaDevAPI | |
| from api.replicate import ReplicateAPI | |
| from api.together import TogetherAPI | |
| from api.fal import FalAPI | |
# Names exported by a star-import of this module: the factory function plus
# every API class imported above.
__all__ = [
    "create_api",
    "FluxAPI",
    "BaselineAPI",
    "FireworksAPI",
    "PrunaAPI",
    "ReplicateAPI",
    "TogetherAPI",
    "FalAPI",
    "PrunaDevAPI",
]
def create_api(api_type: str) -> FluxAPI:
    """
    Factory function to create API instances.

    Args:
        api_type (str): The type of API to create. Must be one of:
            - "baseline"
            - "fireworks"
            - "pruna_dev"
            - "pruna_<speed_mode>" (where <speed_mode> is the desired,
              non-empty speed mode; it is passed to PrunaAPI unchanged)
            - "replicate"
            - "together"
            - "fal"

    Returns:
        FluxAPI: An instance of the requested API implementation

    Raises:
        ValueError: If an invalid API type is provided, including a bare
            "pruna_" with no speed mode after the prefix
    """
    # "pruna_dev" also starts with the generic "pruna_" prefix, so it has to
    # be matched before the prefix check below.
    if api_type == "pruna_dev":
        return PrunaDevAPI()

    pruna_prefix = "pruna_"
    if api_type.startswith(pruna_prefix):
        # Everything after the prefix is the speed mode.
        speed_mode = api_type[len(pruna_prefix):]
        if not speed_mode:
            # Fail here instead of constructing PrunaAPI with an empty mode.
            raise ValueError(
                f"Invalid API type: {api_type}. "
                f"'{pruna_prefix}' must be followed by a speed mode"
            )
        return PrunaAPI(speed_mode)

    # Remaining API types map directly to a class constructed without arguments.
    api_map: dict[str, Type[FluxAPI]] = {
        "baseline": BaselineAPI,
        "fireworks": FireworksAPI,
        "replicate": ReplicateAPI,
        "together": TogetherAPI,
        "fal": FalAPI,
    }
    if api_type not in api_map:
        raise ValueError(f"Invalid API type: {api_type}. Must be one of {list(api_map.keys())} or start with 'pruna_'")
    return api_map[api_type]()