Spaces: Running on Zero
	Upload folder using huggingface_hub
- transport/__init__.py +70 -0
- transport/__pycache__/__init__.cpython-310.pyc +0 -0
- transport/__pycache__/dpm_solver.cpython-310.pyc +0 -0
- transport/__pycache__/integrators.cpython-310.pyc +0 -0
- transport/__pycache__/path.cpython-310.pyc +0 -0
- transport/__pycache__/transport.cpython-310.pyc +0 -0
- transport/__pycache__/utils.cpython-310.pyc +0 -0
- transport/dpm_solver.py +1386 -0
- transport/integrators.py +122 -0
- transport/path.py +201 -0
- transport/transport.py +490 -0
- transport/utils.py +56 -0
    	
transport/__init__.py  ADDED
@@ -0,0 +1,70 @@
from .transport import ModelType, PathType, Sampler, Transport, WeightType


def create_transport(
    path_type="Linear",
    prediction="velocity",
    loss_weight=None,
    train_eps=None,
    sample_eps=None,
    snr_type="uniform",
    do_shift=True,
    seq_len=1024,  # corresponding to 512x512
):
    """Function for creating a Transport object.

    **Note**: model prediction defaults to velocity.

    Args:
    - path_type: type of interpolation path to use ("Linear", "GVP", or "VP"); defaults to linear
    - prediction: model parameterization ("velocity", "noise", or "score"); defaults to velocity
    - loss_weight: loss weighting scheme ("velocity", "likelihood", or None for an unweighted loss)
    - train_eps: small epsilon for avoiding instability during training
    - sample_eps: small epsilon for avoiding instability during sampling
    - snr_type: timestep/SNR sampling scheme used during training
    - do_shift: whether to apply the sequence-length-dependent timestep shift
    - seq_len: sequence length used for the shift; 1024 corresponds to 512x512
    """

    if prediction == "noise":
        model_type = ModelType.NOISE
    elif prediction == "score":
        model_type = ModelType.SCORE
    else:
        model_type = ModelType.VELOCITY

    if loss_weight == "velocity":
        loss_type = WeightType.VELOCITY
    elif loss_weight == "likelihood":
        loss_type = WeightType.LIKELIHOOD
    else:
        loss_type = WeightType.NONE

    path_choice = {
        "Linear": PathType.LINEAR,
        "GVP": PathType.GVP,
        "VP": PathType.VP,
    }

    path_type = path_choice[path_type]

    if path_type in [PathType.VP]:
        train_eps = 1e-5 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    elif path_type in [PathType.GVP, PathType.LINEAR] and model_type != ModelType.VELOCITY:
        train_eps = 1e-3 if train_eps is None else train_eps
        sample_eps = 1e-3 if sample_eps is None else sample_eps
    else:  # velocity & [GVP, LINEAR] is stable everywhere
        train_eps = 0
        sample_eps = 0

    # create flow state
    state = Transport(
        model_type=model_type,
        path_type=path_type,
        loss_type=loss_type,
        train_eps=train_eps,
        sample_eps=sample_eps,
        snr_type=snr_type,
        do_shift=do_shift,
        seq_len=seq_len,
    )

    return state
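
For illustration only (not part of the uploaded file), a minimal sketch of calling `create_transport` with its defaults; the helper simply maps the string options to the enums above and fills in the epsilon defaults:

    from transport import create_transport

    # Linear path + velocity prediction + no loss weighting (the defaults).
    transport = create_transport(
        path_type="Linear",
        prediction="velocity",
        loss_weight=None,
        snr_type="uniform",
        do_shift=True,
        seq_len=1024,  # 512x512 latents
    )
    # For velocity prediction on a Linear/GVP path, train_eps and sample_eps
    # both resolve to 0, since that combination is stable everywhere.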
    	
transport/__pycache__/__init__.cpython-310.pyc  ADDED
    Binary file (1.63 kB).

transport/__pycache__/dpm_solver.cpython-310.pyc  ADDED
    Binary file (50.9 kB).

transport/__pycache__/integrators.cpython-310.pyc  ADDED
    Binary file (3.78 kB).

transport/__pycache__/path.cpython-310.pyc  ADDED
    Binary file (8.35 kB).

transport/__pycache__/transport.cpython-310.pyc  ADDED
    Binary file (15.6 kB).

transport/__pycache__/utils.cpython-310.pyc  ADDED
    Binary file (2.26 kB).
    	
transport/dpm_solver.py  ADDED
@@ -0,0 +1,1386 @@
# Copyright 2024 NVIDIA CORPORATION & AFFILIATES
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# SPDX-License-Identifier: Apache-2.0

# This file is modified from https://github.com/PixArt-alpha/PixArt-sigma
import os

import torch
from tqdm import tqdm

class NoiseScheduleFlow:
    def __init__(
        self,
        schedule="discrete_flow",
    ):
        """Create a wrapper class for the forward SDE (EDM type)."""
        self.T = 1
        self.t0 = 0.001
        self.schedule = schedule  # ['continuous', 'discrete_flow']
        self.total_N = 1000

    def marginal_log_mean_coeff(self, t):
        """
        Compute log(alpha_t) of a given continuous-time label t in [0, T].
        """
        return torch.log(self.marginal_alpha(t))

    def marginal_alpha(self, t):
        """
        Compute alpha_t of a given continuous-time label t in [0, T].
        """
        return 1 - t

    @staticmethod
    def marginal_std(t):
        """
        Compute sigma_t of a given continuous-time label t in [0, T].
        """
        return t

    def marginal_lambda(self, t):
        """
        Compute lambda_t = log(alpha_t) - log(sigma_t) of a given continuous-time label t in [0, T].
        """
        log_mean_coeff = self.marginal_log_mean_coeff(t)
        log_std = torch.log(self.marginal_std(t))
        return log_mean_coeff - log_std

    @staticmethod
    def inverse_lambda(lamb):
        """
        Compute the continuous-time label t in [0, T] of a given half-logSNR lambda_t.
        """
        return torch.exp(-lamb)

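As a quick sanity check of this schedule (an editorial sketch, not part of the uploaded file, assuming the class above is in scope): alpha_t = 1 - t and sigma_t = t, so lambda_t = log((1 - t) / t).

    import torch

    ns = NoiseScheduleFlow()
    t = torch.tensor([0.25])
    ns.marginal_alpha(t)    # tensor([0.7500])
    ns.marginal_std(t)      # tensor([0.2500])
    ns.marginal_lambda(t)   # log(0.75 / 0.25) = log(3) ≈ tensor([1.0986])
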
def model_wrapper(
    model,
    noise_schedule,
    model_type="noise",
    model_kwargs={},
    guidance_type="uncond",
    condition=None,
    unconditional_condition=None,
    guidance_scale=1.0,
    interval_guidance=[0, 1.0],
    classifier_fn=None,
    classifier_kwargs={},
):
    """Create a wrapper function for the noise prediction model.

    DPM-Solver needs to solve the continuous-time diffusion ODEs. For DPMs trained on discrete-time labels, we need to
    first wrap the model function to a noise prediction model that accepts the continuous time as the input.

    We support four types of the diffusion model by setting `model_type`:

        1. "noise": noise prediction model. (Trained by predicting noise).

        2. "x_start": data prediction model. (Trained by predicting the data x_0 at time 0).

        3. "v": velocity prediction model. (Trained by predicting the velocity).
            The "v" prediction derivation is detailed in Appendix D of [1], and is used in Imagen-Video [2].

            [1] Salimans, Tim, and Jonathan Ho. "Progressive distillation for fast sampling of diffusion models."
                arXiv preprint arXiv:2202.00512 (2022).
            [2] Ho, Jonathan, et al. "Imagen Video: High Definition Video Generation with Diffusion Models."
                arXiv preprint arXiv:2210.02303 (2022).

        4. "score": marginal score function. (Trained by denoising score matching).
            Note that the score function and the noise prediction model follow a simple relationship:
            ```
                noise(x_t, t) = -sigma_t * score(x_t, t)
            ```

    We support three types of guided sampling by DPMs by setting `guidance_type`:
        1. "uncond": unconditional sampling by DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``

        2. "classifier": classifier guidance sampling [3] by DPMs and another classifier.
            The input `model` has the following format:
            ``
                model(x, t_input, **model_kwargs) -> noise | x_start | v | score
            ``

            The input `classifier_fn` has the following format:
            ``
                classifier_fn(x, t_input, cond, **classifier_kwargs) -> logits(x, t_input, cond)
            ``

            [3] P. Dhariwal and A. Q. Nichol, "Diffusion models beat GANs on image synthesis,"
                in Advances in Neural Information Processing Systems, vol. 34, 2021, pp. 8780-8794.

        3. "classifier-free": classifier-free guidance sampling by conditional DPMs.
            The input `model` has the following format:
            ``
                model(x, t_input, cond, **model_kwargs) -> noise | x_start | v | score
            ``
            And if cond == `unconditional_condition`, the model output is the unconditional DPM output.

            [4] Ho, Jonathan, and Tim Salimans. "Classifier-free diffusion guidance."
                arXiv preprint arXiv:2207.12598 (2022).


    The `t_input` is the time label of the model, which may be discrete-time labels (i.e. 0 to 999)
    or continuous-time labels (i.e. epsilon to T).

    We wrap the model function to accept only `x` and `t_continuous` as inputs, and output the predicted noise:
    ``
        def model_fn(x, t_continuous) -> noise:
            t_input = get_model_input_time(t_continuous)
            return noise_pred(model, x, t_input, **model_kwargs)
    ``
    where `t_continuous` is the continuous time labels (i.e. epsilon to T). And we use `model_fn` for DPM-Solver.

    ===============================================================

    Args:
        model: A diffusion model with the corresponding format described above.
        noise_schedule: A noise schedule object, such as NoiseScheduleVP.
        model_type: A `str`. The parameterization type of the diffusion model.
                    "noise" or "x_start" or "v" or "score" (or "flow" for flow-matching models).
        model_kwargs: A `dict`. A dict for the other inputs of the model function.
        guidance_type: A `str`. The type of the guidance for sampling.
                    "uncond" or "classifier" or "classifier-free".
        condition: A pytorch tensor. The condition for the guided sampling.
                    Only used for "classifier" or "classifier-free" guidance type.
        unconditional_condition: A pytorch tensor. The condition for the unconditional sampling.
                    Only used for "classifier-free" guidance type.
        guidance_scale: A `float`. The scale for the guided sampling.
        interval_guidance: A list of two floats. Classifier-free guidance is only applied while
                    `t_continuous` lies inside this interval.
        classifier_fn: A classifier function. Only used for the classifier guidance.
        classifier_kwargs: A `dict`. A dict for the other inputs of the classifier function.
    Returns:
        A noise prediction model that accepts the noised data and the continuous time as the inputs.
    """

    def get_model_input_time(t_continuous):
        """
        Convert the continuous-time `t_continuous` (in [epsilon, T]) to the model input time.
        For discrete-time DPMs, we convert `t_continuous` in [1 / N, 1] to `t_input` in [0, 1000 * (N - 1) / N].
        For continuous-time DPMs, we just use `t_continuous`.
        """
        if noise_schedule.schedule == "discrete":
            return (t_continuous - 1.0 / noise_schedule.total_N) * noise_schedule.total_N
        elif noise_schedule.schedule == "discrete_flow":
            return t_continuous * noise_schedule.total_N
        else:
            return t_continuous

    def noise_pred_fn(x, t_continuous, cond=None):
        t_input = get_model_input_time(t_continuous)
        if cond is None:
            output = model(x, t_input, **model_kwargs)
        else:
            output = model(x, t_input, cond, **model_kwargs)
        if model_type == "noise":
            return output
        elif model_type == "x_start":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return (x - expand_dims(alpha_t, x.dim()) * output) / expand_dims(sigma_t, x.dim())
        elif model_type == "v":
            alpha_t, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            return expand_dims(alpha_t, x.dim()) * output + expand_dims(sigma_t, x.dim()) * x
        elif model_type == "score":
            sigma_t = noise_schedule.marginal_std(t_continuous)
            return -expand_dims(sigma_t, x.dim()) * output
        elif model_type == "flow":
            _, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            try:
                noise = (1 - expand_dims(sigma_t, x.dim()).to(x)) * output + x
            except:
                noise = (1 - expand_dims(sigma_t, x.dim()).to(x)) * output[0] + x
            return noise

    def cond_grad_fn(x, t_input):
        """
        Compute the gradient of the classifier, i.e. nabla_{x} log p_t(cond | x_t).
        """
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            log_prob = classifier_fn(x_in, t_input, condition, **classifier_kwargs)
            return torch.autograd.grad(log_prob.sum(), x_in)[0]

    def model_fn(x, t_continuous):
        """
        The noise prediction model function that is used for DPM-Solver.
        """
        guidance_tp = guidance_type
        if guidance_tp == "uncond":
            return noise_pred_fn(x, t_continuous)
        elif guidance_tp == "classifier":
            assert classifier_fn is not None
            t_input = get_model_input_time(t_continuous)
            cond_grad = cond_grad_fn(x, t_input)
            sigma_t = noise_schedule.marginal_std(t_continuous)
            noise = noise_pred_fn(x, t_continuous)
            return noise - guidance_scale * expand_dims(sigma_t, x.dim()) * cond_grad
        elif guidance_tp == "classifier-free":
            if (
                guidance_scale == 1.0
                or unconditional_condition is None
                or not (interval_guidance[0] < t_continuous[0] < interval_guidance[1])
            ):
                return noise_pred_fn(x, t_continuous, cond=condition)
            else:
                x_in = torch.cat([x] * 2)
                t_in = torch.cat([t_continuous] * 2)
                c_in = torch.cat([unconditional_condition, condition])
                try:
                    noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in).chunk(2)
                except:
                    noise_uncond, noise = noise_pred_fn(x_in, t_in, cond=c_in)[0].chunk(2)
                return noise_uncond + guidance_scale * (noise - noise_uncond)

    assert model_type in ["noise", "x_start", "v", "score", "flow"]
    assert guidance_type in [
        "uncond",
        "classifier",
        "classifier-free",
    ]
    return model_fn

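A sketch of how this wrapper might be used with the flow schedule above for classifier-free guidance; `dummy_model` and all tensor shapes here are placeholders for illustration, not part of the upload (note that `expand_dims` is a helper defined further down in this file):

    import torch

    ns = NoiseScheduleFlow()

    def dummy_model(x, t_input, cond=None):
        # Stand-in network: returns a zero velocity/flow field.
        return torch.zeros_like(x)

    model_fn = model_wrapper(
        dummy_model,
        ns,
        model_type="flow",                        # interpret outputs as a flow field
        guidance_type="classifier-free",
        condition=torch.zeros(2, 8),              # placeholder conditioning
        unconditional_condition=torch.zeros(2, 8),
        guidance_scale=4.0,
        interval_guidance=[0, 1.0],
    )

    x = torch.randn(2, 4)
    t = torch.full((2,), 0.5)   # continuous time in (0, 1]
    noise = model_fn(x, t)      # guided noise prediction consumed by DPM-Solver
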
class DPM_Solver:
    def __init__(
        self,
        model_fn,
        noise_schedule,
        algorithm_type="dpmsolver++",
        correcting_x0_fn=None,
        correcting_xt_fn=None,
        thresholding_max_val=1.0,
        dynamic_thresholding_ratio=0.995,
    ):
        """Construct a DPM-Solver.

        We support both DPM-Solver (`algorithm_type="dpmsolver"`) and DPM-Solver++ (`algorithm_type="dpmsolver++"`).

        We also support the "dynamic thresholding" method in Imagen[1]. For pixel-space diffusion models, you
        can set both `algorithm_type="dpmsolver++"` and `correcting_x0_fn="dynamic_thresholding"` to use the
        dynamic thresholding. The "dynamic thresholding" can greatly improve the sample quality for pixel-space
        DPMs with large guidance scales. Note that the thresholding method is **unsuitable** for latent-space
        DPMs (such as stable-diffusion).

        To support advanced algorithms in image-to-image applications, we also support corrector functions for
        both x0 and xt.

        Args:
            model_fn: A noise prediction model function which accepts the continuous-time input (t in [epsilon, T]):
                ``
                def model_fn(x, t_continuous):
                    return noise
                ``
                The shape of `x` is `(batch_size, **shape)`, and the shape of `t_continuous` is `(batch_size,)`.
            noise_schedule: A noise schedule object, such as NoiseScheduleVP.
            algorithm_type: A `str`. Either "dpmsolver" or "dpmsolver++".
            correcting_x0_fn: A `str` or a function with the following format:
                ```
                def correcting_x0_fn(x0, t):
                    x0_new = ...
                    return x0_new
                ```
                This function is to correct the outputs of the data prediction model at each sampling step. e.g.,
                ```
                x0_pred = data_pred_model(xt, t)
                if correcting_x0_fn is not None:
                    x0_pred = correcting_x0_fn(x0_pred, t)
                xt_1 = update(x0_pred, xt, t)
                ```
                If `correcting_x0_fn="dynamic_thresholding"`, we use the dynamic thresholding proposed in Imagen[1].
            correcting_xt_fn: A function with the following format:
                ```
                def correcting_xt_fn(xt, t, step):
                    x_new = ...
                    return x_new
                ```
                This function is to correct the intermediate samples xt at each sampling step. e.g.,
                ```
                xt = ...
                xt = correcting_xt_fn(xt, t, step)
                ```
            thresholding_max_val: A `float`. The max value for thresholding.
                Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.
            dynamic_thresholding_ratio: A `float`. The ratio for dynamic thresholding (see Imagen[1] for details).
                Valid only when using `dpmsolver++` and `correcting_x0_fn="dynamic_thresholding"`.

        [1] Chitwan Saharia, William Chan, Saurabh Saxena, Lala Li, Jay Whang, Emily Denton, Seyed Kamyar Seyed Ghasemipour,
            Burcu Karagol Ayan, S Sara Mahdavi, Rapha Gontijo Lopes, et al. Photorealistic text-to-image diffusion models
            with deep language understanding. arXiv preprint arXiv:2205.11487, 2022b.
        """
        self.model = lambda x, t: model_fn(x, t.expand(x.shape[0]))
        self.noise_schedule = noise_schedule
        assert algorithm_type in ["dpmsolver", "dpmsolver++"]
        self.algorithm_type = algorithm_type
        if correcting_x0_fn == "dynamic_thresholding":
            self.correcting_x0_fn = self.dynamic_thresholding_fn
        else:
            self.correcting_x0_fn = correcting_x0_fn
        self.correcting_xt_fn = correcting_xt_fn
        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
        self.thresholding_max_val = thresholding_max_val
        self.register_progress_bar()

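Continuing the sketch from `model_wrapper` above (illustrative, not part of the upload): the wrapped `model_fn` and the flow schedule plug directly into this constructor, with dynamic thresholding left off since it targets pixel-space models.

    solver = DPM_Solver(
        model_fn,                       # from the model_wrapper sketch above
        ns,                             # NoiseScheduleFlow instance
        algorithm_type="dpmsolver++",   # data-prediction formulation
        correcting_x0_fn=None,          # no dynamic thresholding for latent-space models
    )
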
    def register_progress_bar(self, progress_fn=None):
        """
        Register a progress bar callback function.

        Args:
            progress_fn: Callback function that takes the current step and total steps as parameters
        """
        self.progress_fn = progress_fn if progress_fn is not None else lambda step, total: None

    def update_progress(self, step, total_steps):
        """
        Update sampling progress.

        Args:
            step: Current step number
            total_steps: Total number of steps
        """
        if hasattr(self, "progress_fn"):
            try:
                self.progress_fn(step / total_steps, desc=f"Generating {step}/{total_steps}")
            except:
                self.progress_fn(step, total_steps)

        else:
            # If no progress_fn is registered, do nothing
            pass

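The progress hooks are optional; a small sketch of wiring them to a plain callback (reusing `solver` from the sketch above). Because `on_progress` does not accept the `desc` keyword, the first call inside `update_progress` raises and the positional fallback is used:

    def on_progress(step, total):
        print(f"step {step}/{total}")

    solver.register_progress_bar(on_progress)
    solver.update_progress(5, 20)   # falls back to on_progress(5, 20) -> "step 5/20"
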
    def dynamic_thresholding_fn(self, x0, t):
        """
        The dynamic thresholding method.
        """
        dims = x0.dim()
        p = self.dynamic_thresholding_ratio
        s = torch.quantile(torch.abs(x0).reshape((x0.shape[0], -1)), p, dim=1)
        s = expand_dims(torch.maximum(s, self.thresholding_max_val * torch.ones_like(s).to(s.device)), dims)
        x0 = torch.clamp(x0, -s, s) / s
        return x0

    def noise_prediction_fn(self, x, t):
        """
        Return the noise prediction model.
        """
        return self.model(x, t)

    def data_prediction_fn(self, x, t):
        """
        Return the data prediction model (with corrector).
        """
        noise = self.noise_prediction_fn(x, t)
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        x0 = (x - sigma_t * noise) / alpha_t
        if self.correcting_x0_fn is not None:
            x0 = self.correcting_x0_fn(x0, t)
        return x0

    def model_fn(self, x, t):
        """
        Convert the model to the noise prediction model or the data prediction model.
        """
        if self.algorithm_type == "dpmsolver++":
            return self.data_prediction_fn(x, t)
        else:
            return self.noise_prediction_fn(x, t)

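Under the flow schedule, `data_prediction_fn` inverts x_t = (1 - t) * x0 + t * noise. A tiny sketch (continuing with `solver`; `t` is a scalar tensor here so that alpha_t and sigma_t broadcast against `x`):

    x = torch.randn(2, 4)
    t = torch.tensor(0.5)
    x0_pred = solver.data_prediction_fn(x, t)   # (x - t * noise_pred) / (1 - t)
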
    def get_time_steps(self, skip_type, t_T, t_0, N, device, shift=1.0):
        """Compute the intermediate time steps for sampling.

        Args:
            skip_type: A `str`. The type for the spacing of the time steps. We support four types:
                - 'logSNR': uniform logSNR for the time steps.
                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
                - 'time_uniform_flow': uniform time steps with an additional `shift` applied (used for flow-matching schedules).
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            N: A `int`. The total number of the spacing of the time steps.
            device: A torch device.
            shift: A `float`. The timestep shift factor used by 'time_uniform_flow'.
        Returns:
            A pytorch tensor of the time steps, with the shape (N + 1,).
        """
        if skip_type == "logSNR":
            lambda_T = self.noise_schedule.marginal_lambda(torch.tensor(t_T).to(device))
            lambda_0 = self.noise_schedule.marginal_lambda(torch.tensor(t_0).to(device))
            logSNR_steps = torch.linspace(lambda_T.cpu().item(), lambda_0.cpu().item(), N + 1).to(device)
            return self.noise_schedule.inverse_lambda(logSNR_steps)
        elif skip_type == "time_uniform":
            return torch.linspace(t_T, t_0, N + 1).to(device)
        elif skip_type == "time_quadratic":
            t_order = 2
            t = torch.linspace(t_T ** (1.0 / t_order), t_0 ** (1.0 / t_order), N + 1).pow(t_order).to(device)
            return t
        elif skip_type == "time_uniform_flow":
            betas = torch.linspace(t_T, t_0, N + 1).to(device)
            sigmas = 1.0 - betas
            sigmas = (shift * sigmas / (1 + (shift - 1) * sigmas)).flip(dims=[0])
            return sigmas
        else:
            raise ValueError(
                f"Unsupported skip_type {skip_type}, need to be 'logSNR', 'time_uniform', 'time_quadratic' or 'time_uniform_flow'"
            )

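For flow models the relevant branch is `skip_type="time_uniform_flow"`. A short sketch of the grid it produces (continuing with `solver`, on CPU; values are approximate):

    timesteps = solver.get_time_steps(
        skip_type="time_uniform_flow",
        t_T=1.0,        # sampling starts at pure noise
        t_0=0.001,      # and ends near the data
        N=4,            # N solver steps -> N + 1 grid points
        device="cpu",
        shift=3.0,      # shift > 1 packs more grid points near the high-noise end
    )
    # tensor of shape (5,), decreasing roughly as [1.00, 0.90, 0.75, 0.50, 0.00]
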
    def get_orders_and_timesteps_for_singlestep_solver(self, steps, order, skip_type, t_T, t_0, device):
        """
        Get the order of each step for sampling by the singlestep DPM-Solver.

        We combine both DPM-Solver-1,2,3 to use all the function evaluations, which is named as "DPM-Solver-fast".
        Given a fixed number of function evaluations by `steps`, the sampling procedure by DPM-Solver-fast is:
            - If order == 1:
                We take `steps` of DPM-Solver-1 (i.e. DDIM).
            - If order == 2:
                - Denote K = (steps // 2). We take K or (K + 1) intermediate time steps for sampling.
                - If steps % 2 == 0, we use K steps of DPM-Solver-2.
                - If steps % 2 == 1, we use K steps of DPM-Solver-2 and 1 step of DPM-Solver-1.
            - If order == 3:
                - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
                - If steps % 3 == 0, we use (K - 2) steps of DPM-Solver-3, and 1 step of DPM-Solver-2 and 1 step of DPM-Solver-1.
                - If steps % 3 == 1, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-1.
                - If steps % 3 == 2, we use (K - 1) steps of DPM-Solver-3 and 1 step of DPM-Solver-2.

        ============================================
        Args:
            order: A `int`. The max order for the solver (2 or 3).
            steps: A `int`. The total number of function evaluations (NFE).
            skip_type: A `str`. The type for the spacing of the time steps. We support three types:
                - 'logSNR': uniform logSNR for the time steps.
                - 'time_uniform': uniform time for the time steps. (**Recommended for high-resolutional data**.)
                - 'time_quadratic': quadratic time for the time steps. (Used in DDIM for low-resolutional data.)
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            device: A torch device.
        Returns:
            orders: A list of the solver order of each step.
        """
        if order == 3:
            K = steps // 3 + 1
            if steps % 3 == 0:
                orders = [3,] * (
                    K - 2
                ) + [2, 1]
         | 
| 477 | 
            +
                        elif steps % 3 == 1:
         | 
| 478 | 
            +
                            orders = [3,] * (
         | 
| 479 | 
            +
                                K - 1
         | 
| 480 | 
            +
                            ) + [1]
         | 
| 481 | 
            +
                        else:
         | 
| 482 | 
            +
                            orders = [3,] * (
         | 
| 483 | 
            +
                                K - 1
         | 
| 484 | 
            +
                            ) + [2]
         | 
| 485 | 
            +
                    elif order == 2:
         | 
| 486 | 
            +
                        if steps % 2 == 0:
         | 
| 487 | 
            +
                            K = steps // 2
         | 
| 488 | 
            +
                            orders = [
         | 
| 489 | 
            +
                                2,
         | 
| 490 | 
            +
                            ] * K
         | 
| 491 | 
            +
                        else:
         | 
| 492 | 
            +
                            K = steps // 2 + 1
         | 
| 493 | 
            +
                            orders = [2,] * (
         | 
| 494 | 
            +
                                K - 1
         | 
| 495 | 
            +
                            ) + [1]
         | 
| 496 | 
            +
                    elif order == 1:
         | 
| 497 | 
            +
                        K = 1
         | 
| 498 | 
            +
                        orders = [
         | 
| 499 | 
            +
                            1,
         | 
| 500 | 
            +
                        ] * steps
         | 
| 501 | 
            +
                    else:
         | 
| 502 | 
            +
                        raise ValueError("'order' must be '1' or '2' or '3'.")
         | 
| 503 | 
            +
                    if skip_type == "logSNR":
         | 
| 504 | 
            +
                        # To reproduce the results in DPM-Solver paper
         | 
| 505 | 
            +
                        timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, K, device)
         | 
| 506 | 
            +
                    else:
         | 
| 507 | 
            +
                        timesteps_outer = self.get_time_steps(skip_type, t_T, t_0, steps, device)[
         | 
| 508 | 
            +
                            torch.cumsum(
         | 
| 509 | 
            +
                                torch.tensor(
         | 
| 510 | 
            +
                                    [
         | 
| 511 | 
            +
                                        0,
         | 
| 512 | 
            +
                                    ]
         | 
| 513 | 
            +
                                    + orders
         | 
| 514 | 
            +
                                ),
         | 
| 515 | 
            +
                                0,
         | 
| 516 | 
            +
                            ).to(device)
         | 
| 517 | 
            +
                        ]
         | 
| 518 | 
            +
                    return timesteps_outer, orders
         | 
| 519 | 
            +
             | 
| 520 | 
            +
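    # Worked example for the helper above (not part of the original file): with
    # steps=20 and order=3, K = 20 // 3 + 1 = 7 and steps % 3 == 2, so
    # orders == [3, 3, 3, 3, 3, 3, 2]: six DPM-Solver-3 steps followed by one
    # DPM-Solver-2 step, spending exactly 6 * 3 + 2 = 20 function evaluations;
    # timesteps_outer then holds the 8 boundary times of those 7 steps.
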
    def denoise_to_zero_fn(self, x, s):
        """
        Denoise at the final step, which is equivalent to solving the ODE from lambda_s to infty by first-order discretization.
        """
        return self.data_prediction_fn(x, s)

    def dpm_solver_first_update(self, x, s, t, model_s=None, return_intermediate=False):
        """
        DPM-Solver-1 (equivalent to DDIM) from time `s` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            return_intermediate: A `bool`. If true, also return the model value at time `s`.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        ns = self.noise_schedule
        dims = x.dim()
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        log_alpha_s, log_alpha_t = ns.marginal_log_mean_coeff(s), ns.marginal_log_mean_coeff(t)
        sigma_s, sigma_t = ns.marginal_std(s), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        if self.algorithm_type == "dpmsolver++":
            phi_1 = torch.expm1(-h)
            if model_s is None:
                model_s = self.model_fn(x, s)
            x_t = sigma_t / sigma_s * x - alpha_t * phi_1 * model_s
            if return_intermediate:
                return x_t, {"model_s": model_s}
            else:
                return x_t
        else:
            phi_1 = torch.expm1(h)
            if model_s is None:
                model_s = self.model_fn(x, s)
            x_t = torch.exp(log_alpha_t - log_alpha_s) * x - (sigma_t * phi_1) * model_s
            if return_intermediate:
                return x_t, {"model_s": model_s}
            else:
                return x_t

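    # Illustrative sketch (not part of the original file): the "dpmsolver++"
    # (data-prediction) branch above reduces to a single exponential-integrator step.
    # With h = lambda_t - lambda_s and x0_pred = model_fn(x, s) it computes
    #
    #   x_t = (sigma_t / sigma_s) * x - alpha_t * torch.expm1(-h) * x0_pred
    #
    # which is the DDIM update written in (alpha, sigma, lambda) notation.
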
    def singlestep_dpm_solver_second_update(
        self, x, s, t, r1=0.5, model_s=None, return_intermediate=False, solver_type="dpmsolver"
    ):
        """
        Singlestep solver DPM-Solver-2 from time `s` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            r1: A `float`. The hyperparameter of the second-order solver.
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            return_intermediate: A `bool`. If true, also return the model value at time `s` and `s1` (the intermediate time).
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if solver_type not in ["dpmsolver", "taylor"]:
            raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
        if r1 is None:
            r1 = 0.5
        ns = self.noise_schedule
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        lambda_s1 = lambda_s + r1 * h
        s1 = ns.inverse_lambda(lambda_s1)
        log_alpha_s, log_alpha_s1, log_alpha_t = (
            ns.marginal_log_mean_coeff(s),
            ns.marginal_log_mean_coeff(s1),
            ns.marginal_log_mean_coeff(t),
        )
        sigma_s, sigma_s1, sigma_t = ns.marginal_std(s), ns.marginal_std(s1), ns.marginal_std(t)
        alpha_s1, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_t)

        if self.algorithm_type == "dpmsolver++":
            phi_11 = torch.expm1(-r1 * h)
            phi_1 = torch.expm1(-h)

            if model_s is None:
                model_s = self.model_fn(x, s)
            x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
            model_s1 = self.model_fn(x_s1, s1)
            if solver_type == "dpmsolver":
                x_t = (
                    (sigma_t / sigma_s) * x
                    - (alpha_t * phi_1) * model_s
                    - (0.5 / r1) * (alpha_t * phi_1) * (model_s1 - model_s)
                )
            elif solver_type == "taylor":
                x_t = (
                    (sigma_t / sigma_s) * x
                    - (alpha_t * phi_1) * model_s
                    + (1.0 / r1) * (alpha_t * (phi_1 / h + 1.0)) * (model_s1 - model_s)
                )
        else:
            phi_11 = torch.expm1(r1 * h)
            phi_1 = torch.expm1(h)

            if model_s is None:
                model_s = self.model_fn(x, s)
            x_s1 = torch.exp(log_alpha_s1 - log_alpha_s) * x - (sigma_s1 * phi_11) * model_s
            model_s1 = self.model_fn(x_s1, s1)
            if solver_type == "dpmsolver":
                x_t = (
                    torch.exp(log_alpha_t - log_alpha_s) * x
                    - (sigma_t * phi_1) * model_s
                    - (0.5 / r1) * (sigma_t * phi_1) * (model_s1 - model_s)
                )
            elif solver_type == "taylor":
                x_t = (
                    torch.exp(log_alpha_t - log_alpha_s) * x
                    - (sigma_t * phi_1) * model_s
                    - (1.0 / r1) * (sigma_t * (phi_1 / h - 1.0)) * (model_s1 - model_s)
                )
        if return_intermediate:
            return x_t, {"model_s": model_s, "model_s1": model_s1}
        else:
            return x_t

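    # Note (not part of the original file): the second-order singlestep update above
    # spends one extra model call at the intermediate time s1, chosen so that
    # lambda_s1 = lambda_s + r1 * h with r1 = 0.5 by default; the finite difference
    # (model_s1 - model_s) supplies the correction term on top of the first-order
    # (DDIM-like) step.
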
    def singlestep_dpm_solver_third_update(
        self, x, s, t, r1=1.0 / 3.0, r2=2.0 / 3.0, model_s=None, model_s1=None,
        return_intermediate=False, solver_type="dpmsolver",
    ):
        """
        Singlestep solver DPM-Solver-3 from time `s` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            r1: A `float`. The hyperparameter of the third-order solver.
            r2: A `float`. The hyperparameter of the third-order solver.
            model_s: A pytorch tensor. The model function evaluated at time `s`.
                If `model_s` is None, we evaluate the model by `x` and `s`; otherwise we directly use it.
            model_s1: A pytorch tensor. The model function evaluated at time `s1` (the intermediate time given by `r1`).
                If `model_s1` is None, we evaluate the model at `s1`; otherwise we directly use it.
            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if solver_type not in ["dpmsolver", "taylor"]:
            raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
        if r1 is None:
            r1 = 1.0 / 3.0
        if r2 is None:
            r2 = 2.0 / 3.0
        ns = self.noise_schedule
        lambda_s, lambda_t = ns.marginal_lambda(s), ns.marginal_lambda(t)
        h = lambda_t - lambda_s
        lambda_s1 = lambda_s + r1 * h
        lambda_s2 = lambda_s + r2 * h
        s1 = ns.inverse_lambda(lambda_s1)
        s2 = ns.inverse_lambda(lambda_s2)
        log_alpha_s, log_alpha_s1, log_alpha_s2, log_alpha_t = (
            ns.marginal_log_mean_coeff(s),
            ns.marginal_log_mean_coeff(s1),
            ns.marginal_log_mean_coeff(s2),
            ns.marginal_log_mean_coeff(t),
        )
        sigma_s, sigma_s1, sigma_s2, sigma_t = (
            ns.marginal_std(s),
            ns.marginal_std(s1),
            ns.marginal_std(s2),
            ns.marginal_std(t),
        )
        alpha_s1, alpha_s2, alpha_t = torch.exp(log_alpha_s1), torch.exp(log_alpha_s2), torch.exp(log_alpha_t)

        if self.algorithm_type == "dpmsolver++":
            phi_11 = torch.expm1(-r1 * h)
            phi_12 = torch.expm1(-r2 * h)
            phi_1 = torch.expm1(-h)
            phi_22 = torch.expm1(-r2 * h) / (r2 * h) + 1.0
            phi_2 = phi_1 / h + 1.0
            phi_3 = phi_2 / h - 0.5

            if model_s is None:
                model_s = self.model_fn(x, s)
            if model_s1 is None:
                x_s1 = (sigma_s1 / sigma_s) * x - (alpha_s1 * phi_11) * model_s
                model_s1 = self.model_fn(x_s1, s1)
            x_s2 = (
                (sigma_s2 / sigma_s) * x
                - (alpha_s2 * phi_12) * model_s
                + r2 / r1 * (alpha_s2 * phi_22) * (model_s1 - model_s)
            )
            model_s2 = self.model_fn(x_s2, s2)
            if solver_type == "dpmsolver":
                x_t = (
                    (sigma_t / sigma_s) * x
                    - (alpha_t * phi_1) * model_s
                    + (1.0 / r2) * (alpha_t * phi_2) * (model_s2 - model_s)
                )
            elif solver_type == "taylor":
                D1_0 = (1.0 / r1) * (model_s1 - model_s)
                D1_1 = (1.0 / r2) * (model_s2 - model_s)
                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
                D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
                x_t = (
                    (sigma_t / sigma_s) * x
                    - (alpha_t * phi_1) * model_s
                    + (alpha_t * phi_2) * D1
                    - (alpha_t * phi_3) * D2
                )
        else:
            phi_11 = torch.expm1(r1 * h)
            phi_12 = torch.expm1(r2 * h)
            phi_1 = torch.expm1(h)
            phi_22 = torch.expm1(r2 * h) / (r2 * h) - 1.0
            phi_2 = phi_1 / h - 1.0
            phi_3 = phi_2 / h - 0.5

            if model_s is None:
                model_s = self.model_fn(x, s)
            if model_s1 is None:
                x_s1 = (torch.exp(log_alpha_s1 - log_alpha_s)) * x - (sigma_s1 * phi_11) * model_s
                model_s1 = self.model_fn(x_s1, s1)
            x_s2 = (
                (torch.exp(log_alpha_s2 - log_alpha_s)) * x
                - (sigma_s2 * phi_12) * model_s
                - r2 / r1 * (sigma_s2 * phi_22) * (model_s1 - model_s)
            )
            model_s2 = self.model_fn(x_s2, s2)
            if solver_type == "dpmsolver":
                x_t = (
                    (torch.exp(log_alpha_t - log_alpha_s)) * x
                    - (sigma_t * phi_1) * model_s
                    - (1.0 / r2) * (sigma_t * phi_2) * (model_s2 - model_s)
                )
            elif solver_type == "taylor":
                D1_0 = (1.0 / r1) * (model_s1 - model_s)
                D1_1 = (1.0 / r2) * (model_s2 - model_s)
                D1 = (r2 * D1_0 - r1 * D1_1) / (r2 - r1)
                D2 = 2.0 * (D1_1 - D1_0) / (r2 - r1)
                x_t = (
                    (torch.exp(log_alpha_t - log_alpha_s)) * x
                    - (sigma_t * phi_1) * model_s
                    - (sigma_t * phi_2) * D1
                    - (sigma_t * phi_3) * D2
                )

        if return_intermediate:
            return x_t, {"model_s": model_s, "model_s1": model_s1, "model_s2": model_s2}
        else:
            return x_t

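    # Note (not part of the original file): the third-order singlestep update above
    # uses two intermediate times s1 and s2 (r1 = 1/3 and r2 = 2/3 of the logSNR step h).
    # In the 'taylor' variant, the finite differences D1_0 and D1_1 are combined into
    # first- and second-derivative estimates D1 and D2 of the model output with
    # respect to lambda.
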
    def multistep_dpm_solver_second_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
        """
        Multistep solver DPM-Solver-2 from time `t_prev_list[-1]` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
            model_prev_list: A list of pytorch tensors. The previously computed model values.
            t_prev_list: A list of pytorch tensors. The previous times, each with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if solver_type not in ["dpmsolver", "taylor"]:
            raise ValueError(f"'solver_type' must be either 'dpmsolver' or 'taylor', got {solver_type}")
        ns = self.noise_schedule
        model_prev_1, model_prev_0 = model_prev_list[-2], model_prev_list[-1]
        t_prev_1, t_prev_0 = t_prev_list[-2], t_prev_list[-1]
        lambda_prev_1, lambda_prev_0, lambda_t = (
            ns.marginal_lambda(t_prev_1),
            ns.marginal_lambda(t_prev_0),
            ns.marginal_lambda(t),
        )
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0 = h_0 / h
        D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
        if self.algorithm_type == "dpmsolver++":
            phi_1 = torch.expm1(-h)
            if solver_type == "dpmsolver":
                x_t = (sigma_t / sigma_prev_0) * x - (alpha_t * phi_1) * model_prev_0 - 0.5 * (alpha_t * phi_1) * D1_0
            elif solver_type == "taylor":
                x_t = (
                    (sigma_t / sigma_prev_0) * x
                    - (alpha_t * phi_1) * model_prev_0
                    + (alpha_t * (phi_1 / h + 1.0)) * D1_0
                )
        else:
            phi_1 = torch.expm1(h)
            if solver_type == "dpmsolver":
                x_t = (
                    (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
                    - (sigma_t * phi_1) * model_prev_0
                    - 0.5 * (sigma_t * phi_1) * D1_0
                )
            elif solver_type == "taylor":
                x_t = (
                    (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
                    - (sigma_t * phi_1) * model_prev_0
                    - (sigma_t * (phi_1 / h - 1.0)) * D1_0
                )
        return x_t

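    # Note (not part of the original file): unlike the singlestep variants, the
    # multistep update above does not call the model at all; it combines the two most
    # recent cached evaluations in `model_prev_list`, so the driving loop only adds
    # one new function evaluation per second-order step.
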
    def multistep_dpm_solver_third_update(self, x, model_prev_list, t_prev_list, t, solver_type="dpmsolver"):
        """
        Multistep solver DPM-Solver-3 from time `t_prev_list[-1]` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
            model_prev_list: A list of pytorch tensors. The previously computed model values.
            t_prev_list: A list of pytorch tensors. The previous times, each with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        ns = self.noise_schedule
        model_prev_2, model_prev_1, model_prev_0 = model_prev_list
        t_prev_2, t_prev_1, t_prev_0 = t_prev_list
        lambda_prev_2, lambda_prev_1, lambda_prev_0, lambda_t = (
            ns.marginal_lambda(t_prev_2),
            ns.marginal_lambda(t_prev_1),
            ns.marginal_lambda(t_prev_0),
            ns.marginal_lambda(t),
        )
        log_alpha_prev_0, log_alpha_t = ns.marginal_log_mean_coeff(t_prev_0), ns.marginal_log_mean_coeff(t)
        sigma_prev_0, sigma_t = ns.marginal_std(t_prev_0), ns.marginal_std(t)
        alpha_t = torch.exp(log_alpha_t)

        h_1 = lambda_prev_1 - lambda_prev_2
        h_0 = lambda_prev_0 - lambda_prev_1
        h = lambda_t - lambda_prev_0
        r0, r1 = h_0 / h, h_1 / h
        D1_0 = (1.0 / r0) * (model_prev_0 - model_prev_1)
        D1_1 = (1.0 / r1) * (model_prev_1 - model_prev_2)
        D1 = D1_0 + (r0 / (r0 + r1)) * (D1_0 - D1_1)
        D2 = (1.0 / (r0 + r1)) * (D1_0 - D1_1)
        if self.algorithm_type == "dpmsolver++":
            phi_1 = torch.expm1(-h)
            phi_2 = phi_1 / h + 1.0
            phi_3 = phi_2 / h - 0.5
            x_t = (
                (sigma_t / sigma_prev_0) * x
                - (alpha_t * phi_1) * model_prev_0
                + (alpha_t * phi_2) * D1
                - (alpha_t * phi_3) * D2
            )
        else:
            phi_1 = torch.expm1(h)
            phi_2 = phi_1 / h - 1.0
            phi_3 = phi_2 / h - 0.5
            x_t = (
                (torch.exp(log_alpha_t - log_alpha_prev_0)) * x
                - (sigma_t * phi_1) * model_prev_0
                - (sigma_t * phi_2) * D1
                - (sigma_t * phi_3) * D2
            )
        return x_t

    def singlestep_dpm_solver_update(
        self, x, s, t, order, return_intermediate=False, solver_type="dpmsolver", r1=None, r2=None
    ):
        """
        Singlestep DPM-Solver with the order `order` from time `s` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `s`.
            s: A pytorch tensor. The starting time, with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
            return_intermediate: A `bool`. If true, also return the model value at time `s`, `s1` and `s2` (the intermediate times).
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
            r1: A `float`. The hyperparameter of the second-order or third-order solver.
            r2: A `float`. The hyperparameter of the third-order solver.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if order == 1:
            return self.dpm_solver_first_update(x, s, t, return_intermediate=return_intermediate)
        elif order == 2:
            return self.singlestep_dpm_solver_second_update(
                x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1
            )
        elif order == 3:
            return self.singlestep_dpm_solver_third_update(
                x, s, t, return_intermediate=return_intermediate, solver_type=solver_type, r1=r1, r2=r2
            )
        else:
            raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")

    def multistep_dpm_solver_update(self, x, model_prev_list, t_prev_list, t, order, solver_type="dpmsolver"):
        """
        Multistep DPM-Solver with the order `order` from time `t_prev_list[-1]` to time `t`.

        Args:
            x: A pytorch tensor. The initial value at time `t_prev_list[-1]`.
            model_prev_list: A list of pytorch tensors. The previously computed model values.
            t_prev_list: A list of pytorch tensors. The previous times, each with the shape (1,).
            t: A pytorch tensor. The ending time, with the shape (1,).
            order: An `int`. The order of DPM-Solver. We only support order == 1 or 2 or 3.
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
        Returns:
            x_t: A pytorch tensor. The approximated solution at time `t`.
        """
        if order == 1:
            return self.dpm_solver_first_update(x, t_prev_list[-1], t, model_s=model_prev_list[-1])
        elif order == 2:
            return self.multistep_dpm_solver_second_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
        elif order == 3:
            return self.multistep_dpm_solver_third_update(x, model_prev_list, t_prev_list, t, solver_type=solver_type)
        else:
            raise ValueError(f"Solver order must be 1 or 2 or 3, got {order}")

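    # Illustrative sketch (not part of the original file) of how the multistep
    # dispatcher above is typically driven; `dpm_solver`, `x` and `timesteps` are
    # hypothetical names, and the actual driver is the `sample` method further below.
    #
    #   model_prev_list = [dpm_solver.model_fn(x, timesteps[0:1])]
    #   t_prev_list = [timesteps[0:1]]
    #   for step in range(1, len(timesteps)):
    #       t = timesteps[step:step + 1]
    #       step_order = min(step, 2)  # warm up with a lower order
    #       x = dpm_solver.multistep_dpm_solver_update(
    #           x, model_prev_list, t_prev_list, t, step_order, solver_type="dpmsolver"
    #       )
    #       model_prev_list = (model_prev_list + [dpm_solver.model_fn(x, t)])[-2:]
    #       t_prev_list = (t_prev_list + [t])[-2:]
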
    def dpm_solver_adaptive(
        self, x, order, t_T, t_0, h_init=0.05, atol=0.0078, rtol=0.05, theta=0.9, t_err=1e-5, solver_type="dpmsolver"
    ):
        """
        The adaptive step size solver based on singlestep DPM-Solver.

        Args:
            x: A pytorch tensor. The initial value at time `t_T`.
            order: An `int`. The (higher) order of the solver. We only support order == 2 or 3.
            t_T: A `float`. The starting time of the sampling (default is T).
            t_0: A `float`. The ending time of the sampling (default is epsilon).
            h_init: A `float`. The initial step size (for logSNR).
            atol: A `float`. The absolute tolerance of the solver. For image data, the default setting is 0.0078, following [1].
            rtol: A `float`. The relative tolerance of the solver. The default setting is 0.05.
            theta: A `float`. The safety hyperparameter for adapting the step size. The default setting is 0.9, following [1].
            t_err: A `float`. The tolerance for the time. We solve the diffusion ODE until the absolute error between the
                current time and `t_0` is less than `t_err`. The default setting is 1e-5.
            solver_type: either 'dpmsolver' or 'taylor'. The type for the high-order solvers.
                The type slightly impacts the performance. We recommend to use 'dpmsolver' type.
        Returns:
            x_0: A pytorch tensor. The approximated solution at time `t_0`.

        [1] A. Jolicoeur-Martineau, K. Li, R. Piché-Taillefer, T. Kachman, and I. Mitliagkas, "Gotta go fast when generating data with score-based models," arXiv preprint arXiv:2105.14080, 2021.
        """
        ns = self.noise_schedule
        s = t_T * torch.ones((1,)).to(x)
        lambda_s = ns.marginal_lambda(s)
        lambda_0 = ns.marginal_lambda(t_0 * torch.ones_like(s).to(x))
        h = h_init * torch.ones_like(s).to(x)
        x_prev = x
        nfe = 0
        if order == 2:
            r1 = 0.5
            lower_update = lambda x, s, t: self.dpm_solver_first_update(x, s, t, return_intermediate=True)
            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_second_update(
                x, s, t, r1=r1, solver_type=solver_type, **kwargs
            )
        elif order == 3:
            r1, r2 = 1.0 / 3.0, 2.0 / 3.0
            lower_update = lambda x, s, t: self.singlestep_dpm_solver_second_update(
                x, s, t, r1=r1, return_intermediate=True, solver_type=solver_type
            )
            higher_update = lambda x, s, t, **kwargs: self.singlestep_dpm_solver_third_update(
                x, s, t, r1=r1, r2=r2, solver_type=solver_type, **kwargs
            )
        else:
            raise ValueError(f"For adaptive step size solver, order must be 2 or 3, got {order}")
        while torch.abs(s - t_0).mean() > t_err:
            t = ns.inverse_lambda(lambda_s + h)
            x_lower, lower_noise_kwargs = lower_update(x, s, t)
            x_higher = higher_update(x, s, t, **lower_noise_kwargs)
            delta = torch.max(torch.ones_like(x).to(x) * atol, rtol * torch.max(torch.abs(x_lower), torch.abs(x_prev)))
            norm_fn = lambda v: torch.sqrt(torch.square(v.reshape((v.shape[0], -1))).mean(dim=-1, keepdim=True))
            E = norm_fn((x_higher - x_lower) / delta).max()
            if torch.all(E <= 1.0):
                x = x_higher
                s = t
                x_prev = x_lower
                lambda_s = ns.marginal_lambda(s)
            h = torch.min(theta * h * torch.float_power(E, -1.0 / order).float(), lambda_0 - lambda_s)
            nfe += order
        print("adaptive solver nfe", nfe)
        return x

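    # Note (not part of the original file): the adaptive loop above is a standard
    # embedded-pair scheme. For order=3 it compares a second-order step (x_lower)
    # against a third-order step (x_higher); the step is accepted when the scaled
    # error E <= 1, and the logSNR step size h is rescaled by theta * E**(-1/order)
    # either way, following [1].
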
    def add_noise(self, x, t, noise=None):
        """
        Compute the noised input xt = alpha_t * x + sigma_t * noise.

        Args:
            x: A `torch.Tensor` with shape `(batch_size, *shape)`.
            t: A `torch.Tensor` with shape `(t_size,)`.
        Returns:
            xt with shape `(t_size, batch_size, *shape)`.
        """
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        if noise is None:
            noise = torch.randn((t.shape[0], *x.shape), device=x.device)
        x = x.reshape((-1, *x.shape))
        xt = expand_dims(alpha_t, x.dim()) * x + expand_dims(sigma_t, x.dim()) * noise
        if t.shape[0] == 1:
            return xt.squeeze(0)
        else:
            return xt

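    # Illustrative sketch (not part of the original file): forward-noising a clean
    # batch with the helper above; `x0` and `dpm_solver` are hypothetical names.
    #
    #   t = torch.tensor([0.5], device=x0.device)
    #   x_t = dpm_solver.add_noise(x0, t)  # alpha_t * x0 + sigma_t * standard Gaussian noise
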
    def inverse(
        self,
        x,
        steps=20,
        t_start=None,
        t_end=None,
        order=2,
        skip_type="time_uniform",
        method="multistep",
        lower_order_final=True,
        denoise_to_zero=False,
        solver_type="dpmsolver",
        atol=0.0078,
        rtol=0.05,
        return_intermediate=False,
    ):
        """
        Invert the sample `x` from time `t_start` to `t_end` by DPM-Solver.
        For discrete-time DPMs, we use `t_start=1/N`, where `N` is the total number of time steps used during training.
        """
        t_0 = 1.0 / self.noise_schedule.total_N if t_start is None else t_start
        t_T = self.noise_schedule.T if t_end is None else t_end
        assert (
            t_0 > 0 and t_T > 0
        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of betas array"
        return self.sample(
            x,
            steps=steps,
            t_start=t_0,
            t_end=t_T,
            order=order,
            skip_type=skip_type,
            method=method,
            lower_order_final=lower_order_final,
            denoise_to_zero=denoise_to_zero,
            solver_type=solver_type,
            atol=atol,
            rtol=rtol,
            return_intermediate=return_intermediate,
        )

    def sample(
        self,
        x,
        steps=20,
        t_start=None,
        t_end=None,
        order=2,
        skip_type="time_uniform",
        method="multistep",
        lower_order_final=True,
        denoise_to_zero=False,
        solver_type="dpmsolver",
        atol=0.0078,
        rtol=0.05,
        return_intermediate=False,
        flow_shift=1.0,
    ):
        """
        Compute the sample at time `t_end` by DPM-Solver, given the initial `x` at time `t_start`.

        =====================================================

        We support the following algorithms for both the noise prediction model and the data prediction model:
            - 'singlestep':
                Singlestep DPM-Solver (i.e. "DPM-Solver-fast" in the paper), which combines different orders of singlestep DPM-Solver.
                We combine all the singlestep solvers with order <= `order` to use up all the function evaluations (steps).
                The total number of function evaluations (NFE) == `steps`.
                Given a fixed NFE == `steps`, the sampling procedure is:
                    - If `order` == 1:
                        - Denote K = steps. We use K steps of DPM-Solver-1 (i.e. DDIM).
                    - If `order` == 2:
                        - Denote K = (steps // 2) + (steps % 2). We take K intermediate time steps for sampling.
                        - If steps % 2 == 0, we use K steps of singlestep DPM-Solver-2.
                        - If steps % 2 == 1, we use (K - 1) steps of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
                    - If `order` == 3:
                        - Denote K = (steps // 3 + 1). We take K intermediate time steps for sampling.
                        - If steps % 3 == 0, we use (K - 2) steps of singlestep DPM-Solver-3, 1 step of singlestep DPM-Solver-2 and 1 step of DPM-Solver-1.
                        - If steps % 3 == 1, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of DPM-Solver-1.
                        - If steps % 3 == 2, we use (K - 1) steps of singlestep DPM-Solver-3 and 1 step of singlestep DPM-Solver-2.
            - 'multistep':
                Multistep DPM-Solver with the order of `order`. The total number of function evaluations (NFE) == `steps`.
                We initialize the first `order` values by lower order multistep solvers.
                Given a fixed NFE == `steps`, the sampling procedure is:
                    Denote K = steps.
                    - If `order` == 1:
                        - We use K steps of DPM-Solver-1 (i.e. DDIM).
                    - If `order` == 2:
                        - We first use 1 step of DPM-Solver-1, then (K - 1) steps of multistep DPM-Solver-2.
                    - If `order` == 3:
                        - We first use 1 step of DPM-Solver-1, then 1 step of multistep DPM-Solver-2, then (K - 2) steps of multistep DPM-Solver-3.
            - 'singlestep_fixed':
                Fixed order singlestep DPM-Solver (i.e. DPM-Solver-1, singlestep DPM-Solver-2 or singlestep DPM-Solver-3).
                We use singlestep DPM-Solver-`order` for `order` = 1, 2 or 3, with a total of [`steps` // `order`] * `order` NFE.
            - 'adaptive':
                Adaptive step size DPM-Solver (i.e. "DPM-Solver-12" and "DPM-Solver-23" in the paper).
                We ignore `steps` and use an adaptive step size DPM-Solver with a higher order of `order`.
                You can adjust the absolute tolerance `atol` and the relative tolerance `rtol` to balance the computation cost
                (NFE) and the sample quality.
                    - If `order` == 2, we use DPM-Solver-12, which combines DPM-Solver-1 and singlestep DPM-Solver-2.
                    - If `order` == 3, we use DPM-Solver-23, which combines singlestep DPM-Solver-2 and singlestep DPM-Solver-3.

        =====================================================

        Some advice on choosing the algorithm:
            - For **unconditional sampling** or **guided sampling with small guidance scale** by DPMs:
                Use singlestep DPM-Solver or DPM-Solver++ ("DPM-Solver-fast" in the paper) with `order = 3`.
                e.g., DPM-Solver:
                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver")
                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
                            skip_type='time_uniform', method='singlestep')
                e.g., DPM-Solver++:
                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=3,
                            skip_type='time_uniform', method='singlestep')
            - For **guided sampling with large guidance scale** by DPMs:
                Use multistep DPM-Solver with `algorithm_type="dpmsolver++"` and `order = 2`.
                e.g.
                    >>> dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
                    >>> x_sample = dpm_solver.sample(x, steps=steps, t_start=t_start, t_end=t_end, order=2,
                            skip_type='time_uniform', method='multistep')

        We support three types of `skip_type`:
            - 'logSNR': uniform logSNR for the time steps. **Recommended for low-resolution images**.
            - 'time_uniform': uniform time for the time steps. **Recommended for high-resolution images**.
            - 'time_quadratic': quadratic time for the time steps.

        =====================================================
        Args:
            x: A pytorch tensor. The initial value at time `t_start`,
                e.g. if `t_start` == T, then `x` is a sample from the standard normal distribution.
            steps: An `int`. The total number of function evaluations (NFE).
            t_start: A `float`. The starting time of the sampling.
                If `t_start` is None, we use self.noise_schedule.T (default is 1.0).
            t_end: A `float`. The ending time of the sampling.
                If `t_end` is None, we use 1. / self.noise_schedule.total_N,
                e.g. if total_N == 1000, we have `t_end` == 1e-3.
                For discrete-time DPMs:
                    - We recommend `t_end` == 1. / self.noise_schedule.total_N.
                For continuous-time DPMs:
                    - We recommend `t_end` == 1e-3 when `steps` <= 15, and `t_end` == 1e-4 when `steps` > 15.
            order: An `int`. The order of DPM-Solver.
            skip_type: A `str`. The type for the spacing of the time steps: 'time_uniform', 'logSNR' or 'time_quadratic'.
            method: A `str`. The method for sampling: 'singlestep', 'multistep', 'singlestep_fixed' or 'adaptive'.
            denoise_to_zero: A `bool`. Whether to denoise to time 0 at the final step.
                Default is `False`. If `denoise_to_zero` is `True`, the total NFE is (`steps` + 1).

                This trick was first proposed by DDPM (https://arxiv.org/abs/2006.11239) and
                score_sde (https://arxiv.org/abs/2011.13456). It can improve the FID when
                sampling from diffusion models by diffusion SDEs for low-resolution images
                (such as CIFAR-10). However, we observed that this trick does not matter for
                high-resolution images. As it needs an additional NFE, we do not recommend
                it for high-resolution images.
            lower_order_final: A `bool`. Whether to use lower order solvers at the final steps.
                Only valid for `method=multistep` and `steps < 15`. We empirically find that
                this trick is a key to stabilizing the sampling by DPM-Solver with very few steps
                (especially for steps <= 10), so we recommend setting it to `True`.
            solver_type: A `str`. The Taylor expansion type for the solver, `dpmsolver` or `taylor`. We recommend `dpmsolver`.
            atol: A `float`. The absolute tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
            rtol: A `float`. The relative tolerance of the adaptive step size solver. Valid when `method` == 'adaptive'.
            return_intermediate: A `bool`. Whether to save the xt at each step.
                When set to `True`, the method returns a tuple (x0, intermediates); when set to `False`, it returns only x0.
        Returns:
            x_end: A pytorch tensor. The approximated solution at time `t_end`.

        """
        t_0 = 1.0 / self.noise_schedule.total_N if t_end is None else t_end
        t_T = self.noise_schedule.T if t_start is None else t_start
        assert (
            t_0 > 0 and t_T > 0
        ), "Time range needs to be greater than 0. For discrete-time DPMs, it needs to be in [1 / N, 1], where N is the length of the betas array"
        if return_intermediate:
            assert method in [
                "multistep",
                "singlestep",
                "singlestep_fixed",
            ], "Cannot use adaptive solver when saving intermediate values"
        if self.correcting_xt_fn is not None:
            assert method in [
                "multistep",
                "singlestep",
                "singlestep_fixed",
            ], "Cannot use adaptive solver when correcting_xt_fn is not None"
        device = x.device
        intermediates = []
        with torch.no_grad():
            if method == "adaptive":
                x = self.dpm_solver_adaptive(
                    x, order=order, t_T=t_T, t_0=t_0, atol=atol, rtol=rtol, solver_type=solver_type
                )
            elif method == "multistep":
                assert steps >= order
                timesteps = self.get_time_steps(
                    skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device, shift=flow_shift
                )
                assert timesteps.shape[0] - 1 == steps
                # Init the initial values.
                step = 0
                t = timesteps[step]
                t_prev_list = [t]
                model_prev_list = [self.model_fn(x, t)]
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step)
                if return_intermediate:
                    intermediates.append(x)
                self.update_progress(step + 1, len(timesteps))
                # Init the first `order` values by lower order multistep DPM-Solver.
                for step in range(1, order):
                    t = timesteps[step]
                    x = self.multistep_dpm_solver_update(
                        x, model_prev_list, t_prev_list, t, step, solver_type=solver_type
                    )
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
                    t_prev_list.append(t)
                    model_prev_list.append(self.model_fn(x, t))
                    # update progress bar
                    self.update_progress(step + 1, len(timesteps))
                # Compute the remaining values by `order`-th order multistep DPM-Solver.
                for step in tqdm(range(order, steps + 1), disable=os.getenv("DPM_TQDM", "False") == "True"):
                    t = timesteps[step]
                    # We only use lower order for steps < 10
                    # if lower_order_final and steps < 10:
                    if lower_order_final:  # recommended by Shuchen Xue
                        step_order = min(order, steps + 1 - step)
                    else:
                        step_order = order
                    x = self.multistep_dpm_solver_update(
                        x, model_prev_list, t_prev_list, t, step_order, solver_type=solver_type
                    )
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
                    for i in range(order - 1):
                        t_prev_list[i] = t_prev_list[i + 1]
                        model_prev_list[i] = model_prev_list[i + 1]
                    t_prev_list[-1] = t
                    # We do not need to evaluate the final model value.
                    if step < steps:
                        model_prev_list[-1] = self.model_fn(x, t)
                    # update progress bar
                    self.update_progress(step + 1, len(timesteps))
            elif method in ["singlestep", "singlestep_fixed"]:
                if method == "singlestep":
                    timesteps_outer, orders = self.get_orders_and_timesteps_for_singlestep_solver(
                        steps=steps, order=order, skip_type=skip_type, t_T=t_T, t_0=t_0, device=device
                    )
                elif method == "singlestep_fixed":
                    K = steps // order
                    orders = [
                        order,
                    ] * K
                    timesteps_outer = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=K, device=device)
                for step, order in enumerate(orders):
                    s, t = timesteps_outer[step], timesteps_outer[step + 1]
                    timesteps_inner = self.get_time_steps(
                        skip_type=skip_type, t_T=s.item(), t_0=t.item(), N=order, device=device
                    )
                    lambda_inner = self.noise_schedule.marginal_lambda(timesteps_inner)
                    h = lambda_inner[-1] - lambda_inner[0]
                    r1 = None if order <= 1 else (lambda_inner[1] - lambda_inner[0]) / h
                    r2 = None if order <= 2 else (lambda_inner[2] - lambda_inner[0]) / h
                    x = self.singlestep_dpm_solver_update(x, s, t, order, solver_type=solver_type, r1=r1, r2=r2)
                    if self.correcting_xt_fn is not None:
                        x = self.correcting_xt_fn(x, t, step)
                    if return_intermediate:
                        intermediates.append(x)
                    self.update_progress(step + 1, len(timesteps_outer))
            else:
                raise ValueError(f"Got wrong method {method}")
            if denoise_to_zero:
                t = torch.ones((1,)).to(device) * t_0
                x = self.denoise_to_zero_fn(x, t)
                if self.correcting_xt_fn is not None:
                    x = self.correcting_xt_fn(x, t, step + 1)
                if return_intermediate:
                    intermediates.append(x)
        if return_intermediate:
            return x, intermediates
        else:
            return x

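Putting the docstring's recommendation in one place, a hedged usage sketch of the multistep path for guided sampling with a large guidance scale; `model_fn`, `noise_schedule`, `shape` and `device` are assumed to have been created earlier (see the docstring examples above), and only calls shown in the docstring and signature are used:

# Sketch only, not part of the original file.
dpm_solver = DPM_Solver(model_fn, noise_schedule, algorithm_type="dpmsolver++")
x_T = torch.randn(4, *shape, device=device)   # start from Gaussian noise at t_start == T
x_0 = dpm_solver.sample(
    x_T,
    steps=20,
    order=2,
    skip_type="time_uniform",
    method="multistep",
    lower_order_final=True,
)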
                
#############################################################
# other utility functions
#############################################################


def interpolate_fn(x, xp, yp):
    """
    A piecewise linear function y = f(x), using xp and yp as keypoints.
    We implement f(x) in a differentiable way (i.e. applicable for autograd).
    The function f(x) is well-defined for all x. (For x beyond the bounds of xp, we use the outermost points of xp to define the linear function.)

    Args:
        x: PyTorch tensor with shape [N, C], where N is the batch size, C is the number of channels (we use C = 1 for DPM-Solver).
        xp: PyTorch tensor with shape [C, K], where K is the number of keypoints.
        yp: PyTorch tensor with shape [C, K].
    Returns:
        The function values f(x), with shape [N, C].
    """
    N, K = x.shape[0], xp.shape[1]
    all_x = torch.cat([x.unsqueeze(2), xp.unsqueeze(0).repeat((N, 1, 1))], dim=2)
    sorted_all_x, x_indices = torch.sort(all_x, dim=2)
    x_idx = torch.argmin(x_indices, dim=2)
    cand_start_idx = x_idx - 1
    start_idx = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(1, device=x.device),
        torch.where(
            torch.eq(x_idx, K),
            torch.tensor(K - 2, device=x.device),
            cand_start_idx,
        ),
    )
    end_idx = torch.where(torch.eq(start_idx, cand_start_idx), start_idx + 2, start_idx + 1)
    start_x = torch.gather(sorted_all_x, dim=2, index=start_idx.unsqueeze(2)).squeeze(2)
    end_x = torch.gather(sorted_all_x, dim=2, index=end_idx.unsqueeze(2)).squeeze(2)
    start_idx2 = torch.where(
        torch.eq(x_idx, 0),
        torch.tensor(0, device=x.device),
        torch.where(
            torch.eq(x_idx, K),
            torch.tensor(K - 2, device=x.device),
            cand_start_idx,
        ),
    )
    y_positions_expanded = yp.unsqueeze(0).expand(N, -1, -1)
    start_y = torch.gather(y_positions_expanded, dim=2, index=start_idx2.unsqueeze(2)).squeeze(2)
    end_y = torch.gather(y_positions_expanded, dim=2, index=(start_idx2 + 1).unsqueeze(2)).squeeze(2)
    cand = start_y + (x - start_x) * (end_y - start_y) / (end_x - start_x)
    return cand

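A small numeric check of `interpolate_fn` with a single channel; the keypoints below are arbitrary and the function is assumed to be in scope:

import torch

# Piecewise linear f with f(0) = 0, f(1) = 2, f(2) = 3, extrapolated outside the keypoints.
xp = torch.tensor([[0.0, 1.0, 2.0]])       # [C, K]
yp = torch.tensor([[0.0, 2.0, 3.0]])       # [C, K]
x = torch.tensor([[0.5], [1.5], [3.0]])    # [N, C] query points

print(interpolate_fn(x, xp, yp))
# tensor([[1.0000], [2.5000], [4.0000]]) -- the last query extrapolates along the outermost segment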

def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dimension `dims`.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1], where the total number of dimensions is `dims`.
    """
    return v[(...,) + (None,) * (dims - 1)]
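`expand_dims` just appends trailing singleton axes so a length-N vector can broadcast against an N-sample batch; a quick sketch, assuming the function above is in scope:

import torch

v = torch.tensor([0.1, 0.2])                  # shape [N] = [2]
x = torch.randn(2, 3, 8, 8)
print(expand_dims(v, x.dim()).shape)          # torch.Size([2, 1, 1, 1])
print((expand_dims(v, x.dim()) * x).shape)    # torch.Size([2, 3, 8, 8])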
    	
        transport/integrators.py
    ADDED
    
    | @@ -0,0 +1,122 @@ | |
import torch as th
from torchdiffeq import odeint

from .utils import time_shift, get_lin_function


class sde:
    """SDE solver class"""

    def __init__(
        self,
        drift,
        diffusion,
        *,
        t0,
        t1,
        num_steps,
        sampler_type,
    ):
        assert t0 < t1, "SDE sampler has to be in forward time"

        self.num_timesteps = num_steps
        self.t = th.linspace(t0, t1, num_steps)
        self.dt = self.t[1] - self.t[0]
        self.drift = drift
        self.diffusion = diffusion
        self.sampler_type = sampler_type

    def __Euler_Maruyama_step(self, x, mean_x, t, model, **model_kwargs):
        w_cur = th.randn(x.size()).to(x)
        t = th.ones(x.size(0)).to(x) * t
        dw = w_cur * th.sqrt(self.dt)
        drift = self.drift(x, t, model, **model_kwargs)
        diffusion = self.diffusion(x, t)
        mean_x = x + drift * self.dt
        x = mean_x + th.sqrt(2 * diffusion) * dw
        return x, mean_x

    def __Heun_step(self, x, _, t, model, **model_kwargs):
        w_cur = th.randn(x.size()).to(x)
        dw = w_cur * th.sqrt(self.dt)
        t_cur = th.ones(x.size(0)).to(x) * t
        diffusion = self.diffusion(x, t_cur)
        xhat = x + th.sqrt(2 * diffusion) * dw
        K1 = self.drift(xhat, t_cur, model, **model_kwargs)
        xp = xhat + self.dt * K1
        K2 = self.drift(xp, t_cur + self.dt, model, **model_kwargs)
        return (
            xhat + 0.5 * self.dt * (K1 + K2),
            xhat,
        )  # at the last time point we do not perform the Heun step

    def __forward_fn(self):
        """TODO: generalize here by adding all private functions ending with steps to it"""
        sampler_dict = {
            "Euler": self.__Euler_Maruyama_step,
            "Heun": self.__Heun_step,
        }

        try:
            sampler = sampler_dict[self.sampler_type]
        except KeyError:
            raise NotImplementedError("Sampler type not implemented.")

        return sampler

    def sample(self, init, model, **model_kwargs):
        """forward loop of sde"""
        x = init
        mean_x = init
        samples = []
        sampler = self.__forward_fn()
        for ti in self.t[:-1]:
            with th.no_grad():
                x, mean_x = sampler(x, mean_x, ti, model, **model_kwargs)
                samples.append(x)

        return samples


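Each Euler-Maruyama step above implements x_{k+1} = x_k + f(x_k, t_k) * dt + sqrt(2 * g(x_k, t_k)) * dW_k. A minimal sketch of driving the class with toy drift/diffusion functions (the toy functions and shapes are assumptions, not part of this repo):

import torch as th

def toy_drift(x, t, model, **kwargs):
    return -x  # pull toward zero; `model` is unused in this toy setting

def toy_diffusion(x, t):
    return 0.05 * th.ones_like(x)  # small constant diffusion level

solver = sde(toy_drift, toy_diffusion, t0=0.0, t1=1.0, num_steps=50, sampler_type="Euler")
trajectory = solver.sample(th.randn(8, 2), model=None)   # list of 49 intermediate states
print(len(trajectory), trajectory[-1].shape)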
class ode:
    """ODE solver class"""

    def __init__(
        self,
        drift,
        *,
        t0,
        t1,
        sampler_type,
        num_steps,
        atol,
        rtol,
        do_shift=False,
        time_shifting_factor=None,
    ):
        assert t0 < t1, "ODE sampler has to be in forward time"

        self.drift = drift
        self.do_shift = do_shift
        self.t = th.linspace(t0, t1, num_steps)
        if time_shifting_factor:
            # Time-shift the grid: t -> t / (t + s - s * t), which fixes t = 0 and t = 1.
            self.t = self.t / (self.t + time_shifting_factor - time_shifting_factor * self.t)
        self.atol = atol
        self.rtol = rtol
        self.sampler_type = sampler_type

    def sample(self, x, model, **model_kwargs):
        x = x.float()
        device = x[0].device if isinstance(x, tuple) else x.device

        def _fn(t, x):
            t = th.ones(x[0].size(0)).to(device) * t if isinstance(x, tuple) else th.ones(x.size(0)).to(device) * t
            model_output = self.drift(x, t, model, **model_kwargs).float()
            return model_output

        t = self.t.to(device)
        if self.do_shift:
            mu = get_lin_function(y1=0.5, y2=1.15)(x.shape[1])
            t = time_shift(mu, 1.0, t)
        atol = [self.atol] * len(x) if isinstance(x, tuple) else [self.atol]
        rtol = [self.rtol] * len(x) if isinstance(x, tuple) else [self.rtol]
        samples = odeint(_fn, x, t, method=self.sampler_type, atol=atol, rtol=rtol)
        return samples
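When `time_shifting_factor` is set to s, the constructor warps the uniform grid with t -> t / (t + s - s * t), which keeps the endpoints 0 and 1 fixed and, for s > 1, concentrates grid points near t = 0. A quick numeric check (s = 3 is an arbitrary choice):

import torch as th

s = 3.0
t = th.linspace(0.0, 1.0, 5)
t_shifted = t / (t + s - s * t)
print(t)          # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
print(t_shifted)  # tensor([0.0000, 0.1000, 0.2500, 0.5000, 1.0000])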
    	
        transport/path.py
    ADDED
    
    | @@ -0,0 +1,201 @@ | |
import numpy as np
import torch as th


def expand_t_like_x(t, x):
    """Function to reshape the time t to a broadcastable dimension of x
    Args:
      t: [batch_dim,], time vector
      x: [batch_dim,...], data point
    """
    dims = [1] * len(x[0].size())
    t = t.view(t.size(0), *dims)
    return t


#################### Coupling Plans ####################


class ICPlan:
    """Linear Coupling Plan"""

    def __init__(self, sigma=0.0):
        self.sigma = sigma

    def compute_alpha_t(self, t):
        """Compute the data coefficient along the path"""
        return t, 1

    def compute_sigma_t(self, t):
        """Compute the noise coefficient along the path"""
        return 1 - t, -1

    def compute_d_alpha_alpha_ratio_t(self, t):
        """Compute the ratio between d_alpha and alpha"""
        return 1 / t

    def compute_drift(self, x, t):
        """We always output the SDE according to the score parametrization;"""
        t = expand_t_like_x(t, x)
        alpha_ratio = self.compute_d_alpha_alpha_ratio_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        drift = alpha_ratio * x
        diffusion = alpha_ratio * (sigma_t**2) - sigma_t * d_sigma_t

        return -drift, diffusion

    def compute_diffusion(self, x, t, form="constant", norm=1.0):
        """Compute the diffusion term of the SDE
        Args:
          x: [batch_dim, ...], data point
          t: [batch_dim,], time vector
          form: str, form of the diffusion term
          norm: float, norm of the diffusion term
        """
        t = expand_t_like_x(t, x)
        choices = {
            "constant": norm,
            "SBDM": norm * self.compute_drift(x, t)[1],
            "sigma": norm * self.compute_sigma_t(t)[0],
            "linear": norm * (1 - t),
            "decreasing": 0.25 * (norm * th.cos(np.pi * t) + 1) ** 2,
            "inccreasing-decreasing": norm * th.sin(np.pi * t) ** 2,
        }

        try:
            diffusion = choices[form]
        except KeyError:
            raise NotImplementedError(f"Diffusion form {form} not implemented")

        return diffusion

    def get_score_from_velocity(self, velocity, x, t):
        """Wrapper function: transform a velocity prediction model into a score model
        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = sigma_t**2 - reverse_alpha_ratio * d_sigma_t * sigma_t
        score = (reverse_alpha_ratio * velocity - mean) / var
        return score

    def get_noise_from_velocity(self, velocity, x, t):
        """Wrapper function: transform a velocity prediction model into a denoiser
        Args:
            velocity: [batch_dim, ...] shaped tensor; velocity model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        alpha_t, d_alpha_t = self.compute_alpha_t(t)
        sigma_t, d_sigma_t = self.compute_sigma_t(t)
        mean = x
        reverse_alpha_ratio = alpha_t / d_alpha_t
        var = reverse_alpha_ratio * d_sigma_t - sigma_t
        noise = (reverse_alpha_ratio * velocity - mean) / var
        return noise

    def get_velocity_from_score(self, score, x, t):
        """Wrapper function: transform a score prediction model into a velocity model
        Args:
            score: [batch_dim, ...] shaped tensor; score model output
            x: [batch_dim, ...] shaped tensor; x_t data point
            t: [batch_dim,] time tensor
        """
        t = expand_t_like_x(t, x)
        drift, var = self.compute_drift(x, t)
        velocity = var * score - drift
        return velocity

    def compute_mu_t(self, t, x0, x1):
        """Compute the mean of the time-dependent density p_t"""
        t = expand_t_like_x(t, x1)
        alpha_t, _ = self.compute_alpha_t(t)
        sigma_t, _ = self.compute_sigma_t(t)
        if isinstance(x1, (list, tuple)):
            return [alpha_t[i] * x1[i] + sigma_t[i] * x0[i] for i in range(len(x1))]
        else:
            return alpha_t * x1 + sigma_t * x0

    def compute_xt(self, t, x0, x1):
        """Sample xt from the time-dependent density p_t; rng is required"""
        xt = self.compute_mu_t(t, x0, x1)
        return xt

    def compute_ut(self, t, x0, x1, xt):
        """Compute the vector field corresponding to p_t"""
        t = expand_t_like_x(t, x1)
        _, d_alpha_t = self.compute_alpha_t(t)
        _, d_sigma_t = self.compute_sigma_t(t)
        if isinstance(x1, (list, tuple)):
            return [d_alpha_t * x1[i] + d_sigma_t * x0[i] for i in range(len(x1))]
        else:
            return d_alpha_t * x1 + d_sigma_t * x0

    def plan(self, t, x0, x1):
        xt = self.compute_xt(t, x0, x1)
        ut = self.compute_ut(t, x0, x1, xt)
        return t, xt, ut


         | 
| 148 | 
            +
                """class for VP path flow matching"""
         | 
| 149 | 
            +
             | 
| 150 | 
            +
                def __init__(self, sigma_min=0.1, sigma_max=20.0):
         | 
| 151 | 
            +
                    self.sigma_min = sigma_min
         | 
| 152 | 
            +
                    self.sigma_max = sigma_max
         | 
| 153 | 
            +
                    self.log_mean_coeff = (
         | 
| 154 | 
            +
                        lambda t: -0.25 * ((1 - t) ** 2) * (self.sigma_max - self.sigma_min) - 0.5 * (1 - t) * self.sigma_min
         | 
| 155 | 
            +
                    )
         | 
| 156 | 
            +
                    self.d_log_mean_coeff = lambda t: 0.5 * (1 - t) * (self.sigma_max - self.sigma_min) + 0.5 * self.sigma_min
         | 
| 157 | 
            +
             | 
| 158 | 
            +
                def compute_alpha_t(self, t):
         | 
| 159 | 
            +
                    """Compute coefficient of x1"""
         | 
| 160 | 
            +
                    alpha_t = self.log_mean_coeff(t)
         | 
| 161 | 
            +
                    alpha_t = th.exp(alpha_t)
         | 
| 162 | 
            +
                    d_alpha_t = alpha_t * self.d_log_mean_coeff(t)
         | 
| 163 | 
            +
                    return alpha_t, d_alpha_t
         | 
| 164 | 
            +
             | 
| 165 | 
            +
                def compute_sigma_t(self, t):
         | 
| 166 | 
            +
                    """Compute coefficient of x0"""
         | 
| 167 | 
            +
                    p_sigma_t = 2 * self.log_mean_coeff(t)
         | 
| 168 | 
            +
                    sigma_t = th.sqrt(1 - th.exp(p_sigma_t))
         | 
| 169 | 
            +
                    d_sigma_t = th.exp(p_sigma_t) * (2 * self.d_log_mean_coeff(t)) / (-2 * sigma_t)
         | 
| 170 | 
            +
                    return sigma_t, d_sigma_t
         | 
| 171 | 
            +
             | 
| 172 | 
            +
                def compute_d_alpha_alpha_ratio_t(self, t):
         | 
| 173 | 
            +
                    """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
         | 
| 174 | 
            +
                    return self.d_log_mean_coeff(t)
         | 
| 175 | 
            +
             | 
| 176 | 
            +
                def compute_drift(self, x, t):
         | 
| 177 | 
            +
                    """Compute the drift term of the SDE"""
         | 
| 178 | 
            +
                    t = expand_t_like_x(t, x)
         | 
| 179 | 
            +
                    beta_t = self.sigma_min + (1 - t) * (self.sigma_max - self.sigma_min)
         | 
| 180 | 
            +
                    return -0.5 * beta_t * x, beta_t / 2
         | 
| 181 | 
            +
             | 
| 182 | 
            +
             | 
| 183 | 
            +
            class GVPCPlan(ICPlan):
         | 
| 184 | 
            +
                def __init__(self, sigma=0.0):
         | 
| 185 | 
            +
                    super().__init__(sigma)
         | 
| 186 | 
            +
             | 
| 187 | 
            +
                def compute_alpha_t(self, t):
         | 
| 188 | 
            +
                    """Compute coefficient of x1"""
         | 
| 189 | 
            +
                    alpha_t = th.sin(t * np.pi / 2)
         | 
| 190 | 
            +
                    d_alpha_t = np.pi / 2 * th.cos(t * np.pi / 2)
         | 
| 191 | 
            +
                    return alpha_t, d_alpha_t
         | 
| 192 | 
            +
             | 
| 193 | 
            +
                def compute_sigma_t(self, t):
         | 
| 194 | 
            +
                    """Compute coefficient of x0"""
         | 
| 195 | 
            +
                    sigma_t = th.cos(t * np.pi / 2)
         | 
| 196 | 
            +
                    d_sigma_t = -np.pi / 2 * th.sin(t * np.pi / 2)
         | 
| 197 | 
            +
                    return sigma_t, d_sigma_t
         | 
| 198 | 
            +
             | 
| 199 | 
            +
                def compute_d_alpha_alpha_ratio_t(self, t):
         | 
| 200 | 
            +
                    """Special purposed function for computing numerical stabled d_alpha_t / alpha_t"""
         | 
| 201 | 
            +
                    return np.pi / (2 * th.tan(t * np.pi / 2))
         | 
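
A quick sanity check of the GVP schedule defined above (a minimal sketch, assuming the repository root is on PYTHONPATH so that `transport.path` is importable): the coefficients should stay on the unit circle, and compute_ut should be the time derivative of compute_mu_t.

import torch as th
from transport.path import GVPCPlan  # import path assumed from this repo's layout

plan = GVPCPlan()
x0, x1 = th.randn(4, 8), th.randn(4, 8)   # noise / data pairs
t = th.rand(4) * 0.8 + 0.1                # keep t away from the endpoints

# GVP schedule: alpha_t = sin(pi t / 2), sigma_t = cos(pi t / 2) -> alpha^2 + sigma^2 = 1
alpha_t, d_alpha_t = plan.compute_alpha_t(t)
sigma_t, d_sigma_t = plan.compute_sigma_t(t)
assert th.allclose(alpha_t**2 + sigma_t**2, th.ones_like(t), atol=1e-6)

# compute_ut should match a central finite difference of compute_mu_t in t
eps = 1e-4
ut = plan.compute_ut(t, x0, x1, None)
fd = (plan.compute_mu_t(t + eps, x0, x1) - plan.compute_mu_t(t - eps, x0, x1)) / (2 * eps)
assert th.allclose(ut, fd, atol=1e-2)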
    	
        transport/transport.py
    ADDED
    
@@ -0,0 +1,490 @@
import enum
import math
from typing import Callable

import numpy as np
import torch as th

from . import path
from .integrators import ode, sde
from .utils import mean_flat, expand_dims
from .dpm_solver import NoiseScheduleFlow, model_wrapper, DPM_Solver


class ModelType(enum.Enum):
    """
    Which type of output the model predicts.
    """

    NOISE = enum.auto()  # the model predicts epsilon
    SCORE = enum.auto()  # the model predicts \nabla \log p(x)
    VELOCITY = enum.auto()  # the model predicts v(x)


class PathType(enum.Enum):
    """
    Which type of path to use.
    """

    LINEAR = enum.auto()
    GVP = enum.auto()
    VP = enum.auto()


class WeightType(enum.Enum):
    """
    Which type of weighting to use.
    """

    NONE = enum.auto()
    VELOCITY = enum.auto()
    LIKELIHOOD = enum.auto()


class Transport:
    def __init__(self, *, model_type, path_type, loss_type, train_eps, sample_eps, snr_type, do_shift, seq_len):
        path_options = {
            PathType.LINEAR: path.ICPlan,
            PathType.GVP: path.GVPCPlan,
            PathType.VP: path.VPCPlan,
        }

        self.loss_type = loss_type
        self.model_type = model_type
        self.path_sampler = path_options[path_type]()
        self.train_eps = train_eps
        self.sample_eps = sample_eps

        self.snr_type = snr_type
        self.do_shift = do_shift
        self.seq_len = seq_len

    def prior_logp(self, z):
        """
        Standard multivariate normal prior
        Assume z is batched
        """
        shape = th.tensor(z.size())
        N = th.prod(shape[1:])
        _fn = lambda x: -N / 2.0 * np.log(2 * np.pi) - th.sum(x**2) / 2.0
        return th.vmap(_fn)(z)

    def check_interval(
        self,
        train_eps,
        sample_eps,
        *,
        diffusion_form="SBDM",
        sde=False,
        reverse=False,
        eval=False,
        last_step_size=0.0,
    ):
        t0 = 0
        t1 = 1
        eps = train_eps if not eval else sample_eps
        if type(self.path_sampler) in [path.VPCPlan]:
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        elif (type(self.path_sampler) in [path.ICPlan, path.GVPCPlan]) and (
            self.model_type != ModelType.VELOCITY or sde
        ):  # avoid numerical issue by taking a first semi-implicit step
            t0 = eps if (diffusion_form == "SBDM" and sde) or self.model_type != ModelType.VELOCITY else 0
            t1 = 1 - eps if (not sde or last_step_size == 0) else 1 - last_step_size

        if reverse:
            t0, t1 = 1 - t0, 1 - t1

        return t0, t1

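    # For the VP path, check_interval() pulls t1 away from the singular endpoint (t1 = 1 - eps),
    # while the linear/GVP paths with a velocity model keep the full [0, 1] interval; e.g. with
    # the linear path and a velocity model, check_interval(1e-5, 1e-3) returns (0, 1), whereas
    # under the VP path it returns (0, 1 - 1e-5), or (0, 1 - 1e-3) when eval=True.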
    def sample(self, x1):
        """Sampling x0 & t based on shape of x1 (if needed)
        Args:
          x1 - data point; [batch, *dim]
        """
        if isinstance(x1, (list, tuple)):
            x0 = [th.randn_like(img_start) for img_start in x1]
        else:
            x0 = th.randn_like(x1)
        t0, t1 = self.check_interval(self.train_eps, self.sample_eps)

        if self.snr_type.startswith("uniform"):
            assert t0 == 0.0 and t1 == 1.0, "not implemented."
            if "_" in self.snr_type:
                _, t0, t1 = self.snr_type.split("_")
                t0, t1 = float(t0), float(t1)
            t = th.rand((len(x1),)) * (t1 - t0) + t0
        elif self.snr_type == "lognorm":
            u = th.normal(mean=0.0, std=1.0, size=(len(x1),))
            t = 1 / (1 + th.exp(-u)) * (t1 - t0) + t0
        else:
            raise NotImplementedError("Not implemented snr_type %s" % self.snr_type)

        if self.do_shift:
            base_shift: float = 0.5
            max_shift: float = 1.15
            mu = self.get_lin_function(y1=base_shift, y2=max_shift)(self.seq_len)
            t = self.time_shift(mu, 1.0, t)
        t = t.to(x1[0])
        return t, x0, x1

    def time_shift(self, mu: float, sigma: float, t: th.Tensor):
        # the following implementation was originally for t=0: clean / t=1: noise;
        # since we adopt the reverse convention, the 1 - t operations are needed
        t = 1 - t
        t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
        t = 1 - t
        return t

    def get_lin_function(
        self, x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15
    ) -> Callable[[float], float]:
        m = (y2 - y1) / (x2 - x1)
        b = y1 - m * x1
        return lambda x: m * x + b

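    # With the defaults above, do_shift maps seq_len through the straight line passing through
    # (256, 0.5) and (4096, 1.15) to get mu, then warps the sampled t via
    # t' = 1 - exp(mu) / (exp(mu) + (1 / (1 - t) - 1) ** sigma); e.g. at seq_len=1024, mu is
    # about 0.63 and t = 0.5 is shifted to roughly 0.35, i.e. toward the noise end (t = 0).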
    def training_losses(self, model, x1, model_kwargs=None):
        """Loss for training the score model
        Args:
        - model: backbone model; could be score, noise, or velocity
        - x1: datapoint
        - model_kwargs: additional arguments for the model
        """
        if model_kwargs is None:
            model_kwargs = {}
        t, x0, x1 = self.sample(x1)
        t, xt, ut = self.path_sampler.plan(t, x0, x1)
        if "cond" in model_kwargs:
            conds = model_kwargs.pop("cond")
            xt = [th.cat([x, cond], dim=0) if cond is not None else x for x, cond in zip(xt, conds)]
        model_output = model(xt, t, **model_kwargs)
        B = len(x0)

        terms = {}
        # terms['pred'] = model_output
        if self.model_type == ModelType.VELOCITY:
            if isinstance(x1, (list, tuple)):
                assert len(model_output) == len(ut) == len(x1)
                for i in range(B):
                    assert (
                        model_output[i].shape == ut[i].shape == x1[i].shape
                    ), f"{model_output[i].shape} {ut[i].shape} {x1[i].shape}"
                terms["task_loss"] = th.stack(
                    [((ut[i] - model_output[i]) ** 2).mean() for i in range(B)],
                    dim=0,
                )
            else:
                terms["task_loss"] = mean_flat(((model_output - ut) ** 2))
        else:
            raise NotImplementedError

        terms["loss"] = terms["task_loss"]
        terms["task_loss"] = terms["task_loss"].clone().detach()
        terms["t"] = t
        return terms

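    # A minimal training-step sketch, assuming an initialized velocity backbone `net` that
    # maps (x_t, t, **kwargs) to a tensor shaped like x_t:
    #   terms = transport.training_losses(net, x1)
    #   terms["loss"].mean().backward()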
    def get_drift(self):
        """member function for obtaining the drift of the probability flow ODE"""

        def score_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            model_output = model(x, t, **model_kwargs)
            return -drift_mean + drift_var * model_output  # by change of variable

        def noise_ode(x, t, model, **model_kwargs):
            drift_mean, drift_var = self.path_sampler.compute_drift(x, t)
            sigma_t, _ = self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))
            model_output = model(x, t, **model_kwargs)
            score = model_output / -sigma_t
            return -drift_mean + drift_var * score

        def velocity_ode(x, t, model, **model_kwargs):
            model_output = model(x, t, **model_kwargs)
            return model_output

        if self.model_type == ModelType.NOISE:
            drift_fn = noise_ode
        elif self.model_type == ModelType.SCORE:
            drift_fn = score_ode
        else:
            drift_fn = velocity_ode

        def body_fn(x, t, model, **model_kwargs):
            model_output = drift_fn(x, t, model, **model_kwargs)
            assert model_output.shape == x.shape, "Output shape from ODE solver must match input shape"
            return model_output

        return body_fn

    def get_score(
        self,
    ):
        """member function for obtaining score of
        x_t = alpha_t * x + sigma_t * eps"""
        if self.model_type == ModelType.NOISE:
            score_fn = (
                lambda x, t, model, **kwargs: model(x, t, **kwargs)
                / -self.path_sampler.compute_sigma_t(path.expand_t_like_x(t, x))[0]
            )
        elif self.model_type == ModelType.SCORE:
            score_fn = lambda x, t, model, **kwargs: model(x, t, **kwargs)
        elif self.model_type == ModelType.VELOCITY:
            score_fn = lambda x, t, model, **kwargs: self.path_sampler.get_score_from_velocity(
                model(x, t, **kwargs), x, t
            )
        else:
            raise NotImplementedError()

        return score_fn


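# The Transport class above owns the training-time pieces (interpolant construction and the
# flow-matching loss); the Sampler below wraps a trained model for inference. get_drift()
# exposes the probability-flow ODE right-hand side, and get_score() the score implied by the
# model's prediction type; for a velocity model the score is recovered via
# path_sampler.get_score_from_velocity(model(x, t), x, t).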
class Sampler:
    """Sampler class for the transport model"""

    def __init__(
        self,
        transport,
    ):
        """Constructor for a general sampler; supporting different sampling methods
        Args:
        - transport: a transport object specifying model prediction & interpolant type
        """

        self.transport = transport
        self.drift = self.transport.get_drift()
        self.score = self.transport.get_score()

    def __get_sde_diffusion_and_drift(
        self,
        *,
        diffusion_form="SBDM",
        diffusion_norm=1.0,
    ):
        def diffusion_fn(x, t):
            diffusion = self.transport.path_sampler.compute_diffusion(x, t, form=diffusion_form, norm=diffusion_norm)
            return diffusion

        sde_drift = lambda x, t, model, **kwargs: self.drift(x, t, model, **kwargs) + diffusion_fn(x, t) * self.score(
            x, t, model, **kwargs
        )

        sde_diffusion = diffusion_fn

        return sde_drift, sde_diffusion

    def __get_last_step(
        self,
        sde_drift,
        *,
        last_step,
        last_step_size,
    ):
        """Get the last step function of the SDE solver"""

        if last_step is None:
            last_step_fn = lambda x, t, model, **model_kwargs: x
        elif last_step == "Mean":
            last_step_fn = (
                lambda x, t, model, **model_kwargs: x + sde_drift(x, t, model, **model_kwargs) * last_step_size
            )
        elif last_step == "Tweedie":
            alpha = self.transport.path_sampler.compute_alpha_t  # simple aliasing; the original name was too long
            sigma = self.transport.path_sampler.compute_sigma_t
            last_step_fn = lambda x, t, model, **model_kwargs: x / alpha(t)[0][0] + (sigma(t)[0][0] ** 2) / alpha(t)[0][
                0
            ] * self.score(x, t, model, **model_kwargs)
        elif last_step == "Euler":
            last_step_fn = (
                lambda x, t, model, **model_kwargs: x + self.drift(x, t, model, **model_kwargs) * last_step_size
            )
        else:
            raise NotImplementedError()

        return last_step_fn

    def sample_sde(
        self,
        *,
        sampling_method="Euler",
        diffusion_form="SBDM",
        diffusion_norm=1.0,
        last_step="Mean",
        last_step_size=0.04,
        num_steps=250,
    ):
        """returns a sampling function with given SDE settings
        Args:
        - sampling_method: type of sampler used in solving the SDE; default to be Euler-Maruyama
        - diffusion_form: function form of diffusion coefficient; default to be matching SBDM
        - diffusion_norm: function magnitude of diffusion coefficient; default to 1
        - last_step: type of the last step; default to identity
        - last_step_size: size of the last step; default to match the stride of 250 steps over [0,1]
        - num_steps: total integration step of SDE
        """

        if last_step is None:
            last_step_size = 0.0

        sde_drift, sde_diffusion = self.__get_sde_diffusion_and_drift(
            diffusion_form=diffusion_form,
            diffusion_norm=diffusion_norm,
        )

        t0, t1 = self.transport.check_interval(
            self.transport.train_eps,
            self.transport.sample_eps,
            diffusion_form=diffusion_form,
            sde=True,
            eval=True,
            reverse=False,
            last_step_size=last_step_size,
        )

        _sde = sde(
            sde_drift,
            sde_diffusion,
            t0=t0,
            t1=t1,
            num_steps=num_steps,
            sampler_type=sampling_method,
        )

        last_step_fn = self.__get_last_step(sde_drift, last_step=last_step, last_step_size=last_step_size)

        def _sample(init, model, **model_kwargs):
            xs = _sde.sample(init, model, **model_kwargs)
            ts = th.ones(init.size(0), device=init.device) * t1
            x = last_step_fn(xs[-1], ts, model, **model_kwargs)
            xs.append(x)

            assert len(xs) == num_steps, "Samples do not match the number of steps"

            return xs

        return _sample

    def sample_dpm(
        self,
        model,
        model_kwargs=None,
    ):
        if model_kwargs is None:
            model_kwargs = {}

        noise_schedule = NoiseScheduleFlow(schedule="discrete_flow")

        def noise_pred_fn(x, t_continuous):
            output = model(x, 1 - t_continuous, **model_kwargs)
            _, sigma_t = noise_schedule.marginal_alpha(t_continuous), noise_schedule.marginal_std(t_continuous)
            try:
                noise = x - (1 - expand_dims(sigma_t, x.dim()).to(x)) * output
            except:
                # fall back to the first element when the backbone returns a tuple of outputs
                noise = x - (1 - expand_dims(sigma_t, x.dim()).to(x)) * output[0]
            return noise

        return DPM_Solver(noise_pred_fn, noise_schedule, algorithm_type="dpmsolver++").sample

            +
                def sample_ode(
         | 
| 387 | 
            +
                    self,
         | 
| 388 | 
            +
                    *,
         | 
| 389 | 
            +
                    sampling_method="dopri5",
         | 
| 390 | 
            +
                    num_steps=50,
         | 
| 391 | 
            +
                    atol=1e-6,
         | 
| 392 | 
            +
                    rtol=1e-3,
         | 
| 393 | 
            +
                    reverse=False,
         | 
| 394 | 
            +
                    do_shift=False,
         | 
| 395 | 
            +
                    time_shifting_factor=None, 
         | 
| 396 | 
            +
                ):
         | 
| 397 | 
            +
                    """returns a sampling function with given ODE settings
         | 
| 398 | 
            +
                    Args:
         | 
| 399 | 
            +
                    - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
         | 
| 400 | 
            +
                    - num_steps:
         | 
| 401 | 
            +
                        - fixed solver (Euler, Heun): the actual number of integration steps performed
         | 
| 402 | 
            +
                        - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
         | 
| 403 | 
            +
                    - atol: absolute error tolerance for the solver
         | 
| 404 | 
            +
                    - rtol: relative error tolerance for the solver
         | 
| 405 | 
            +
                    """
         | 
| 406 | 
            +
             | 
| 407 | 
            +
                    # for flux
         | 
| 408 | 
            +
                    drift = lambda x, t, model, **kwargs: self.drift(x, t, model, **kwargs)
         | 
| 409 | 
            +
             | 
| 410 | 
            +
                    t0, t1 = self.transport.check_interval(
         | 
| 411 | 
            +
                        self.transport.train_eps,
         | 
| 412 | 
            +
                        self.transport.sample_eps,
         | 
| 413 | 
            +
                        sde=False,
         | 
| 414 | 
            +
                        eval=True,
         | 
| 415 | 
            +
                        reverse=reverse,
         | 
| 416 | 
            +
                        last_step_size=0.0,
         | 
| 417 | 
            +
                    )
         | 
| 418 | 
            +
             | 
| 419 | 
            +
                    _ode = ode(
         | 
| 420 | 
            +
                        drift=drift,
         | 
| 421 | 
            +
                        t0=t0,
         | 
| 422 | 
            +
                        t1=t1,
         | 
| 423 | 
            +
                        sampler_type=sampling_method,
         | 
| 424 | 
            +
                        num_steps=num_steps,
         | 
| 425 | 
            +
                        atol=atol,
         | 
| 426 | 
            +
                        rtol=rtol,
         | 
| 427 | 
            +
                        do_shift=do_shift,
         | 
| 428 | 
            +
                        time_shifting_factor=time_shifting_factor,
         | 
| 429 | 
            +
                    )
         | 
| 430 | 
            +
             | 
| 431 | 
            +
                    return _ode.sample
         | 
| 432 | 
            +
             | 
| 433 | 
            +
                def sample_ode_likelihood(
         | 
| 434 | 
            +
                    self,
         | 
| 435 | 
            +
                    *,
         | 
| 436 | 
            +
                    sampling_method="dopri5",
         | 
| 437 | 
            +
                    num_steps=50,
         | 
| 438 | 
            +
                    atol=1e-6,
         | 
| 439 | 
            +
                    rtol=1e-3,
         | 
| 440 | 
            +
                ):
         | 
| 441 | 
            +
                    """returns a sampling function for calculating likelihood with given ODE settings
         | 
| 442 | 
            +
                    Args:
         | 
| 443 | 
            +
                    - sampling_method: type of sampler used in solving the ODE; default to be Dopri5
         | 
| 444 | 
            +
                    - num_steps:
         | 
| 445 | 
            +
                        - fixed solver (Euler, Heun): the actual number of integration steps performed
         | 
| 446 | 
            +
                        - adaptive solver (Dopri5): the number of datapoints saved during integration; produced by interpolation
         | 
| 447 | 
            +
                    - atol: absolute error tolerance for the solver
         | 
| 448 | 
            +
                    - rtol: relative error tolerance for the solver
         | 
| 449 | 
            +
                    """
         | 
| 450 | 
            +
             | 
| 451 | 
            +
                    def _likelihood_drift(x, t, model, **model_kwargs):
         | 
| 452 | 
            +
                        x, _ = x
         | 
| 453 | 
            +
                        eps = th.randint(2, x.size(), dtype=th.float, device=x.device) * 2 - 1
         | 
| 454 | 
            +
                        t = th.ones_like(t) * (1 - t)
         | 
| 455 | 
            +
                        with th.enable_grad():
         | 
| 456 | 
            +
                            x.requires_grad = True
         | 
| 457 | 
            +
                            grad = th.autograd.grad(th.sum(self.drift(x, t, model, **model_kwargs) * eps), x)[0]
         | 
| 458 | 
            +
                            logp_grad = th.sum(grad * eps, dim=tuple(range(1, len(x.size()))))
         | 
| 459 | 
            +
                            drift = self.drift(x, t, model, **model_kwargs)
         | 
| 460 | 
            +
                        return (-drift, logp_grad)
         | 
| 461 | 
            +
             | 
| 462 | 
            +
                    t0, t1 = self.transport.check_interval(
         | 
| 463 | 
            +
                        self.transport.train_eps,
         | 
| 464 | 
            +
                        self.transport.sample_eps,
         | 
| 465 | 
            +
                        sde=False,
         | 
| 466 | 
            +
                        eval=True,
         | 
| 467 | 
            +
                        reverse=False,
         | 
| 468 | 
            +
                        last_step_size=0.0,
         | 
| 469 | 
            +
                    )
         | 
| 470 | 
            +
             | 
| 471 | 
            +
                    _ode = ode(
         | 
| 472 | 
            +
                        drift=_likelihood_drift,
         | 
| 473 | 
            +
                        t0=t0,
         | 
| 474 | 
            +
                        t1=t1,
         | 
| 475 | 
            +
                        sampler_type=sampling_method,
         | 
| 476 | 
            +
                        num_steps=num_steps,
         | 
| 477 | 
            +
                        atol=atol,
         | 
| 478 | 
            +
                        rtol=rtol,
         | 
| 479 | 
            +
                    )
         | 
| 480 | 
            +
             | 
| 481 | 
            +
                    def _sample_fn(x, model, **model_kwargs):
         | 
| 482 | 
            +
                        init_logp = th.zeros(x.size(0)).to(x)
         | 
| 483 | 
            +
                        input = (x, init_logp)
         | 
| 484 | 
            +
                        drift, delta_logp = _ode.sample(input, model, **model_kwargs)
         | 
| 485 | 
            +
                        drift, delta_logp = drift[-1], delta_logp[-1]
         | 
| 486 | 
            +
                        prior_logp = self.transport.prior_logp(drift)
         | 
| 487 | 
            +
                        logp = prior_logp - delta_logp
         | 
| 488 | 
            +
                        return logp, drift
         | 
| 489 | 
            +
             | 
| 490 | 
            +
                    return _sample_fn
         | 
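
End-to-end, the pieces above wire together roughly as follows. This is a hedged sketch: the `transport.transport` import path and the indexable-trajectory return of the `ode` integrator are assumptions based on this repo's layout, and the dummy lambda stands in for a trained backbone.

import torch as th
from transport.transport import ModelType, PathType, Sampler, Transport, WeightType

transport = Transport(
    model_type=ModelType.VELOCITY,
    path_type=PathType.LINEAR,
    loss_type=WeightType.NONE,
    train_eps=0.0,
    sample_eps=0.0,
    snr_type="uniform",
    do_shift=False,
    seq_len=1024,
)

# training side: per-sample flow-matching loss against the interpolant's velocity field
dummy_model = lambda x, t, **kwargs: th.zeros_like(x)  # stand-in for the real backbone
x1 = th.randn(8, 4, 32, 32)
terms = transport.training_losses(dummy_model, x1)
loss = terms["loss"].mean()

# inference side: the Sampler turns the same Transport into an ODE sampler
sampler = Sampler(transport)
sample_fn = sampler.sample_ode(num_steps=30)          # dopri5 by default
trajectory = sample_fn(th.randn(2, 4, 32, 32), dummy_model)
samples = trajectory[-1]                              # final state of the integration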
    	
        transport/utils.py
    ADDED
    
@@ -0,0 +1,56 @@
import torch as th
import math

class EasyDict:
    def __init__(self, sub_dict):
        for k, v in sub_dict.items():
            setattr(self, k, v)

    def __getitem__(self, key):
        return getattr(self, key)


def mean_flat(x):
    """
    Take the mean over all non-batch dimensions.
    """
    return th.mean(x, dim=list(range(1, len(x.size()))))


def log_state(state):
    result = []

    sorted_state = dict(sorted(state.items()))
    for key, value in sorted_state.items():
        # Check if the value is an instance of a class
        if "<object" in str(value) or "object at" in str(value):
            result.append(f"{key}: [{value.__class__.__name__}]")
        else:
            result.append(f"{key}: {value}")

    return "\n".join(result)

def time_shift(mu: float, sigma: float, t: th.Tensor):
    # the following implementation was originally for t=0: clean / t=1: noise;
    # since we adopt the reverse convention, the 1 - t operations are needed
    t = 1 - t
    t = math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma)
    t = 1 - t
    return t

def get_lin_function(x1: float = 256, y1: float = 0.5, x2: float = 4096, y2: float = 1.15):
    m = (y2 - y1) / (x2 - x1)
    b = y1 - m * x1
    return lambda x: m * x + b

def expand_dims(v, dims):
    """
    Expand the tensor `v` to the dim `dims`.

    Args:
        `v`: a PyTorch tensor with shape [N].
        `dims`: an `int`.
    Returns:
        a PyTorch tensor with shape [N, 1, 1, ..., 1] and the total dimension is `dims`.
    """
    return v[(...,) + (None,) * (dims - 1)]
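
The helpers above are small but load-bearing; a short sketch of how they behave (assuming the `transport.utils` import path from this repo's layout):

import torch as th
from transport.utils import expand_dims, mean_flat, time_shift

x = th.randn(8, 4, 32, 32)
v = th.rand(8)                          # e.g. per-sample sigma_t values
print(expand_dims(v, x.dim()).shape)    # torch.Size([8, 1, 1, 1]); broadcasts against x
print(mean_flat((x - 1.0) ** 2).shape)  # torch.Size([8]); one scalar per sample

# time_shift warps a timestep toward the noise end (t -> 0) as mu grows
for mu in (0.5, 1.15):
    print(mu, float(time_shift(mu, 1.0, th.tensor(0.5))))  # ~0.38, then ~0.24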
