File size: 5,310 Bytes
7efee70 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 |
import argparse
def parse_args():
parser = argparse.ArgumentParser()
parser.add_argument(
"--config_path", type=str,
default='',
help="Path to config file"
)
parser.add_argument(
"--optimal_transport_method",
type=str,
default="exact",
help="Use optimal transport in CFM training",
)
parser.add_argument(
"--split_ratios",
nargs=2,
type=float,
default=[0.9, 0.1],
help="Split ratios for training/validation data in CFM training",
)
parser.add_argument(
"--accelerator", type=str, default="cpu", help="Training accelerator"
)
parser.add_argument("--date", type=str)
parser.add_argument("--seed", default=2, type=int)
parser.add_argument("--device", default="cuda:1", type=str)
parser.add_argument("--molecule", default="aldp", type=str)
parser.add_argument('--wandb', action='store_true', default=False)
parser.add_argument('--unseen', action='store_true', default=False)
parser.add_argument('--run_name', default=None, type=str)
# Logger Config
parser.add_argument("--save_dir", default="", type=str)
parser.add_argument("--root_dir", default="", type=str)
# Policy Config
parser.add_argument("--bias", default="force", type=str)
# Sampling Config
parser.add_argument("--start_state", default="c5", type=str)
parser.add_argument("--end_state", default="c7ax", type=str)
parser.add_argument("--num_steps", default=100, type=int)
#parser.add_argument("--timestep", default=1, type=float)
parser.add_argument("--sigma", default=0.1, type=float)
parser.add_argument("--num_samples", default=16, type=int)
parser.add_argument("--temperature", default=300, type=float)
parser.add_argument("--friction", default=2.0, type=float)
parser.add_argument("--rbf", action='store_true', default=False)
parser.add_argument("--use_delta_to_target", action='store_true', default=False)
parser.add_argument("--use_gnn", action='store_true', default=False)
# Training Config
parser.add_argument("--start_temperature", default=600, type=float)
parser.add_argument("--end_temperature", default=300, type=float)
parser.add_argument("--num_rollouts", default=1000, type=int)
parser.add_argument("--trains_per_rollout", default=1000, type=int)
parser.add_argument("--log_z_lr", default=1e-3, type=float)
parser.add_argument("--policy_lr", default=1e-4, type=float)
parser.add_argument("--batch_size", default=64, type=int)
parser.add_argument("--buffer_size", default=1000, type=int)
parser.add_argument("--max_grad_norm", default=1, type=int)
parser.add_argument("--control_variate", default="global", type=str)
parser.add_argument("--self_normalize", action='store_true', default=False)
# path objective
parser.add_argument("--objective", default="ce", type=str)
parser.add_argument("--vel_conditioned", action='store_true', default=False)
parser.add_argument("--dir_only", action='store_true', default=False)
# cell experiment
parser.add_argument("--num_particles", default=16, type=int)
#parser.add_argument("--gene_dim", default=50, type=int)
parser.add_argument("--kT", type=float, default=0.0)
######### DATASETS #################
parser = datasets_parser(parser)
######### METRICS ##################
parser = metric_parser(parser)
return parser.parse_args()
def datasets_parser(parser):
parser.add_argument("--dim", type=int, default=50, help="Dimension of data")
parser.add_argument(
"--data_type",
type=str,
default="tahoe",
help="Type of data, now wither scrna or one of toys",
)
parser.add_argument(
"--data_name",
type=str,
default="tahoe",
help="Path to the dataset",
)
return parser
def metric_parser(parser):
parser.add_argument(
"--n_centers",
type=int,
default=300,
help="Number of centers for RBF network",
)
parser.add_argument(
"--kappa",
type=float,
default=1.5,
help="Kappa parameter for RBF network",
)
parser.add_argument(
"--rho",
type=float,
default=-2.75,
help="Rho parameter in Riemanian Velocity Calculation",
)
parser.add_argument(
"--velocity_metric",
type=str,
default="rbf",
help="Metric for velocity calculation",
)
parser.add_argument(
"--gamma",
nargs="+",
type=float,
default=0.2,
help="Gamma parameter in Riemanian Velocity Calculation",
)
parser.add_argument(
"--metric_epochs",
type=int,
default=200,
help="Number of epochs for metric learning",
)
parser.add_argument(
"--metric_patience",
type=int,
default=25,
help="Patience for metric learning",
)
parser.add_argument(
"--metric_lr",
type=float,
default=1e-2,
help="Learning rate for metric learning",
)
parser.add_argument(
"--alpha_metric",
type=float,
default=1.0,
help="Alpha parameter for metric learning",
)
return parser
|