								# coding: utf-8
"""
Benchmark the inference speed of each module in LivePortrait.
TODO: heavy GPT style, need to refactor
"""
import yaml
import torch
import time
import numpy as np
from src.utils.helper import load_model, concat_feat
from src.config.inference_config import InferenceConfig


def initialize_inputs(batch_size=1):
    """
    Generate random input tensors and move them to GPU
    """
    feature_3d = torch.randn(batch_size, 32, 16, 64, 64).cuda().half()
    kp_source = torch.randn(batch_size, 21, 3).cuda().half()
    kp_driving = torch.randn(batch_size, 21, 3).cuda().half()
    source_image = torch.randn(batch_size, 3, 256, 256).cuda().half()
    generator_input = torch.randn(batch_size, 256, 64, 64).cuda().half()
    eye_close_ratio = torch.randn(batch_size, 3).cuda().half()
    lip_close_ratio = torch.randn(batch_size, 2).cuda().half()
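    # Pre-concatenate keypoint-derived features for the stitching, eye, and lip retargeting heads.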
    feat_stitching = concat_feat(kp_source, kp_driving).half()
    feat_eye = concat_feat(kp_source, eye_close_ratio).half()
    feat_lip = concat_feat(kp_source, lip_close_ratio).half()
    inputs = {
        'feature_3d': feature_3d,
        'kp_source': kp_source,
        'kp_driving': kp_driving,
        'source_image': source_image,
        'generator_input': generator_input,
        'feat_stitching': feat_stitching,
        'feat_eye': feat_eye,
        'feat_lip': feat_lip
    }
    return inputs


def load_and_compile_models(cfg, model_config):
    """
    Load and compile models for inference
    """
    appearance_feature_extractor = load_model(cfg.checkpoint_F, model_config, cfg.device_id, 'appearance_feature_extractor')
    motion_extractor = load_model(cfg.checkpoint_M, model_config, cfg.device_id, 'motion_extractor')
    warping_module = load_model(cfg.checkpoint_W, model_config, cfg.device_id, 'warping_module')
    spade_generator = load_model(cfg.checkpoint_G, model_config, cfg.device_id, 'spade_generator')
    stitching_retargeting_module = load_model(cfg.checkpoint_S, model_config, cfg.device_id, 'stitching_retargeting_module')
    models_with_params = [
        ('Appearance Feature Extractor', appearance_feature_extractor),
        ('Motion Extractor', motion_extractor),
        ('Warping Network', warping_module),
        ('SPADE Decoder', spade_generator)
    ]
    compiled_models = {}
    for name, model in models_with_params:
        model = model.half()
        model = torch.compile(model, mode='max-autotune')  # Optimize for inference
        model.eval()  # Switch to evaluation mode
        compiled_models[name] = model
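    # The stitching/retargeting checkpoint holds three separate sub-networks; cast and compile each one as well.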
    retargeting_models = ['stitching', 'eye', 'lip']
    for retarget in retargeting_models:
        module = stitching_retargeting_module[retarget].half()
        module = torch.compile(module, mode='max-autotune')  # Optimize for inference
        module.eval()  # Switch to evaluation mode
        stitching_retargeting_module[retarget] = module
    return compiled_models, stitching_retargeting_module


def warm_up_models(compiled_models, stitching_retargeting_module, inputs):
    """
    Warm up models to prepare them for benchmarking
    """
    print("Warm up start!")
    with torch.no_grad():
        for _ in range(10):
            compiled_models['Appearance Feature Extractor'](inputs['source_image'])
            compiled_models['Motion Extractor'](inputs['source_image'])
            compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
            compiled_models['SPADE Decoder'](inputs['generator_input'])  # Adjust input as required
            stitching_retargeting_module['stitching'](inputs['feat_stitching'])
            stitching_retargeting_module['eye'](inputs['feat_eye'])
            stitching_retargeting_module['lip'](inputs['feat_lip'])
    print("Warm up end!")
def measure_inference_times(compiled_models, stitching_retargeting_module, inputs):
    """
    Measure inference times for each model
    """
    times = {name: [] for name in compiled_models.keys()}
    times['Retargeting Models'] = []
    overall_times = []
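    # CUDA kernels launch asynchronously, so synchronize before starting and after each block
    # to make sure the wall-clock timings reflect actual GPU execution time.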
    with torch.no_grad():
        for _ in range(100):
            torch.cuda.synchronize()
            overall_start = time.time()
            start = time.time()
            compiled_models['Appearance Feature Extractor'](inputs['source_image'])
            torch.cuda.synchronize()
            times['Appearance Feature Extractor'].append(time.time() - start)
            start = time.time()
            compiled_models['Motion Extractor'](inputs['source_image'])
            torch.cuda.synchronize()
            times['Motion Extractor'].append(time.time() - start)
            start = time.time()
            compiled_models['Warping Network'](inputs['feature_3d'], inputs['kp_driving'], inputs['kp_source'])
            torch.cuda.synchronize()
            times['Warping Network'].append(time.time() - start)
            start = time.time()
            compiled_models['SPADE Decoder'](inputs['generator_input'])  # Adjust input as required
            torch.cuda.synchronize()
            times['SPADE Decoder'].append(time.time() - start)
            start = time.time()
            stitching_retargeting_module['stitching'](inputs['feat_stitching'])
            stitching_retargeting_module['eye'](inputs['feat_eye'])
            stitching_retargeting_module['lip'](inputs['feat_lip'])
            torch.cuda.synchronize()
            times['Retargeting Models'].append(time.time() - start)
            overall_times.append(time.time() - overall_start)
    return times, overall_times


def print_benchmark_results(compiled_models, stitching_retargeting_module, retargeting_models, times, overall_times):
    """
    Print benchmark results with average and standard deviation of inference times
    """
    average_times = {name: np.mean(times[name]) * 1000 for name in times.keys()}
    std_times = {name: np.std(times[name]) * 1000 for name in times.keys()}
    for name, model in compiled_models.items():
        num_params = sum(p.numel() for p in model.parameters())
        num_params_in_millions = num_params / 1e6
        print(f"Number of parameters for {name}: {num_params_in_millions:.2f} M")
    for index, retarget in enumerate(retargeting_models):
        num_params = sum(p.numel() for p in stitching_retargeting_module[retarget].parameters())
        num_params_in_millions = num_params / 1e6
        print(f"Number of parameters for part_{index} in Stitching and Retargeting Modules: {num_params_in_millions:.2f} M")
    for name, avg_time in average_times.items():
        std_time = std_times[name]
        print(f"Average inference time for {name} over 100 runs: {avg_time:.2f} ms (std: {std_time:.2f} ms)")
    print(f"Average end-to-end time over 100 runs: {np.mean(overall_times) * 1000:.2f} ms (std: {np.std(overall_times) * 1000:.2f} ms)")


def main():
    """
    Main function to benchmark speed and model parameters
    """
    # Sample input tensors
    inputs = initialize_inputs()
    # Load configuration
    cfg = InferenceConfig(device_id=0)
    model_config_path = cfg.models_config
    with open(model_config_path, 'r') as file:
        model_config = yaml.safe_load(file)
    # Load and compile models
    compiled_models, stitching_retargeting_module = load_and_compile_models(cfg, model_config)
    # Warm up models
    warm_up_models(compiled_models, stitching_retargeting_module, inputs)
    # Measure inference times
    times, overall_times = measure_inference_times(compiled_models, stitching_retargeting_module, inputs)
    # Print benchmark results
    print_benchmark_results(compiled_models, stitching_retargeting_module, ['stitching', 'eye', 'lip'], times, overall_times)


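# Note: intended to be run directly on a CUDA machine from the LivePortrait repository root,
# e.g. `python speed.py` (the file name is an assumption; adjust to wherever this script lives).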
if __name__ == "__main__":
    main()