File size: 5,310 Bytes
7efee70
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
import argparse


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path", type=str, 
        default='', 
        help="Path to config file"
    )
    parser.add_argument(
        "--optimal_transport_method",
        type=str,
        default="exact",
        help="Use optimal transport in CFM training",
    )
    parser.add_argument(
        "--split_ratios",
        nargs=2,
        type=float,
        default=[0.9, 0.1],
        help="Split ratios for training/validation data in CFM training",
    )
    parser.add_argument(
        "--accelerator", type=str, default="cpu", help="Training accelerator"
    )
    parser.add_argument("--date", type=str)
    parser.add_argument("--seed", default=2, type=int)
    parser.add_argument("--device", default="cuda:1", type=str)
    parser.add_argument("--molecule", default="aldp", type=str)
    parser.add_argument('--wandb', action='store_true', default=False)
    parser.add_argument('--unseen', action='store_true', default=False)
    parser.add_argument('--run_name', default=None, type=str)
    # Logger Config
    parser.add_argument("--save_dir", default="", type=str)
    parser.add_argument("--root_dir", default="", type=str)
    # Policy Config
    parser.add_argument("--bias", default="force", type=str)
    # Sampling Config
    parser.add_argument("--start_state", default="c5", type=str)
    parser.add_argument("--end_state", default="c7ax", type=str)
    parser.add_argument("--num_steps", default=100, type=int)
    #parser.add_argument("--timestep", default=1, type=float)
    parser.add_argument("--sigma", default=0.1, type=float)
    parser.add_argument("--num_samples", default=16, type=int)
    parser.add_argument("--temperature", default=300, type=float)
    parser.add_argument("--friction", default=2.0, type=float)
    parser.add_argument("--rbf", action='store_true', default=False)
    parser.add_argument("--use_delta_to_target", action='store_true', default=False)
    parser.add_argument("--use_gnn", action='store_true', default=False)
    # Training Config
    parser.add_argument("--start_temperature", default=600, type=float)
    parser.add_argument("--end_temperature", default=300, type=float)
    parser.add_argument("--num_rollouts", default=1000, type=int)
    parser.add_argument("--trains_per_rollout", default=1000, type=int)
    parser.add_argument("--log_z_lr", default=1e-3, type=float)
    parser.add_argument("--policy_lr", default=1e-4, type=float)
    parser.add_argument("--batch_size", default=64, type=int)
    parser.add_argument("--buffer_size", default=1000, type=int)
    parser.add_argument("--max_grad_norm", default=1, type=int)
    parser.add_argument("--control_variate", default="global", type=str)
    parser.add_argument("--self_normalize", action='store_true', default=False)
    # path objective
    parser.add_argument("--objective", default="ce", type=str)
    parser.add_argument("--vel_conditioned", action='store_true', default=False)
    parser.add_argument("--dir_only", action='store_true', default=False)
    
    # cell experiment
    parser.add_argument("--num_particles", default=16, type=int)
    #parser.add_argument("--gene_dim", default=50, type=int)
    parser.add_argument("--kT", type=float, default=0.0)
    ######### DATASETS #################
    parser = datasets_parser(parser)

    ######### METRICS ##################
    parser = metric_parser(parser)

    return parser.parse_args()


def datasets_parser(parser):
    parser.add_argument("--dim", type=int, default=50, help="Dimension of data")

    parser.add_argument(
        "--data_type",
        type=str,
        default="tahoe",
        help="Type of data, now wither scrna or one of toys",
    )
    parser.add_argument(
        "--data_name",
        type=str,
        default="tahoe",
        help="Path to the dataset",
    )
    return parser


def metric_parser(parser):
    parser.add_argument(
        "--n_centers",
        type=int,
        default=300,
        help="Number of centers for RBF network",
    )
    parser.add_argument(
        "--kappa",
        type=float,
        default=1.5,
        help="Kappa parameter for RBF network",
    )
    parser.add_argument(
        "--rho",
        type=float,
        default=-2.75,
        help="Rho parameter in Riemanian Velocity Calculation",
    )
    parser.add_argument(
        "--velocity_metric",
        type=str,
        default="rbf",
        help="Metric for velocity calculation",
    )
    parser.add_argument(
        "--gamma",
        nargs="+",
        type=float,
        default=0.2,
        help="Gamma parameter in Riemanian Velocity Calculation",
    )
    parser.add_argument(
        "--metric_epochs",
        type=int,
        default=200,
        help="Number of epochs for metric learning",
    )
    parser.add_argument(
        "--metric_patience",
        type=int,
        default=25,
        help="Patience for metric learning",
    )
    parser.add_argument(
        "--metric_lr",
        type=float,
        default=1e-2,
        help="Learning rate for metric learning",
    )
    parser.add_argument(
        "--alpha_metric",
        type=float,
        default=1.0,
        help="Alpha parameter for metric learning",
    )
    return parser