# Spaces: Sleeping -- status banner captured from the hosting page together
# with this file; it is not part of the program.
import argparse
from typing import Optional, Sequence


def _build_parser() -> argparse.ArgumentParser:
    """Build the argument parser used by :func:`parse_args`."""
    parser = argparse.ArgumentParser(description="Process Reward Optimization.")

    # update paths here! These defaults are absolute paths on one machine.
    parser.add_argument(
        "--cache_dir",
        type=str,
        help="HF cache directory",
        default="/shared-local/aoq951/HF_CACHE/",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        help="Directory to save images",
        default="/shared-local/aoq951/ReNO/outputs",
    )

    # model and optim
    parser.add_argument("--model", type=str, help="Model to use", default="sdxl-turbo")
    parser.add_argument("--lr", type=float, help="Learning rate", default=5.0)
    parser.add_argument("--n_iters", type=int, help="Number of iterations", default=50)
    parser.add_argument(
        "--n_inference_steps",
        type=int,
        help="Number of inference steps",
        default=1,
    )
    parser.add_argument(
        "--optim",
        choices=["sgd", "adam", "lbfgs"],
        default="sgd",
        help="Optimizer to be used",
    )
    # NOTE: inverted flag -- `nesterov` defaults to True and passing
    # --nesterov sets it to False. Kept as-is so existing command lines
    # keep their meaning.
    parser.add_argument(
        "--nesterov",
        default=True,
        action="store_false",
        help="Passing this flag DISABLES Nesterov (default: enabled)",
    )
    parser.add_argument(
        "--grad_clip", type=float, help="Gradient clipping", default=0.1
    )
    parser.add_argument("--seed", type=int, help="Seed to use", default=0)

    # reward losses: each reward is opt-in via --enable_<name> and paired
    # with a --<name>_weighting value.
    parser.add_argument(
        "--enable_hps",
        default=False,
        action="store_true",
        help="Enable the HPS reward",
    )
    parser.add_argument(
        "--hps_weighting", type=float, help="Weighting for HPS", default=5.0
    )
    parser.add_argument(
        "--enable_imagereward",
        default=False,
        action="store_true",
        help="Enable the ImageReward reward",
    )
    parser.add_argument(
        "--imagereward_weighting",
        type=float,
        help="Weighting for ImageReward",
        default=1.0,
    )
    parser.add_argument(
        "--enable_clip",
        default=False,
        action="store_true",
        help="Enable the CLIP reward",
    )
    parser.add_argument(
        "--clip_weighting", type=float, help="Weighting for CLIP", default=0.01
    )
    parser.add_argument(
        "--enable_pickscore",
        default=False,
        action="store_true",
        help="Enable the PickScore reward",
    )
    parser.add_argument(
        "--pickscore_weighting",
        type=float,
        help="Weighting for PickScore",
        default=0.05,
    )
    # --disable_aesthetic stores False into a dest that already defaults to
    # False, so on its own it could never switch the aesthetic reward on.
    # --enable_aesthetic (same dest, same default) makes it reachable while
    # leaving the default and the old flag's behaviour unchanged.
    parser.add_argument(
        "--disable_aesthetic",
        default=False,
        action="store_false",
        dest="enable_aesthetic",
        help="Disable the aesthetic reward (already the default)",
    )
    parser.add_argument(
        "--enable_aesthetic",
        default=False,
        action="store_true",
        dest="enable_aesthetic",
        help="Enable the aesthetic reward",
    )
    # NOTE(review): the weighting defaults to 0.0, so enabling the aesthetic
    # reward probably also needs a nonzero --aesthetic_weighting; confirm
    # against the loss code.
    parser.add_argument(
        "--aesthetic_weighting",
        type=float,
        help="Weighting for Aesthetic",
        default=0.0,
    )
    parser.add_argument(
        "--disable_reg",
        default=True,
        action="store_false",
        dest="enable_reg",
        help="Disable regularization (enabled by default)",
    )
    parser.add_argument(
        "--reg_weight", type=float, help="Regularization weight", default=0.01
    )

    # task specific
    parser.add_argument(
        "--task",
        type=str,
        help="Task to run",
        default="single",
        choices=[
            "t2i-compbench",
            "single",
            "parti-prompts",
            "geneval",
            "example-prompts",
        ],
    )
    parser.add_argument(
        "--prompt",
        type=str,
        help="Prompt to run",
        default="A red dog and a green cat",
    )
    parser.add_argument(
        "--benchmark_reward",
        help="Reward to benchmark on",
        default="total",
        choices=["ImageReward", "PickScore", "HPS", "CLIP", "total"],
    )

    # general
    parser.add_argument("--save_all_images", default=False, action="store_true")
    parser.add_argument("--no_optim", default=False, action="store_true")
    parser.add_argument("--imageselect", default=False, action="store_true")
    parser.add_argument("--memsave", default=False, action="store_true")
    parser.add_argument("--dtype", type=str, help="Data type to use", default="float16")
    parser.add_argument("--device_id", type=str, help="Device ID to use", default=None)
    parser.add_argument(
        "--cpu_offloading",
        help="Enable CPU offloading",
        default=False,
        action="store_true",
    )

    # optional multi-step model
    parser.add_argument("--enable_multi_apply", default=False, action="store_true")
    parser.add_argument(
        "--multi_step_model",
        type=str,
        help="Multi-step model to use",
        default="flux",
    )
    return parser


def parse_args(argv: Optional[Sequence[str]] = None) -> argparse.Namespace:
    """Parse the command-line options for the reward-optimization script.

    Args:
        argv: Argument strings to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which matches the original zero-argument
            behaviour.

    Returns:
        The populated ``argparse.Namespace``.
    """
    return _build_parser().parse_args(argv)