Spaces:
Sleeping
Sleeping
| import json | |
| import logging | |
| import os | |
| import blobfile as bf | |
| import torch | |
| import gc | |
| from datasets import load_dataset | |
| from pytorch_lightning import seed_everything | |
| from tqdm import tqdm | |
| from arguments import parse_args | |
| from models import get_model, get_multi_apply_fn | |
| from rewards import get_reward_losses | |
| from training import LatentNoiseTrainer, get_optimizer | |
def find_and_move_object_to_cpu():
    """Move every live ``torch.nn.Module`` holding CUDA tensors to the CPU.

    Scans all objects tracked by the garbage collector. A module is moved when
    any of its parameters (or, failing that, any of its buffers) is on a CUDA
    device. ``Module.to`` moves parameters and buffers together, so a single
    call per module is enough.

    This is a best-effort sweep: an object that fails inspection or cannot be
    moved is skipped and logged at DEBUG level instead of aborting the sweep.
    """
    for obj in gc.get_objects():
        try:
            if not isinstance(obj, torch.nn.Module):
                continue
            name = type(obj).__name__
            if any(param.is_cuda for param in obj.parameters()):
                print(f"Found PyTorch model on CUDA: {name}")
                obj.to('cpu')
                print(f"Moved {name} to CPU.")
            elif any(buf.is_cuda for buf in obj.buffers()):
                # Parameters are on the CPU but some buffers are still on CUDA.
                print(f"Found buffer on CUDA in {name}")
                obj.to('cpu')
                print(f"Moved buffers of {name} to CPU.")
        except Exception as exc:  # best effort: one bad object must not stop the sweep
            logging.debug("Skipping %s during CUDA sweep: %s", type(obj).__name__, exc)
def clear_gpu():
    """Release cached CUDA memory and print usage before and after.

    Runs a Python garbage collection so that tensors kept alive only by
    reference cycles are freed. Then, when CUDA is available, it waits for
    queued CUDA work and returns the caching allocator's unused blocks via
    ``torch.cuda.empty_cache()``. Tensors that are still referenced stay where
    they are; this function does not move anything to the CPU.

    Without CUDA it only runs the garbage collection and reports 0.0 MB.
    """
    def _usage_mb():
        # Query the allocator only when CUDA exists; report zeros otherwise.
        if not torch.cuda.is_available():
            return 0.0, 0.0
        mb = 1024 ** 2
        return torch.cuda.memory_allocated() / mb, torch.cuda.memory_reserved() / mb

    allocated, reserved = _usage_mb()
    print(f"Memory allocated before clearing: {allocated} MB")
    print(f"Memory reserved before clearing: {reserved} MB")
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # Ensure that all operations are completed
        torch.cuda.empty_cache()
        print("GPU memory cleared.")
    allocated, reserved = _usage_mb()
    print(f"Memory allocated after clearing: {allocated} MB")
    print(f"Memory reserved after clearing: {reserved} MB")
def unload_previous_model_if_needed(loaded_model_setup):
    """Free GPU memory held by a previously loaded model before loading a new one.

    Args:
        loaded_model_setup: The tuple returned by an earlier ``setup`` call
            (``args`` at index 0, ``trainer`` at 1, ``pipe`` at 7), or ``None``
            when no previous setup is known.

    With ``None``, the sweep only runs when CUDA reports allocated memory,
    meaning something unknown is still resident. In every sweep, each
    CUDA-resident module is moved to the CPU and the allocator cache is
    released.
    """
    if loaded_model_setup is None:
        if not (torch.cuda.is_available() and torch.cuda.memory_allocated() > 0):
            return
        print("Unknown model or tensors are still loaded on the GPU. Clearing GPU memory.")
    else:
        print("Unloading previous model from GPU to free memory.")
    # Modules are found through the garbage collector. This covers the
    # previous pipe/trainer as well as any other module left on the GPU.
    find_and_move_object_to_cpu()
    # Return the now-unused cached blocks; moving modules alone leaves them reserved.
    clear_gpu()
