Spaces:
Running
Running
| import argparse | |
| import json | |
| from pathlib import Path | |
| from typing import List | |
| from tqdm import tqdm | |
| from api import create_api | |
| from benchmark import create_benchmark | |
def _load_metadata(metadata_file: Path) -> dict:
    """Read ``metadata_file`` (JSON Lines) into a dict keyed by ``filepath``.

    Returns an empty dict when the file does not exist. Blank lines are
    ignored so a trailing newline or a hand edit does not break loading.
    """
    metadata = {}
    if not metadata_file.exists():
        return metadata
    with open(metadata_file, "r", encoding="utf-8") as f:
        for line in f:
            line = line.strip()
            if not line:
                continue
            entry = json.loads(line)
            metadata[entry["filepath"]] = entry
    return metadata


def _save_metadata(metadata_file: Path, metadata: dict) -> None:
    """Write ``metadata`` values to ``metadata_file``, one JSON object per line.

    The data is written to a sibling temp file first and then renamed over the
    target, so an interruption during the write cannot leave a truncated file.
    """
    tmp_file = metadata_file.with_name(metadata_file.name + ".tmp")
    with open(tmp_file, "w", encoding="utf-8") as f:
        for entry in metadata.values():
            f.write(json.dumps(entry) + "\n")
    tmp_file.replace(metadata_file)


def _resolve_job(benchmark_type: str, item, benchmark_dir: Path):
    """Map one benchmark item to ``(prompt, output_path, metadata_key)``.

    * ``geneval`` items are ``(metadata, folder_name)``. The image goes to
      ``<benchmark_dir>/<folder_name>/samples/0000.png`` (that directory is
      created here), and the metadata key is that full path.
    * Every other benchmark yields ``(prompt, image_path)``. The image goes to
      ``<benchmark_dir>/<image_path>``, and the metadata key is ``image_path``
      as given.

    NOTE: the two key formats differ (full path vs. benchmark-relative path).
    This is kept as-is because existing ``metadata.jsonl`` files and their
    consumers depend on it.
    """
    if benchmark_type == "geneval":
        metadata, folder_name = item
        samples_path = benchmark_dir / folder_name / "samples"
        samples_path.mkdir(parents=True, exist_ok=True)
        image_path = samples_path / "0000.png"
        return metadata["prompt"], image_path, str(image_path)
    prompt, image_path = item
    return prompt, benchmark_dir / image_path, str(image_path)


def generate_images(api_type: str, benchmarks: List[str]):
    """Generate images for each benchmark's prompts using the given API.

    Output is written under ``images/<api_type>/<benchmark_type>/``. Each
    benchmark directory gets a ``metadata.jsonl`` file containing, per
    generated image, its ``filepath``, ``prompt`` and the ``inference_time``
    returned by ``api.generate_image``.

    The function is resumable:

    * images that already exist on disk are skipped;
    * metadata entries from earlier runs are kept;
    * metadata is saved even if a benchmark's loop is interrupted, so images
      produced before the interruption keep their entries.

    Args:
        api_type: Name passed to ``create_api`` to select the image generator.
        benchmarks: Benchmark names passed to ``create_benchmark``.

    A failure for a single prompt is printed and skipped; it does not abort
    the run.
    """
    images_dir = Path("images")
    api = create_api(api_type)
    api_dir = images_dir / api_type
    api_dir.mkdir(parents=True, exist_ok=True)

    for benchmark_type in tqdm(benchmarks, desc="Processing benchmarks"):
        print(f"\nProcessing benchmark: {benchmark_type}")
        benchmark = create_benchmark(benchmark_type)
        benchmark_dir = api_dir / benchmark_type
        benchmark_dir.mkdir(parents=True, exist_ok=True)

        metadata_file = benchmark_dir / "metadata.jsonl"
        existing_metadata = _load_metadata(metadata_file)

        try:
            # Iterate the benchmark itself (not a wrapper generator) so tqdm
            # can still read its length for the progress-bar total.
            for item in tqdm(benchmark, desc=f"Generating images for {benchmark_type}", leave=False):
                prompt, output_path, metadata_key = _resolve_job(benchmark_type, item, benchmark_dir)
                if output_path.exists():
                    continue
                # The benchmark-supplied image_path may contain subdirectories;
                # make sure the destination directory exists before writing.
                output_path.parent.mkdir(parents=True, exist_ok=True)
                try:
                    inference_time = api.generate_image(prompt, output_path)
                except Exception as e:
                    # Best-effort: one failing prompt must not stop the run.
                    print(f"\nError generating image for prompt: {prompt}")
                    print(f"Error: {str(e)}")
                    continue
                existing_metadata[metadata_key] = {
                    "filepath": metadata_key,
                    "prompt": prompt,
                    "inference_time": inference_time,
                }
        finally:
            # Save even on crash or Ctrl-C. Otherwise images written this run
            # would be skipped on the next run and never get metadata.
            _save_metadata(metadata_file, existing_metadata)
def main():
    """Command-line entry point: parse the API type and benchmark names, then generate."""
    cli = argparse.ArgumentParser(
        description="Generate images for specified benchmarks using a given API"
    )
    cli.add_argument("api_type", help="Type of API to use for image generation")
    cli.add_argument("benchmarks", nargs="+", help="List of benchmark types to run")
    options = cli.parse_args()
    generate_images(options.api_type, options.benchmarks)
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()