hungchiayu1 committed

Commit ffead1e · 1 Parent(s): e7af757

initial commit

This view is limited to 50 files because it contains too many changes. See raw diff.
- README.md +5 -5
- app.py +140 -0
- audioldm/__init__.py +8 -0
- audioldm/__main__.py +183 -0
- audioldm/__pycache__/__init__.cpython-310.pyc +0 -0
- audioldm/__pycache__/__init__.cpython-39.pyc +0 -0
- audioldm/__pycache__/ldm.cpython-310.pyc +0 -0
- audioldm/__pycache__/ldm.cpython-39.pyc +0 -0
- audioldm/__pycache__/pipeline.cpython-310.pyc +0 -0
- audioldm/__pycache__/pipeline.cpython-39.pyc +0 -0
- audioldm/__pycache__/utils.cpython-310.pyc +0 -0
- audioldm/__pycache__/utils.cpython-39.pyc +0 -0
- audioldm/audio/__init__.py +2 -0
- audioldm/audio/__pycache__/__init__.cpython-310.pyc +0 -0
- audioldm/audio/__pycache__/__init__.cpython-39.pyc +0 -0
- audioldm/audio/__pycache__/audio_processing.cpython-310.pyc +0 -0
- audioldm/audio/__pycache__/audio_processing.cpython-39.pyc +0 -0
- audioldm/audio/__pycache__/mix.cpython-39.pyc +0 -0
- audioldm/audio/__pycache__/stft.cpython-310.pyc +0 -0
- audioldm/audio/__pycache__/stft.cpython-39.pyc +0 -0
- audioldm/audio/__pycache__/tools.cpython-310.pyc +0 -0
- audioldm/audio/__pycache__/tools.cpython-39.pyc +0 -0
- audioldm/audio/__pycache__/torch_tools.cpython-39.pyc +0 -0
- audioldm/audio/audio_processing.py +100 -0
- audioldm/audio/stft.py +186 -0
- audioldm/audio/tools.py +85 -0
- audioldm/hifigan/__init__.py +7 -0
- audioldm/hifigan/__pycache__/__init__.cpython-310.pyc +0 -0
- audioldm/hifigan/__pycache__/__init__.cpython-39.pyc +0 -0
- audioldm/hifigan/__pycache__/models.cpython-310.pyc +0 -0
- audioldm/hifigan/__pycache__/models.cpython-39.pyc +0 -0
- audioldm/hifigan/__pycache__/utilities.cpython-310.pyc +0 -0
- audioldm/hifigan/__pycache__/utilities.cpython-39.pyc +0 -0
- audioldm/hifigan/models.py +174 -0
- audioldm/hifigan/utilities.py +86 -0
- audioldm/latent_diffusion/__init__.py +0 -0
- audioldm/latent_diffusion/__pycache__/__init__.cpython-310.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/__init__.cpython-39.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/attention.cpython-310.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/attention.cpython-39.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/ddim.cpython-310.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/ddim.cpython-39.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/ddpm.cpython-310.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/ddpm.cpython-39.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/ema.cpython-310.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/ema.cpython-39.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/openaimodel.cpython-39.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/util.cpython-310.pyc +0 -0
- audioldm/latent_diffusion/__pycache__/util.cpython-39.pyc +0 -0
- audioldm/latent_diffusion/attention.py +469 -0
    	
README.md
CHANGED

@@ -1,10 +1,10 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
+title: Tango
+emoji: 🐠
+colorFrom: indigo
+colorTo: pink
 sdk: gradio
-sdk_version:
+sdk_version: 3.28.0
 app_file: app.py
 pinned: false
 ---
    	
app.py
ADDED

@@ -0,0 +1,140 @@
import gradio as gr
import json
import torch
import wavio
from tqdm import tqdm
from huggingface_hub import snapshot_download
from models import AudioDiffusion, DDPMScheduler
from audioldm.audio.stft import TacotronSTFT
from audioldm.variational_autoencoder import AutoencoderKL
from gradio import Markdown


class Tango:
    def __init__(self, name="declare-lab/tango2", device="cuda:0"):
        # Download the checkpoint snapshot and load the three sub-model configs.
        path = snapshot_download(repo_id=name)

        vae_config = json.load(open("{}/vae_config.json".format(path)))
        stft_config = json.load(open("{}/stft_config.json".format(path)))
        main_config = json.load(open("{}/main_config.json".format(path)))

        self.vae = AutoencoderKL(**vae_config).to(device)
        self.stft = TacotronSTFT(**stft_config).to(device)
        self.model = AudioDiffusion(**main_config).to(device)

        vae_weights = torch.load("{}/pytorch_model_vae.bin".format(path), map_location=device)
        stft_weights = torch.load("{}/pytorch_model_stft.bin".format(path), map_location=device)
        main_weights = torch.load("{}/pytorch_model_main.bin".format(path), map_location=device)

        self.vae.load_state_dict(vae_weights)
        self.stft.load_state_dict(stft_weights)
        self.model.load_state_dict(main_weights)

        print("Successfully loaded checkpoint from:", name)

        self.vae.eval()
        self.stft.eval()
        self.model.eval()

        self.scheduler = DDPMScheduler.from_pretrained(main_config["scheduler_name"], subfolder="scheduler")

    def chunks(self, lst, n):
        """Yield successive n-sized chunks from a list."""
        for i in range(0, len(lst), n):
            yield lst[i:i + n]

    def generate(self, prompt, steps=100, guidance=3, samples=1, disable_progress=True):
        """Generate audio for a single prompt string."""
        with torch.no_grad():
            latents = self.model.inference([prompt], self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
            mel = self.vae.decode_first_stage(latents)
            wave = self.vae.decode_to_waveform(mel)
        return wave[0]

    def generate_for_batch(self, prompts, steps=200, guidance=3, samples=1, batch_size=8, disable_progress=True):
        """Generate audio for a list of prompt strings."""
        outputs = []
        for k in tqdm(range(0, len(prompts), batch_size)):
            batch = prompts[k: k + batch_size]
            with torch.no_grad():
                latents = self.model.inference(batch, self.scheduler, steps, guidance, samples, disable_progress=disable_progress)
                mel = self.vae.decode_first_stage(latents)
                wave = self.vae.decode_to_waveform(mel)
                outputs += [item for item in wave]
        if samples == 1:
            return outputs
        else:
            return list(self.chunks(outputs, samples))


# Initialize TANGO
if torch.cuda.is_available():
    tango = Tango()
else:
    tango = Tango(device="cpu")


def gradio_generate(prompt, steps, guidance):
    output_wave = tango.generate(prompt, steps, guidance)
    # output_filename = f"{prompt.replace(' ', '_')}_{steps}_{guidance}"[:250] + ".wav"
    output_filename = "temp.wav"
    wavio.write(output_filename, output_wave, rate=16000, sampwidth=2)

    return output_filename


# description_text = """
# <p><a href="https://huggingface.co/spaces/declare-lab/tango/blob/main/app.py?duplicate=true"> <img style="margin-top: 0em; margin-bottom: 0em" src="https://bit.ly/3gLdBN6" alt="Duplicate Space"></a> For faster inference without waiting in queue, you may duplicate the space and upgrade to a GPU in the settings. <br/><br/>
# Generate audio using TANGO by providing a text prompt.
# <br/><br/>Limitations: TANGO is trained on the small AudioCaps dataset so it may not generate good audio \
# samples related to concepts that it has not seen in training (e.g. singing). For the same reason, TANGO \
# is not always able to finely control its generations over textual control prompts. For example, \
# the generations from TANGO for prompts Chopping tomatoes on a wooden table and Chopping potatoes \
# on a metal table are very similar. \
# <br/><br/>We are currently training another version of TANGO on larger datasets to enhance its generalization, \
# compositional and controllable generation ability.
# <br/><br/>We recommend using a guidance scale of 3. The default number of steps is set to 100. More steps generally lead to better quality of generated audios but will take longer.
# <br/><br/>
# <h1> ChatGPT-enhanced audio generation</h1>
# <br/>
# As TANGO consists of an instruction-tuned LLM, it is able to process complex sound descriptions allowing us to provide more detailed instructions to improve the generation quality.
# For example, ``A boat is moving on the sea'' vs ``The sound of the water lapping against the hull of the boat or splashing as you move through the waves''. The latter is obtained by prompting ChatGPT to explain the sound generated when a boat moves on the sea.
# Using this ChatGPT-generated description of the sound, TANGO provides superior results.
# <p/>
# """
description_text = ""

# Gradio input and output components
input_text = gr.Textbox(lines=2, label="Prompt")
output_audio = gr.Audio(label="Generated Audio", type="filepath")
denoising_steps = gr.Slider(minimum=100, maximum=200, value=100, step=1, label="Steps", interactive=True)
guidance_scale = gr.Slider(minimum=1, maximum=10, value=3, step=0.1, label="Guidance Scale", interactive=True)

# Gradio interface
gr_interface = gr.Interface(
    fn=gradio_generate,
    inputs=[input_text, denoising_steps, guidance_scale],
    outputs=[output_audio],
    title="TANGO2: Aligning Diffusion-based Text-to-Audio Generative Models through Direct Preference Optimization",
    description=description_text,
    allow_flagging=False,
    examples=[
        ["A lady is singing a song with a kid"],
        ["The sound of the water lapping against the hull of the boat or splashing as you move through the waves"],
        ["An audience cheering and clapping"],
        ["Rolling thunder with lightning strikes"],
        ["Gentle water stream, birds chirping and sudden gun shot"],
        ["A car engine revving"],
        ["A dog barking"],
        ["A cat meowing"],
        ["Wooden table tapping sound while water pouring"],
        ["Emergency sirens wailing"],
        ["two gunshots followed by birds flying away while chirping"],
        ["Whistling with birds chirping"],
        ["A person snoring"],
        ["Motor vehicles are driving with loud engines and a person whistles"],
        ["People cheering in a stadium while thunder and lightning strikes"],
        ["A helicopter is in flight"],
        ["A dog barking and a man talking and a racing car passes by"],
    ],
    cache_examples=False,  # Turn on to cache.
)

# Launch Gradio app
gr_interface.launch()
    	
audioldm/__init__.py
ADDED

@@ -0,0 +1,8 @@
from .ldm import LatentDiffusion
from .utils import seed_everything, save_wave, get_time, get_duration
from .pipeline import *
    	
audioldm/__main__.py
ADDED

@@ -0,0 +1,183 @@
#!/usr/bin/python3
import os
from audioldm import text_to_audio, style_transfer, build_model, save_wave, get_time, round_up_duration, get_duration
import argparse

CACHE_DIR = os.getenv(
    "AUDIOLDM_CACHE_DIR",
    os.path.join(os.path.expanduser("~"), ".cache/audioldm"))

parser = argparse.ArgumentParser()

parser.add_argument(
    "--mode",
    type=str,
    required=False,
    default="generation",
    help="generation: text-to-audio generation; transfer: style transfer",
    choices=["generation", "transfer"]
)

parser.add_argument(
    "-t",
    "--text",
    type=str,
    required=False,
    default="",
    help="Text prompt to the model for audio generation",
)

parser.add_argument(
    "-f",
    "--file_path",
    type=str,
    required=False,
    default=None,
    help="(--mode transfer): Original audio file for style transfer; or (--mode generation): the guidance audio file for generating similar audio",
)

parser.add_argument(
    "--transfer_strength",
    type=float,
    required=False,
    default=0.5,
    help="A value between 0 and 1. 0 means original audio without transfer, 1 means completely transfer to the audio indicated by text",
)

parser.add_argument(
    "-s",
    "--save_path",
    type=str,
    required=False,
    help="The path to save model output",
    default="./output",
)

parser.add_argument(
    "--model_name",
    type=str,
    required=False,
    help="The checkpoint you are going to use",
    default="audioldm-s-full",
    choices=["audioldm-s-full", "audioldm-l-full", "audioldm-s-full-v2"]
)

parser.add_argument(
    "-ckpt",
    "--ckpt_path",
    type=str,
    required=False,
    help="The path to the pretrained .ckpt model",
    default=None,
)

parser.add_argument(
    "-b",
    "--batchsize",
    type=int,
    required=False,
    default=1,
    help="Generate how many samples at the same time",
)

parser.add_argument(
    "--ddim_steps",
    type=int,
    required=False,
    default=200,
    help="The sampling step for DDIM",
)

parser.add_argument(
    "-gs",
    "--guidance_scale",
    type=float,
    required=False,
    default=2.5,
    help="Guidance scale (large => better quality and relevance to text; small => better diversity)",
)

parser.add_argument(
    "-dur",
    "--duration",
    type=float,
    required=False,
    default=10.0,
    help="The duration of the samples",
)

parser.add_argument(
    "-n",
    "--n_candidate_gen_per_text",
    type=int,
    required=False,
    default=3,
    help="Automatic quality control. This number controls the number of candidates (e.g., generate three audios and choose the best to show you). A larger value usually leads to better quality with heavier computation",
)

parser.add_argument(
    "--seed",
    type=int,
    required=False,
    default=42,
    help="Changing this value (any integer number) will lead to a different generation result.",
)

args = parser.parse_args()

if(args.ckpt_path is not None):
    print("Warning: ckpt_path has no effect after version 0.0.20.")

assert args.duration % 2.5 == 0, "Duration must be a multiple of 2.5"

mode = args.mode
if(mode == "generation" and args.file_path is not None):
    mode = "generation_audio_to_audio"
    if(len(args.text) > 0):
        print("Warning: You have specified the --file_path. --text will be ignored")
        args.text = ""

save_path = os.path.join(args.save_path, mode)

if(args.file_path is not None):
    save_path = os.path.join(save_path, os.path.basename(args.file_path.split(".")[0]))

text = args.text
random_seed = args.seed
duration = args.duration
guidance_scale = args.guidance_scale
n_candidate_gen_per_text = args.n_candidate_gen_per_text

os.makedirs(save_path, exist_ok=True)
audioldm = build_model(model_name=args.model_name)

if(args.mode == "generation"):
    waveform = text_to_audio(
        audioldm,
        text,
        args.file_path,
        random_seed,
        duration=duration,
        guidance_scale=guidance_scale,
        ddim_steps=args.ddim_steps,
        n_candidate_gen_per_text=n_candidate_gen_per_text,
        batchsize=args.batchsize,
    )

elif(args.mode == "transfer"):
    assert args.file_path is not None
    assert os.path.exists(args.file_path), "The original audio file '%s' for style transfer does not exist." % args.file_path
    waveform = style_transfer(
        audioldm,
        text,
        args.file_path,
        args.transfer_strength,
        random_seed,
        duration=duration,
        guidance_scale=guidance_scale,
        ddim_steps=args.ddim_steps,
        batchsize=args.batchsize,
    )
    waveform = waveform[:, None, :]

save_wave(waveform, save_path, name="%s_%s" % (get_time(), text))
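For reference, a minimal hedged sketch (not part of the commit) of the script's generation branch called programmatically rather than through argparse. It reuses only the functions the script itself imports (build_model, text_to_audio, save_wave, get_time) with the script's own default values, and assumes the "audioldm-s-full" checkpoint can be downloaded.

# Hedged sketch: the CLI's text-to-audio generation path, called directly.
from audioldm import build_model, text_to_audio, save_wave, get_time

model = build_model(model_name="audioldm-s-full")
waveform = text_to_audio(
    model,
    "A dog barking",
    None,              # no guidance audio file
    42,                # random seed
    duration=10.0,
    guidance_scale=2.5,
    ddim_steps=200,
    n_candidate_gen_per_text=3,
    batchsize=1,
)
save_wave(waveform, "./output/generation", name="%s_%s" % (get_time(), "A dog barking"))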
    	
audioldm/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (315 Bytes)

audioldm/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (322 Bytes)

audioldm/__pycache__/ldm.cpython-310.pyc
ADDED
Binary file (16.1 kB)

audioldm/__pycache__/ldm.cpython-39.pyc
ADDED
Binary file (16 kB)

audioldm/__pycache__/pipeline.cpython-310.pyc
ADDED
Binary file (6.63 kB)

audioldm/__pycache__/pipeline.cpython-39.pyc
ADDED
Binary file (6.54 kB)

audioldm/__pycache__/utils.cpython-310.pyc
ADDED
Binary file (8.01 kB)

audioldm/__pycache__/utils.cpython-39.pyc
ADDED
Binary file (7.35 kB)
    	
audioldm/audio/__init__.py
ADDED

@@ -0,0 +1,2 @@
from .tools import wav_to_fbank, read_wav_file
from .stft import TacotronSTFT
    	
audioldm/audio/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (253 Bytes)

audioldm/audio/__pycache__/__init__.cpython-39.pyc
ADDED
Binary file (260 Bytes)

audioldm/audio/__pycache__/audio_processing.cpython-310.pyc
ADDED
Binary file (2.78 kB)

audioldm/audio/__pycache__/audio_processing.cpython-39.pyc
ADDED
Binary file (2.78 kB)

audioldm/audio/__pycache__/mix.cpython-39.pyc
ADDED
Binary file (1.7 kB)

audioldm/audio/__pycache__/stft.cpython-310.pyc
ADDED
Binary file (4.98 kB)

audioldm/audio/__pycache__/stft.cpython-39.pyc
ADDED
Binary file (4.99 kB)

audioldm/audio/__pycache__/tools.cpython-310.pyc
ADDED
Binary file (2.18 kB)

audioldm/audio/__pycache__/tools.cpython-39.pyc
ADDED
Binary file (2.19 kB)

audioldm/audio/__pycache__/torch_tools.cpython-39.pyc
ADDED
Binary file (3.79 kB)
    	
audioldm/audio/audio_processing.py
ADDED

@@ -0,0 +1,100 @@
import torch
import numpy as np
import librosa.util as librosa_util
from scipy.signal import get_window


def window_sumsquare(
    window,
    n_frames,
    hop_length,
    win_length,
    n_fft,
    dtype=np.float32,
    norm=None,
):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function.  By default, this matches `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    win_sq = librosa_util.pad_center(win_sq, n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
    return x


def griffin_lim(magnitudes, stft_fn, n_iters=30):
    """
    PARAMS
    ------
    magnitudes: spectrogram magnitudes
    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
    """

    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
    angles = angles.astype(np.float32)
    angles = torch.autograd.Variable(torch.from_numpy(angles))
    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)

    for i in range(n_iters):
        _, angles = stft_fn.transform(signal)
        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
    return signal


def dynamic_range_compression(x, normalize_fun=torch.log, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return normalize_fun(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C
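A small hedged sketch (not part of the commit) of the dynamic-range helpers above: compression clamps the magnitudes and takes a log, and decompression inverts it with exp, so the round trip recovers everything above the clip value. The tensor shape is illustrative only.

# Hedged round-trip sketch for the dynamic-range helpers (shape is illustrative).
import torch
from audioldm.audio.audio_processing import (
    dynamic_range_compression,
    dynamic_range_decompression,
)

mag = torch.rand(1, 64, 100)                      # fake magnitude spectrogram, values in [0, 1)
log_mag = dynamic_range_compression(mag)          # log(clamp(mag, min=1e-5) * 1)
recovered = dynamic_range_decompression(log_mag)  # exp(log_mag) / 1
assert torch.allclose(recovered, torch.clamp(mag, min=1e-5), atol=1e-6)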
    	
        audioldm/audio/stft.py
    ADDED
    
    | @@ -0,0 +1,186 @@ | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | |
|  | 
|  | |
| 1 | 
            +
            import torch
         | 
| 2 | 
            +
            import torch.nn.functional as F
         | 
| 3 | 
            +
            import numpy as np
         | 
| 4 | 
            +
            from scipy.signal import get_window
         | 
| 5 | 
            +
            from librosa.util import pad_center, tiny
         | 
| 6 | 
            +
            from librosa.filters import mel as librosa_mel_fn
         | 
| 7 | 
            +
             | 
| 8 | 
            +
            from audioldm.audio.audio_processing import (
         | 
| 9 | 
            +
                dynamic_range_compression,
         | 
| 10 | 
            +
                dynamic_range_decompression,
         | 
| 11 | 
            +
                window_sumsquare,
         | 
| 12 | 
            +
            )
         | 
| 13 | 
            +
             | 
| 14 | 
            +
             | 
| 15 | 
            +
            class STFT(torch.nn.Module):
         | 
| 16 | 
            +
                """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
         | 
| 17 | 
            +
             | 
| 18 | 
            +
                def __init__(self, filter_length, hop_length, win_length, window="hann"):
         | 
| 19 | 
            +
                    super(STFT, self).__init__()
         | 
| 20 | 
            +
                    self.filter_length = filter_length
         | 
| 21 | 
            +
                    self.hop_length = hop_length
         | 
| 22 | 
            +
                    self.win_length = win_length
         | 
| 23 | 
            +
                    self.window = window
         | 
| 24 | 
            +
                    self.forward_transform = None
         | 
| 25 | 
            +
                    scale = self.filter_length / self.hop_length
         | 
| 26 | 
            +
                    fourier_basis = np.fft.fft(np.eye(self.filter_length))
         | 
| 27 | 
            +
             | 
| 28 | 
            +
                    cutoff = int((self.filter_length / 2 + 1))
         | 
| 29 | 
            +
                    fourier_basis = np.vstack(
         | 
| 30 | 
            +
                        [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
         | 
| 31 | 
            +
                    )
         | 
| 32 | 
            +
             | 
| 33 | 
            +
                    forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
         | 
| 34 | 
            +
                    inverse_basis = torch.FloatTensor(
         | 
| 35 | 
            +
                        np.linalg.pinv(scale * fourier_basis).T[:, None, :]
         | 
| 36 | 
            +
                    )
         | 
| 37 | 
            +
             | 
| 38 | 
            +
                    if window is not None:
         | 
| 39 | 
            +
                        assert filter_length >= win_length
         | 
| 40 | 
            +
                        # get window and zero center pad it to filter_length
         | 
| 41 | 
            +
                        fft_window = get_window(window, win_length, fftbins=True)
         | 
| 42 | 
            +
                        fft_window = pad_center(fft_window, filter_length)
         | 
| 43 | 
            +
                        fft_window = torch.from_numpy(fft_window).float()
         | 
| 44 | 
            +
             | 
| 45 | 
            +
                        # window the bases
         | 
| 46 | 
            +
                        forward_basis *= fft_window
         | 
| 47 | 
            +
                        inverse_basis *= fft_window
         | 
| 48 | 
            +
             | 
| 49 | 
            +
                    self.register_buffer("forward_basis", forward_basis.float())
         | 
| 50 | 
            +
                    self.register_buffer("inverse_basis", inverse_basis.float())
         | 
| 51 | 
            +
             | 
| 52 | 
            +
                def transform(self, input_data):
         | 
| 53 | 
            +
                    device = self.forward_basis.device
         | 
| 54 | 
            +
                    input_data = input_data.to(device)
         | 
| 55 | 
            +
                    
         | 
| 56 | 
            +
                    num_batches = input_data.size(0)
         | 
| 57 | 
            +
                    num_samples = input_data.size(1)
         | 
| 58 | 
            +
             | 
| 59 | 
            +
                    self.num_samples = num_samples
         | 
| 60 | 
            +
             | 
| 61 | 
            +
                    # similar to librosa, reflect-pad the input
         | 
| 62 | 
            +
                    input_data = input_data.view(num_batches, 1, num_samples)
         | 
| 63 | 
            +
                    input_data = F.pad(
         | 
| 64 | 
            +
                        input_data.unsqueeze(1),
         | 
| 65 | 
            +
                        (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
         | 
| 66 | 
            +
                        mode="reflect",
         | 
| 67 | 
            +
                    )
         | 
| 68 | 
            +
                    input_data = input_data.squeeze(1)
         | 
| 69 | 
            +
             | 
| 70 | 
            +
                    forward_transform = F.conv1d(
         | 
| 71 | 
            +
                        input_data,
         | 
| 72 | 
            +
                        torch.autograd.Variable(self.forward_basis, requires_grad=False),
         | 
| 73 | 
            +
                        stride=self.hop_length,
         | 
| 74 | 
            +
                        padding=0,
         | 
| 75 | 
            +
                    )#.cpu()
         | 
| 76 | 
            +
             | 
| 77 | 
            +
                    cutoff = int((self.filter_length / 2) + 1)
         | 
| 78 | 
            +
                    real_part = forward_transform[:, :cutoff, :]
         | 
| 79 | 
            +
                    imag_part = forward_transform[:, cutoff:, :]
         | 
| 80 | 
            +
             | 
| 81 | 
            +
                    magnitude = torch.sqrt(real_part**2 + imag_part**2)
         | 
| 82 | 
            +
                    phase = torch.autograd.Variable(torch.atan2(imag_part.data, real_part.data))
         | 
| 83 | 
            +
             | 
| 84 | 
            +
                    return magnitude, phase
         | 
| 85 | 
            +
             | 
| 86 | 
            +
                def inverse(self, magnitude, phase):
         | 
| 87 | 
            +
                    device = self.forward_basis.device
         | 
| 88 | 
            +
                    magnitude, phase = magnitude.to(device), phase.to(device)
         | 
| 89 | 
            +
                    
         | 
| 90 | 
            +
                    recombine_magnitude_phase = torch.cat(
         | 
| 91 | 
            +
                        [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
         | 
| 92 | 
            +
                    )
         | 
| 93 | 
            +
             | 
| 94 | 
            +
                    inverse_transform = F.conv_transpose1d(
         | 
| 95 | 
            +
                        recombine_magnitude_phase,
         | 
| 96 | 
            +
                        torch.autograd.Variable(self.inverse_basis, requires_grad=False),
         | 
| 97 | 
            +
                        stride=self.hop_length,
         | 
| 98 | 
            +
                        padding=0,
         | 
| 99 | 
            +
                    )
         | 
| 100 | 
            +
             | 
| 101 | 
            +
        if self.window is not None:
            window_sum = window_sumsquare(
                self.window,
                magnitude.size(-1),
                hop_length=self.hop_length,
                win_length=self.win_length,
                n_fft=self.filter_length,
                dtype=np.float32,
            )
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0]
            )
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False
            )
            window_sum = window_sum
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                approx_nonzero_indices
            ]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
        inverse_transform = inverse_transform[:, :, : -int(self.filter_length / 2) :]

        return inverse_transform

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction


class TacotronSTFT(torch.nn.Module):
    def __init__(
        self,
        filter_length,
        hop_length,
        win_length,
        n_mel_channels,
        sampling_rate,
        mel_fmin,
        mel_fmax,
    ):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(
            sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer("mel_basis", mel_basis)

    def spectral_normalize(self, magnitudes, normalize_fun):
        output = dynamic_range_compression(magnitudes, normalize_fun)
        return output

    def spectral_de_normalize(self, magnitudes):
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y, normalize_fun=torch.log):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert torch.min(y.data) >= -1, torch.min(y.data)
        assert torch.max(y.data) <= 1, torch.max(y.data)

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output, normalize_fun)
        energy = torch.norm(magnitudes, dim=1)

        log_magnitudes = self.spectral_normalize(magnitudes, normalize_fun)

        return mel_output, log_magnitudes, energy
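An illustrative usage sketch for the mel front-end (not from the committed code; the constructor values simply mirror the 16 kHz, 64-mel HiFi-GAN configuration defined later in this commit):

# Usage sketch (illustrative only): extract a mel-spectrogram from one second
# of audio at 16 kHz, using the same settings as the HiFi-GAN vocoder config.
import torch
from audioldm.audio.stft import TacotronSTFT

fn_STFT = TacotronSTFT(
    filter_length=1024, hop_length=160, win_length=1024,
    n_mel_channels=64, sampling_rate=16000, mel_fmin=0, mel_fmax=8000,
)
wav = torch.rand(1, 16000) * 2 - 1                   # (B, T), values in [-1, 1]
mel, log_mag, energy = fn_STFT.mel_spectrogram(wav)
print(mel.shape)                                     # (1, 64, n_frames)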
    	
        audioldm/audio/tools.py
    ADDED
    
@@ -0,0 +1,85 @@
import torch
import numpy as np
import torchaudio


def get_mel_from_wav(audio, _stft):
    audio = torch.clip(torch.FloatTensor(audio).unsqueeze(0), -1, 1)
    audio = torch.autograd.Variable(audio, requires_grad=False)
    melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio)
    melspec = torch.squeeze(melspec, 0).numpy().astype(np.float32)
    log_magnitudes_stft = (
        torch.squeeze(log_magnitudes_stft, 0).numpy().astype(np.float32)
    )
    energy = torch.squeeze(energy, 0).numpy().astype(np.float32)
    return melspec, log_magnitudes_stft, energy


def _pad_spec(fbank, target_length=1024):
    n_frames = fbank.shape[0]
    p = target_length - n_frames
    # cut and pad
    if p > 0:
        m = torch.nn.ZeroPad2d((0, 0, 0, p))
        fbank = m(fbank)
    elif p < 0:
        fbank = fbank[0:target_length, :]

    if fbank.size(-1) % 2 != 0:
        fbank = fbank[..., :-1]

    return fbank


def pad_wav(waveform, segment_length):
    waveform_length = waveform.shape[-1]
    assert waveform_length > 100, "Waveform is too short, %s" % waveform_length
    if segment_length is None or waveform_length == segment_length:
        return waveform
    elif waveform_length > segment_length:
        return waveform[:segment_length]
    elif waveform_length < segment_length:
        temp_wav = np.zeros((1, segment_length))
        temp_wav[:, :waveform_length] = waveform
    return temp_wav


def normalize_wav(waveform):
    waveform = waveform - np.mean(waveform)
    waveform = waveform / (np.max(np.abs(waveform)) + 1e-8)
    return waveform * 0.5


def read_wav_file(filename, segment_length):
    # waveform, sr = librosa.load(filename, sr=None, mono=True) # 4 times slower
    waveform, sr = torchaudio.load(filename)  # Faster!!!
    waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)
    waveform = waveform.numpy()[0, ...]
    waveform = normalize_wav(waveform)
    waveform = waveform[None, ...]
    waveform = pad_wav(waveform, segment_length)

    waveform = waveform / np.max(np.abs(waveform))
    waveform = 0.5 * waveform

    return waveform


def wav_to_fbank(filename, target_length=1024, fn_STFT=None):
    assert fn_STFT is not None

    # mixup
    waveform = read_wav_file(filename, target_length * 160)  # hop size is 160

    waveform = waveform[0, ...]
    waveform = torch.FloatTensor(waveform)

    fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT)

    fbank = torch.FloatTensor(fbank.T)
    log_magnitudes_stft = torch.FloatTensor(log_magnitudes_stft.T)

    fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec(
        log_magnitudes_stft, target_length
    )

    return fbank, log_magnitudes_stft, waveform
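A minimal end-to-end sketch of wav_to_fbank (not from the committed code; "example.wav" is a hypothetical path, and the TacotronSTFT settings again mirror the 16 kHz config used in this commit): the file is resampled to 16 kHz, normalized, padded or cropped to target_length frames at a hop of 160 samples, and returned as a mel filter-bank tensor.

# Usage sketch (illustrative only; "example.wav" is a placeholder input file).
import torch
from audioldm.audio.stft import TacotronSTFT
from audioldm.audio.tools import wav_to_fbank

fn_STFT = TacotronSTFT(1024, 160, 1024, 64, 16000, 0, 8000)
fbank, log_mag, wav = wav_to_fbank("example.wav", target_length=1024, fn_STFT=fn_STFT)
print(fbank.shape)        # (target_length, n_mel_channels) == (1024, 64)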
    	
        audioldm/hifigan/__init__.py
    ADDED
    
@@ -0,0 +1,7 @@
from .models import Generator


class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self
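AttrDict simply exposes dictionary keys as attributes, which is how the HiFi-GAN config defined in audioldm/hifigan/utilities.py is consumed (h.num_mels, h.upsample_rates, ...). A two-line sketch, not from the committed code:

from audioldm.hifigan import AttrDict

h = AttrDict({"num_mels": 64, "sampling_rate": 16000})
assert h.num_mels == 64 and h["sampling_rate"] == 16000  # same data, two access styles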
    	
    audioldm/hifigan/__pycache__/__init__.cpython-310.pyc
    ADDED
    Binary file (569 Bytes).

    audioldm/hifigan/__pycache__/__init__.cpython-39.pyc
    ADDED
    Binary file (574 Bytes).

    audioldm/hifigan/__pycache__/models.cpython-310.pyc
    ADDED
    Binary file (3.73 kB).

    audioldm/hifigan/__pycache__/models.cpython-39.pyc
    ADDED
    Binary file (3.73 kB).

    audioldm/hifigan/__pycache__/utilities.cpython-310.pyc
    ADDED
    Binary file (2.48 kB).

    audioldm/hifigan/__pycache__/utilities.cpython-39.pyc
    ADDED
    Binary file (2.37 kB).
    	
        audioldm/hifigan/models.py
    ADDED
    
@@ -0,0 +1,174 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import Conv1d, ConvTranspose1d
from torch.nn.utils import weight_norm, remove_weight_norm

LRELU_SLOPE = 0.1


def init_weights(m, mean=0.0, std=0.01):
    classname = m.__class__.__name__
    if classname.find("Conv") != -1:
        m.weight.data.normal_(mean, std)


def get_padding(kernel_size, dilation=1):
    return int((kernel_size * dilation - dilation) / 2)


class ResBlock(torch.nn.Module):
    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock, self).__init__()
        self.h = h
        self.convs1 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[0],
                        padding=get_padding(kernel_size, dilation[0]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[1],
                        padding=get_padding(kernel_size, dilation[1]),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=dilation[2],
                        padding=get_padding(kernel_size, dilation[2]),
                    )
                ),
            ]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList(
            [
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
                weight_norm(
                    Conv1d(
                        channels,
                        channels,
                        kernel_size,
                        1,
                        dilation=1,
                        padding=get_padding(kernel_size, 1),
                    )
                ),
            ]
        )
        self.convs2.apply(init_weights)

    def forward(self, x):
        for c1, c2 in zip(self.convs1, self.convs2):
            xt = F.leaky_relu(x, LRELU_SLOPE)
            xt = c1(xt)
            xt = F.leaky_relu(xt, LRELU_SLOPE)
            xt = c2(xt)
            x = xt + x
        return x

    def remove_weight_norm(self):
        for l in self.convs1:
            remove_weight_norm(l)
        for l in self.convs2:
            remove_weight_norm(l)


class Generator(torch.nn.Module):
    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        self.conv_pre = weight_norm(
            Conv1d(h.num_mels, h.upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        h.upsample_initial_channel // (2**i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(h, ch, k, d))

        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)

        return x

    def remove_weight_norm(self):
        # print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
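A minimal shape-contract sketch for Generator (not from the committed code), using the HIFIGAN_16K_64 config defined in audioldm/hifigan/utilities.py below: each mel frame is upsampled by prod(upsample_rates) = 5*4*2*2*2 = 160 samples, i.e. exactly the 160-sample hop at 16 kHz.

# Sketch: a (B, num_mels, T) mel maps to a (B, 1, T * 160) waveform.
import torch
from audioldm.hifigan import AttrDict, Generator
from audioldm.hifigan.utilities import HIFIGAN_16K_64

h = AttrDict(HIFIGAN_16K_64)
vocoder = Generator(h).eval()                # untrained weights; shapes only
with torch.no_grad():
    mel = torch.randn(1, h.num_mels, 100)    # 100 mel frames
    wav = vocoder(mel)
print(wav.shape)                             # torch.Size([1, 1, 16000])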
    	
        audioldm/hifigan/utilities.py
    ADDED
    
@@ -0,0 +1,86 @@
import os
import json

import torch
import numpy as np

import audioldm.hifigan as hifigan

HIFIGAN_16K_64 = {
    "resblock": "1",
    "num_gpus": 6,
    "batch_size": 16,
    "learning_rate": 0.0002,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,
    "upsample_rates": [5, 4, 2, 2, 2],
    "upsample_kernel_sizes": [16, 16, 8, 4, 4],
    "upsample_initial_channel": 1024,
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    "segment_size": 8192,
    "num_mels": 64,
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 160,
    "win_size": 1024,
    "sampling_rate": 16000,
    "fmin": 0,
    "fmax": 8000,
    "fmax_for_loss": None,
    "num_workers": 4,
    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1,
    },
}


def get_available_checkpoint_keys(model, ckpt):
    print("==> Attempt to reload from %s" % ckpt)
    state_dict = torch.load(ckpt)["state_dict"]
    current_state_dict = model.state_dict()
    new_state_dict = {}
    for k in state_dict.keys():
        if (
            k in current_state_dict.keys()
            and current_state_dict[k].size() == state_dict[k].size()
        ):
            new_state_dict[k] = state_dict[k]
        else:
            print("==> WARNING: Skipping %s" % k)
    print(
        "%s out of %s keys are matched"
        % (len(new_state_dict.keys()), len(state_dict.keys()))
    )
    return new_state_dict


def get_param_num(model):
    num_param = sum(param.numel() for param in model.parameters())
    return num_param


def get_vocoder(config, device):
    config = hifigan.AttrDict(HIFIGAN_16K_64)
    vocoder = hifigan.Generator(config)
    vocoder.eval()
    vocoder.remove_weight_norm()
    vocoder.to(device)
    return vocoder


def vocoder_infer(mels, vocoder, lengths=None):
    vocoder.eval()
    with torch.no_grad():
        wavs = vocoder(mels).squeeze(1)

    wavs = (wavs.cpu().numpy() * 32768).astype("int16")

    if lengths is not None:
        wavs = wavs[:, :lengths]

    return wavs
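How these helpers fit together, as a minimal sketch (not from the committed code): get_vocoder as written ignores its config argument, builds the Generator from HIFIGAN_16K_64, and does not load any checkpoint, so real use would first copy pretrained HiFi-GAN weights in (for example via get_available_checkpoint_keys); the shapes and dtype below hold either way.

# Sketch: batch of mel-spectrograms in, batch of int16 waveforms out.
import torch
from audioldm.hifigan.utilities import get_vocoder, vocoder_infer

device = "cuda" if torch.cuda.is_available() else "cpu"
vocoder = get_vocoder(None, device)              # config argument is unused internally
mels = torch.randn(2, 64, 256, device=device)    # (batch, num_mels, frames)
wavs = vocoder_infer(mels, vocoder)              # numpy int16, one row per clip
print(wavs.shape, wavs.dtype)                    # (2, 40960) int16  (256 frames * 160)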
    	
        audioldm/latent_diffusion/__init__.py
    ADDED
    
    File without changes
    	
    audioldm/latent_diffusion/__pycache__/__init__.cpython-310.pyc
    ADDED
    Binary file (157 Bytes).

    audioldm/latent_diffusion/__pycache__/__init__.cpython-39.pyc
    ADDED
    Binary file (164 Bytes).

    audioldm/latent_diffusion/__pycache__/attention.cpython-310.pyc
    ADDED
    Binary file (11.4 kB).

    audioldm/latent_diffusion/__pycache__/attention.cpython-39.pyc
    ADDED
    Binary file (11.4 kB).

    audioldm/latent_diffusion/__pycache__/ddim.cpython-310.pyc
    ADDED
    Binary file (7.2 kB).

    audioldm/latent_diffusion/__pycache__/ddim.cpython-39.pyc
    ADDED
    Binary file (7.11 kB).

    audioldm/latent_diffusion/__pycache__/ddpm.cpython-310.pyc
    ADDED
    Binary file (11.1 kB).

    audioldm/latent_diffusion/__pycache__/ddpm.cpython-39.pyc
    ADDED
    Binary file (11 kB).

    audioldm/latent_diffusion/__pycache__/ema.cpython-310.pyc
    ADDED
    Binary file (3.01 kB).

    audioldm/latent_diffusion/__pycache__/ema.cpython-39.pyc
    ADDED
    Binary file (3 kB).

    audioldm/latent_diffusion/__pycache__/openaimodel.cpython-39.pyc
    ADDED
    Binary file (23.7 kB).

    audioldm/latent_diffusion/__pycache__/util.cpython-310.pyc
    ADDED
    Binary file (9.53 kB).

    audioldm/latent_diffusion/__pycache__/util.cpython-39.pyc
    ADDED
    Binary file (9.6 kB).
    	
        audioldm/latent_diffusion/attention.py
    ADDED
    
@@ -0,0 +1,469 @@
from inspect import isfunction
import math
import torch
import torch.nn.functional as F
from torch import nn
from einops import rearrange

from audioldm.latent_diffusion.util import checkpoint


def exists(val):
    return val is not None


def uniq(arr):
    return {el: True for el in arr}.keys()


def default(val, d):
    if exists(val):
        return val
    return d() if isfunction(d) else d


def max_neg_value(t):
    return -torch.finfo(t.dtype).max


def init_(tensor):
    dim = tensor.shape[-1]
    std = 1 / math.sqrt(dim)
    tensor.uniform_(-std, std)
    return tensor


# feedforward
class GEGLU(nn.Module):
    def __init__(self, dim_in, dim_out):
        super().__init__()
        self.proj = nn.Linear(dim_in, dim_out * 2)

    def forward(self, x):
        x, gate = self.proj(x).chunk(2, dim=-1)
        return x * F.gelu(gate)


class FeedForward(nn.Module):
    def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
        super().__init__()
        inner_dim = int(dim * mult)
        dim_out = default(dim_out, dim)
        project_in = (
            nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
            if not glu
            else GEGLU(dim, inner_dim)
        )

        self.net = nn.Sequential(
            project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
        )

    def forward(self, x):
        return self.net(x)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


def Normalize(in_channels):
    return torch.nn.GroupNorm(
        num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
    )


class LinearAttention(nn.Module):
    def __init__(self, dim, heads=4, dim_head=32):
        super().__init__()
        self.heads = heads
        hidden_dim = dim_head * heads
        self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
        self.to_out = nn.Conv2d(hidden_dim, dim, 1)

    def forward(self, x):
        b, c, h, w = x.shape
        qkv = self.to_qkv(x)
        q, k, v = rearrange(
            qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3
        )
        k = k.softmax(dim=-1)
        context = torch.einsum("bhdn,bhen->bhde", k, v)
        out = torch.einsum("bhde,bhdn->bhen", context, q)
        out = rearrange(
            out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w
        )
        return self.to_out(out)


class SpatialSelfAttention(nn.Module):
    def __init__(self, in_channels):
        super().__init__()
        self.in_channels = in_channels

        self.norm = Normalize(in_channels)
        self.q = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.k = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.v = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )
        self.proj_out = torch.nn.Conv2d(
            in_channels, in_channels, kernel_size=1, stride=1, padding=0
        )

    def forward(self, x):
        h_ = x
        h_ = self.norm(h_)
        q = self.q(h_)
        k = self.k(h_)
        v = self.v(h_)

        # compute attention
        b, c, h, w = q.shape
        q = rearrange(q, "b c h w -> b (h w) c")
        k = rearrange(k, "b c h w -> b c (h w)")
        w_ = torch.einsum("bij,bjk->bik", q, k)

        w_ = w_ * (int(c) ** (-0.5))
        w_ = torch.nn.functional.softmax(w_, dim=2)

        # attend to values
        v = rearrange(v, "b c h w -> b c (h w)")
        w_ = rearrange(w_, "b i j -> b j i")
        h_ = torch.einsum("bij,bjk->bik", v, w_)
        h_ = rearrange(h_, "b c (h w) -> b c h w", h=h)
        h_ = self.proj_out(h_)

        return x + h_


class CrossAttention(nn.Module):
    """
    ### Cross Attention Layer
    This falls back to self-attention when conditional embeddings are not specified.
    """

    # use_flash_attention: bool = True
    use_flash_attention: bool = False

    def __init__(
        self,
        query_dim,
        context_dim=None,
        heads=8,
        dim_head=64,
        dropout=0.0,
        is_inplace: bool = True,
    ):
        # def __init__(self, d_model: int, d_cond: int, n_heads: int, d_head: int, is_inplace: bool = True):
        """
        :param d_model: is the input embedding size
        :param n_heads: is the number of attention heads
        :param d_head: is the size of an attention head
        :param d_cond: is the size of the conditional embeddings
        :param is_inplace: specifies whether to perform the attention softmax computation inplace to
            save memory
        """
        super().__init__()

        self.is_inplace = is_inplace
        self.n_heads = heads
        self.d_head = dim_head

        # Attention scaling factor
        self.scale = dim_head**-0.5

        # The normal self-attention layer
        if context_dim is None:
            context_dim = query_dim

        # Query, key and value mappings
        d_attn = dim_head * heads
        self.to_q = nn.Linear(query_dim, d_attn, bias=False)
        self.to_k = nn.Linear(context_dim, d_attn, bias=False)
        self.to_v = nn.Linear(context_dim, d_attn, bias=False)

        # Final linear layer
        self.to_out = nn.Sequential(nn.Linear(d_attn, query_dim), nn.Dropout(dropout))

        # Setup [flash attention](https://github.com/HazyResearch/flash-attention).
        # Flash attention is only used if it's installed
        # and `CrossAttention.use_flash_attention` is set to `True`.
        try:
            # You can install flash attention by cloning their Github repo,
            # [https://github.com/HazyResearch/flash-attention](https://github.com/HazyResearch/flash-attention)
            # and then running `python setup.py install`
            from flash_attn.flash_attention import FlashAttention

            self.flash = FlashAttention()
            # Set the scale for scaled dot-product attention.
            self.flash.softmax_scale = self.scale
        # Set to `None` if it's not installed
        except ImportError:
            self.flash = None

    def forward(self, x, context=None, mask=None):
        """
        :param x: are the input embeddings of shape `[batch_size, height * width, d_model]`
        :param context: is the conditional embeddings of shape `[batch_size, n_cond, d_cond]`
        """

        # If `context` is `None` we perform self attention
        has_cond = context is not None
        if not has_cond:
            context = x

        # Get query, key and value vectors
        q = self.to_q(x)
        k = self.to_k(context)
        v = self.to_v(context)

        # Use flash attention if it's available and the head size is less than or equal to `128`
        if (
            CrossAttention.use_flash_attention
            and self.flash is not None
            and not has_cond
            and self.d_head <= 128
        ):
            return self.flash_attention(q, k, v)
        # Otherwise, fall back to normal attention
        else:
            return self.normal_attention(q, k, v)

    def flash_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        #### Flash Attention
        :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param k: are the key vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        :param v: are the value vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
        """

        # Get batch size and number of elements along sequence axis (`width * height`)
        batch_size, seq_len, _ = q.shape

        # Stack `q`, `k`, `v` vectors for flash attention, to get a single tensor of
        # shape `[batch_size, seq_len, 3, n_heads * d_head]`
        qkv = torch.stack((q, k, v), dim=2)
        # Split the heads
        qkv = qkv.view(batch_size, seq_len, 3, self.n_heads, self.d_head)

        # Flash attention works for head sizes `32`, `64` and `128`, so we have to pad the heads to
        # fit this size.
        if self.d_head <= 32:
            pad = 32 - self.d_head
        elif self.d_head <= 64:
            pad = 64 - self.d_head
        elif self.d_head <= 128:
            pad = 128 - self.d_head
        else:
            raise ValueError(f"Head size {self.d_head} too large for Flash Attention")

        # Pad the heads
        if pad:
            qkv = torch.cat(
                (qkv, qkv.new_zeros(batch_size, seq_len, 3, self.n_heads, pad)), dim=-1
            )

        # Compute attention
        # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
        # This gives a tensor of shape `[batch_size, seq_len, n_heads, d_padded]`
        # TODO here I add the dtype changing
        out, _ = self.flash(qkv.type(torch.float16))
        # Truncate the extra head size
        out = out[:, :, :, : self.d_head].float()
        # Reshape to `[batch_size, seq_len, n_heads * d_head]`
         | 
| 284 | 
            +
                    out = out.reshape(batch_size, seq_len, self.n_heads * self.d_head)
         | 
| 285 | 
            +
             | 
| 286 | 
            +
                    # Map to `[batch_size, height * width, d_model]` with a linear layer
         | 
| 287 | 
            +
                    return self.to_out(out)
         | 
| 288 | 
            +
             | 
| 289 | 
            +
                def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
         | 
| 290 | 
            +
                    """
         | 
| 291 | 
            +
                    #### Normal Attention
         | 
| 292 | 
            +
             | 
| 293 | 
            +
                    :param q: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
         | 
| 294 | 
            +
                    :param k: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
         | 
| 295 | 
            +
                    :param v: are the query vectors before splitting heads, of shape `[batch_size, seq, d_attn]`
         | 
| 296 | 
            +
                    """
         | 
| 297 | 
            +
             | 
| 298 | 
            +
                    # Split them to heads of shape `[batch_size, seq_len, n_heads, d_head]`
         | 
| 299 | 
            +
                    q = q.view(*q.shape[:2], self.n_heads, -1)  # [bs, 64, 20, 32]
         | 
| 300 | 
            +
                    k = k.view(*k.shape[:2], self.n_heads, -1)  # [bs, 1, 20, 32]
         | 
| 301 | 
            +
                    v = v.view(*v.shape[:2], self.n_heads, -1)
         | 
| 302 | 
            +
             | 
| 303 | 
            +
                    # Calculate attention $\frac{Q K^\top}{\sqrt{d_{key}}}$
         | 
| 304 | 
            +
                    attn = torch.einsum("bihd,bjhd->bhij", q, k) * self.scale
         | 
| 305 | 
            +
             | 
| 306 | 
            +
                    # Compute softmax
         | 
| 307 | 
            +
                    # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)$$
         | 
| 308 | 
            +
                    if self.is_inplace:
         | 
| 309 | 
            +
                        half = attn.shape[0] // 2
         | 
| 310 | 
            +
                        attn[half:] = attn[half:].softmax(dim=-1)
         | 
| 311 | 
            +
                        attn[:half] = attn[:half].softmax(dim=-1)
         | 
| 312 | 
            +
                    else:
         | 
| 313 | 
            +
                        attn = attn.softmax(dim=-1)
         | 
| 314 | 
            +
             | 
| 315 | 
            +
                    # Compute attention output
         | 
| 316 | 
            +
                    # $$\underset{seq}{softmax}\Bigg(\frac{Q K^\top}{\sqrt{d_{key}}}\Bigg)V$$
         | 
| 317 | 
            +
                    # attn: [bs, 20, 64, 1]
         | 
| 318 | 
            +
                    # v: [bs, 1, 20, 32]
         | 
| 319 | 
            +
                    out = torch.einsum("bhij,bjhd->bihd", attn, v)
         | 
| 320 | 
            +
                    # Reshape to `[batch_size, height * width, n_heads * d_head]`
         | 
| 321 | 
            +
                    out = out.reshape(*out.shape[:2], -1)
         | 
| 322 | 
            +
                    # Map to `[batch_size, height * width, d_model]` with a linear layer
         | 
| 323 | 
            +
                    return self.to_out(out)
         | 
| 324 | 
            +
             | 
| 325 | 
            +
             | 
| 326 | 
            +
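

# --- Illustrative usage sketch (not part of the original file) ----------------
# A minimal example of driving `CrossAttention` above. The keyword arguments
# mirror how `BasicTransformerBlock` instantiates it further down in this file;
# the concrete dimensions and tensor shapes are hypothetical.
def _cross_attention_usage_example():
    attn = CrossAttention(query_dim=320, context_dim=768, heads=8, dim_head=40)
    x = torch.randn(2, 64 * 64, 320)  # [batch_size, height * width, d_model]
    cond = torch.randn(2, 77, 768)    # [batch_size, n_cond, d_cond]
    self_attn_out = attn(x)                 # self-attention  -> [2, 4096, 320]
    cross_attn_out = attn(x, context=cond)  # cross-attention -> [2, 4096, 320]
    return self_attn_out, cross_attn_out
# -------------------------------------------------------------------------------
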
# class CrossAttention(nn.Module):
# def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
#     super().__init__()
#     inner_dim = dim_head * heads
#     context_dim = default(context_dim, query_dim)

#     self.scale = dim_head ** -0.5
#     self.heads = heads

#     self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
#     self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
#     self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

#     self.to_out = nn.Sequential(
#         nn.Linear(inner_dim, query_dim),
#         nn.Dropout(dropout)
#     )

# def forward(self, x, context=None, mask=None):
#     h = self.heads

#     q = self.to_q(x)
#     context = default(context, x)
#     k = self.to_k(context)
#     v = self.to_v(context)

#     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

#     sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

#     if exists(mask):
#         mask = rearrange(mask, 'b ... -> b (...)')
#         max_neg_value = -torch.finfo(sim.dtype).max
#         mask = repeat(mask, 'b j -> (b h) () j', h=h)
#         sim.masked_fill_(~mask, max_neg_value)

#     # attention, what we cannot get enough of
#     attn = sim.softmax(dim=-1)

#     out = einsum('b i j, b j d -> b i d', attn, v)
#     out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
#     return self.to_out(out)


class BasicTransformerBlock(nn.Module):
    def __init__(
        self,
        dim,
        n_heads,
        d_head,
        dropout=0.0,
        context_dim=None,
        gated_ff=True,
        checkpoint=True,
    ):
        super().__init__()
        self.attn1 = CrossAttention(
            query_dim=dim, heads=n_heads, dim_head=d_head, dropout=dropout
        )  # self-attention
        self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
        self.attn2 = CrossAttention(
            query_dim=dim,
            context_dim=context_dim,
            heads=n_heads,
            dim_head=d_head,
            dropout=dropout,
        )  # cross-attention; acts as self-attention if `context` is `None`
        self.norm1 = nn.LayerNorm(dim)
        self.norm2 = nn.LayerNorm(dim)
        self.norm3 = nn.LayerNorm(dim)
        self.checkpoint = checkpoint

    def forward(self, x, context=None):
        if context is None:
            return checkpoint(self._forward, (x,), self.parameters(), self.checkpoint)
        else:
            return checkpoint(
                self._forward, (x, context), self.parameters(), self.checkpoint
            )

    def _forward(self, x, context=None):
        # Pre-norm residual layout: self-attention, cross-attention, feed-forward
        x = self.attn1(self.norm1(x)) + x
        x = self.attn2(self.norm2(x), context=context) + x
        x = self.ff(self.norm3(x)) + x
        return x
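

# --- Illustrative sketch (not part of the original file) ----------------------
# `BasicTransformerBlock._forward` above is a standard pre-norm transformer layer:
#
#     x = x + SelfAttn(LayerNorm(x))
#     x = x + CrossAttn(LayerNorm(x), context)
#     x = x + FeedForward(LayerNorm(x))
#
# A hypothetical call with made-up shapes:
def _basic_transformer_block_example():
    block = BasicTransformerBlock(dim=320, n_heads=8, d_head=40, context_dim=768)
    x = torch.randn(2, 4096, 320)   # [batch_size, seq_len, dim]
    cond = torch.randn(2, 77, 768)  # [batch_size, n_cond, context_dim]
    return block(x, context=cond)   # -> [2, 4096, 320]
# -------------------------------------------------------------------------------
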
class SpatialTransformer(nn.Module):
    """
    Transformer block for image-like data.
    First, the input is projected (i.e. embedded) and reshaped to `[b, t, d]`.
    Then standard transformer blocks are applied.
    Finally, the result is reshaped back to an image.
    """

    def __init__(
        self,
        in_channels,
        n_heads,
        d_head,
        depth=1,
        dropout=0.0,
        context_dim=None,
        no_context=False,
    ):
        super().__init__()

        if no_context:
            context_dim = None

        self.in_channels = in_channels
        inner_dim = n_heads * d_head
        self.norm = Normalize(in_channels)

        self.proj_in = nn.Conv2d(
            in_channels, inner_dim, kernel_size=1, stride=1, padding=0
        )

        self.transformer_blocks = nn.ModuleList(
            [
                BasicTransformerBlock(
                    inner_dim, n_heads, d_head, dropout=dropout, context_dim=context_dim
                )
                for _ in range(depth)
            ]
        )

        self.proj_out = zero_module(
            nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x, context=None):
        # Note: if no context is given, cross-attention defaults to self-attention
        b, c, h, w = x.shape
        x_in = x
        x = self.norm(x)
        x = self.proj_in(x)
        x = rearrange(x, "b c h w -> b (h w) c")
        for block in self.transformer_blocks:
            x = block(x, context=context)
        x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w)
        x = self.proj_out(x)
        return x + x_in
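

# --- Illustrative usage sketch (not part of the original file) ----------------
# `SpatialTransformer` consumes image-like feature maps, runs the inner
# `BasicTransformerBlock`s over the flattened spatial positions, and adds the
# result back to the input; the shapes below are hypothetical.
def _spatial_transformer_example():
    st = SpatialTransformer(in_channels=320, n_heads=8, d_head=40, depth=1, context_dim=768)
    feat = torch.randn(2, 320, 32, 32)  # [batch_size, in_channels, height, width]
    cond = torch.randn(2, 77, 768)      # [batch_size, n_cond, context_dim]
    return st(feat, context=cond)       # residual output -> [2, 320, 32, 32]
# -------------------------------------------------------------------------------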