def _configure_logging(log_path):
    """Send root-logger INFO records to ``log_path`` (truncated) and to the console."""
    logger = logging.getLogger()
    # Close leftover handlers before removing them. Otherwise repeated setup()
    # calls in one process leak an open log file each time.
    for old_handler in list(logger.handlers):
        logger.removeHandler(old_handler)
        old_handler.close()
    formatter = logging.Formatter("%(asctime)s - %(message)s")
    file_handler = logging.FileHandler(log_path, mode="w", encoding="utf-8")
    file_handler.setFormatter(formatter)
    logger.addHandler(file_handler)
    logger.setLevel("INFO")
    console_handler = logging.StreamHandler()
    console_handler.setFormatter(formatter)
    logger.addHandler(console_handler)


def _latent_shape(model_name, model):
    """Return the initial-latent shape for the loaded pipeline ``model``.

    Flux uses a fixed shape. Pixart reads its transformer config; every other
    model reads its UNet config.
    """
    if model_name == "flux":
        return (1, 16 * 64, 64)
    scale = model.vae_scale_factor
    if model_name == "pixart":
        sample_size = model.transformer.config.sample_size
        in_channels = model.transformer.config.in_channels
    else:
        sample_size = model.unet.config.sample_size
        in_channels = model.unet.in_channels
    height = sample_size * scale
    width = sample_size * scale
    return (1, in_channels, height // scale, width // scale)


def setup(args, loaded_model_setup=None):
    """Prepare logging, the diffusion pipeline and the latent-noise trainer.

    Args:
        args: Parsed command-line arguments.
        loaded_model_setup: The tuple returned by an earlier ``setup`` call,
            or ``None``. When its ``args.model`` (index 0) equals
            ``args.model``, that trainer (index 1) and pipe (index 7) are
            reused and only the trainer's run settings are refreshed.
            Otherwise the previous model is unloaded and a new one is built.

    Returns:
        ``(args, trainer, device, dtype, shape, enable_grad, settings, pipe)``.
    """
    seed_everything(args.seed)
    bf.makedirs(f"{args.save_dir}/logs/{args.task}")
    settings = (
        f"{args.model}{'_' + args.prompt if args.task == 't2i-compbench' else ''}"
        f"{'_no-optim' if args.no_optim else ''}_{args.seed if args.task != 'geneval' else ''}"
        f"_lr{args.lr}_gc{args.grad_clip}_iter{args.n_iters}"
        f"_reg{args.reg_weight if args.enable_reg else '0'}"
        f"{'_pickscore' + str(args.pickscore_weighting) if args.enable_pickscore else ''}"
        f"{'_clip' + str(args.clip_weighting) if args.enable_clip else ''}"
        f"{'_hps' + str(args.hps_weighting) if args.enable_hps else ''}"
        f"{'_imagereward' + str(args.imagereward_weighting) if args.enable_imagereward else ''}"
        f"{'_aesthetic' + str(args.aesthetic_weighting) if args.enable_aesthetic else ''}"
    )
    _configure_logging(f"{args.save_dir}/logs/{args.task}/{settings}.txt")
    logging.info(args)

    if args.device_id is not None:
        logging.info(f"Using CUDA device {args.device_id}")
        # Only takes effect if CUDA has not been initialised yet in this process.
        os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
        # os.environ only accepts str; device_id may have been parsed as an int.
        os.environ["CUDA_VISIBLE_DEVICES"] = str(args.device_id)
    device = torch.device("cuda")
    dtype = torch.float16 if args.dtype == "float16" else torch.float32
    enable_grad = not args.no_optim

    # Same model as the loaded setup: reuse the trainer and pipe.
    if loaded_model_setup and args.model == loaded_model_setup[0].model:
        print(f"Reusing model {args.model} from loaded setup.")
        trainer = loaded_model_setup[1]
        trainer.n_iters = args.n_iters
        trainer.n_inference_steps = args.n_inference_steps
        trainer.seed = args.seed
        trainer.save_all_images = args.save_all_images
        trainer.no_optim = args.no_optim
        trainer.regularize = args.enable_reg
        trainer.regularization_weight = args.reg_weight
        trainer.grad_clip = args.grad_clip
        trainer.log_metrics = args.task == "single" or not args.no_optim
        trainer.imageselect = args.imageselect
        # NOTE(review): reward losses, device and cpu_offloading are not
        # refreshed here. If the reward flags or weights change between runs,
        # the old reward losses stay in use. Confirm this is intended.
        shape = _latent_shape(args.model, trainer.model)
        pipe = loaded_model_setup[7]
        return args, trainer, device, dtype, shape, enable_grad, settings, pipe

    # Different (or no) model: unload the previous one and build from scratch.
    unload_previous_model_if_needed(loaded_model_setup)
    print(f"Loading new model: {args.model}")
    reward_losses = get_reward_losses(args, dtype, device, args.cache_dir)
    pipe = get_model(
        args.model, dtype, device, args.cache_dir, args.memsave, args.cpu_offloading
    )
    torch.cuda.empty_cache()  # release memory cached while loading the model
    trainer = LatentNoiseTrainer(
        reward_losses=reward_losses,
        model=pipe,
        n_iters=args.n_iters,
        n_inference_steps=args.n_inference_steps,
        seed=args.seed,
        save_all_images=args.save_all_images,
        device=device if not args.cpu_offloading else 'cpu',  # Use CPU if offloading is enabled
        no_optim=args.no_optim,
        regularize=args.enable_reg,
        regularization_weight=args.reg_weight,
        grad_clip=args.grad_clip,
        log_metrics=args.task == "single" or not args.no_optim,
        imageselect=args.imageselect,
    )
    shape = _latent_shape(args.model, pipe)
    torch.cuda.empty_cache()  # Free up cached memory
    return args, trainer, device, dtype, shape, enable_grad, settings, pipe
def _new_latents(args, shape, device, dtype, requires_grad):
    """Sample fresh standard-normal initial latents; return ``(latents, optimizer)``."""
    init_latents = torch.randn(shape, device=device, dtype=dtype)
    latents = torch.nn.Parameter(init_latents, requires_grad=requires_grad)
    optimizer = get_optimizer(args.optim, latents, args.lr, args.nesterov)
    return latents, optimizer


def _add_rewards(total_init, total_best, init_rewards, best_rewards):
    """Add one sample's reward dicts into the running totals (keys follow ``best_rewards``)."""
    for key in best_rewards:
        total_best[key] = total_best.get(key, 0.0) + best_rewards[key]
        total_init[key] = total_init.get(key, 0.0) + init_rewards[key]


def _average_rewards(total_init, total_best, count):
    """Divide the running totals by ``count`` in place; leave them unchanged when ``count`` is 0."""
    if not count:
        return
    for key in total_best:
        total_best[key] /= count
        total_init[key] /= count


def _read_lines(path):
    """Return every line of the text file at ``path`` (trailing newlines kept)."""
    with open(path, "r") as fh:
        return fh.readlines()


def _prepare_pipe_for_single(args, pipe, device, dtype):
    """Place ``pipe`` for a single-prompt run and return it.

    Flux gets optional sequential CPU offloading, VAE slicing/tiling and a
    float16 cast. Every other model is moved to ``device``.
    """
    if args.model == "flux":
        if args.cpu_offloading:
            pipe.enable_sequential_cpu_offload()
        pipe.vae.enable_slicing()
        pipe.vae.enable_tiling()
        # Cast here instead of in the pipeline constructor, because casting in
        # the constructor loads all models into CPU memory at once.
        pipe.to(torch.float16)
        return pipe
    # NOTE(review): a CUDA pipe typically reports an indexed device (e.g.
    # cuda:0), which is not equal to torch.device("cuda"). These guards
    # therefore usually let the move and dtype cast run every time. Kept
    # as-is, because skipping them would also skip the cast and VAE upcast.
    if args.model == "hyper-sd":
        if pipe.device != torch.device('cuda'):
            pipe.unet = pipe.unet.to(device, dtype)
            pipe = pipe.to(device, dtype)
            # Upcast the VAE to float32.
            pipe.vae = pipe.vae.to(dtype=torch.float32)
    elif args.model == "pixart":
        if pipe.device != torch.device('cuda'):
            pipe.to(device)
    elif pipe.device != torch.device('cuda'):
        pipe.to(device, dtype)
    return pipe


def _build_multi_apply_fn(args, pipe, device, dtype):
    """Return the multi-step apply function when ``args.enable_multi_apply`` is set, else ``None``."""
    if not args.enable_multi_apply:
        return None
    return get_multi_apply_fn(
        model_type=args.multi_step_model,
        seed=args.seed,
        pipe=pipe,
        cache_dir=args.cache_dir,
        device=device if not args.cpu_offloading else 'cpu',
        dtype=dtype,
    )


def _run_single(args, trainer, device, dtype, shape, enable_grad, settings, pipe,
                multi_apply_fn, progress_callback):
    """Optimise ``args.prompt`` once; return the trainer's (initial, best) reward dicts."""
    torch.cuda.empty_cache()  # Free up cached memory
    print(f"PIPE:{pipe}")
    latents, optimizer = _new_latents(args, shape, device, dtype, enable_grad)
    save_dir = f"{args.save_dir}/{args.task}/{settings}/{args.prompt[:150]}"
    os.makedirs(save_dir, exist_ok=True)
    _, best_image, init_rewards, best_rewards = trainer.train(
        latents, args.prompt, optimizer, save_dir, multi_apply_fn,
        progress_callback=progress_callback,
    )
    best_image.save(f"{save_dir}/best_image.png")
    clear_gpu()
    return init_rewards, best_rewards


def _run_example_prompts(args, trainer, device, dtype, shape, enable_grad, settings, multi_apply_fn):
    """Optimise each prompt in assets/example_prompts.txt; write mean rewards to results.txt."""
    prompts = _read_lines("assets/example_prompts.txt")
    run_dir = f"{args.save_dir}/{args.task}/{settings}"
    total_init_rewards, total_best_rewards = {}, {}
    for i, prompt in tqdm(enumerate(prompts)):
        latents, optimizer = _new_latents(args, shape, device, dtype, enable_grad)
        prompt = prompt.strip()
        # The per-prompt directory name keeps its historical ".png" suffix.
        save_dir = f"{run_dir}/{i:03d}_{prompt[:150]}.png"
        os.makedirs(save_dir, exist_ok=True)
        init_image, best_image, init_rewards, best_rewards = trainer.train(
            latents, prompt, optimizer, save_dir, multi_apply_fn
        )
        _add_rewards(total_init_rewards, total_best_rewards, init_rewards, best_rewards)
        best_image.save(f"{save_dir}/best_image.png")
        init_image.save(f"{save_dir}/init_image.png")
        logging.info(f"Initial rewards: {init_rewards}")
        logging.info(f"Best rewards: {best_rewards}")
    _average_rewards(total_init_rewards, total_best_rewards, len(prompts))
    os.makedirs(run_dir, exist_ok=True)  # exists already unless the prompt file was empty
    with open(f"{run_dir}/results.txt", "w") as f:
        f.write(
            f"Mean initial all rewards: {total_init_rewards}\n"
            f"Mean best all rewards: {total_best_rewards}\n"
        )
    return total_init_rewards, total_best_rewards


def _run_t2i_compbench(args, trainer, device, dtype, shape, enable_grad, settings, multi_apply_fn):
    """Optimise each prompt of the T2I-CompBench split named by ``args.prompt``."""
    prompts = _read_lines(f"../T2I-CompBench/examples/dataset/{args.prompt}.txt")
    samples_dir = f"{args.save_dir}/{args.task}/{settings}/samples"
    os.makedirs(samples_dir, exist_ok=True)
    total_init_rewards, total_best_rewards = {}, {}
    for i, prompt in tqdm(enumerate(prompts)):
        latents, optimizer = _new_latents(args, shape, device, dtype, enable_grad)
        prompt = prompt.strip()
        _, best_image, init_rewards, best_rewards = trainer.train(
            latents, prompt, optimizer, None, multi_apply_fn
        )
        _add_rewards(total_init_rewards, total_best_rewards, init_rewards, best_rewards)
        best_image.save(f"{samples_dir}/{prompt}_{i:06d}.png")
        logging.info(f"Initial rewards: {init_rewards}")
        logging.info(f"Best rewards: {best_rewards}")
    _average_rewards(total_init_rewards, total_best_rewards, len(prompts))
    return total_init_rewards, total_best_rewards


def _run_parti_prompts(args, trainer, device, dtype, shape, enable_grad, settings, multi_apply_fn):
    """Optimise each PartiPrompts prompt and write benchmark statistics to results.txt."""
    parti_dataset = load_dataset("nateraw/parti-prompts", split="train")
    run_dir = f"{args.save_dir}/{args.task}/{settings}"
    total_reward_diff = 0.0
    total_best_reward = 0.0
    total_init_reward = 0.0
    total_improved_samples = 0
    total_init_rewards, total_best_rewards = {}, {}
    for index, sample in enumerate(parti_dataset):
        sample_dir = f"{run_dir}/{index}"
        os.makedirs(sample_dir, exist_ok=True)
        prompt = sample["Prompt"]
        # Fresh latents and optimizer for every prompt, created before use.
        latents, optimizer = _new_latents(args, shape, device, dtype, enable_grad)
        # As in the other benchmark tasks, save_dir is None and the best image is saved below.
        _, best_image, init_rewards, best_rewards = trainer.train(
            latents, prompt, optimizer, None, multi_apply_fn
        )
        best_image.save(f"{sample_dir}/best_image.png")
        with open(f"{sample_dir}/prompt.txt", "w") as f:
            f.write(
                f"{prompt} \n Initial Rewards: {init_rewards} \n Best Rewards: {best_rewards}"
            )
        logging.info(f"Initial rewards: {init_rewards}")
        logging.info(f"Best rewards: {best_rewards}")
        initial_reward = init_rewards[args.benchmark_reward]
        best_reward = best_rewards[args.benchmark_reward]
        total_reward_diff += best_reward - initial_reward
        total_best_reward += best_reward
        total_init_reward += initial_reward
        # A lower value counts as improved. This presumably treats rewards as
        # losses; TODO confirm against the trainer.
        if best_reward < initial_reward:
            total_improved_samples += 1
        _add_rewards(total_init_rewards, total_best_rewards, init_rewards, best_rewards)
    denom = parti_dataset.num_rows or 1  # avoid ZeroDivisionError on an empty split
    improvement_percentage = total_improved_samples / denom
    mean_best_reward = total_best_reward / denom
    mean_init_reward = total_init_reward / denom
    mean_reward_diff = total_reward_diff / denom
    logging.info(
        f"Improvement percentage: {improvement_percentage:.4f}, "
        f"mean initial reward: {mean_init_reward:.4f}, "
        f"mean best reward: {mean_best_reward:.4f}, "
        f"mean reward diff: {mean_reward_diff:.4f}"
    )
    _average_rewards(total_init_rewards, total_best_rewards, len(parti_dataset))
    os.makedirs(run_dir, exist_ok=True)
    with open(f"{run_dir}/results.txt", "w") as f:
        f.write(
            f"Mean improvement: {improvement_percentage:.4f}, "
            f"mean initial reward: {mean_init_reward:.4f}, "
            f"mean best reward: {mean_best_reward:.4f}, "
            f"mean reward diff: {mean_reward_diff:.4f}\n"
            f"Mean initial all rewards: {total_init_rewards}\n"
            f"Mean best all rewards: {total_best_rewards}"
        )
    return total_init_rewards, total_best_rewards


def _run_geneval(args, trainer, device, dtype, shape, enable_grad, settings, multi_apply_fn):
    """Optimise each GenEval prompt and save samples in GenEval's directory layout."""
    with open("../geneval/prompts/evaluation_metadata.jsonl") as fp:
        metadatas = [json.loads(line) for line in fp]
    outdir = f"{args.save_dir}/{args.task}/{settings}"
    total_init_rewards, total_best_rewards = {}, {}
    for index, metadata in enumerate(metadatas):
        # Same requires_grad policy as the other tasks (no graph under --no_optim).
        latents, optimizer = _new_latents(args, shape, device, dtype, enable_grad)
        prompt = metadata["prompt"]
        _, best_image, init_rewards, best_rewards = trainer.train(
            latents, prompt, optimizer, None, multi_apply_fn
        )
        logging.info(f"Initial rewards: {init_rewards}")
        logging.info(f"Best rewards: {best_rewards}")
        outpath = f"{outdir}/{index:0>5}"
        os.makedirs(f"{outpath}/samples", exist_ok=True)
        with open(f"{outpath}/metadata.jsonl", "w") as fp:
            json.dump(metadata, fp)
        best_image.save(f"{outpath}/samples/{args.seed:05}.png")
        _add_rewards(total_init_rewards, total_best_rewards, init_rewards, best_rewards)
    _average_rewards(total_init_rewards, total_best_rewards, len(metadatas))
    return total_init_rewards, total_best_rewards


def execute_task(args, trainer, device, dtype, shape, enable_grad, settings, pipe, progress_callback=None):
    """Run the task named by ``args.task`` using the values returned by ``setup``.

    Supported tasks:
        * ``single``: run ``args.prompt`` once.
        * ``example-prompts``, ``t2i-compbench``, ``parti-prompts``,
          ``geneval``: batch benchmarks.

    Each task optimises fresh latents per prompt, saves images under
    ``args.save_dir`` and finally logs the mean initial and best rewards.

    Args:
        args, trainer, device, dtype, shape, enable_grad, settings, pipe:
            The tuple produced by ``setup``.
        progress_callback: Forwarded to ``trainer.train`` (``single`` only).

    Raises:
        ValueError: If ``args.task`` is not a supported task.
    """
    batch_runners = {
        "example-prompts": _run_example_prompts,
        "t2i-compbench": _run_t2i_compbench,
        "parti-prompts": _run_parti_prompts,
        "geneval": _run_geneval,
    }
    # Validate before doing any expensive work.
    if args.task != "single" and args.task not in batch_runners:
        raise ValueError(f"Unknown task {args.task}")

    if args.task == "single":
        pipe = _prepare_pipe_for_single(args, pipe, device, dtype)
    # Every task passes this to trainer.train, so build it for all of them.
    multi_apply_fn = _build_multi_apply_fn(args, pipe, device, dtype)

    if args.task == "single":
        total_init_rewards, total_best_rewards = _run_single(
            args, trainer, device, dtype, shape, enable_grad, settings, pipe,
            multi_apply_fn, progress_callback,
        )
    else:
        total_init_rewards, total_best_rewards = batch_runners[args.task](
            args, trainer, device, dtype, shape, enable_grad, settings, multi_apply_fn
        )
    # log total rewards
    logging.info(f"Mean initial rewards: {total_init_rewards}")
    logging.info(f"Mean best rewards: {total_best_rewards}")
def main():
    """Command-line entry point: parse arguments, build a fresh setup, run the task."""
    parsed_args = parse_args()
    # setup() returns exactly the positional arguments execute_task() expects.
    setup_result = setup(parsed_args, loaded_model_setup=None)
    execute_task(*setup_result)


if __name__ == "__main__":
    main()