Spaces: Running on L40S
| from thop import profile | |
| from thop import clever_format | |
| import torch | |
| from tqdm import tqdm | |
| import time | |
| import sys | |
| sys.path.append('./') | |
| def analyze_model(model, inputs): | |
| # model size | |
| num_trainable_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad) | |
| print("Num trainable parameters: {} M".format(num_trainable_parameters/1000./1000.)) | |
| # computation cost | |
| with torch.no_grad(): | |
| model.eval() | |
| macs, params = profile(model, inputs=inputs) | |
| macs, params = clever_format([macs, params], "%.3f") | |
| print("Macs: {}, Params: {}".format(macs, params)) | |
| run_times = 50 | |
| # eval forward 100 times | |
| with torch.no_grad(): | |
| model = model.eval().to('cuda') | |
| inputs = [i.to('cuda') if isinstance(i, torch.Tensor) else i for i in inputs] | |
| model.init_device_dtype(inputs[0].device, inputs[0].dtype) | |
| st = time.time() | |
| for i in tqdm(range(run_times)): | |
| _ = model(*inputs) | |
| et = time.time() | |
| print("Eval forward : {:.03f} secs/per iter".format((et-st)/float(run_times))) | |
| # train backward 100 times | |
| model = model.train().to('cuda') | |
| inputs = [i.to('cuda') if isinstance(i, torch.Tensor) else i for i in inputs] | |
| model.init_device_dtype(inputs[0].device, inputs[0].dtype) | |
| optimizer = torch.optim.Adam(model.parameters(), lr=1e-3) | |
| optimizer.zero_grad() | |
| st = time.time() | |
| for i in tqdm(range(run_times)): | |
| inputs = [torch.rand_like(i) if isinstance(i, torch.cuda.FloatTensor) else i for i in inputs] | |
| out = model(*inputs) | |
| optimizer.zero_grad() | |
| out.mean().backward() | |
| optimizer.step() | |
| et = time.time() | |
| print("Train forward : {:.03f} secs/per iter".format((et-st)/float(run_times))) | |
| def fetch_model_v3_transformer(): | |
| # num params: 326M | |
| # macs (uncorrect): 261G/iter | |
| # infer: 0.32s/iter | |
| # train: 2.54s/iter | |
| from models_transformercond_winorm_ch16_everything_512 import PromptCondAudioDiffusion | |
| model = PromptCondAudioDiffusion( \ | |
| "configs/scheduler/stable_diffusion_2.1_largenoise.json", \ | |
| None, \ | |
| "configs/models/transformer2D.json" | |
| ) | |
| inputs = [ | |
| torch.rand(1,16,1024*3//8,32), | |
| torch.rand(1,7,512), | |
| torch.tensor([1,]), | |
| torch.tensor([0,]), | |
| False, | |
| ] | |
| return model, inputs | |
| def fetch_model_v3_unet(): | |
| # num params: 310M | |
| # infer: 0.10s/iter | |
| # train: 0.70s/iter | |
| from models_musicldm_winorm_ch16_everything_sepnorm import PromptCondAudioDiffusion | |
| model = PromptCondAudioDiffusion( \ | |
| "configs/scheduler/stable_diffusion_2.1_largenoise.json", \ | |
| None, \ | |
| "configs/diffusion_clapcond_model_config_ch16_everything.json" | |
| ) | |
| inputs = [ | |
| torch.rand(1,16,1024*3//8,32), | |
| torch.rand(1,7,512), | |
| torch.tensor([1,]), | |
| torch.tensor([0,]), | |
| False, | |
| ] | |
| return model, inputs | |
| if __name__=="__main__": | |
| model, inputs = fetch_model_v3_transformer() | |
| # model, inputs = fetch_model_v3_unet() | |
| analyze_model(model, inputs) | |