bubbliiiing committed on
Commit e262715 · 1 Parent(s): 788d423
update v3

Files changed:
- .gitignore +160 -0
- app.py +3 -3
- easyanimate/api/api.py +38 -4
- easyanimate/api/post_infer.py +9 -7
- easyanimate/data/dataset_image_video.py +64 -3
- easyanimate/models/attention.py +196 -139
- easyanimate/models/autoencoder_magvit.py +9 -3
- easyanimate/models/motion_module.py +146 -277
- easyanimate/models/norm.py +97 -0
- easyanimate/models/patch.py +1 -1
- easyanimate/models/transformer3d.py +81 -75
- easyanimate/pipeline/pipeline_easyanimate.py +1 -1
- easyanimate/pipeline/pipeline_easyanimate_inpaint.py +257 -91
- easyanimate/ui/ui.py +810 -173
- easyanimate/utils/utils.py +107 -0
 
    	
.gitignore ADDED
@@ -0,0 +1,160 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+#  Usually these files are written by a python script from a template
+#  before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+#   For a library or package, you might want to ignore these files since the code is
+#   intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+#   According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+#   However, in case of collaboration, if having platform-specific dependencies or dependencies
+#   having no cross-platform support, pipenv may install dependencies that don't work, or not
+#   install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+#   Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+#   This is especially recommended for binary packages to ensure reproducibility, and is more
+#   commonly ignored for libraries.
+#   https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+#   Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+#   pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+#   in version control.
+#   https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+#  JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+#  be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+#  and can be added to the global gitignore or merged into this file.  For a more nuclear
+#  option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
    	
app.py CHANGED
@@ -11,9 +11,9 @@ if __name__ == "__main__":
     server_port = 7860
 
     # Params below is used when ui_mode = "modelscope"
-    edition = "
-    config_path = "config/
-    model_name = "models/Diffusion_Transformer/
+    edition = "v3"
+    config_path = "config/easyanimate_video_slicevae_motion_module_v3.yaml"
+    model_name = "models/Diffusion_Transformer/EasyAnimateV3-XL-2-InP-512x512"
     savedir_sample = "samples"
 
     if ui_mode == "modelscope":
    	
easyanimate/api/api.py CHANGED
@@ -1,10 +1,14 @@
 import io
+import gc
 import base64
 import torch
 import gradio as gr
+import tempfile
+import hashlib
 
 from fastapi import FastAPI
 from io import BytesIO
+from PIL import Image
 
 # Function to encode a file to Base64
 def encode_file_to_base64(file_path):
@@ -59,16 +63,34 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
         lora_model_path = datas.get('lora_model_path', 'none')
         lora_alpha_slider = datas.get('lora_alpha_slider', 0.55)
         prompt_textbox = datas.get('prompt_textbox', None)
-        negative_prompt_textbox = datas.get('negative_prompt_textbox', '')
+        negative_prompt_textbox = datas.get('negative_prompt_textbox', 'The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion.')
         sampler_dropdown = datas.get('sampler_dropdown', 'Euler')
         sample_step_slider = datas.get('sample_step_slider', 30)
+        resize_method = datas.get('resize_method', "Generate by")
         width_slider = datas.get('width_slider', 672)
         height_slider = datas.get('height_slider', 384)
+        base_resolution = datas.get('base_resolution', 512)
         is_image = datas.get('is_image', False)
+        generation_method = datas.get('generation_method', False)
         length_slider = datas.get('length_slider', 144)
+        overlap_video_length = datas.get('overlap_video_length', 4)
+        partial_video_length = datas.get('partial_video_length', 72)
         cfg_scale_slider = datas.get('cfg_scale_slider', 6)
+        start_image = datas.get('start_image', None)
+        end_image = datas.get('end_image', None)
         seed_textbox = datas.get("seed_textbox", 43)
 
+        generation_method = "Image Generation" if is_image else generation_method
+
+        temp_directory = tempfile.gettempdir()
+        if start_image is not None:
+            start_image = base64.b64decode(start_image)
+            start_image = [Image.open(BytesIO(start_image))]
+
+        if end_image is not None:
+            end_image = base64.b64decode(end_image)
+            end_image = [Image.open(BytesIO(end_image))]
+
         try:
             save_sample_path, comment = controller.generate(
                 "",
@@ -80,17 +102,29 @@ def infer_forward_api(_: gr.Blocks, app: FastAPI, controller):
                 negative_prompt_textbox, 
                 sampler_dropdown, 
                 sample_step_slider, 
+                resize_method,
                 width_slider, 
                 height_slider, 
-
+                base_resolution,
+                generation_method,
                 length_slider, 
+                overlap_video_length, 
+                partial_video_length, 
                 cfg_scale_slider, 
+                start_image,
+                end_image,
                 seed_textbox,
                 is_api = True,
             )
         except Exception as e:
+            gc.collect()
             torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
            save_sample_path = ""
            comment = f"Error. error information is {str(e)}"
-
-
+            return {"message": comment}
+
+        if save_sample_path != "":
+            return {"message": comment, "save_sample_path": save_sample_path, "base64_encoding": encode_file_to_base64(save_sample_path)}
+        else:
+            return {"message": comment, "save_sample_path": save_sample_path}
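For reference, the new start_image / end_image fields are plain base64 strings: the handler above decodes them with base64.b64decode and opens them via Image.open(BytesIO(...)). Below is a minimal client-side sketch of how such a payload could be built; the helper name and the example prompt/values are illustrative and not part of this commit.

import base64

def encode_image_to_base64(image_path):
    # Hypothetical client helper: send the raw image file bytes as a base64
    # string so the server-side b64decode + Image.open round-trips cleanly.
    with open(image_path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

# Example request body using fields read by infer_forward_api above.
datas = {
    "prompt_textbox": "an example prompt",
    "generation_method": "Video Generation",
    "length_slider": 72,
    "start_image": encode_image_to_base64("start.png"),
    "end_image": None,  # optional; same encoding when provided
}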
    	
easyanimate/api/post_infer.py CHANGED
@@ -26,7 +26,7 @@ def post_update_edition(edition, url='http://0.0.0.0:7860'):
     data = r.content.decode('utf-8')
     return data
 
-def post_infer(is_image, length_slider, url='http://127.0.0.1:7860'):
+def post_infer(generation_method, length_slider, url='http://127.0.0.1:7860'):
     datas = json.dumps({
         "base_model_path": "none",
         "motion_module_path": "none",
@@ -38,7 +38,7 @@ def post_infer(is_image, length_slider, url='http://127.0.0.1:7860'):
         "sample_step_slider": 30, 
         "width_slider": 672, 
         "height_slider": 384, 
-        "
+        "generation_method": "Video Generation",
         "length_slider": length_slider,
         "cfg_scale_slider": 6,
         "seed_textbox": 43,
@@ -55,29 +55,31 @@ if __name__ == '__main__':
     # -------------------------- #
     #  Step 1: update edition
     # -------------------------- #
-    edition = "
+    edition = "v3"
     outputs = post_update_edition(edition)
     print('Output update edition: ', outputs)
 
     # -------------------------- #
     #  Step 2: update edition
     # -------------------------- #
-    diffusion_transformer_path = "
+    diffusion_transformer_path = "models/Diffusion_Transformer/EasyAnimateV3-XL-2-512x512"
     outputs = post_diffusion_transformer(diffusion_transformer_path)
     print('Output update edition: ', outputs)
 
     # -------------------------- #
     #  Step 3: infer
     # -------------------------- #
-
-
-
+    # "Video Generation" and "Image Generation"
+    generation_method = "Video Generation"
+    length_slider = 72
+    outputs = post_infer(generation_method, length_slider)
 
     # Get decoded data
     outputs = json.loads(outputs)
     base64_encoding = outputs["base64_encoding"]
     decoded_data = base64.b64decode(base64_encoding)
 
+    is_image = True if generation_method == "Image Generation" else False
     if is_image or length_slider == 1:
         file_path = "1.png"
     else:
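An illustrative end-to-end use of the updated client above, assuming post_infer and its imports are available. Note that, as committed, post_infer() hardcodes "generation_method": "Video Generation" in its request body, and the video filename used below is an assumption because the else: branch is truncated in this view.

import base64
import json

# Sketch only: request a 72-frame video and write the decoded result to disk.
generation_method = "Video Generation"
length_slider = 72
outputs = json.loads(post_infer(generation_method, length_slider))

decoded_data = base64.b64decode(outputs["base64_encoding"])
# "1.mp4" is an assumed name; the else-branch filename is cut off above.
file_path = "1.png" if (generation_method == "Image Generation" or length_slider == 1) else "1.mp4"
with open(file_path, "wb") as f:
    f.write(decoded_data)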
    	
easyanimate/data/dataset_image_video.py CHANGED
@@ -12,6 +12,7 @@ import gc
 import numpy as np
 import torch
 import torchvision.transforms as transforms
+
 from func_timeout import func_timeout, FunctionTimedOut
 from decord import VideoReader
 from PIL import Image
@@ -21,6 +22,52 @@ from contextlib import contextmanager
 
 VIDEO_READER_TIMEOUT = 20
 
+def get_random_mask(shape):
+    f, c, h, w = shape
+
+    if f != 1:
+        mask_index = np.random.randint(1, 4)
+    else:
+        mask_index = np.random.randint(1, 2)
+    mask = torch.zeros((f, 1, h, w), dtype=torch.uint8)
+
+    if mask_index == 0:
+        center_x = torch.randint(0, w, (1,)).item()
+        center_y = torch.randint(0, h, (1,)).item()
+        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # width range of the block
+        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # height range of the block
+
+        start_x = max(center_x - block_size_x // 2, 0)
+        end_x = min(center_x + block_size_x // 2, w)
+        start_y = max(center_y - block_size_y // 2, 0)
+        end_y = min(center_y + block_size_y // 2, h)
+        mask[:, :, start_y:end_y, start_x:end_x] = 1
+    elif mask_index == 1:
+        mask[:, :, :, :] = 1
+    elif mask_index == 2:
+        mask_frame_index = np.random.randint(1, 5)
+        mask[mask_frame_index:, :, :, :] = 1
+    elif mask_index == 3:
+        mask_frame_index = np.random.randint(1, 5)
+        mask[mask_frame_index:-mask_frame_index, :, :, :] = 1
+    elif mask_index == 4:
+        center_x = torch.randint(0, w, (1,)).item()
+        center_y = torch.randint(0, h, (1,)).item()
+        block_size_x = torch.randint(w // 4, w // 4 * 3, (1,)).item()  # width range of the block
+        block_size_y = torch.randint(h // 4, h // 4 * 3, (1,)).item()  # height range of the block
+
+        start_x = max(center_x - block_size_x // 2, 0)
+        end_x = min(center_x + block_size_x // 2, w)
+        start_y = max(center_y - block_size_y // 2, 0)
+        end_y = min(center_y + block_size_y // 2, h)
+
+        mask_frame_before = np.random.randint(0, f // 2)
+        mask_frame_after = np.random.randint(f // 2, f)
+        mask[mask_frame_before:mask_frame_after, :, start_y:end_y, start_x:end_x] = 1
+    else:
+        raise ValueError(f"The mask_index {mask_index} is not define")
+    return mask
+
 class ImageVideoSampler(BatchSampler):
     """A sampler wrapper for grouping images with similar aspect ratio into a same batch.
 
@@ -88,10 +135,11 @@ class ImageVideoDataset(Dataset):
             video_sample_size=512, video_sample_stride=4, video_sample_n_frames=16,
             image_sample_size=512,
             video_repeat=0,
-            text_drop_ratio
+            text_drop_ratio=-1,
             enable_bucket=False,
             video_length_drop_start=0.1, 
             video_length_drop_end=0.9,
+            enable_inpaint=False,
         ):
         # Loading annotations from files
         print(f"loading annotations from {ann_path} ...")
@@ -120,6 +168,8 @@ class ImageVideoDataset(Dataset):
         # TODO: enable bucket training
         self.enable_bucket = enable_bucket
         self.text_drop_ratio = text_drop_ratio
+        self.enable_inpaint  = enable_inpaint
+
         self.video_length_drop_start = video_length_drop_start
         self.video_length_drop_end = video_length_drop_end
 
@@ -165,7 +215,7 @@ class ImageVideoDataset(Dataset):
 
                 video_length = int(self.video_length_drop_end * len(video_reader))
                 clip_length = min(video_length, (min_sample_n_frames - 1) * self.video_sample_stride + 1)
-                start_idx   = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length)
+                start_idx   = random.randint(int(self.video_length_drop_start * video_length), video_length - clip_length) if video_length != clip_length else 0
                 batch_index = np.linspace(start_idx, start_idx + clip_length - 1, min_sample_n_frames, dtype=int)
 
                 try:
@@ -230,6 +280,17 @@ class ImageVideoDataset(Dataset):
            except Exception as e:
                print(e, self.dataset[idx % len(self.dataset)])
                idx = random.randint(0, self.length-1)
+
+        if self.enable_inpaint and not self.enable_bucket:
+            mask = get_random_mask(pixel_values.size())
+            mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask
+            sample["mask_pixel_values"] = mask_pixel_values
+            sample["mask"] = mask
+
+            clip_pixel_values = sample["pixel_values"][0].permute(1, 2, 0).contiguous()
+            clip_pixel_values = (clip_pixel_values * 0.5 + 0.5) * 255
+            sample["clip_pixel_values"] = clip_pixel_values
+
         return sample
 
 if __name__ == "__main__":
@@ -238,4 +299,4 @@ if __name__ == "__main__":
     )
     dataloader = torch.utils.data.DataLoader(dataset, batch_size=4, num_workers=16)
     for idx, batch in enumerate(dataloader):
-        print(batch["pixel_values"].shape, len(batch["text"]))
+        print(batch["pixel_values"].shape, len(batch["text"]))
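A small sanity-check sketch for the new get_random_mask helper and the enable_inpaint branch above, using a dummy tensor in place of real decoded frames; the shapes and value range follow the dataset code, and nothing in this snippet is part of the commit itself.

import torch
# get_random_mask is the helper added in this commit (see the diff above).

# Dummy (frames, channels, height, width) clip scaled to [-1, 1].
pixel_values = torch.rand(16, 3, 256, 256) * 2 - 1

mask = get_random_mask(pixel_values.size())  # (f, 1, h, w) uint8, broadcasts over channels
mask_pixel_values = pixel_values * (1 - mask) + torch.ones_like(pixel_values) * -1 * mask

# Masked regions are filled with -1, matching the inpaint branch of __getitem__.
print(mask.shape, mask_pixel_values.min().item(), mask_pixel_values.max().item())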
    	
        easyanimate/models/attention.py
    CHANGED
    
    | 
@@ -11,17 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import math
 from typing import Any, Dict, Optional
 
 import torch
 import torch.nn.functional as F
 import torch.nn.init as init
-
 from diffusers.models.attention import AdaLayerNorm, FeedForward
-from diffusers.models.attention_processor import Attention
 from diffusers.models.embeddings import SinusoidalPositionalEmbedding
-from diffusers.models.lora import LoRACompatibleLinear
 from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
 from diffusers.utils import USE_PEFT_BACKEND
 from diffusers.utils.import_utils import is_xformers_available
@@ -29,7 +37,8 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
 from einops import rearrange, repeat
 from torch import nn
 
-from .motion_module import get_motion_module
 
 if is_xformers_available():
     import xformers
@@ -38,6 +47,13 @@ else:
     xformers = None
 
 
 @maybe_allow_in_graph
 class GatedSelfAttentionDense(nn.Module):
     r"""
@@ -59,8 +75,8 @@ class GatedSelfAttentionDense(nn.Module):
         self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
         self.ff = FeedForward(query_dim, activation_fn="geglu")
 
-        self.norm1 = nn.LayerNorm(query_dim)
-        self.norm2 = nn.LayerNorm(query_dim)
 
         self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
         self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
@@ -80,14 +96,6 @@ class GatedSelfAttentionDense(nn.Module):
         return x
 
 
-def zero_module(module):
-    # Zero out the parameters of a module and return it.
-    for p in module.parameters():
-        p.detach().zero_()
-    return module
-
-
-
 class KVCompressionCrossAttention(nn.Module):
     r"""
     A cross attention layer.
@@ -154,7 +162,7 @@ class KVCompressionCrossAttention(nn.Module):
             stride=2,
             bias=True
         )
-        self.kv_compression_norm = nn.LayerNorm(query_dim)
         init.constant_(self.kv_compression.weight, 1 / 4)
         if self.kv_compression.bias is not None:
             init.constant_(self.kv_compression.bias, 0)
@@ -410,6 +418,8 @@ class TemporalTransformerBlock(nn.Module):
         # motion module kwargs
         motion_module_type = "VanillaGrid",
         motion_module_kwargs = None,
     ):
         super().__init__()
         self.only_cross_attention = only_cross_attention
@@ -442,7 +452,7 @@ class TemporalTransformerBlock(nn.Module):
         elif self.use_ada_layer_norm_zero:
             self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
         else:
-            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
 
         self.kvcompression = kvcompression
         if kvcompression:
@@ -456,16 +466,28 @@ class TemporalTransformerBlock(nn.Module):
                 upcast_attention=upcast_attention,
             )
         else:
-            self.attn1 = Attention(
-                query_dim=dim,
-                heads=num_attention_heads,
-                dim_head=attention_head_dim,
-                dropout=dropout,
-                bias=attention_bias,
-                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
-                upcast_attention=upcast_attention,
-            )
-
 
         self.attn_temporal = get_motion_module(
             in_channels = dim,
@@ -481,27 +503,45 @@ class TemporalTransformerBlock(nn.Module):
             self.norm2 = (
                 AdaLayerNorm(dim, num_embeds_ada_norm)
                 if self.use_ada_layer_norm
-                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
             )
-            self.attn2 = Attention(
-                query_dim=dim,
-                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
-                heads=num_attention_heads,
-                dim_head=attention_head_dim,
-                dropout=dropout,
-                bias=attention_bias,
-                upcast_attention=upcast_attention,
-            )  # is self-attn if encoder_hidden_states is none
         else:
             self.norm2 = None
             self.attn2 = None
 
         # 3. Feed-forward
         if not self.use_ada_layer_norm_single:
-            self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
 
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
 
         # 4. Fuser
         if attention_type == "gated" or attention_type == "gated-text-image":
             self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
@@ -654,6 +694,9 @@ class TemporalTransformerBlock(nn.Module):
             )
         else:
             ff_output = self.ff(norm_hidden_states, scale=lora_scale)
 
         if self.use_ada_layer_norm_zero:
             ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -723,6 +766,8 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
         attention_type: str = "default",
         positional_embeddings: Optional[str] = None,
         num_positional_embeddings: Optional[int] = None,
     ):
         super().__init__()
         self.only_cross_attention = only_cross_attention
@@ -755,17 +800,30 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
         elif self.use_ada_layer_norm_zero:
             self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
         else:
-            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
 
-        self.attn1 = Attention(
-            query_dim=dim,
-            heads=num_attention_heads,
-            dim_head=attention_head_dim,
-            dropout=dropout,
-            bias=attention_bias,
-            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
-            upcast_attention=upcast_attention,
-        )
 
         # 2. Cross-Attn
         if cross_attention_dim is not None or double_self_attention:
@@ -775,27 +833,45 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
             self.norm2 = (
                 AdaLayerNorm(dim, num_embeds_ada_norm)
                 if self.use_ada_layer_norm
-                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
             )
-            self.attn2 = Attention(
-                query_dim=dim,
-                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
-                heads=num_attention_heads,
-                dim_head=attention_head_dim,
-                dropout=dropout,
-                bias=attention_bias,
-                upcast_attention=upcast_attention,
-            )  # is self-attn if encoder_hidden_states is none
         else:
             self.norm2 = None
             self.attn2 = None
 
         # 3. Feed-forward
         if not self.use_ada_layer_norm_single:
-            self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
 
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
 
         # 4. Fuser
         if attention_type == "gated" or attention_type == "gated-text-image":
             self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
@@ -927,6 +1003,9 @@ class SelfAttentionTemporalTransformerBlock(nn.Module):
             )
         else:
             ff_output = self.ff(norm_hidden_states, scale=lora_scale)
 
         if self.use_ada_layer_norm_zero:
             ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -997,6 +1076,8 @@ class KVCompressionTransformerBlock(nn.Module):
         positional_embeddings: Optional[str] = None,
         num_positional_embeddings: Optional[int] = None,
         kvcompression: Optional[bool] = False,
     ):
         super().__init__()
         self.only_cross_attention = only_cross_attention
@@ -1029,7 +1110,7 @@ class KVCompressionTransformerBlock(nn.Module):
         elif self.use_ada_layer_norm_zero:
             self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
         else:
-            self.norm1 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
 
         self.kvcompression = kvcompression
         if kvcompression:
@@ -1043,16 +1124,28 @@ class KVCompressionTransformerBlock(nn.Module):
                 upcast_attention=upcast_attention,
             )
         else:
-            self.attn1 = Attention(
-                query_dim=dim,
-                heads=num_attention_heads,
-                dim_head=attention_head_dim,
-                dropout=dropout,
-                bias=attention_bias,
-                cross_attention_dim=cross_attention_dim if only_cross_attention else None,
-                upcast_attention=upcast_attention,
-            )
-
 
         # 2. Cross-Attn
         if cross_attention_dim is not None or double_self_attention:
@@ -1062,27 +1155,45 @@ class KVCompressionTransformerBlock(nn.Module):
             self.norm2 = (
                 AdaLayerNorm(dim, num_embeds_ada_norm)
                 if self.use_ada_layer_norm
-                else nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
             )
-            self.attn2 = Attention(
-                query_dim=dim,
-                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
-                heads=num_attention_heads,
-                dim_head=attention_head_dim,
-                dropout=dropout,
-                bias=attention_bias,
-                upcast_attention=upcast_attention,
-            )  # is self-attn if encoder_hidden_states is none
         else:
             self.norm2 = None
             self.attn2 = None
 
         # 3. Feed-forward
         if not self.use_ada_layer_norm_single:
-            self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
 
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
 
         # 4. Fuser
         if attention_type == "gated" or attention_type == "gated-text-image":
             self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
@@ -1229,6 +1340,9 @@ class KVCompressionTransformerBlock(nn.Module):
             )
         else:
             ff_output = self.ff(norm_hidden_states, scale=lora_scale)
 
         if self.use_ada_layer_norm_zero:
             ff_output = gate_mlp.unsqueeze(1) * ff_output
@@ -1239,61 +1353,4 @@ class KVCompressionTransformerBlock(nn.Module):
         if hidden_states.ndim == 4:
             hidden_states = hidden_states.squeeze(1)
 
-        return hidden_states
-
-
-class FeedForward(nn.Module):
-    r"""
-    A feed-forward layer.
-
-    Parameters:
-        dim (`int`): The number of channels in the input.
-        dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
-        mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
-        dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
-        activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
-        final_dropout (`bool` *optional*, defaults to False): Apply a final dropout.
-    """
-
-    def __init__(
-        self,
-        dim: int,
-        dim_out: Optional[int] = None,
-        mult: int = 4,
-        dropout: float = 0.0,
-        activation_fn: str = "geglu",
-        final_dropout: bool = False,
-    ):
-        super().__init__()
-        inner_dim = int(dim * mult)
-        dim_out = dim_out if dim_out is not None else dim
-        linear_cls = LoRACompatibleLinear if not USE_PEFT_BACKEND else nn.Linear
-
-        if activation_fn == "gelu":
-            act_fn = GELU(dim, inner_dim)
-        if activation_fn == "gelu-approximate":
-            act_fn = GELU(dim, inner_dim, approximate="tanh")
-        elif activation_fn == "geglu":
-            act_fn = GEGLU(dim, inner_dim)
-        elif activation_fn == "geglu-approximate":
-            act_fn = ApproximateGELU(dim, inner_dim)
-
-        self.net = nn.ModuleList([])
-        # project in
-        self.net.append(act_fn)
-        # project dropout
-        self.net.append(nn.Dropout(dropout))
-        # project out
-        self.net.append(linear_cls(inner_dim, dim_out))
-        # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout
-        if final_dropout:
-            self.net.append(nn.Dropout(dropout))
-
-    def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor:
-        compatible_cls = (GEGLU,) if USE_PEFT_BACKEND else (GEGLU, LoRACompatibleLinear)
-        for module in self.net:
-            if isinstance(module, compatible_cls):
-                hidden_states = module(hidden_states, scale)
-            else:
-                hidden_states = module(hidden_states)
-        return hidden_states
@@ -11,17 +11,25 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from typing import Any, Dict, Optional
 
+import diffusers
+import pkg_resources
 import torch
 import torch.nn.functional as F
 import torch.nn.init as init
+
+installed_version = diffusers.__version__
+
+if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
+    from diffusers.models.attention_processor import (Attention,
+                                                      AttnProcessor2_0,
+                                                      HunyuanAttnProcessor2_0)
+else:
+    from diffusers.models.attention_processor import Attention, AttnProcessor2_0
+
 from diffusers.models.attention import AdaLayerNorm, FeedForward
 from diffusers.models.embeddings import SinusoidalPositionalEmbedding
 from diffusers.models.normalization import AdaLayerNorm, AdaLayerNormZero
 from diffusers.utils import USE_PEFT_BACKEND
 from diffusers.utils.import_utils import is_xformers_available
@@ -29,7 +37,8 @@ from diffusers.utils.torch_utils import maybe_allow_in_graph
 from einops import rearrange, repeat
 from torch import nn
 
+from .motion_module import PositionalEncoding, get_motion_module
+from .norm import FP32LayerNorm
 
 if is_xformers_available():
     import xformers
@@ -38,6 +47,13 @@ else:
     xformers = None
 
 
+def zero_module(module):
+    # Zero out the parameters of a module and return it.
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
 @maybe_allow_in_graph
 class GatedSelfAttentionDense(nn.Module):
     r"""
@@ -59,8 +75,8 @@ class GatedSelfAttentionDense(nn.Module):
         self.attn = Attention(query_dim=query_dim, heads=n_heads, dim_head=d_head)
         self.ff = FeedForward(query_dim, activation_fn="geglu")
 
+        self.norm1 = FP32LayerNorm(query_dim)
+        self.norm2 = FP32LayerNorm(query_dim)
 
         self.register_parameter("alpha_attn", nn.Parameter(torch.tensor(0.0)))
         self.register_parameter("alpha_dense", nn.Parameter(torch.tensor(0.0)))
@@ -80,14 +96,6 @@ class GatedSelfAttentionDense(nn.Module):
         return x
 
 
             
            class KVCompressionCrossAttention(nn.Module):
         
     | 
| 100 | 
         
             
                r"""
         
     | 
| 101 | 
         
             
                A cross attention layer.
         
     | 
| 
         | 
|
| 162 | 
         
             
                        stride=2,
         
     | 
| 163 | 
         
             
                        bias=True
         
     | 
| 164 | 
         
             
                    )
         
     | 
| 165 | 
         
            +
                    self.kv_compression_norm = FP32LayerNorm(query_dim)
         
     | 
| 166 | 
         
             
                    init.constant_(self.kv_compression.weight, 1 / 4)
         
     | 
| 167 | 
         
             
                    if self.kv_compression.bias is not None:
         
     | 
| 168 | 
         
             
                        init.constant_(self.kv_compression.bias, 0)
         
     | 
| 
         | 
|
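FP32LayerNorm is imported from the new easyanimate/models/norm.py, which is not part of this file's diff. As a hedged illustration only, the usual pattern behind such a class is to compute the LayerNorm statistics in float32 and cast the result back to the input dtype, which keeps fp16/bf16 training numerically stable. A minimal sketch under that assumption:

import torch
import torch.nn as nn
import torch.nn.functional as F

class FP32LayerNormSketch(nn.LayerNorm):
    # Assumed behaviour: run LayerNorm in float32, return the original dtype.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        origin_dtype = x.dtype
        weight = self.weight.float() if self.weight is not None else None
        bias = self.bias.float() if self.bias is not None else None
        out = F.layer_norm(x.float(), self.normalized_shape, weight, bias, self.eps)
        return out.to(origin_dtype)

norm = FP32LayerNormSketch(16)
x = torch.randn(2, 8, 16, dtype=torch.float16)
print(norm(x).dtype)  # torch.float16, but the statistics were computed in float32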
@@ -410,6 +418,8 @@ class TemporalTransformerBlock(nn.Module):
         # motion module kwargs
         motion_module_type = "VanillaGrid",
         motion_module_kwargs = None,
+        qk_norm = False,
+        after_norm = False,
     ):
         super().__init__()
         self.only_cross_attention = only_cross_attention
@@ -442,7 +452,7 @@ class TemporalTransformerBlock(nn.Module):
         elif self.use_ada_layer_norm_zero:
             self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
         else:
+            self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
 
         self.kvcompression = kvcompression
         if kvcompression:
@@ -456,16 +466,28 @@ class TemporalTransformerBlock(nn.Module):
                 upcast_attention=upcast_attention,
             )
         else:
+            if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
+                self.attn1 = Attention(
+                    query_dim=dim,
+                    heads=num_attention_heads,
+                    dim_head=attention_head_dim,
+                    dropout=dropout,
+                    bias=attention_bias,
+                    cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+                    upcast_attention=upcast_attention,
+                    qk_norm="layer_norm" if qk_norm else None,
+                    processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
+                )
+            else:
+                self.attn1 = Attention(
+                    query_dim=dim,
+                    heads=num_attention_heads,
+                    dim_head=attention_head_dim,
+                    dropout=dropout,
+                    bias=attention_bias,
+                    cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+                    upcast_attention=upcast_attention,
+                )
 
         self.attn_temporal = get_motion_module(
             in_channels = dim,
@@ -481,27 +503,45 @@ class TemporalTransformerBlock(nn.Module):
             self.norm2 = (
                 AdaLayerNorm(dim, num_embeds_ada_norm)
                 if self.use_ada_layer_norm
+                else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
             )
+            if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
+                self.attn2 = Attention(
+                    query_dim=dim,
+                    cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                    heads=num_attention_heads,
+                    dim_head=attention_head_dim,
+                    dropout=dropout,
+                    bias=attention_bias,
+                    upcast_attention=upcast_attention,
+                    qk_norm="layer_norm" if qk_norm else None,
+                    processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
+                )  # is self-attn if encoder_hidden_states is none
+            else:
+                self.attn2 = Attention(
+                    query_dim=dim,
+                    cross_attention_dim=cross_attention_dim if not double_self_attention else None,
+                    heads=num_attention_heads,
+                    dim_head=attention_head_dim,
+                    dropout=dropout,
+                    bias=attention_bias,
+                    upcast_attention=upcast_attention,
+                )  # is self-attn if encoder_hidden_states is none
         else:
             self.norm2 = None
             self.attn2 = None
 
         # 3. Feed-forward
         if not self.use_ada_layer_norm_single:
+            self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
 
         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
 
+        if after_norm:
+            self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
+        else:
+            self.norm4 = None
+
         # 4. Fuser
         if attention_type == "gated" or attention_type == "gated-text-image":
             self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
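The two new constructor flags are wired as follows: qk_norm=True asks diffusers' Attention for qk_norm="layer_norm" and switches the processor to HunyuanAttnProcessor2_0 (hence the 0.28.2 version gate at the top of the file), while after_norm=True appends an extra FP32LayerNorm behind the feed-forward. A condensed, standalone sketch of the qk_norm wiring, assuming a diffusers release that ships HunyuanAttnProcessor2_0 (the helper name and the dimensions below are illustrative only):

import torch
from diffusers.models.attention_processor import Attention, AttnProcessor2_0, HunyuanAttnProcessor2_0

def build_self_attention(dim: int, heads: int, head_dim: int, qk_norm: bool) -> Attention:
    # With qk_norm=True, queries and keys are LayerNorm-ed before the dot product,
    # and the Hunyuan processor is the one that applies those norms.
    return Attention(
        query_dim=dim,
        heads=heads,
        dim_head=head_dim,
        qk_norm="layer_norm" if qk_norm else None,
        processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
    )

attn = build_self_attention(dim=64, heads=8, head_dim=8, qk_norm=True)
hidden_states = torch.randn(1, 16, 64)
print(attn(hidden_states).shape)  # torch.Size([1, 16, 64])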
@@ -654,6 +694,9 @@ class TemporalTransformerBlock(nn.Module):
             )
         else:
             ff_output = self.ff(norm_hidden_states, scale=lora_scale)
+
+        if self.norm4 is not None:
+            ff_output = self.norm4(ff_output)
 
         if self.use_ada_layer_norm_zero:
             ff_output = gate_mlp.unsqueeze(1) * ff_output
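For reference, a toy sketch of where the optional norm4 sits in the forward pass when after_norm is enabled (plain nn.LayerNorm and a simple MLP stand in for FP32LayerNorm and the block's FeedForward here):

import torch
import torch.nn as nn

dim = 64
ff = nn.Sequential(nn.Linear(dim, dim * 4), nn.GELU(), nn.Linear(dim * 4, dim))
norm4 = nn.LayerNorm(dim)  # would be None when after_norm=False

hidden_states = torch.randn(1, 16, dim)
ff_output = ff(hidden_states)
if norm4 is not None:
    ff_output = norm4(ff_output)          # extra normalization on the feed-forward output
hidden_states = ff_output + hidden_states  # residual connection, as in the block's forward
print(hidden_states.shape)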
| 766 | 
         
             
                    attention_type: str = "default",
         
     | 
| 767 | 
         
             
                    positional_embeddings: Optional[str] = None,
         
     | 
| 768 | 
         
             
                    num_positional_embeddings: Optional[int] = None,
         
     | 
| 769 | 
         
            +
                    qk_norm = False,
         
     | 
| 770 | 
         
            +
                    after_norm = False,
         
     | 
| 771 | 
         
             
                ):
         
     | 
| 772 | 
         
             
                    super().__init__()
         
     | 
| 773 | 
         
             
                    self.only_cross_attention = only_cross_attention
         
     | 
| 
         | 
|
| 800 | 
         
             
                    elif self.use_ada_layer_norm_zero:
         
     | 
| 801 | 
         
             
                        self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
         
     | 
| 802 | 
         
             
                    else:
         
     | 
| 803 | 
         
            +
                        self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
         
     | 
| 804 | 
         | 
| 805 | 
         
            +
                    if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
         
     | 
| 806 | 
         
            +
                        self.attn1 = Attention(
         
     | 
| 807 | 
         
            +
                            query_dim=dim,
         
     | 
| 808 | 
         
            +
                            heads=num_attention_heads,
         
     | 
| 809 | 
         
            +
                            dim_head=attention_head_dim,
         
     | 
| 810 | 
         
            +
                            dropout=dropout,
         
     | 
| 811 | 
         
            +
                            bias=attention_bias,
         
     | 
| 812 | 
         
            +
                            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
         
     | 
| 813 | 
         
            +
                            upcast_attention=upcast_attention,
         
     | 
| 814 | 
         
            +
                            qk_norm="layer_norm" if qk_norm else None,
         
     | 
| 815 | 
         
            +
                            processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
         
     | 
| 816 | 
         
            +
                        )
         
     | 
| 817 | 
         
            +
                    else:
         
     | 
| 818 | 
         
            +
                        self.attn1 = Attention(
         
     | 
| 819 | 
         
            +
                            query_dim=dim,
         
     | 
| 820 | 
         
            +
                            heads=num_attention_heads,
         
     | 
| 821 | 
         
            +
                            dim_head=attention_head_dim,
         
     | 
| 822 | 
         
            +
                            dropout=dropout,
         
     | 
| 823 | 
         
            +
                            bias=attention_bias,
         
     | 
| 824 | 
         
            +
                            cross_attention_dim=cross_attention_dim if only_cross_attention else None,
         
     | 
| 825 | 
         
            +
                            upcast_attention=upcast_attention,
         
     | 
| 826 | 
         
            +
                        )
         
     | 
| 827 | 
         | 
| 828 | 
         
             
                    # 2. Cross-Attn
         
     | 
| 829 | 
         
             
                    if cross_attention_dim is not None or double_self_attention:
         
     | 
| 
         | 
|
| 833 | 
         
             
                        self.norm2 = (
         
     | 
| 834 | 
         
             
                            AdaLayerNorm(dim, num_embeds_ada_norm)
         
     | 
| 835 | 
         
             
                            if self.use_ada_layer_norm
         
     | 
| 836 | 
         
            +
                            else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
         
     | 
| 837 | 
         
             
                        )
         
     | 
| 838 | 
         
            +
                        if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
         
     | 
| 839 | 
         
            +
                            self.attn2 = Attention(
         
     | 
| 840 | 
         
            +
                                query_dim=dim,
         
     | 
| 841 | 
         
            +
                                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
         
     | 
| 842 | 
         
            +
                                heads=num_attention_heads,
         
     | 
| 843 | 
         
            +
                                dim_head=attention_head_dim,
         
     | 
| 844 | 
         
            +
                                dropout=dropout,
         
     | 
| 845 | 
         
            +
                                bias=attention_bias,
         
     | 
| 846 | 
         
            +
                                upcast_attention=upcast_attention,
         
     | 
| 847 | 
         
            +
                                qk_norm="layer_norm" if qk_norm else None,
         
     | 
| 848 | 
         
            +
                                processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
         
     | 
| 849 | 
         
            +
                            )  # is self-attn if encoder_hidden_states is none
         
     | 
| 850 | 
         
            +
                        else:
         
     | 
| 851 | 
         
            +
                            self.attn2 = Attention(
         
     | 
| 852 | 
         
            +
                                query_dim=dim,
         
     | 
| 853 | 
         
            +
                                cross_attention_dim=cross_attention_dim if not double_self_attention else None,
         
     | 
| 854 | 
         
            +
                                heads=num_attention_heads,
         
     | 
| 855 | 
         
            +
                                dim_head=attention_head_dim,
         
     | 
| 856 | 
         
            +
                                dropout=dropout,
         
     | 
| 857 | 
         
            +
                                bias=attention_bias,
         
     | 
| 858 | 
         
            +
                                upcast_attention=upcast_attention,
         
     | 
| 859 | 
         
            +
                            )  # is self-attn if encoder_hidden_states is none
         
     | 
| 860 | 
         
             
                    else:
         
     | 
| 861 | 
         
             
                        self.norm2 = None
         
     | 
| 862 | 
         
             
                        self.attn2 = None
         
     | 
| 863 | 
         | 
| 864 | 
         
             
                    # 3. Feed-forward
         
     | 
| 865 | 
         
             
                    if not self.use_ada_layer_norm_single:
         
     | 
| 866 | 
         
            +
                        self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
         
     | 
| 867 | 
         | 
| 868 | 
         
             
                    self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
         
     | 
| 869 | 
         | 
| 870 | 
         
            +
                    if after_norm:
         
     | 
| 871 | 
         
            +
                        self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
         
     | 
| 872 | 
         
            +
                    else:
         
     | 
| 873 | 
         
            +
                        self.norm4 = None
         
     | 
| 874 | 
         
            +
             
     | 
| 875 | 
         
             
                    # 4. Fuser
         
     | 
| 876 | 
         
             
                    if attention_type == "gated" or attention_type == "gated-text-image":
         
     | 
| 877 | 
         
             
                        self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)
         
     | 
| 
         | 
|
| 1003 | 
         
             
                        )
         
     | 
| 1004 | 
         
             
                    else:
         
     | 
| 1005 | 
         
             
                        ff_output = self.ff(norm_hidden_states, scale=lora_scale)
         
     | 
| 1006 | 
         
            +
                    
         
     | 
| 1007 | 
         
            +
                    if self.norm4 is not None:
         
     | 
| 1008 | 
         
            +
                        ff_output = self.norm4(ff_output)
         
     | 
| 1009 | 
         | 
| 1010 | 
         
             
                    if self.use_ada_layer_norm_zero:
         
     | 
| 1011 | 
         
             
                        ff_output = gate_mlp.unsqueeze(1) * ff_output
         
     | 
| 1076 |          positional_embeddings: Optional[str] = None,
| 1077 |          num_positional_embeddings: Optional[int] = None,
| 1078 |          kvcompression: Optional[bool] = False,
| 1079 | +        qk_norm = False,
| 1080 | +        after_norm = False,
| 1081 |      ):
| 1082 |          super().__init__()
| 1083 |          self.only_cross_attention = only_cross_attention

| 1110 |          elif self.use_ada_layer_norm_zero:
| 1111 |              self.norm1 = AdaLayerNormZero(dim, num_embeds_ada_norm)
| 1112 |          else:
| 1113 | +            self.norm1 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
| 1114 |
| 1115 |          self.kvcompression = kvcompression
| 1116 |          if kvcompression:

| 1124 |                  upcast_attention=upcast_attention,
| 1125 |              )
| 1126 |          else:
| 1127 | +            if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
| 1128 | +                self.attn1 = Attention(
| 1129 | +                    query_dim=dim,
| 1130 | +                    heads=num_attention_heads,
| 1131 | +                    dim_head=attention_head_dim,
| 1132 | +                    dropout=dropout,
| 1133 | +                    bias=attention_bias,
| 1134 | +                    cross_attention_dim=cross_attention_dim if only_cross_attention else None,
| 1135 | +                    upcast_attention=upcast_attention,
| 1136 | +                    qk_norm="layer_norm" if qk_norm else None,
| 1137 | +                    processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
| 1138 | +                )
| 1139 | +            else:
| 1140 | +                self.attn1 = Attention(
| 1141 | +                    query_dim=dim,
| 1142 | +                    heads=num_attention_heads,
| 1143 | +                    dim_head=attention_head_dim,
| 1144 | +                    dropout=dropout,
| 1145 | +                    bias=attention_bias,
| 1146 | +                    cross_attention_dim=cross_attention_dim if only_cross_attention else None,
| 1147 | +                    upcast_attention=upcast_attention,
| 1148 | +                )
| 1149 |
| 1150 |          # 2. Cross-Attn
| 1151 |          if cross_attention_dim is not None or double_self_attention:

| 1155 |              self.norm2 = (
| 1156 |                  AdaLayerNorm(dim, num_embeds_ada_norm)
| 1157 |                  if self.use_ada_layer_norm
| 1158 | +                else FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
| 1159 |              )
| 1160 | +            if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
| 1161 | +                self.attn2 = Attention(
| 1162 | +                    query_dim=dim,
| 1163 | +                    cross_attention_dim=cross_attention_dim if not double_self_attention else None,
| 1164 | +                    heads=num_attention_heads,
| 1165 | +                    dim_head=attention_head_dim,
| 1166 | +                    dropout=dropout,
| 1167 | +                    bias=attention_bias,
| 1168 | +                    upcast_attention=upcast_attention,
| 1169 | +                    qk_norm="layer_norm" if qk_norm else None,
| 1170 | +                    processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
| 1171 | +                )  # is self-attn if encoder_hidden_states is none
| 1172 | +            else:
| 1173 | +                self.attn2 = Attention(
| 1174 | +                    query_dim=dim,
| 1175 | +                    cross_attention_dim=cross_attention_dim if not double_self_attention else None,
| 1176 | +                    heads=num_attention_heads,
| 1177 | +                    dim_head=attention_head_dim,
| 1178 | +                    dropout=dropout,
| 1179 | +                    bias=attention_bias,
| 1180 | +                    upcast_attention=upcast_attention,
| 1181 | +                )  # is self-attn if encoder_hidden_states is none
| 1182 |          else:
| 1183 |              self.norm2 = None
| 1184 |              self.attn2 = None
| 1185 |
| 1186 |          # 3. Feed-forward
| 1187 |          if not self.use_ada_layer_norm_single:
| 1188 | +            self.norm3 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
| 1189 |
| 1190 |          self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
| 1191 |
| 1192 | +        if after_norm:
| 1193 | +            self.norm4 = FP32LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
| 1194 | +        else:
| 1195 | +            self.norm4 = None
| 1196 | +
| 1197 |          # 4. Fuser
| 1198 |          if attention_type == "gated" or attention_type == "gated-text-image":
| 1199 |              self.fuser = GatedSelfAttentionDense(dim, cross_attention_dim, num_attention_heads, attention_head_dim)

| 1340 |              )
| 1341 |          else:
| 1342 |              ff_output = self.ff(norm_hidden_states, scale=lora_scale)
| 1343 | +
| 1344 | +        if self.norm4 is not None:
| 1345 | +            ff_output = self.norm4(ff_output)
| 1346 |
| 1347 |          if self.use_ada_layer_norm_zero:
| 1348 |              ff_output = gate_mlp.unsqueeze(1) * ff_output

| 1353 |          if hidden_states.ndim == 4:
| 1354 |              hidden_states = hidden_states.squeeze(1)
| 1355 |
| 1356 | +        return hidden_states
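The other recurring change in this file is a diffusers version gate: when the installed diffusers is at least 0.28.2, the blocks build `Attention` with `qk_norm="layer_norm"` and `HunyuanAttnProcessor2_0` (which applies the query/key norms); older installs fall back to a plain `Attention` with `AttnProcessor2_0`. A hedged sketch of that gating in isolation; `build_self_attention` is an illustrative helper, and the keyword arguments assume a diffusers release that provides these classes:

    import diffusers
    import pkg_resources
    from diffusers.models.attention_processor import Attention, AttnProcessor2_0

    installed_version = diffusers.__version__
    NEW_DIFFUSERS = pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2")

    def build_self_attention(dim: int, heads: int, head_dim: int, qk_norm: bool = False) -> Attention:
        # Build one self-attention layer, enabling q/k LayerNorm only when diffusers supports it.
        common = dict(query_dim=dim, heads=heads, dim_head=head_dim, dropout=0.0, bias=True)
        if NEW_DIFFUSERS:
            from diffusers.models.attention_processor import HunyuanAttnProcessor2_0
            return Attention(
                **common,
                qk_norm="layer_norm" if qk_norm else None,
                processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
            )
        # Older diffusers: no qk_norm argument, default processor only.
        return Attention(**common, processor=AttnProcessor2_0())

    attn = build_self_attention(dim=320, heads=8, head_dim=40, qk_norm=True)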
        easyanimate/models/autoencoder_magvit.py
    CHANGED
    
@@ -17,7 +17,12 @@ import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from diffusers.configuration_utils import ConfigMixin, register_to_config
-from diffusers.loaders import FromOriginalVAEMixin
+
+try:
+    from diffusers.loaders import FromOriginalVAEMixin
+except:
+    from diffusers.loaders import FromOriginalModelMixin as FromOriginalVAEMixin
+
 from diffusers.models.attention_processor import (
     ADDED_KV_ATTENTION_PROCESSORS, CROSS_ATTENTION_PROCESSORS, Attention,
     AttentionProcessor, AttnAddedKVProcessor, AttnProcessor)

@@ -93,6 +98,7 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         norm_num_groups: int = 32,
         scaling_factor: float = 0.1825,
         slice_compression_vae=False,
+        use_tiling=False,
         mini_batch_encoder=9,
         mini_batch_decoder=3,
     ):

@@ -145,8 +151,8 @@ class AutoencoderKLMagvit(ModelMixin, ConfigMixin, FromOriginalVAEMixin):
         self.mini_batch_encoder = mini_batch_encoder
         self.mini_batch_decoder = mini_batch_decoder
         self.use_slicing = False
-        self.use_tiling =
-        self.tile_sample_min_size =
+        self.use_tiling = use_tiling
+        self.tile_sample_min_size = 384
         self.tile_overlap_factor = 0.25
         self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(ch_mult) - 1)))
         self.scaling_factor = scaling_factor
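Two small changes here: the `FromOriginalVAEMixin` import is wrapped in a try/except because newer diffusers exposes `FromOriginalModelMixin` instead, and tiled encode/decode becomes opt-in through a `use_tiling` constructor argument with a fixed 384-pixel sample tile, from which the latent tile size is derived by dividing by the total downsampling factor `2 ** (len(ch_mult) - 1)`. A minimal sketch of both ideas; `TilingConfig` is an illustrative container rather than the repository's class, and the shim below uses `except ImportError` where the diff uses a bare `except`:

    # Compatibility shim: keep the old mixin name working across diffusers versions.
    try:
        from diffusers.loaders import FromOriginalVAEMixin
    except ImportError:
        from diffusers.loaders import FromOriginalModelMixin as FromOriginalVAEMixin

    class TilingConfig:
        # Illustrative container for the tiling attributes the VAE constructor sets up.
        def __init__(self, ch_mult=(1, 2, 4, 4), use_tiling: bool = False):
            self.use_tiling = use_tiling
            self.tile_sample_min_size = 384            # pixel-space tile edge
            self.tile_overlap_factor = 0.25            # 25% overlap between neighbouring tiles
            # Latent-space tile edge: divide by the total spatial downsampling 2 ** (len(ch_mult) - 1).
            self.tile_latent_min_size = int(self.tile_sample_min_size / (2 ** (len(ch_mult) - 1)))

    cfg = TilingConfig(use_tiling=True)
    print(cfg.tile_latent_min_size)                    # 384 / 8 = 48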
    	
        easyanimate/models/motion_module.py
    CHANGED
    
    | 
         @@ -1,248 +1,33 @@ 
     | 
|
| 1 | 
         
             
            """Modified from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
         
     | 
| 2 | 
         
             
            """
         
     | 
| 3 | 
         
             
            import math
         
     | 
| 4 | 
         
            -
            from typing import Any, Callable, List, Optional, Tuple, Union
         
     | 
| 5 | 
         | 
| 
         | 
|
| 
         | 
|
| 6 | 
         
             
            import torch
         
     | 
| 7 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 8 | 
         
             
            from diffusers.models.attention import FeedForward
         
     | 
| 9 | 
         
             
            from diffusers.utils.import_utils import is_xformers_available
         
     | 
| 10 | 
         
             
            from einops import rearrange, repeat
         
     | 
| 11 | 
         
             
            from torch import nn
         
     | 
| 12 | 
         | 
| 
         | 
|
| 
         | 
|
| 13 | 
         
             
            if is_xformers_available():
         
     | 
| 14 | 
         
             
                import xformers
         
     | 
| 15 | 
         
             
                import xformers.ops
         
     | 
| 16 | 
         
             
            else:
         
     | 
| 17 | 
         
             
                xformers = None
         
     | 
| 18 | 
         | 
| 19 | 
         
            -
            class CrossAttention(nn.Module):
         
     | 
| 20 | 
         
            -
                r"""
         
     | 
| 21 | 
         
            -
                A cross attention layer.
         
     | 
| 22 | 
         
            -
             
     | 
| 23 | 
         
            -
                Parameters:
         
     | 
| 24 | 
         
            -
                    query_dim (`int`): The number of channels in the query.
         
     | 
| 25 | 
         
            -
                    cross_attention_dim (`int`, *optional*):
         
     | 
| 26 | 
         
            -
                        The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
         
     | 
| 27 | 
         
            -
                    heads (`int`,  *optional*, defaults to 8): The number of heads to use for multi-head attention.
         
     | 
| 28 | 
         
            -
                    dim_head (`int`,  *optional*, defaults to 64): The number of channels in each head.
         
     | 
| 29 | 
         
            -
                    dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
         
     | 
| 30 | 
         
            -
                    bias (`bool`, *optional*, defaults to False):
         
     | 
| 31 | 
         
            -
                        Set to `True` for the query, key, and value linear layers to contain a bias parameter.
         
     | 
| 32 | 
         
            -
                """
         
     | 
| 33 | 
         
            -
             
     | 
| 34 | 
         
            -
                def __init__(
         
     | 
| 35 | 
         
            -
                    self,
         
     | 
| 36 | 
         
            -
                    query_dim: int,
         
     | 
| 37 | 
         
            -
                    cross_attention_dim: Optional[int] = None,
         
     | 
| 38 | 
         
            -
                    heads: int = 8,
         
     | 
| 39 | 
         
            -
                    dim_head: int = 64,
         
     | 
| 40 | 
         
            -
                    dropout: float = 0.0,
         
     | 
| 41 | 
         
            -
                    bias=False,
         
     | 
| 42 | 
         
            -
                    upcast_attention: bool = False,
         
     | 
| 43 | 
         
            -
                    upcast_softmax: bool = False,
         
     | 
| 44 | 
         
            -
                    added_kv_proj_dim: Optional[int] = None,
         
     | 
| 45 | 
         
            -
                    norm_num_groups: Optional[int] = None,
         
     | 
| 46 | 
         
            -
                ):
         
     | 
| 47 | 
         
            -
                    super().__init__()
         
     | 
| 48 | 
         
            -
                    inner_dim = dim_head * heads
         
     | 
| 49 | 
         
            -
                    cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
         
     | 
| 50 | 
         
            -
                    self.upcast_attention = upcast_attention
         
     | 
| 51 | 
         
            -
                    self.upcast_softmax = upcast_softmax
         
     | 
| 52 | 
         
            -
             
     | 
| 53 | 
         
            -
                    self.scale = dim_head**-0.5
         
     | 
| 54 | 
         
            -
             
     | 
| 55 | 
         
            -
                    self.heads = heads
         
     | 
| 56 | 
         
            -
                    # for slice_size > 0 the attention score computation
         
     | 
| 57 | 
         
            -
                    # is split across the batch axis to save memory
         
     | 
| 58 | 
         
            -
                    # You can set slice_size with `set_attention_slice`
         
     | 
| 59 | 
         
            -
                    self.sliceable_head_dim = heads
         
     | 
| 60 | 
         
            -
                    self._slice_size = None
         
     | 
| 61 | 
         
            -
                    self._use_memory_efficient_attention_xformers = False
         
     | 
| 62 | 
         
            -
                    self.added_kv_proj_dim = added_kv_proj_dim
         
     | 
| 63 | 
         
            -
             
     | 
| 64 | 
         
            -
                    if norm_num_groups is not None:
         
     | 
| 65 | 
         
            -
                        self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
         
     | 
| 66 | 
         
            -
                    else:
         
     | 
| 67 | 
         
            -
                        self.group_norm = None
         
     | 
| 68 | 
         
            -
             
     | 
| 69 | 
         
            -
                    self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
         
     | 
| 70 | 
         
            -
                    self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
         
     | 
| 71 | 
         
            -
                    self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
         
     | 
| 72 | 
         
            -
             
     | 
| 73 | 
         
            -
                    if self.added_kv_proj_dim is not None:
         
     | 
| 74 | 
         
            -
                        self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
         
     | 
| 75 | 
         
            -
                        self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
         
     | 
| 76 | 
         
            -
             
     | 
| 77 | 
         
            -
                    self.to_out = nn.ModuleList([])
         
     | 
| 78 | 
         
            -
                    self.to_out.append(nn.Linear(inner_dim, query_dim))
         
     | 
| 79 | 
         
            -
                    self.to_out.append(nn.Dropout(dropout))
         
     | 
| 80 | 
         
            -
             
     | 
| 81 | 
         
            -
                def set_use_memory_efficient_attention_xformers(
         
     | 
| 82 | 
         
            -
                    self, valid: bool, attention_op: Optional[Callable] = None
         
     | 
| 83 | 
         
            -
                ) -> None:
         
     | 
| 84 | 
         
            -
                    self._use_memory_efficient_attention_xformers = valid
         
     | 
| 85 | 
         
            -
                    
         
     | 
| 86 | 
         
            -
                def reshape_heads_to_batch_dim(self, tensor):
         
     | 
| 87 | 
         
            -
                    batch_size, seq_len, dim = tensor.shape
         
     | 
| 88 | 
         
            -
                    head_size = self.heads
         
     | 
| 89 | 
         
            -
                    tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
         
     | 
| 90 | 
         
            -
                    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
         
     | 
| 91 | 
         
            -
                    return tensor
         
     | 
| 92 | 
         
            -
             
     | 
| 93 | 
         
            -
                def reshape_batch_dim_to_heads(self, tensor):
         
     | 
| 94 | 
         
            -
                    batch_size, seq_len, dim = tensor.shape
         
     | 
| 95 | 
         
            -
                    head_size = self.heads
         
     | 
| 96 | 
         
            -
                    tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
         
     | 
| 97 | 
         
            -
                    tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
         
     | 
| 98 | 
         
            -
                    return tensor
         
     | 
| 99 | 
         
            -
             
     | 
| 100 | 
         
            -
                def set_attention_slice(self, slice_size):
         
     | 
| 101 | 
         
            -
                    if slice_size is not None and slice_size > self.sliceable_head_dim:
         
     | 
| 102 | 
         
            -
                        raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
         
     | 
| 103 | 
         
            -
             
     | 
| 104 | 
         
            -
                    self._slice_size = slice_size
         
     | 
| 105 | 
         
            -
             
     | 
| 106 | 
         
            -
                def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
         
     | 
| 107 | 
         
            -
                    batch_size, sequence_length, _ = hidden_states.shape
         
     | 
| 108 | 
         
            -
             
     | 
| 109 | 
         
            -
                    encoder_hidden_states = encoder_hidden_states
         
     | 
| 110 | 
         
            -
             
     | 
| 111 | 
         
            -
                    if self.group_norm is not None:
         
     | 
| 112 | 
         
            -
                        hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
         
     | 
| 113 | 
         
            -
             
     | 
| 114 | 
         
            -
                    query = self.to_q(hidden_states)
         
     | 
| 115 | 
         
            -
                    dim = query.shape[-1]
         
     | 
| 116 | 
         
            -
                    query = self.reshape_heads_to_batch_dim(query)
         
     | 
| 117 | 
         
            -
             
     | 
| 118 | 
         
            -
                    if self.added_kv_proj_dim is not None:
         
     | 
| 119 | 
         
            -
                        key = self.to_k(hidden_states)
         
     | 
| 120 | 
         
            -
                        value = self.to_v(hidden_states)
         
     | 
| 121 | 
         
            -
                        encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
         
     | 
| 122 | 
         
            -
                        encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
         
     | 
| 123 | 
         
            -
             
     | 
| 124 | 
         
            -
                        key = self.reshape_heads_to_batch_dim(key)
         
     | 
| 125 | 
         
            -
                        value = self.reshape_heads_to_batch_dim(value)
         
     | 
| 126 | 
         
            -
                        encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
         
     | 
| 127 | 
         
            -
                        encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
         
     | 
| 128 | 
         
            -
             
     | 
| 129 | 
         
            -
                        key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
         
     | 
| 130 | 
         
            -
                        value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
         
     | 
| 131 | 
         
            -
                    else:
         
     | 
| 132 | 
         
            -
                        encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
         
     | 
| 133 | 
         
            -
                        key = self.to_k(encoder_hidden_states)
         
     | 
| 134 | 
         
            -
                        value = self.to_v(encoder_hidden_states)
         
     | 
| 135 | 
         
            -
             
     | 
| 136 | 
         
            -
                        key = self.reshape_heads_to_batch_dim(key)
         
     | 
| 137 | 
         
            -
                        value = self.reshape_heads_to_batch_dim(value)
         
     | 
| 138 | 
         
            -
             
     | 
| 139 | 
         
            -
                    if attention_mask is not None:
         
     | 
| 140 | 
         
            -
                        if attention_mask.shape[-1] != query.shape[1]:
         
     | 
| 141 | 
         
            -
                            target_length = query.shape[1]
         
     | 
| 142 | 
         
            -
                            attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
         
     | 
| 143 | 
         
            -
                            attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
         
     | 
| 144 | 
         
            -
             
     | 
| 145 | 
         
            -
                    # attention, what we cannot get enough of
         
     | 
| 146 | 
         
            -
                    if self._use_memory_efficient_attention_xformers:
         
     | 
| 147 | 
         
            -
                        hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
         
     | 
| 148 | 
         
            -
                        # Some versions of xformers return output in fp32, cast it back to the dtype of the input
         
     | 
| 149 | 
         
            -
                        hidden_states = hidden_states.to(query.dtype)
         
     | 
| 150 | 
         
            -
                    else:
         
     | 
| 151 | 
         
            -
                        if self._slice_size is None or query.shape[0] // self._slice_size == 1:
         
     | 
| 152 | 
         
            -
                            hidden_states = self._attention(query, key, value, attention_mask)
         
     | 
| 153 | 
         
            -
                        else:
         
     | 
| 154 | 
         
            -
                            hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
         
     | 
| 155 | 
         
            -
             
     | 
| 156 | 
         
            -
                    # linear proj
         
     | 
| 157 | 
         
            -
                    hidden_states = self.to_out[0](hidden_states)
         
     | 
| 158 | 
         
            -
             
     | 
| 159 | 
         
            -
                    # dropout
         
     | 
| 160 | 
         
            -
                    hidden_states = self.to_out[1](hidden_states)
         
     | 
| 161 | 
         
            -
                    return hidden_states
         
     | 
| 162 | 
         
            -
             
     | 
| 163 | 
         
            -
                def _attention(self, query, key, value, attention_mask=None):
         
     | 
| 164 | 
         
            -
                    if self.upcast_attention:
         
     | 
| 165 | 
         
            -
                        query = query.float()
         
     | 
| 166 | 
         
            -
                        key = key.float()
         
     | 
| 167 | 
         
            -
             
     | 
| 168 | 
         
            -
                    attention_scores = torch.baddbmm(
         
     | 
| 169 | 
         
            -
                        torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
         
     | 
| 170 | 
         
            -
                        query,
         
     | 
| 171 | 
         
            -
                        key.transpose(-1, -2),
         
     | 
| 172 | 
         
            -
                        beta=0,
         
     | 
| 173 | 
         
            -
                        alpha=self.scale,
         
     | 
| 174 | 
         
            -
                    )
         
     | 
| 175 | 
         
            -
             
     | 
| 176 | 
         
            -
                    if attention_mask is not None:
         
     | 
| 177 | 
         
            -
                        attention_scores = attention_scores + attention_mask
         
     | 
| 178 | 
         
            -
             
     | 
| 179 | 
         
            -
                    if self.upcast_softmax:
         
     | 
| 180 | 
         
            -
                        attention_scores = attention_scores.float()
         
     | 
| 181 | 
         
            -
             
     | 
| 182 | 
         
            -
                    attention_probs = attention_scores.softmax(dim=-1)
         
     | 
| 183 | 
         
            -
             
     | 
| 184 | 
         
            -
                    # cast back to the original dtype
         
     | 
| 185 | 
         
            -
                    attention_probs = attention_probs.to(value.dtype)
         
     | 
| 186 | 
         
            -
             
     | 
| 187 | 
         
            -
                    # compute attention output
         
     | 
| 188 | 
         
            -
                    hidden_states = torch.bmm(attention_probs, value)
         
     | 
| 189 | 
         
            -
             
     | 
| 190 | 
         
            -
                    # reshape hidden_states
         
     | 
| 191 | 
         
            -
                    hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         
     | 
| 192 | 
         
            -
                    return hidden_states
         
     | 
| 193 | 
         
            -
             
     | 
| 194 | 
         
            -
                def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
         
     | 
| 195 | 
         
            -
                    batch_size_attention = query.shape[0]
         
     | 
| 196 | 
         
            -
                    hidden_states = torch.zeros(
         
     | 
| 197 | 
         
            -
                        (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
         
     | 
| 198 | 
         
            -
                    )
         
     | 
| 199 | 
         
            -
                    slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
         
     | 
| 200 | 
         
            -
                    for i in range(hidden_states.shape[0] // slice_size):
         
     | 
| 201 | 
         
            -
                        start_idx = i * slice_size
         
     | 
| 202 | 
         
            -
                        end_idx = (i + 1) * slice_size
         
     | 
| 203 | 
         
            -
             
     | 
| 204 | 
         
            -
                        query_slice = query[start_idx:end_idx]
         
     | 
| 205 | 
         
            -
                        key_slice = key[start_idx:end_idx]
         
     | 
| 206 | 
         
            -
             
     | 
| 207 | 
         
            -
                        if self.upcast_attention:
         
     | 
| 208 | 
         
            -
                            query_slice = query_slice.float()
         
     | 
| 209 | 
         
            -
                            key_slice = key_slice.float()
         
     | 
| 210 | 
         
            -
             
     | 
| 211 | 
         
            -
                        attn_slice = torch.baddbmm(
         
     | 
| 212 | 
         
            -
                            torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
         
     | 
| 213 | 
         
            -
                            query_slice,
         
     | 
| 214 | 
         
            -
                            key_slice.transpose(-1, -2),
         
     | 
| 215 | 
         
            -
                            beta=0,
         
     | 
| 216 | 
         
            -
                            alpha=self.scale,
         
     | 
| 217 | 
         
            -
                        )
         
     | 
| 218 | 
         
            -
             
     | 
| 219 | 
         
            -
                        if attention_mask is not None:
         
     | 
| 220 | 
         
            -
                            attn_slice = attn_slice + attention_mask[start_idx:end_idx]
         
     | 
| 221 | 
         
            -
             
     | 
| 222 | 
         
            -
                        if self.upcast_softmax:
         
     | 
| 223 | 
         
            -
                            attn_slice = attn_slice.float()
         
     | 
| 224 | 
         
            -
             
     | 
| 225 | 
         
            -
                        attn_slice = attn_slice.softmax(dim=-1)
         
     | 
| 226 | 
         
            -
             
     | 
| 227 | 
         
            -
                        # cast back to the original dtype
         
     | 
| 228 | 
         
            -
                        attn_slice = attn_slice.to(value.dtype)
         
     | 
| 229 | 
         
            -
                        attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
         
     | 
| 230 | 
         
            -
             
     | 
| 231 | 
         
            -
                        hidden_states[start_idx:end_idx] = attn_slice
         
     | 
| 232 | 
         
            -
             
     | 
| 233 | 
         
            -
                    # reshape hidden_states
         
     | 
| 234 | 
         
            -
                    hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         
     | 
| 235 | 
         
            -
                    return hidden_states
         
     | 
| 236 | 
         
            -
             
     | 
| 237 | 
         
            -
                def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
         
     | 
| 238 | 
         
            -
                    # TODO attention_mask
         
     | 
| 239 | 
         
            -
                    query = query.contiguous()
         
     | 
| 240 | 
         
            -
                    key = key.contiguous()
         
     | 
| 241 | 
         
            -
                    value = value.contiguous()
         
     | 
| 242 | 
         
            -
                    hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
         
     | 
| 243 | 
         
            -
                    hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
         
     | 
| 244 | 
         
            -
                    return hidden_states
         
     | 
| 245 | 
         
            -
             
     | 
| 246 | 
         
             
            def zero_module(module):
         
     | 
| 247 | 
         
             
                # Zero out the parameters of a module and return it.
         
     | 
| 248 | 
         
             
                for p in module.parameters():
         
     | 
| 
         @@ -275,6 +60,11 @@ class VanillaTemporalModule(nn.Module): 
     | 
|
| 275 | 
         
             
                    zero_initialize                    = True,
         
     | 
| 276 | 
         
             
                    block_size                         = 1,
         
     | 
| 277 | 
         
             
                    grid                               = False,
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 278 | 
         
             
                ):
         
     | 
| 279 | 
         
             
                    super().__init__()
         
     | 
| 280 | 
         | 
| 
         @@ -289,17 +79,87 @@ class VanillaTemporalModule(nn.Module): 
     | 
|
| 289 | 
         
             
                        temporal_position_encoding_max_len=temporal_position_encoding_max_len,
         
     | 
| 290 | 
         
             
                        grid=grid,
         
     | 
| 291 | 
         
             
                        block_size=block_size,
         
     | 
| 
         | 
|
| 
         | 
|
| 292 | 
         
             
                    )
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 293 | 
         
             
                    if zero_initialize:
         
     | 
| 294 | 
         
             
                        self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
         
     | 
| 
         | 
|
| 
         | 
|
| 295 | 
         | 
| 296 | 
         
             
                def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None, anchor_frame_idx=None):
         
     | 
| 297 | 
         
             
                    hidden_states = input_tensor
         
     | 
| 298 | 
         
             
                    hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
         
     | 
| 
         | 
|
| 
         | 
|
| 299 | 
         | 
| 300 | 
         
             
                    output = hidden_states
         
     | 
| 301 | 
         
             
                    return output
         
     | 
| 302 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 303 | 
         
             
            class TemporalTransformer3DModel(nn.Module):
         
     | 
| 304 | 
         
             
                def __init__(
         
     | 
| 305 | 
         
             
                    self,
         
     | 
| 
         @@ -321,6 +181,8 @@ class TemporalTransformer3DModel(nn.Module): 
     | 
|
| 321 | 
         
             
                    temporal_position_encoding_max_len = 4096,
         
     | 
| 322 | 
         
             
                    grid                               = False,
         
     | 
| 323 | 
         
             
                    block_size                         = 1,
         
     | 
| 
         | 
|
| 
         | 
|
| 324 | 
         
             
                ):
         
     | 
| 325 | 
         
             
                    super().__init__()
         
     | 
| 326 | 
         | 
| 
         @@ -348,6 +210,8 @@ class TemporalTransformer3DModel(nn.Module): 
     | 
|
| 348 | 
         
             
                                temporal_position_encoding_max_len=temporal_position_encoding_max_len,
         
     | 
| 349 | 
         
             
                                block_size=block_size,
         
     | 
| 350 | 
         
             
                                grid=grid,
         
     | 
| 
         | 
|
| 
         | 
|
| 351 | 
         
             
                            )
         
     | 
| 352 | 
         
             
                            for d in range(num_layers)
         
     | 
| 353 | 
         
             
                        ]
         
     | 
| 
         @@ -398,6 +262,8 @@ class TemporalTransformerBlock(nn.Module): 
     | 
|
| 398 | 
         
             
                    temporal_position_encoding_max_len = 4096,
         
     | 
| 399 | 
         
             
                    block_size                         = 1,
         
     | 
| 400 | 
         
             
                    grid                               = False,
         
     | 
| 
         | 
|
| 
         | 
|
| 401 | 
         
             
                ):
         
     | 
| 402 | 
         
             
                    super().__init__()
         
     | 
| 403 | 
         | 
| 
         @@ -422,15 +288,36 @@ class TemporalTransformerBlock(nn.Module): 
     | 
|
| 422 | 
         
             
                                temporal_position_encoding_max_len=temporal_position_encoding_max_len,
         
     | 
| 423 | 
         
             
                                block_size=block_size,
         
     | 
| 424 | 
         
             
                                grid=grid,
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 425 | 
         
             
                            )
         
     | 
| 426 | 
         
             
                        )
         
     | 
| 427 | 
         
            -
                        norms.append( 
     | 
| 428 | 
         | 
| 429 | 
         
             
                    self.attention_blocks = nn.ModuleList(attention_blocks)
         
     | 
| 430 | 
         
             
                    self.norms = nn.ModuleList(norms)
         
     | 
| 431 | 
         | 
| 432 | 
         
             
                    self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
         
     | 
| 433 | 
         
            -
                    self.ff_norm =  
     | 
| 434 | 
         | 
| 435 | 
         
             
                def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
         
     | 
| 436 | 
         
             
                    for attention_block, norm in zip(self.attention_blocks, self.norms):
         
     | 
| 
         @@ -468,7 +355,7 @@ class PositionalEncoding(nn.Module): 
     | 
|
| 468 | 
         
             
                    x = x + self.pe[:, :x.size(1)]
         
     | 
| 469 | 
         
             
                    return self.dropout(x)
         
     | 
| 470 | 
         | 
| 471 | 
         
            -
            class VersatileAttention( 
     | 
| 472 | 
         
             
                def __init__(
         
     | 
| 473 | 
         
             
                        self,
         
     | 
| 474 | 
         
             
                        attention_mode                     = None,
         
     | 
| 
         @@ -477,21 +364,23 @@ class VersatileAttention(CrossAttention): 
     | 
|
| 477 | 
         
             
                        temporal_position_encoding_max_len = 4096,  
         
     | 
| 478 | 
         
             
                        grid                               = False,
         
     | 
| 479 | 
         
             
                        block_size                         = 1,
         
     | 
| 
         | 
|
| 480 | 
         
             
                        *args, **kwargs
         
     | 
| 481 | 
         
             
                    ):
         
     | 
| 482 | 
         
             
                    super().__init__(*args, **kwargs)
         
     | 
| 483 | 
         
            -
                    assert attention_mode == "Temporal"
         
     | 
| 484 | 
         | 
| 485 | 
         
             
                    self.attention_mode = attention_mode
         
     | 
| 486 | 
         
             
                    self.is_cross_attention = kwargs["cross_attention_dim"] is not None
         
     | 
| 487 | 
         | 
| 488 | 
         
             
                    self.block_size = block_size
         
     | 
| 489 | 
         
             
                    self.grid = grid
         
     | 
| 
         | 
|
| 490 | 
         
             
                    self.pos_encoder = PositionalEncoding(
         
     | 
| 491 | 
         
             
                        kwargs["query_dim"],
         
     | 
| 492 | 
         
             
                        dropout=0., 
         
     | 
| 493 | 
         
             
                        max_len=temporal_position_encoding_max_len
         
     | 
| 494 | 
         
            -
                    ) if (temporal_position_encoding and attention_mode == "Temporal") else None
         
     | 
| 495 | 
         | 
| 496 | 
         
             
                def extra_repr(self):
         
     | 
| 497 | 
         
             
                    return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"
         
     | 
| 
         @@ -503,8 +392,13 @@ class VersatileAttention(CrossAttention): 
     | 
|
| 503 | 
         
             
                        # for add pos_encoder 
         
     | 
| 504 | 
         
             
                        _, before_d, _c = hidden_states.size()
         
     | 
| 505 | 
         
             
                        hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
         
     | 
| 506 | 
         
            -
                         
     | 
| 507 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 508 | 
         | 
| 509 | 
         
             
                        if self.grid:
         
     | 
| 510 | 
         
             
                            hidden_states = rearrange(hidden_states, "(b d) f c -> b f d c", f=video_length, d=before_d)
         
     | 
| 
         @@ -515,61 +409,36 @@ class VersatileAttention(CrossAttention): 
     | 
|
| 515 | 
         
             
                        else:
         
     | 
| 516 | 
         
             
                            d = before_d    
         
     | 
| 517 | 
         
             
                        encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 518 | 
         
             
                    else:
         
     | 
| 519 | 
         
             
                        raise NotImplementedError
         
     | 
| 520 | 
         | 
| 521 | 
         
            -
                    encoder_hidden_states = encoder_hidden_states
         
     | 
| 522 | 
         
            -
             
     | 
| 523 | 
         
            -
                    if self.group_norm is not None:
         
     | 
| 524 | 
         
            -
                         hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
-
-                    query = self.to_q(hidden_states)
-                    dim = query.shape[-1]
-                    query = self.reshape_heads_to_batch_dim(query)
-
-                    if self.added_kv_proj_dim is not None:
-                        raise NotImplementedError
-
                     encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
-                    key = self.to_k(encoder_hidden_states)
-                    value = self.to_v(encoder_hidden_states)
-
-                    key = self.reshape_heads_to_batch_dim(key)
-                    value = self.reshape_heads_to_batch_dim(value)
-
-                    if attention_mask is not None:
-                        if attention_mask.shape[-1] != query.shape[1]:
-                            target_length = query.shape[1]
-                            attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
-                            attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)

                     bs = 512
                     new_hidden_states = []
-                    for i in range(0, query.shape[0], bs):
-                        if self._slice_size is None or query[i : i + bs].shape[0] // self._slice_size == 1:
-                            hidden_states = self._attention(query[i : i + bs], key[i : i + bs], value[i : i + bs], attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
-                        else:
-                            hidden_states = self._sliced_attention(query[i : i + bs], key[i : i + bs], value[i : i + bs], sequence_length, dim, attention_mask[i : i + bs] if attention_mask is not None else attention_mask)
-                        new_hidden_states.append(hidden_states)
                     hidden_states = torch.cat(new_hidden_states, dim = 0)

-                    # linear proj
-                    hidden_states = self.to_out[0](hidden_states)
-
-                    # dropout
-                    hidden_states = self.to_out[1](hidden_states)
-
                     if self.attention_mode == "Temporal":
                         hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
                         if self.grid:
                             hidden_states = rearrange(hidden_states, "(b f n m) (h w) c -> (b f) h n w m c", f=video_length, n=self.block_size, m=self.block_size, h=height // self.block_size, w=weight // self.block_size)
                             hidden_states = rearrange(hidden_states, "b h n w m c -> b (h n) (w m) c")
                             hidden_states = rearrange(hidden_states, "b h w c -> b (h w) c")

                     return hidden_states
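For context (not part of the commit): the Temporal attention path above relies on einops rearrangements so that attention mixes frames rather than spatial positions. A minimal sketch of that regrouping, with made-up dimensions:

# Minimal sketch of the "(b f) d c <-> (b d) f c" regrouping used by the
# Temporal attention mode (dummy sizes; not code from the repository).
import torch
from einops import rearrange

b, f, d, c = 2, 8, 16, 32              # batch, frames, spatial tokens, channels
x = torch.randn(b * f, d, c)           # tokens arrive frame-major: (b f) d c

x_t = rearrange(x, "(b f) d c -> (b d) f c", f=f)   # sequence axis becomes time
assert x_t.shape == (b * d, f, c)                   # attention now runs across frames

x_back = rearrange(x_t, "(b d) f c -> (b f) d c", d=d)
assert x_back.shape == (b * f, d, c)                # original layout restored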
easyanimate/models/motion_module.py CHANGED

 """Modified from https://github.com/guoyww/AnimateDiff/blob/main/animatediff/models/motion_module.py
 """
 import math

+import diffusers
+import pkg_resources
 import torch
+
+installed_version = diffusers.__version__
+
+if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
+    from diffusers.models.attention_processor import (Attention,
+                                                      AttnProcessor2_0,
+                                                      HunyuanAttnProcessor2_0)
+else:
+    from diffusers.models.attention_processor import Attention, AttnProcessor2_0
+
 from diffusers.models.attention import FeedForward
 from diffusers.utils.import_utils import is_xformers_available
 from einops import rearrange, repeat
 from torch import nn

+from .norm import FP32LayerNorm
+
 if is_xformers_available():
     import xformers
     import xformers.ops
 else:
     xformers = None
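A brief aside on the version gate above (my reading, not stated in the commit): HunyuanAttnProcessor2_0 and the qk_norm argument are only available in newer diffusers releases, so the module probes the installed version once and otherwise falls back to the plain Attention/AttnProcessor2_0 path. A minimal sketch of how that gate is consumed, mirroring the constructor calls later in this file (the 320/8/40 sizes are placeholders):

# Sketch: pick the attention processor based on the installed diffusers version.
import diffusers
import pkg_resources
from diffusers.models.attention_processor import Attention, AttnProcessor2_0

new_enough = pkg_resources.parse_version(diffusers.__version__) >= pkg_resources.parse_version("0.28.2")

if new_enough:
    from diffusers.models.attention_processor import HunyuanAttnProcessor2_0
    # QK layer-norm plus the Hunyuan processor, as in the constructors below.
    attn = Attention(query_dim=320, heads=8, dim_head=40,
                     qk_norm="layer_norm", processor=HunyuanAttnProcessor2_0())
else:
    # Older diffusers: no qk_norm support, use the default SDPA processor.
    attn = Attention(query_dim=320, heads=8, dim_head=40, processor=AttnProcessor2_0())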
 def zero_module(module):
     # Zero out the parameters of a module and return it.
     for p in module.parameters():

         zero_initialize                    = True,
         block_size                         = 1,
         grid                               = False,
+        remove_time_embedding_in_photo     = False,
+
+        global_num_attention_heads         = 16,
+        global_attention                   = False,
+        qk_norm                            = False,
     ):
         super().__init__()

             temporal_position_encoding_max_len=temporal_position_encoding_max_len,
             grid=grid,
             block_size=block_size,
+            remove_time_embedding_in_photo=remove_time_embedding_in_photo,
+            qk_norm=qk_norm,
         )
+        self.global_transformer = GlobalTransformer3DModel(
+            in_channels=in_channels,
+            num_attention_heads=global_num_attention_heads,
+            attention_head_dim=in_channels // global_num_attention_heads // temporal_attention_dim_div,
+            qk_norm=qk_norm,
+        ) if global_attention else None
         if zero_initialize:
             self.temporal_transformer.proj_out = zero_module(self.temporal_transformer.proj_out)
+            if global_attention:
+                self.global_transformer.proj_out = zero_module(self.global_transformer.proj_out)

     def forward(self, input_tensor, encoder_hidden_states=None, attention_mask=None, anchor_frame_idx=None):
         hidden_states = input_tensor
         hidden_states = self.temporal_transformer(hidden_states, encoder_hidden_states, attention_mask)
+        if self.global_transformer is not None:
+            hidden_states = self.global_transformer(hidden_states)

         output = hidden_states
         return output
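A side note on zero_initialize (a generic sketch of the idea, not repository code): zeroing proj_out makes the newly added temporal/global branches contribute nothing at initialization, so a residual connection passes the pretrained features through unchanged until the motion weights are trained.

# Sketch: with proj_out zero-initialized, a residual block x + proj_out(f(x))
# starts out as an exact identity mapping.
import torch
from torch import nn

def zero_module(module):
    # Zero out the parameters of a module and return it.
    for p in module.parameters():
        p.detach().zero_()
    return module

proj_out = zero_module(nn.Linear(64, 64))
x = torch.randn(2, 10, 64)
branch = torch.randn(2, 10, 64)       # stand-in for the new attention branch f(x)
out = x + proj_out(branch)            # residual connection around the new branch
assert torch.equal(out, x)            # identity at initialization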
+class GlobalTransformer3DModel(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        num_attention_heads,
+        attention_head_dim,
+        dropout                            = 0.0,
+        attention_bias                     = False,
+        upcast_attention                   = False,
+        qk_norm                            = False,
+    ):
+        super().__init__()
+
+        inner_dim = num_attention_heads * attention_head_dim
+
+        self.norm1 = FP32LayerNorm(inner_dim)
+        self.proj_in = nn.Linear(in_channels, inner_dim)
+        self.norm2 = FP32LayerNorm(inner_dim)
+        if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2"):
+            self.attention = Attention(
+                query_dim=inner_dim,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+                qk_norm="layer_norm" if qk_norm else None,
+                processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
+            )
+        else:
+            self.attention = Attention(
+                query_dim=inner_dim,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+            )
+        self.proj_out = nn.Linear(inner_dim, in_channels)
+
+    def forward(self, hidden_states):
+        assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+        video_length, height, width = hidden_states.shape[2], hidden_states.shape[3], hidden_states.shape[4]
+        hidden_states = rearrange(hidden_states, "b c f h w -> b (f h w) c")
+
+        residual = hidden_states
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.proj_in(hidden_states)
+
+        # Attention Blocks
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.attention(hidden_states)
+        hidden_states = self.proj_out(hidden_states)
+
+        output = hidden_states + residual
+        output = rearrange(output, "b (f h w) c -> b c f h w", f=video_length, h=height, w=width)
+        return output
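For orientation (a sketch with made-up sizes, not repository code): GlobalTransformer3DModel flattens the whole clip into one token sequence, so a single attention call can mix information across every frame and spatial position, and then the 5-D layout is restored.

# Sketch of the flatten -> attend -> restore pattern, with an identity stand-in
# for self.attention so the snippet stays self-contained.
import torch
from einops import rearrange

b, c, f, h, w = 1, 8, 4, 6, 6
x = torch.randn(b, c, f, h, w)

tokens = rearrange(x, "b c f h w -> b (f h w) c")    # one sequence of f*h*w tokens
assert tokens.shape == (b, f * h * w, c)

attended = tokens                                     # placeholder for self.attention(tokens)
y = rearrange(attended, "b (f h w) c -> b c f h w", f=f, h=h, w=w)
assert y.shape == x.shape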
 class TemporalTransformer3DModel(nn.Module):
     def __init__(
         self,

         temporal_position_encoding_max_len = 4096,
         grid                               = False,
         block_size                         = 1,
+        remove_time_embedding_in_photo     = False,
+        qk_norm                            = False,
     ):
         super().__init__()

                 temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                 block_size=block_size,
                 grid=grid,
+                remove_time_embedding_in_photo=remove_time_embedding_in_photo,
+                qk_norm=qk_norm
             )
             for d in range(num_layers)
         ]

         temporal_position_encoding_max_len = 4096,
         block_size                         = 1,
         grid                               = False,
+        remove_time_embedding_in_photo     = False,
+        qk_norm                            = False,
     ):
         super().__init__()

                 temporal_position_encoding_max_len=temporal_position_encoding_max_len,
                 block_size=block_size,
                 grid=grid,
+                remove_time_embedding_in_photo=remove_time_embedding_in_photo,
+                qk_norm="layer_norm" if qk_norm else None,
+                processor=HunyuanAttnProcessor2_0() if qk_norm else AttnProcessor2_0(),
+            ) if pkg_resources.parse_version(installed_version) >= pkg_resources.parse_version("0.28.2") else \
+            VersatileAttention(
+                attention_mode=block_name.split("_")[0],
+                cross_attention_dim=cross_attention_dim if block_name.endswith("_Cross") else None,
+
+                query_dim=dim,
+                heads=num_attention_heads,
+                dim_head=attention_head_dim,
+                dropout=dropout,
+                bias=attention_bias,
+                upcast_attention=upcast_attention,
+
+                cross_frame_attention_mode=cross_frame_attention_mode,
+                temporal_position_encoding=temporal_position_encoding,
+                temporal_position_encoding_max_len=temporal_position_encoding_max_len,
+                block_size=block_size,
+                grid=grid,
+                remove_time_embedding_in_photo=remove_time_embedding_in_photo,
             )
         )
+        norms.append(FP32LayerNorm(dim))

         self.attention_blocks = nn.ModuleList(attention_blocks)
         self.norms = nn.ModuleList(norms)

         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+        self.ff_norm = FP32LayerNorm(dim)

     def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None, height=None, weight=None):
         for attention_block, norm in zip(self.attention_blocks, self.norms):

         x = x + self.pe[:, :x.size(1)]
         return self.dropout(x)

+class VersatileAttention(Attention):
     def __init__(
             self,
             attention_mode                     = None,

             temporal_position_encoding_max_len = 4096,
             grid                               = False,
             block_size                         = 1,
+            remove_time_embedding_in_photo     = False,
             *args, **kwargs
         ):
         super().__init__(*args, **kwargs)
+        assert attention_mode == "Temporal" or attention_mode == "Global"

         self.attention_mode = attention_mode
         self.is_cross_attention = kwargs["cross_attention_dim"] is not None

         self.block_size = block_size
         self.grid = grid
+        self.remove_time_embedding_in_photo = remove_time_embedding_in_photo
         self.pos_encoder = PositionalEncoding(
             kwargs["query_dim"],
             dropout=0.,
             max_len=temporal_position_encoding_max_len
+        ) if (temporal_position_encoding and attention_mode == "Temporal") or (temporal_position_encoding and attention_mode == "Global") else None

     def extra_repr(self):
         return f"(Module Info) Attention_Mode: {self.attention_mode}, Is_Cross_Attention: {self.is_cross_attention}"

             # for add pos_encoder
             _, before_d, _c = hidden_states.size()
             hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+
+            if self.remove_time_embedding_in_photo:
+                if self.pos_encoder is not None and video_length > 1:
+                    hidden_states = self.pos_encoder(hidden_states)
+            else:
+                if self.pos_encoder is not None:
+                    hidden_states = self.pos_encoder(hidden_states)

             if self.grid:
                 hidden_states = rearrange(hidden_states, "(b d) f c -> b f d c", f=video_length, d=before_d)

             else:
                 d = before_d
             encoder_hidden_states = repeat(encoder_hidden_states, "b n c -> (b d) n c", d=d) if encoder_hidden_states is not None else encoder_hidden_states
+        elif self.attention_mode == "Global":
+            # for add pos_encoder
+            _, d, _c = hidden_states.size()
+            hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+            if self.pos_encoder is not None:
+                hidden_states = self.pos_encoder(hidden_states)
+            hidden_states = rearrange(hidden_states, "(b d) f c -> b (f d) c", f=video_length, d=d)
         else:
             raise NotImplementedError

         encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states

         bs = 512
         new_hidden_states = []
+        for i in range(0, hidden_states.shape[0], bs):
+            __hidden_states = super().forward(
+                hidden_states[i : i + bs],
+                encoder_hidden_states=encoder_hidden_states[i : i + bs],
+                attention_mask=attention_mask
+            )
+            new_hidden_states.append(__hidden_states)
         hidden_states = torch.cat(new_hidden_states, dim = 0)

         if self.attention_mode == "Temporal":
             hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
             if self.grid:
                 hidden_states = rearrange(hidden_states, "(b f n m) (h w) c -> (b f) h n w m c", f=video_length, n=self.block_size, m=self.block_size, h=height // self.block_size, w=weight // self.block_size)
                 hidden_states = rearrange(hidden_states, "b h n w m c -> b (h n) (w m) c")
                 hidden_states = rearrange(hidden_states, "b h w c -> b (h w) c")
+        elif self.attention_mode == "Global":
+            hidden_states = rearrange(hidden_states, "b (f d) c -> (b f) d c", f=video_length, d=d)

         return hidden_states
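An aside on the bs = 512 loop in VersatileAttention.forward above (a generic sketch, not the repository's code): attention is applied to at most 512 rows of the batch dimension at a time and the results are concatenated, which bounds peak activation memory while leaving the output unchanged for any per-sample module.

# Sketch: chunk a batched module call along dim 0 and concatenate the results.
import torch
from torch import nn

def chunked_forward(module: nn.Module, x: torch.Tensor, bs: int = 512) -> torch.Tensor:
    outputs = []
    for i in range(0, x.shape[0], bs):
        outputs.append(module(x[i : i + bs]))   # process one slice of the batch
    return torch.cat(outputs, dim=0)

attn_like = nn.Linear(32, 32)                    # stand-in for an attention module
x = torch.randn(1300, 32)                        # 1300 rows -> chunks of 512, 512, 276
assert torch.allclose(chunked_forward(attn_like, x), attn_like(x), atol=1e-6)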
    	
easyanimate/models/norm.py ADDED
@@ -0,0 +1,97 @@
+from typing import Any, Dict, Optional, Tuple
+
+import torch
+import torch.nn.functional as F
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from torch import nn
+
+
+def zero_module(module):
+    # Zero out the parameters of a module and return it.
+    for p in module.parameters():
+        p.detach().zero_()
+    return module
+
+
+class FP32LayerNorm(nn.LayerNorm):
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        origin_dtype = inputs.dtype
+        if hasattr(self, 'weight') and self.weight is not None:
+            return F.layer_norm(
+                inputs.float(), self.normalized_shape, self.weight.float(), self.bias.float(), self.eps
+            ).to(origin_dtype)
+        else:
+            return F.layer_norm(
+                inputs.float(), self.normalized_shape, None, None, self.eps
+            ).to(origin_dtype)
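A quick usage note (not part of the commit): FP32LayerNorm exists for mixed-precision runs; the normalization statistics are computed in float32 and the result is cast back to the activation dtype. The underlying idea in two lines:

# Sketch: layer norm computed in fp32, returned in the original half-precision dtype.
import torch
import torch.nn.functional as F

x = torch.randn(2, 10, 64, dtype=torch.bfloat16)
y = F.layer_norm(x.float(), (64,)).to(x.dtype)   # fp32 math, bf16 out
assert y.dtype == torch.bfloat16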
+class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
+    """
+    For PixArt-Alpha.
+
+    Reference:
+    https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
+    """
+
+    def __init__(self, embedding_dim, size_emb_dim, use_additional_conditions: bool = False):
+        super().__init__()
+
+        self.outdim = size_emb_dim
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+        self.use_additional_conditions = use_additional_conditions
+        if use_additional_conditions:
+            self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+            self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+            self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
+
+            self.resolution_embedder.linear_2 = zero_module(self.resolution_embedder.linear_2)
+            self.aspect_ratio_embedder.linear_2 = zero_module(self.aspect_ratio_embedder.linear_2)
+
+    def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
+
+        if self.use_additional_conditions:
+            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
+            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
+            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
+            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
+            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
+        else:
+            conditioning = timesteps_emb
+
+        return conditioning
+
+class AdaLayerNormSingle(nn.Module):
+    r"""
+    Norm layer adaptive layer norm single (adaLN-single).
+
+    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
+    """
+
+    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
+        super().__init__()
+
+        self.emb = PixArtAlphaCombinedTimestepSizeEmbeddings(
+            embedding_dim, size_emb_dim=embedding_dim // 3, use_additional_conditions=use_additional_conditions
+        )
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+    def forward(
+        self,
+        timestep: torch.Tensor,
+        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
+        batch_size: Optional[int] = None,
+        hidden_dtype: Optional[torch.dtype] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        # No modulation happening here.
+        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
+        return self.linear(self.silu(embedded_timestep)), embedded_timestep
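For context (an assumption about downstream use, following the usual PixArt-Alpha adaLN-single recipe rather than anything shown in this diff): the 6 * embedding_dim output of AdaLayerNormSingle.linear is typically chunked into shift/scale/gate triples for the attention and feed-forward sub-blocks.

# Sketch (hypothetical consumer of AdaLayerNormSingle's output): split the
# conditioning vector into six modulation tensors and apply the pre-attention pair.
import torch

dim, batch = 1152, 2
conditioning = torch.randn(batch, 6 * dim)        # stand-in for self.linear(self.silu(emb))

shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = conditioning.chunk(6, dim=1)

x = torch.randn(batch, 16, dim)                   # token features
x_mod = x * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)   # pre-attention modulation
assert x_mod.shape == x.shape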
    	
easyanimate/models/patch.py CHANGED
@@ -1,10 +1,10 @@
+import math
 from typing import Optional

 import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.nn.init as init
-import math
 from einops import rearrange
 from torch import nn
    	
        easyanimate/models/transformer3d.py
    CHANGED
    
    | 
         @@ -15,26 +15,30 @@ import json 
     | 
|
| 15 | 
         
             
            import math
         
     | 
| 16 | 
         
             
            import os
         
     | 
| 17 | 
         
             
            from dataclasses import dataclass
         
     | 
| 18 | 
         
            -
            from typing import Any, Dict, Optional
         
     | 
| 19 | 
         | 
| 20 | 
         
             
            import numpy as np
         
     | 
| 21 | 
         
             
            import torch
         
     | 
| 22 | 
         
             
            import torch.nn.functional as F
         
     | 
| 23 | 
         
             
            import torch.nn.init as init
         
     | 
| 24 | 
         
             
            from diffusers.configuration_utils import ConfigMixin, register_to_config
         
     | 
| 25 | 
         
            -
            from diffusers.models.attention import BasicTransformerBlock
         
     | 
| 26 | 
         
            -
            from diffusers.models.embeddings import PatchEmbed,  
     | 
| 
         | 
|
| 27 | 
         
             
            from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
         
     | 
| 28 | 
         
             
            from diffusers.models.modeling_utils import ModelMixin
         
     | 
| 29 | 
         
            -
            from diffusers.models.normalization import  
     | 
| 30 | 
         
            -
            from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, is_torch_version
         
     | 
| 
         | 
|
| 
         | 
|
| 31 | 
         
             
            from einops import rearrange
         
     | 
| 32 | 
         
             
            from torch import nn
         
     | 
| 33 | 
         
            -
            from typing import Dict, Optional, Tuple
         
     | 
| 34 | 
         | 
| 35 | 
         
             
            from .attention import (SelfAttentionTemporalTransformerBlock,
         
     | 
| 36 | 
         
             
                                    TemporalTransformerBlock)
         
     | 
| 37 | 
         
            -
            from . 
     | 
| 
         | 
|
| 
         | 
|
| 38 | 
         | 
| 39 | 
         
             
            try:
         
     | 
| 40 | 
         
             
                from diffusers.models.embeddings import PixArtAlphaTextProjection
         
     | 
| 
         @@ -48,77 +52,25 @@ def zero_module(module): 
     | 
|
| 48 | 
         
             
                    p.detach().zero_()
         
     | 
| 49 | 
         
             
                return module
         
     | 
| 50 | 
         | 
| 51 | 
         
            -
            class PixArtAlphaCombinedTimestepSizeEmbeddings(nn.Module):
         
     | 
| 52 | 
         
            -
                """
         
     | 
| 53 | 
         
            -
                For PixArt-Alpha.
         
     | 
| 54 | 
         | 
| 55 | 
         
            -
             
     | 
| 56 | 
         
            -
                https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L164C9-L168C29
         
     | 
| 57 | 
         
             
                """
         
     | 
| 
         | 
|
| 58 | 
         | 
| 59 | 
         
            -
                 
     | 
| 60 | 
         
            -
                    super().__init__()
         
     | 
| 61 | 
         
            -
             
     | 
| 62 | 
         
            -
                    self.outdim = size_emb_dim
         
     | 
| 63 | 
         
            -
                    self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
         
     | 
| 64 | 
         
            -
                    self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
         
     | 
| 65 | 
         
            -
             
     | 
| 66 | 
         
            -
                    self.use_additional_conditions = use_additional_conditions
         
     | 
| 67 | 
         
            -
                    if use_additional_conditions:
         
     | 
| 68 | 
         
            -
                        self.additional_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
         
     | 
| 69 | 
         
            -
                        self.resolution_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
         
     | 
| 70 | 
         
            -
                        self.aspect_ratio_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=size_emb_dim)
         
     | 
| 71 | 
         
            -
                        
         
     | 
| 72 | 
         
            -
                        self.resolution_embedder.linear_2 = zero_module(self.resolution_embedder.linear_2)
         
     | 
| 73 | 
         
            -
                        self.aspect_ratio_embedder.linear_2 = zero_module(self.aspect_ratio_embedder.linear_2)
         
     | 
| 74 | 
         
            -
             
     | 
| 75 | 
         
            -
                def forward(self, timestep, resolution, aspect_ratio, batch_size, hidden_dtype):
         
     | 
| 76 | 
         
            -
-        timesteps_proj = self.time_proj(timestep)
-        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
-
-        if self.use_additional_conditions:
-            resolution_emb = self.additional_condition_proj(resolution.flatten()).to(hidden_dtype)
-            resolution_emb = self.resolution_embedder(resolution_emb).reshape(batch_size, -1)
-            aspect_ratio_emb = self.additional_condition_proj(aspect_ratio.flatten()).to(hidden_dtype)
-            aspect_ratio_emb = self.aspect_ratio_embedder(aspect_ratio_emb).reshape(batch_size, -1)
-            conditioning = timesteps_emb + torch.cat([resolution_emb, aspect_ratio_emb], dim=1)
-        else:
-            conditioning = timesteps_emb
-
-        return conditioning
-
-class AdaLayerNormSingle(nn.Module):
-    r"""
-    Norm layer adaptive layer norm single (adaLN-single).
-
-    As proposed in PixArt-Alpha (see: https://arxiv.org/abs/2310.00426; Section 2.3).
-
-    Parameters:
-        embedding_dim (`int`): The size of each embedding vector.
-        use_additional_conditions (`bool`): To use additional conditions for normalization or not.
     """
-    def __init__(self, embedding_dim: int, use_additional_conditions: bool = False):
         super().__init__()
-
-        self.
-
-        )
-
-
-
-    def forward(
-        self,
-        timestep: torch.Tensor,
-        added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
-        batch_size: Optional[int] = None,
-        hidden_dtype: Optional[torch.dtype] = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
-        # No modulation happening here.
-        embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
-        return self.linear(self.silu(embedded_timestep)), embedded_timestep
-

 class TimePositionalEncoding(nn.Module):
     def __init__(
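Note (not part of the commit): the removed block above is the PixArt-Alpha adaLN-single conditioning path, which this commit relocates into easyanimate/models/norm.py (see the new `from .norm import AdaLayerNormSingle` import further down). As a minimal sketch of what the layer computes, assuming the standard diffusers-style layout (a combined timestep embedder feeding one shared SiLU + linear projection); the `emb` stand-in below is a hypothetical placeholder, not the real embedder:

import torch
import torch.nn as nn

class AdaLayerNormSingleSketch(nn.Module):
    """Sketch of adaLN-single: one shared projection produces the six
    modulation parameters (shift/scale/gate pairs) consumed by every block."""

    def __init__(self, embedding_dim: int):
        super().__init__()
        # Placeholder for the combined timestep(/resolution/aspect-ratio) embedder
        # implemented by the removed code; input size 256 is illustrative only.
        self.emb = nn.Sequential(
            nn.Linear(256, embedding_dim), nn.SiLU(), nn.Linear(embedding_dim, embedding_dim)
        )
        self.silu = nn.SiLU()
        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)

    def forward(self, timestep_features: torch.Tensor):
        embedded_timestep = self.emb(timestep_features)              # (N, D)
        # (N, 6*D) conditioning plus the raw embedded timestep, mirroring the
        # `return self.linear(self.silu(embedded_timestep)), embedded_timestep` line above.
        return self.linear(self.silu(embedded_timestep)), embedded_timestep

layer = AdaLayerNormSingleSketch(embedding_dim=1152)
cond, emb = layer(torch.randn(2, 256))
print(cond.shape, emb.shape)  # torch.Size([2, 6912]) torch.Size([2, 1152])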
@@ -229,9 +181,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         # motion module kwargs
         motion_module_type = "VanillaGrid",
         motion_module_kwargs = None,

         # time position encoding
-        time_position_encoding_before_transformer = False
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection
@@ -320,6 +277,35 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
                         attention_type=attention_type,
                         motion_module_type=motion_module_type,
                         motion_module_kwargs=motion_module_kwargs,
                     )
                     for d in range(num_layers)
                 ]
@@ -346,6 +332,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
                         kvcompression=False if d < 14 else True,
                         motion_module_type=motion_module_type,
                         motion_module_kwargs=motion_module_kwargs,
                     )
                     for d in range(num_layers)
                 ]
@@ -369,6 +357,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
                         norm_elementwise_affine=norm_elementwise_affine,
                         norm_eps=norm_eps,
                         attention_type=attention_type,
                     )
                     for d in range(num_layers)
                 ]
@@ -438,8 +428,11 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
             self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)

         self.caption_projection = None
         if caption_channels is not None:
             self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)

         self.gradient_checkpointing = False
@@ -456,12 +449,14 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
         hidden_states: torch.Tensor,
         inpaint_latents: torch.Tensor = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         timestep: Optional[torch.LongTensor] = None,
         added_cond_kwargs: Dict[str, torch.Tensor] = None,
         class_labels: Optional[torch.LongTensor] = None,
         cross_attention_kwargs: Dict[str, Any] = None,
         attention_mask: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
         """
@@ -520,6 +515,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
             attention_mask = attention_mask.unsqueeze(1)

         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
             encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0
@@ -560,6 +557,13 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
             encoder_hidden_states = self.caption_projection(encoder_hidden_states)
             encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

         skips = []
         skip_index = 0
         for index, block in enumerate(self.transformer_blocks):
@@ -590,7 +594,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
                 args = {
                     "basic": [],
                     "motionmodule": [video_length, height, width],
-                    "selfattentiontemporal": [],
                     "kvcompression_motionmodule": [video_length, height, width],
                 }[self.basic_block_type]
                 hidden_states = torch.utils.checkpoint.checkpoint(
@@ -609,7 +614,8 @@ class Transformer3DModel(ModelMixin, ConfigMixin):
                 kwargs = {
                     "basic": {},
                     "motionmodule": {"num_frames":video_length, "height":height, "width":width},
-                    "selfattentiontemporal": {},
                     "kvcompression_motionmodule": {"num_frames":video_length, "height":height, "width":width},
                 }[self.basic_block_type]
                 hidden_states = block(
 import math
 import os
 from dataclasses import dataclass
+from typing import Any, Dict, Optional, Tuple

 import numpy as np
 import torch
 import torch.nn.functional as F
 import torch.nn.init as init
 from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.attention import BasicTransformerBlock, FeedForward
+from diffusers.models.embeddings import (PatchEmbed, PixArtAlphaTextProjection,
+                                         TimestepEmbedding, Timesteps)
 from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
 from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.normalization import AdaLayerNormContinuous
+from diffusers.utils import (USE_PEFT_BACKEND, BaseOutput, is_torch_version,
+                             logging)
+from diffusers.utils.torch_utils import maybe_allow_in_graph
 from einops import rearrange
 from torch import nn

 from .attention import (SelfAttentionTemporalTransformerBlock,
                         TemporalTransformerBlock)
+from .norm import AdaLayerNormSingle
+from .patch import (CasualPatchEmbed3D, Patch1D, PatchEmbed3D, PatchEmbedF3D,
+                    TemporalUpsampler3D, UnPatch1D)

 try:
     from diffusers.models.embeddings import PixArtAlphaTextProjection

         p.detach().zero_()
     return module

+class CLIPProjection(nn.Module):
     """
+    Projects caption embeddings. Also handles dropout for classifier-free guidance.

+    Adapted from https://github.com/PixArt-alpha/PixArt-alpha/blob/master/diffusion/model/nets/PixArt_blocks.py
     """

+    def __init__(self, in_features, hidden_size, num_tokens=120):
         super().__init__()
+        self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
+        self.act_1 = nn.GELU(approximate="tanh")
+        self.linear_2 = nn.Linear(in_features=hidden_size, out_features=hidden_size, bias=True)
+        self.linear_2 = zero_module(self.linear_2)
+    def forward(self, caption):
+        hidden_states = self.linear_1(caption)
+        hidden_states = self.act_1(hidden_states)
+        hidden_states = self.linear_2(hidden_states)
+        return hidden_states
 class TimePositionalEncoding(nn.Module):
     def __init__(

         # motion module kwargs
         motion_module_type = "VanillaGrid",
         motion_module_kwargs = None,
+        motion_module_kwargs_odd = None,
+        motion_module_kwargs_even = None,

         # time position encoding
+        time_position_encoding_before_transformer = False,
+
+        qk_norm = False,
+        after_norm = False,
     ):
         super().__init__()
         self.use_linear_projection = use_linear_projection

                         attention_type=attention_type,
                         motion_module_type=motion_module_type,
                         motion_module_kwargs=motion_module_kwargs,
+                        qk_norm=qk_norm,
+                        after_norm=after_norm,
+                    )
+                    for d in range(num_layers)
+                ]
+            )
+        elif self.basic_block_type == "global_motionmodule":
+            self.transformer_blocks = nn.ModuleList(
+                [
+                    TemporalTransformerBlock(
+                        inner_dim,
+                        num_attention_heads,
+                        attention_head_dim,
+                        dropout=dropout,
+                        cross_attention_dim=cross_attention_dim,
+                        activation_fn=activation_fn,
+                        num_embeds_ada_norm=num_embeds_ada_norm,
+                        attention_bias=attention_bias,
+                        only_cross_attention=only_cross_attention,
+                        double_self_attention=double_self_attention,
+                        upcast_attention=upcast_attention,
+                        norm_type=norm_type,
+                        norm_elementwise_affine=norm_elementwise_affine,
+                        norm_eps=norm_eps,
+                        attention_type=attention_type,
+                        motion_module_type=motion_module_type,
+                        motion_module_kwargs=motion_module_kwargs_even if d % 2 == 0 else motion_module_kwargs_odd,
+                        qk_norm=qk_norm,
+                        after_norm=after_norm,
                     )
                     for d in range(num_layers)
                 ]
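Note (not part of the commit): the new "global_motionmodule" branch alternates two motion-module configurations across depth, giving even-indexed blocks `motion_module_kwargs_even` and odd-indexed blocks `motion_module_kwargs_odd`. A minimal sketch of that per-depth selection; the kwargs values below are purely illustrative, not EasyAnimate's real options:

def build_motion_kwargs(num_layers, kwargs_even, kwargs_odd):
    """Return the kwargs each transformer block at depth d would receive,
    mirroring `kwargs_even if d % 2 == 0 else kwargs_odd` in the diff above."""
    return [kwargs_even if d % 2 == 0 else kwargs_odd for d in range(num_layers)]

per_layer = build_motion_kwargs(
    num_layers=4,
    kwargs_even={"temporal_attention_mode": "global"},       # hypothetical setting
    kwargs_odd={"temporal_attention_mode": "local", "window": 8},  # hypothetical setting
)
print(per_layer[0], per_layer[1])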
                         kvcompression=False if d < 14 else True,
                         motion_module_type=motion_module_type,
                         motion_module_kwargs=motion_module_kwargs,
+                        qk_norm=qk_norm,
+                        after_norm=after_norm,
                     )
                     for d in range(num_layers)
                 ]

                         norm_elementwise_affine=norm_elementwise_affine,
                         norm_eps=norm_eps,
                         attention_type=attention_type,
+                        qk_norm=qk_norm,
+                        after_norm=after_norm,
                     )
                     for d in range(num_layers)
                 ]

             self.adaln_single = AdaLayerNormSingle(inner_dim, use_additional_conditions=self.use_additional_conditions)

         self.caption_projection = None
+        self.clip_projection = None
         if caption_channels is not None:
             self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
+            if in_channels == 12:
+                self.clip_projection = CLIPProjection(in_features=768, hidden_size=inner_dim * 8)

         self.gradient_checkpointing = False

         hidden_states: torch.Tensor,
         inpaint_latents: torch.Tensor = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
+        clip_encoder_hidden_states: Optional[torch.Tensor] = None,
         timestep: Optional[torch.LongTensor] = None,
         added_cond_kwargs: Dict[str, torch.Tensor] = None,
         class_labels: Optional[torch.LongTensor] = None,
         cross_attention_kwargs: Dict[str, Any] = None,
         attention_mask: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
+        clip_attention_mask: Optional[torch.Tensor] = None,
         return_dict: bool = True,
     ):
         """

             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
             attention_mask = attention_mask.unsqueeze(1)

+        if clip_attention_mask is not None:
+            encoder_attention_mask = torch.cat([encoder_attention_mask, clip_attention_mask], dim=1)
         # convert encoder_attention_mask to a bias the same way we do for attention_mask
         if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
             encoder_attention_mask = (1 - encoder_attention_mask.to(encoder_hidden_states.dtype)) * -10000.0

             encoder_hidden_states = self.caption_projection(encoder_hidden_states)
             encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])

+        if clip_encoder_hidden_states is not None and encoder_hidden_states is not None:
+            batch_size = hidden_states.shape[0]
+            clip_encoder_hidden_states = self.clip_projection(clip_encoder_hidden_states)
+            clip_encoder_hidden_states = clip_encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
+
+            encoder_hidden_states = torch.cat([encoder_hidden_states, clip_encoder_hidden_states], dim = 1)
+
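Note (not part of the commit): taken together, the additions above bolt a CLIP-derived conditioning stream onto the existing T5 caption path: 768-dim CLIP features are projected by `CLIPProjection` (hidden size `inner_dim * 8`), reshaped into extra tokens, and concatenated with the projected caption tokens, with the attention masks concatenated to match. A standalone sketch of that concatenation, with made-up tensor sizes and a stand-in projection:

import torch
import torch.nn as nn

batch, text_tokens, inner_dim = 2, 120, 1152      # illustrative sizes only

caption_tokens = torch.randn(batch, text_tokens, inner_dim)   # after caption_projection
clip_features = torch.randn(batch, 1, 768)                    # raw CLIP embedding (assumed pooled)

# Stand-in for CLIPProjection: 768 -> inner_dim * 8, then reshape into 8 extra tokens.
clip_projection = nn.Sequential(nn.Linear(768, inner_dim * 8), nn.GELU(approximate="tanh"))
clip_tokens = clip_projection(clip_features).view(batch, -1, inner_dim)   # (2, 8, 1152)

encoder_hidden_states = torch.cat([caption_tokens, clip_tokens], dim=1)   # (2, 128, 1152)

text_mask = torch.ones(batch, text_tokens)
clip_mask = torch.ones(batch, clip_tokens.shape[1])
encoder_attention_mask = torch.cat([text_mask, clip_mask], dim=1)         # (2, 128)
print(encoder_hidden_states.shape, encoder_attention_mask.shape)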
         
             
         skips = []
         skip_index = 0
         for index, block in enumerate(self.transformer_blocks):

                 args = {
                     "basic": [],
                     "motionmodule": [video_length, height, width],
+                    "global_motionmodule": [video_length, height, width],
+                    "selfattentiontemporal": [],
                     "kvcompression_motionmodule": [video_length, height, width],
                 }[self.basic_block_type]
                 hidden_states = torch.utils.checkpoint.checkpoint(

                 kwargs = {
                     "basic": {},
                     "motionmodule": {"num_frames":video_length, "height":height, "width":width},
+                    "global_motionmodule": {"num_frames":video_length, "height":height, "width":width},
+                    "selfattentiontemporal": {},
                     "kvcompression_motionmodule": {"num_frames":video_length, "height":height, "width":width},
                 }[self.basic_block_type]
                 hidden_states = block(
easyanimate/pipeline/pipeline_easyanimate.py CHANGED

@@ -578,7 +578,7 @@ class EasyAnimatePipeline(DiffusionPipeline):

     def decode_latents(self, latents):
         video_length = latents.shape[2]
-        latents = 1 /
+        latents = 1 / self.vae.config.scaling_factor * latents
         if self.vae.quant_conv.weight.ndim==5:
             mini_batch_encoder = self.vae.mini_batch_encoder
             mini_batch_decoder = self.vae.mini_batch_decoder
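Note (not part of the commit): the only functional change in this file replaces a hard-coded latent scale with the value stored on the VAE config. The usual diffusers-style convention is to multiply by `scaling_factor` after encoding and divide by it before decoding; a minimal sketch under that assumption, with toy stand-ins for the VAE:

import torch

SCALING_FACTOR = 0.18215   # illustrative; the real value comes from vae.config.scaling_factor

def encode_to_latents(vae_encode, pixels, scaling_factor=SCALING_FACTOR):
    # Scale up so the latents are roughly unit-variance for the diffusion model.
    return vae_encode(pixels) * scaling_factor

def decode_from_latents(vae_decode, latents, scaling_factor=SCALING_FACTOR):
    # Undo the scaling before decoding, mirroring
    # `latents = 1 / self.vae.config.scaling_factor * latents` above.
    return vae_decode(latents / scaling_factor)

fake_encode = lambda x: x[:, :4]                     # toy "encoder": keep 4 channels
fake_decode = lambda z: torch.cat([z, z], dim=1)     # toy "decoder"

pixels = torch.randn(1, 8, 16, 16)
z = encode_to_latents(fake_encode, pixels)
out = decode_from_latents(fake_decode, z)
print(z.shape, out.shape)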
    	
easyanimate/pipeline/pipeline_easyanimate_inpaint.py CHANGED

@@ -15,13 +15,16 @@
 import html
 import inspect
 import re
 import copy
 import urllib.parse as ul
 from dataclasses import dataclass
 from typing import Callable, List, Optional, Tuple, Union

 import numpy as np
 import torch
 from diffusers import DiffusionPipeline, ImagePipelineOutput
 from diffusers.image_processor import VaeImageProcessor
 from diffusers.models import AutoencoderKL

@@ -33,6 +36,7 @@ from diffusers.utils.torch_utils import randn_tensor
 from einops import rearrange
 from tqdm import tqdm
 from transformers import T5EncoderModel, T5Tokenizer

 from ..models.transformer3d import Transformer3DModel

@@ -109,11 +113,15 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
         vae: AutoencoderKL,
         transformer: Transformer3DModel,
         scheduler: DPMSolverMultistepScheduler,
     ):
         super().__init__()

         self.register_modules(
-            tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer,
         )

         self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
@@ -503,41 +511,64 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
         return_video_latents=False,
     ):
         if self.vae.quant_conv.weight.ndim==5:
-
         else:
             shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)

         if isinstance(generator, list) and len(generator) != batch_size:
             raise ValueError(
                 f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
-
         if return_video_latents or (latents is None and not is_strength_max):
-            video = video.to(device=device, dtype=dtype)

-            if video.shape[1] == 4:
-                video_latents = video
             else:
-
-

         if latents is None:
-
-            noise = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
             # if strength is 1. then initialise the latents to noise, else initial to image + noise
             latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
         else:
             noise = latents.to(device)
-
-                raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
-            latents = latents.to(device)

         # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * self.scheduler.init_noise_sigma
         outputs = (latents,)

         if return_noise:
@@ -548,33 +579,61 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):

         return outputs

-    def
-
-
-
             for i in range(0, latents.shape[2], mini_batch_decoder):
                 with torch.no_grad():
                     start_index = i
                     end_index = i + mini_batch_decoder
                     latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
-
-
-
             video = video.clamp(-1, 1)
         else:
             latents = rearrange(latents, "b c f h w -> (b f) c h w")
-            # video = self.vae.decode(latents).sample
             video = []
             for frame_idx in tqdm(range(latents.shape[0])):
                 video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
| 
         @@ -599,6 +658,16 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): 
     | 
|
| 599 | 
         | 
| 600 | 
         
             
                    return image_latents
         
     | 
| 601 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 602 | 
         
             
                def prepare_mask_latents(
         
     | 
| 603 | 
         
             
                    self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
         
     | 
| 604 | 
         
             
                ):
         
     | 
| 
@@ -610,19 +679,26 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
         mask = mask.to(device=device, dtype=self.vae.dtype)
         if self.vae.quant_conv.weight.ndim==5:
             bs = 1
             new_mask = []
-
-
-                for j in range(0, mask.shape[2], mini_batch):
-                    mask_bs = mask[i : i + bs, :, j: j + mini_batch, :, :]
                     mask_bs = self.vae.encode(mask_bs)[0]
                     mask_bs = mask_bs.sample()
-
-
             mask = torch.cat(new_mask, dim = 0)
-            mask = mask *

         else:
             if mask.shape[1] == 4:
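Note (not part of the commit): in the 3D-VAE branch, the mask (and, in the next hunk, the masked video) is encoded a few frames at a time: the `(b, c, f, h, w)` tensor is sliced along the frame axis, each slice is passed through `vae.encode`, and the per-chunk latents are concatenated. A sketch of that chunked-encode loop with a dummy encoder; the function name, chunk size, and shapes are illustrative only:

import torch

def encode_video_in_minibatches(encode_fn, video, mini_batch=8):
    """Encode a (b, c, f, h, w) tensor `mini_batch` frames at a time along dim=2,
    mirroring the chunked self.vae.encode(...) loops in prepare_mask_latents."""
    chunks = []
    for j in range(0, video.shape[2], mini_batch):
        chunk = video[:, :, j:j + mini_batch, :, :]
        chunks.append(encode_fn(chunk))
    return torch.cat(chunks, dim=2)

def fake_encode(x):
    # Dummy stand-in: 3 -> 4 channels, spatially downsampled by 8, frames preserved.
    b, c, f, h, w = x.shape
    return torch.randn(b, 4, f, h // 8, w // 8)

mask = torch.rand(1, 3, 16, 64, 64)
latents = encode_video_in_minibatches(fake_encode, mask, mini_batch=8)
print(latents.shape)  # torch.Size([1, 4, 16, 8, 8])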
@@ -636,19 +712,26 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
         masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
         if self.vae.quant_conv.weight.ndim==5:
             bs = 1
             new_mask_pixel_values = []
-
-
-                for j in range(0, masked_image.shape[2], mini_batch):
-                    mask_pixel_values_bs = masked_image[i : i + bs, :, j: j + mini_batch, :, :]
                     mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
                     mask_pixel_values_bs = mask_pixel_values_bs.sample()
-
-
             masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
-            masked_image_latents = masked_image_latents *

         else:
             if masked_image.shape[1] == 4:
@@ -693,7 +776,9 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
         callback_steps: int = 1,
         clean_caption: bool = True,
         mask_feature: bool = True,
-        max_sequence_length: int = 120
     ) -> Union[EasyAnimatePipelineOutput, Tuple]:
         """
         Function invoked when calling the pipeline for generation.
| 
         @@ -767,6 +852,8 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): 
     | 
|
| 767 | 
         
             
                    # 1. Check inputs. Raise error if not correct
         
     | 
| 768 | 
         
             
                    height = height or self.transformer.config.sample_size * self.vae_scale_factor
         
     | 
| 769 | 
         
             
                    width = width or self.transformer.config.sample_size * self.vae_scale_factor
         
     | 
| 
         | 
|
| 
         | 
|
| 770 | 
         | 
| 771 | 
         
             
                    # 2. Default height and width to transformer
         
     | 
| 772 | 
         
             
                    if prompt is not None and isinstance(prompt, str):
         
     | 
| 
         @@ -806,11 +893,13 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): 
     | 
|
| 806 | 
         
             
                        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
         
     | 
| 807 | 
         
             
                        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
         
     | 
| 808 | 
         | 
| 809 | 
         
            -
                    # 4.  
     | 
| 810 | 
         
             
                    self.scheduler.set_timesteps(num_inference_steps, device=device)
         
     | 
| 811 | 
         
            -
                    timesteps = self. 
     | 
| 
         | 
|
| 
         | 
|
| 812 | 
         
             
                    # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
         
     | 
| 813 | 
         
            -
                    latent_timestep = timesteps[:1].repeat(batch_size)
         
     | 
| 814 | 
         
             
                    # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
         
     | 
| 815 | 
         
             
                    is_strength_max = strength == 1.0
         
     | 
| 816 | 
         | 
| 
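Note (not part of the commit): the truncated lines around `set_timesteps` follow the usual img2img/inpaint pattern: after setting the full schedule, only the last `strength` fraction of the timesteps is kept, and the first kept timestep is used to noise the initial latents. A sketch of that truncation, assuming the standard `get_timesteps` helper rather than EasyAnimate's exact code:

import torch

def get_timesteps(timesteps, num_inference_steps, strength, order=1):
    """Keep only the final `strength` fraction of the schedule (img2img convention)."""
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
    t_start = max(num_inference_steps - init_timestep, 0)
    return timesteps[t_start * order:], num_inference_steps - t_start

full = torch.linspace(999, 0, steps=50).long()   # stand-in for scheduler.timesteps
kept, steps = get_timesteps(full, num_inference_steps=50, strength=0.6)
latent_timestep = kept[:1]                       # timestep used to noise the initial latents
print(len(kept), steps, latent_timestep)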
@@ -825,7 +914,7 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline):
         # Prepare latent variables
         num_channels_latents = self.vae.config.latent_channels
         num_channels_transformer = self.transformer.config.in_channels
-        return_image_latents = num_channels_transformer == 4

         # 5. Prepare latents.
         latents_outputs = self.prepare_latents(
| 
         @@ -857,30 +946,83 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): 
     | 
|
| 857 | 
         
             
                        mask_condition = mask_condition.to(dtype=torch.float32)
         
     | 
| 858 | 
         
             
                        mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
         
     | 
| 859 | 
         | 
| 860 | 
         
            -
                        if  
     | 
| 861 | 
         
            -
                             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 862 | 
         
             
                        else:
         
     | 
| 863 | 
         
            -
                             
     | 
| 864 | 
         
            -
             
     | 
| 865 | 
         
            -
             
     | 
| 866 | 
         
            -
                             
     | 
| 867 | 
         
            -
             
     | 
| 868 | 
         
            -
             
     | 
| 869 | 
         
            -
                             
     | 
| 870 | 
         
            -
                             
     | 
| 871 | 
         
            -
             
     | 
| 872 | 
         
            -
                             
     | 
| 873 | 
         
            -
                             
     | 
| 874 | 
         
            -
             
     | 
| 875 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 876 | 
         
             
                    else:
         
     | 
| 877 | 
         
            -
                         
     | 
| 878 | 
         
            -
                         
     | 
| 879 | 
         | 
| 880 | 
         
             
                    # Check that sizes of mask, masked image and latents match
         
     | 
| 881 | 
         
             
                    if num_channels_transformer == 12:
         
     | 
| 882 | 
         
             
                        # default case for runwayml/stable-diffusion-inpainting
         
     | 
| 883 | 
         
            -
                        num_channels_mask =  
     | 
| 884 | 
         
             
                        num_channels_masked_image = masked_video_latents.shape[1]
         
     | 
| 885 | 
         
             
                        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
         
     | 
| 886 | 
         
             
                            raise ValueError(
         
     | 
| 
         @@ -890,12 +1032,12 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): 
     | 
|
| 890 | 
         
             
                                f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
         
     | 
| 891 | 
         
             
                                " `pipeline.transformer` or your `mask_image` or `image` input."
         
     | 
| 892 | 
         
             
                            )
         
     | 
| 893 | 
         
            -
                    elif num_channels_transformer  
     | 
| 894 | 
         
             
                        raise ValueError(
         
     | 
| 895 | 
         
             
                            f"The transformer {self.transformer.__class__} should have 9 input channels, not {self.transformer.config.in_channels}."
         
     | 
| 896 | 
         
             
                        )
         
     | 
| 897 | 
         | 
| 898 | 
         
            -
                    #  
     | 
| 899 | 
         
             
                    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
         
     | 
| 900 | 
         | 
| 901 | 
         
             
                    # 6.1 Prepare micro-conditions.
         
     | 
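Note (not part of the commit): the 12-channel check encodes the inpainting input layout: the transformer's `in_channels` must equal the sum of the noise latents, the encoded mask, and the encoded masked-video latents, otherwise the pipeline raises. A quick arithmetic sketch of the check, assuming a 4-channel VAE latent space for each component:

num_channels_latents = 4        # noisy video latents
num_channels_mask = 4           # mask, brought to latent shape (assumed 4 channels here)
num_channels_masked_image = 4   # masked video latents

in_channels = 12                # transformer.config.in_channels for the inpaint model

total = num_channels_latents + num_channels_mask + num_channels_masked_image
if total != in_channels:
    raise ValueError(
        f"Expected {in_channels} input channels but the inpaint conditioning sums to {total}."
    )
print("channel layout OK:", total)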
| 
         @@ -912,21 +1054,25 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): 
     | 
|
| 912 | 
         | 
| 913 | 
         
             
                        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
         
     | 
| 914 | 
         | 
| 915 | 
         
            -
                     
     | 
| 916 | 
         
            -
                     
     | 
| 
         | 
|
| 917 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 918 | 
         
             
                    with self.progress_bar(total=num_inference_steps) as progress_bar:
         
     | 
| 919 | 
         
             
                        for i, t in enumerate(timesteps):
         
     | 
| 920 | 
         
             
                            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
         
     | 
| 921 | 
         
             
                            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
         
     | 
| 922 | 
         | 
| 923 | 
         
            -
                            if  
     | 
| 924 | 
         
            -
                                 
     | 
| 925 | 
         
            -
                                 
     | 
| 926 | 
         
            -
             
     | 
| 927 | 
         
            -
                                 
     | 
| 928 | 
         
            -
                                 
     | 
| 929 | 
         
            -
             
     | 
| 930 | 
         
             
                            current_timestep = t
         
     | 
| 931 | 
         
             
                            if not torch.is_tensor(current_timestep):
         
     | 
| 932 | 
         
             
                                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
         
     | 
| 
         @@ -949,7 +1095,9 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): 
     | 
|
| 949 | 
         
             
                                encoder_attention_mask=prompt_attention_mask,
         
     | 
| 950 | 
         
             
                                timestep=current_timestep,
         
     | 
| 951 | 
         
             
                                added_cond_kwargs=added_cond_kwargs,
         
     | 
| 952 | 
         
            -
                                inpaint_latents=inpaint_latents 
     | 
| 
         | 
|
| 
         | 
|
| 953 | 
         
             
                                return_dict=False,
         
     | 
| 954 | 
         
             
                            )[0]
         
     | 
| 955 | 
         | 
| 
         @@ -964,6 +1112,17 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): 
     | 
|
| 964 | 
         
             
                            # compute previous image: x_t -> x_t-1
         
     | 
| 965 | 
         
             
                            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
         
     | 
| 966 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 967 | 
         
             
                            # call the callback, if provided
         
     | 
| 968 | 
         
             
                            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         
     | 
| 969 | 
         
             
                                progress_bar.update()
         
     | 
| 
         @@ -971,9 +1130,16 @@ class EasyAnimateInpaintPipeline(DiffusionPipeline): 
     | 
|
| 971 | 
         
             
                                    step_idx = i // getattr(self.scheduler, "order", 1)
         
     | 
| 972 | 
         
             
                                    callback(step_idx, t, latents)
         
     | 
| 973 | 
         | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 974 | 
         
             
                    # Post-processing
         
     | 
| 975 | 
         
             
                    video = self.decode_latents(latents)
         
     | 
| 976 | 
         
            -
             
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 977 | 
         
             
                    # Convert to tensor
         
     | 
| 978 | 
         
             
                    if output_type == "latent":
         
     | 
| 979 | 
         
             
                        video = torch.from_numpy(video)
         
     | 
| 
         | 
|
| 15 | 
         
             
            import html
         
     | 
| 16 | 
         
             
            import inspect
         
     | 
| 17 | 
         
             
            import re
         
     | 
| 18 | 
         
            +
            import gc
         
     | 
| 19 | 
         
             
            import copy
         
     | 
| 20 | 
         
             
            import urllib.parse as ul
         
     | 
| 21 | 
         
             
            from dataclasses import dataclass
         
     | 
| 22 | 
         
            +
            from PIL import Image
         
     | 
| 23 | 
         
             
            from typing import Callable, List, Optional, Tuple, Union
         
     | 
| 24 | 
         | 
| 25 | 
         
             
            import numpy as np
         
     | 
| 26 | 
         
             
            import torch
         
     | 
| 27 | 
         
            +
            import torch.nn.functional as F
         
     | 
| 28 | 
         
             
            from diffusers import DiffusionPipeline, ImagePipelineOutput
         
     | 
| 29 | 
         
             
            from diffusers.image_processor import VaeImageProcessor
         
     | 
| 30 | 
         
             
            from diffusers.models import AutoencoderKL
         
     | 
| 
         | 
|
| 36 | 
         
             
            from einops import rearrange
         
     | 
| 37 | 
         
             
            from tqdm import tqdm
         
     | 
| 38 | 
         
             
            from transformers import T5EncoderModel, T5Tokenizer
         
     | 
| 39 | 
         
            +
            from transformers import CLIPVisionModelWithProjection,  CLIPImageProcessor
         
     | 
| 40 | 
         | 
| 41 | 
         
             
            from ..models.transformer3d import Transformer3DModel
         
     | 
| 42 | 
         | 
| 
         | 
|
| 113 | 
         
             
                    vae: AutoencoderKL,
         
     | 
| 114 | 
         
             
                    transformer: Transformer3DModel,
         
     | 
| 115 | 
         
             
                    scheduler: DPMSolverMultistepScheduler,
         
     | 
| 116 | 
         
            +
                    clip_image_processor:CLIPImageProcessor = None,
         
     | 
| 117 | 
         
            +
                    clip_image_encoder:CLIPVisionModelWithProjection = None,
         
     | 
| 118 | 
         
             
                ):
         
     | 
| 119 | 
         
             
                    super().__init__()
         
     | 
| 120 | 
         | 
| 121 | 
         
             
                    self.register_modules(
         
     | 
| 122 | 
         
            +
                        tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, 
         
     | 
| 123 | 
         
            +
                        scheduler=scheduler,
         
     | 
| 124 | 
         
            +
                        clip_image_processor=clip_image_processor, clip_image_encoder=clip_image_encoder,
         
     | 
| 125 | 
         
             
                    )
         
     | 
| 126 | 
         | 
| 127 | 
         
             
                    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
         
     | 
| 
         | 
|
| 511 | 
         
             
                    return_video_latents=False,
         
     | 
| 512 | 
         
             
                ):
         
     | 
| 513 | 
         
             
                    if self.vae.quant_conv.weight.ndim==5:
         
     | 
| 514 | 
         
            +
                        mini_batch_encoder = self.vae.mini_batch_encoder
         
     | 
| 515 | 
         
            +
                        mini_batch_decoder = self.vae.mini_batch_decoder
         
     | 
| 516 | 
         
            +
                        shape = (batch_size, num_channels_latents, int(video_length // mini_batch_encoder * mini_batch_decoder) if video_length != 1 else 1, height // self.vae_scale_factor, width // self.vae_scale_factor)
         
     | 
| 517 | 
         
             
                    else:
         
     | 
| 518 | 
         
             
                        shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
         
     | 
| 519 | 
         
            +
             
     | 
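The `shape` computed in the magvit branch above compresses the frame axis: the 3D VAE consumes `mini_batch_encoder` input frames per step and emits `mini_batch_decoder` latent frames. A minimal sketch of the arithmetic, with illustrative values (the concrete mini-batch sizes come from the loaded VAE, not from this diff):

    # Illustrative numbers only; real values come from self.vae.
    video_length = 72
    mini_batch_encoder = 9
    mini_batch_decoder = 3
    latent_frames = video_length // mini_batch_encoder * mini_batch_decoder  # 72 // 9 * 3 = 24
    latent_frames = latent_frames if video_length != 1 else 1                # a single image keeps one latent frame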
| 520 | 
         
             
                    if isinstance(generator, list) and len(generator) != batch_size:
         
     | 
| 521 | 
         
             
                        raise ValueError(
         
     | 
| 522 | 
         
             
                            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
         
     | 
| 523 | 
         
             
                            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
         
     | 
| 524 | 
         
             
                        )
         
     | 
| 525 | 
         
            +
             
     | 
| 526 | 
         
             
                    if return_video_latents or (latents is None and not is_strength_max):
         
     | 
| 527 | 
         
            +
                        video = video.to(device=device, dtype=self.vae.dtype)
         
     | 
| 528 | 
         
            +
                        if self.vae.quant_conv.weight.ndim==5:
         
     | 
| 529 | 
         
            +
                            bs = 1
         
     | 
| 530 | 
         
            +
                            mini_batch_encoder = self.vae.mini_batch_encoder
         
     | 
| 531 | 
         
            +
                            new_video = []
         
     | 
| 532 | 
         
            +
                            if self.vae.slice_compression_vae:
         
     | 
| 533 | 
         
            +
                                for i in range(0, video.shape[0], bs):
         
     | 
| 534 | 
         
            +
                                    video_bs = video[i : i + bs]
         
     | 
| 535 | 
         
            +
                                    video_bs = self.vae.encode(video_bs)[0]
         
     | 
| 536 | 
         
            +
                                    video_bs = video_bs.sample()
         
     | 
| 537 | 
         
            +
                                    new_video.append(video_bs)
         
     | 
| 538 | 
         
            +
                            else:
         
     | 
| 539 | 
         
            +
                                for i in range(0, video.shape[0], bs):
         
     | 
| 540 | 
         
            +
                                    new_video_mini_batch = []
         
     | 
| 541 | 
         
            +
                                    for j in range(0, video.shape[2], mini_batch_encoder):
         
     | 
| 542 | 
         
            +
                                        video_bs = video[i : i + bs, :, j: j + mini_batch_encoder, :, :]
         
     | 
| 543 | 
         
            +
                                        video_bs = self.vae.encode(video_bs)[0]
         
     | 
| 544 | 
         
            +
                                        video_bs = video_bs.sample()
         
     | 
| 545 | 
         
            +
                                        new_video_mini_batch.append(video_bs)
         
     | 
| 546 | 
         
            +
                                    new_video_mini_batch = torch.cat(new_video_mini_batch, dim = 2)
         
     | 
| 547 | 
         
            +
                                    new_video.append(new_video_mini_batch)
         
     | 
| 548 | 
         
            +
                            video = torch.cat(new_video, dim = 0)
         
     | 
| 549 | 
         
            +
                            video = video * self.vae.config.scaling_factor
         
     | 
| 550 | 
         | 
| 551 | 
         
             
                        else:
         
     | 
| 552 | 
         
            +
                            if video.shape[1] == 4:
         
     | 
| 553 | 
         
            +
                                video = video
         
     | 
| 554 | 
         
            +
                            else:
         
     | 
| 555 | 
         
            +
                                video_length = video.shape[2]
         
     | 
| 556 | 
         
            +
                                video = rearrange(video, "b c f h w -> (b f) c h w")
         
     | 
| 557 | 
         
            +
                                video = self._encode_vae_image(video, generator=generator)
         
     | 
| 558 | 
         
            +
                                video = rearrange(video, "(b f) c h w -> b c f h w", f=video_length)
         
     | 
| 559 | 
         
            +
                        video_latents = video.repeat(batch_size // video.shape[0], 1, 1, 1, 1)
         
     | 
| 560 | 
         | 
| 561 | 
         
             
                    if latents is None:
         
     | 
| 562 | 
         
            +
                        noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         
     | 
| 563 | 
         
             
                        # if strength is 1. then initialise the latents to noise, else initial to image + noise
         
     | 
| 564 | 
         
             
                        latents = noise if is_strength_max else self.scheduler.add_noise(video_latents, noise, timestep)
         
     | 
| 565 | 
         
            +
                        # if pure noise then scale the initial latents by the  Scheduler's init sigma
         
     | 
| 566 | 
         
            +
                        latents = latents * self.scheduler.init_noise_sigma if is_strength_max else latents
         
     | 
| 567 | 
         
             
                    else:
         
     | 
| 568 | 
         
             
                        noise = latents.to(device)
         
     | 
| 569 | 
         
            +
                        latents = noise * self.scheduler.init_noise_sigma
         
     | 
| 570 | 
         | 
| 571 | 
         
             
                    # scale the initial noise by the standard deviation required by the scheduler
         
     | 
| 572 | 
         
             
                    outputs = (latents,)
         
     | 
| 573 | 
         | 
| 574 | 
         
             
                    if return_noise:
         
     | 
| 
         | 
|
| 579 | 
         | 
| 580 | 
         
             
                    return outputs
         
     | 
| 581 | 
         | 
| 582 | 
         
            +
                def smooth_output(self, video, mini_batch_encoder, mini_batch_decoder):
         
     | 
| 583 | 
         
            +
                    if video.size()[2] <= mini_batch_encoder:
         
     | 
| 584 | 
         
            +
                        return video
         
     | 
| 585 | 
         
            +
                    prefix_index_before = mini_batch_encoder // 2
         
     | 
| 586 | 
         
            +
                    prefix_index_after = mini_batch_encoder - prefix_index_before
         
     | 
| 587 | 
         
            +
                    pixel_values = video[:, :, prefix_index_before:-prefix_index_after]
         
     | 
| 588 | 
         
            +
                    
         
     | 
| 589 | 
         
            +
                    if self.vae.slice_compression_vae:
         
     | 
| 590 | 
         
            +
                        latents = self.vae.encode(pixel_values)[0]
         
     | 
| 591 | 
         
            +
                        latents = latents.sample()
         
     | 
| 592 | 
         
            +
                    else:
         
     | 
| 593 | 
         
            +
                        new_pixel_values = []
         
     | 
| 594 | 
         
            +
                        for i in range(0, pixel_values.shape[2], mini_batch_encoder):
         
     | 
| 595 | 
         
            +
                            with torch.no_grad():
         
     | 
| 596 | 
         
            +
                                pixel_values_bs = pixel_values[:, :, i: i + mini_batch_encoder, :, :]
         
     | 
| 597 | 
         
            +
                                pixel_values_bs = self.vae.encode(pixel_values_bs)[0]
         
     | 
| 598 | 
         
            +
                                pixel_values_bs = pixel_values_bs.sample()
         
     | 
| 599 | 
         
            +
                                new_pixel_values.append(pixel_values_bs)
         
     | 
| 600 | 
         
            +
                        latents = torch.cat(new_pixel_values, dim = 2)
         
     | 
| 601 | 
         
            +
                            
         
     | 
| 602 | 
         
            +
                    if self.vae.slice_compression_vae:
         
     | 
| 603 | 
         
            +
                        middle_video = self.vae.decode(latents)[0]
         
     | 
| 604 | 
         
            +
                    else:
         
     | 
| 605 | 
         
            +
                        middle_video = []
         
     | 
| 606 | 
         
             
                        for i in range(0, latents.shape[2], mini_batch_decoder):
         
     | 
| 607 | 
         
             
                            with torch.no_grad():
         
     | 
| 608 | 
         
             
                                start_index = i
         
     | 
| 609 | 
         
             
                                end_index = i + mini_batch_decoder
         
     | 
| 610 | 
         
             
                                latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
         
     | 
| 611 | 
         
            +
                                middle_video.append(latents_bs)
         
     | 
| 612 | 
         
            +
                        middle_video = torch.cat(middle_video, 2)
         
     | 
| 613 | 
         
            +
                    video[:, :, prefix_index_before:-prefix_index_after] = (video[:, :, prefix_index_before:-prefix_index_after] + middle_video) / 2
         
     | 
| 614 | 
         
            +
                    return video
         
     | 
| 615 | 
         
            +
                
         
     | 
| 616 | 
         
            +
                def decode_latents(self, latents):
         
     | 
| 617 | 
         
            +
                    video_length = latents.shape[2]
         
     | 
| 618 | 
         
            +
                    latents = 1 / self.vae.config.scaling_factor * latents
         
     | 
| 619 | 
         
            +
                    if self.vae.quant_conv.weight.ndim==5:
         
     | 
| 620 | 
         
            +
                        mini_batch_encoder = self.vae.mini_batch_encoder
         
     | 
| 621 | 
         
            +
                        mini_batch_decoder = self.vae.mini_batch_decoder
         
     | 
| 622 | 
         
            +
                        if self.vae.slice_compression_vae:
         
     | 
| 623 | 
         
            +
                            video = self.vae.decode(latents)[0]
         
     | 
| 624 | 
         
            +
                        else:
         
     | 
| 625 | 
         
            +
                            video = []
         
     | 
| 626 | 
         
            +
                            for i in range(0, latents.shape[2], mini_batch_decoder):
         
     | 
| 627 | 
         
            +
                                with torch.no_grad():
         
     | 
| 628 | 
         
            +
                                    start_index = i
         
     | 
| 629 | 
         
            +
                                    end_index = i + mini_batch_decoder
         
     | 
| 630 | 
         
            +
                                    latents_bs = self.vae.decode(latents[:, :, start_index:end_index, :, :])[0]
         
     | 
| 631 | 
         
            +
                                    video.append(latents_bs)
         
     | 
| 632 | 
         
            +
                            video = torch.cat(video, 2)
         
     | 
| 633 | 
         
             
                        video = video.clamp(-1, 1)
         
     | 
| 634 | 
         
            +
                        video = self.smooth_output(video, mini_batch_encoder, mini_batch_decoder).cpu().clamp(-1, 1)
         
     | 
| 635 | 
         
             
                    else:
         
     | 
| 636 | 
         
             
                        latents = rearrange(latents, "b c f h w -> (b f) c h w")
         
     | 
| 637 | 
         
             
                        video = []
         
     | 
| 638 | 
         
             
                        for frame_idx in tqdm(range(latents.shape[0])):
         
     | 
| 639 | 
         
             
                            video.append(self.vae.decode(latents[frame_idx:frame_idx+1]).sample)
         
     | 
| 
         | 
|
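Note on the two helpers added above: `decode_latents` decodes the latent video in chunks of `mini_batch_decoder` frames, and `smooth_output` then re-encodes/decodes a window offset by half a chunk and averages it back in, so frames near chunk boundaries are blended from two decoding passes. A toy sketch of that blend (shapes and chunk size are assumptions, not taken from the model):

    import torch

    mini_batch_encoder = 8                            # assumed chunk size
    prefix = mini_batch_encoder // 2                  # 4
    video = torch.rand(1, 3, 24, 8, 8)                # b c f h w, decoded chunk by chunk
    middle = torch.rand(1, 3, 24 - 2 * prefix, 8, 8)  # second pass over the shifted window
    video[:, :, prefix:-prefix] = (video[:, :, prefix:-prefix] + middle) / 2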
| 658 | 
         | 
| 659 | 
         
             
                    return image_latents
         
     | 
| 660 | 
         | 
| 661 | 
         
            +
                # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.StableDiffusionImg2ImgPipeline.get_timesteps
         
     | 
| 662 | 
         
            +
                def get_timesteps(self, num_inference_steps, strength, device):
         
     | 
| 663 | 
         
            +
                    # get the original timestep using init_timestep
         
     | 
| 664 | 
         
            +
                    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
         
     | 
| 665 | 
         
            +
             
     | 
| 666 | 
         
            +
                    t_start = max(num_inference_steps - init_timestep, 0)
         
     | 
| 667 | 
         
            +
                    timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
         
     | 
| 668 | 
         
            +
             
     | 
| 669 | 
         
            +
                    return timesteps, num_inference_steps - t_start
         
     | 
| 670 | 
         
            +
             
     | 
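`get_timesteps` mirrors the diffusers img2img helper: with denoising strength `s` and `N` inference steps, only the last `int(N * s)` timesteps are actually run, so part of the input video survives. A quick worked example with illustrative numbers:

    num_inference_steps = 50
    strength = 0.7
    init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 35
    t_start = max(num_inference_steps - init_timestep, 0)                          # 15
    # the first 15 scheduler steps are skipped; denoising starts from a noised
    # version of the input rather than from pure noise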
| 671 | 
         
             
                def prepare_mask_latents(
         
     | 
| 672 | 
         
             
                    self, mask, masked_image, batch_size, height, width, dtype, device, generator, do_classifier_free_guidance
         
     | 
| 673 | 
         
             
                ):
         
     | 
| 
         | 
|
| 679 | 
         
             
                    mask = mask.to(device=device, dtype=self.vae.dtype)
         
     | 
| 680 | 
         
             
                    if self.vae.quant_conv.weight.ndim==5:
         
     | 
| 681 | 
         
             
                        bs = 1
         
     | 
| 682 | 
         
            +
                        mini_batch_encoder = self.vae.mini_batch_encoder
         
     | 
| 683 | 
         
             
                        new_mask = []
         
     | 
| 684 | 
         
            +
                        if self.vae.slice_compression_vae:
         
     | 
| 685 | 
         
            +
                            for i in range(0, mask.shape[0], bs):
         
     | 
| 686 | 
         
            +
                                mask_bs = mask[i : i + bs]
         
     | 
| 687 | 
         
             
                                mask_bs = self.vae.encode(mask_bs)[0]
         
     | 
| 688 | 
         
             
                                mask_bs = mask_bs.sample()
         
     | 
| 689 | 
         
            +
                                new_mask.append(mask_bs)
         
     | 
| 690 | 
         
            +
                        else:
         
     | 
| 691 | 
         
            +
                            for i in range(0, mask.shape[0], bs):
         
     | 
| 692 | 
         
            +
                                new_mask_mini_batch = []
         
     | 
| 693 | 
         
            +
                                for j in range(0, mask.shape[2], mini_batch_encoder):
         
     | 
| 694 | 
         
            +
                                    mask_bs = mask[i : i + bs, :, j: j + mini_batch_encoder, :, :]
         
     | 
| 695 | 
         
            +
                                    mask_bs = self.vae.encode(mask_bs)[0]
         
     | 
| 696 | 
         
            +
                                    mask_bs = mask_bs.sample()
         
     | 
| 697 | 
         
            +
                                    new_mask_mini_batch.append(mask_bs)
         
     | 
| 698 | 
         
            +
                                new_mask_mini_batch = torch.cat(new_mask_mini_batch, dim = 2)
         
     | 
| 699 | 
         
            +
                                new_mask.append(new_mask_mini_batch)
         
     | 
| 700 | 
         
             
                        mask = torch.cat(new_mask, dim = 0)
         
     | 
| 701 | 
         
            +
                        mask = mask * self.vae.config.scaling_factor
         
     | 
| 702 | 
         | 
| 703 | 
         
             
                    else:
         
     | 
| 704 | 
         
             
                        if mask.shape[1] == 4:
         
     | 
| 
         | 
|
| 712 | 
         
             
                    masked_image = masked_image.to(device=device, dtype=self.vae.dtype)
         
     | 
| 713 | 
         
             
                    if self.vae.quant_conv.weight.ndim==5:
         
     | 
| 714 | 
         
             
                        bs = 1
         
     | 
| 715 | 
         
            +
                        mini_batch_encoder = self.vae.mini_batch_encoder
         
     | 
| 716 | 
         
             
                        new_mask_pixel_values = []
         
     | 
| 717 | 
         
            +
                        if self.vae.slice_compression_vae:
         
     | 
| 718 | 
         
            +
                            for i in range(0, masked_image.shape[0], bs):
         
     | 
| 719 | 
         
            +
                                mask_pixel_values_bs = masked_image[i : i + bs]
         
     | 
| 720 | 
         
             
                                mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
         
     | 
| 721 | 
         
             
                                mask_pixel_values_bs = mask_pixel_values_bs.sample()
         
     | 
| 722 | 
         
            +
                                new_mask_pixel_values.append(mask_pixel_values_bs)
         
     | 
| 723 | 
         
            +
                        else:
         
     | 
| 724 | 
         
            +
                            for i in range(0, masked_image.shape[0], bs):
         
     | 
| 725 | 
         
            +
                                new_mask_pixel_values_mini_batch = []
         
     | 
| 726 | 
         
            +
                                for j in range(0, masked_image.shape[2], mini_batch_encoder):
         
     | 
| 727 | 
         
            +
                                    mask_pixel_values_bs = masked_image[i : i + bs, :, j: j + mini_batch_encoder, :, :]
         
     | 
| 728 | 
         
            +
                                    mask_pixel_values_bs = self.vae.encode(mask_pixel_values_bs)[0]
         
     | 
| 729 | 
         
            +
                                    mask_pixel_values_bs = mask_pixel_values_bs.sample()
         
     | 
| 730 | 
         
            +
                                    new_mask_pixel_values_mini_batch.append(mask_pixel_values_bs)
         
     | 
| 731 | 
         
            +
                                new_mask_pixel_values_mini_batch = torch.cat(new_mask_pixel_values_mini_batch, dim = 2)
         
     | 
| 732 | 
         
            +
                                new_mask_pixel_values.append(new_mask_pixel_values_mini_batch)
         
     | 
| 733 | 
         
             
                        masked_image_latents = torch.cat(new_mask_pixel_values, dim = 0)
         
     | 
| 734 | 
         
            +
                        masked_image_latents = masked_image_latents * self.vae.config.scaling_factor
         
     | 
| 735 | 
         | 
| 736 | 
         
             
                    else:
         
     | 
| 737 | 
         
             
                        if masked_image.shape[1] == 4:
         
     | 
| 
         | 
|
| 776 | 
         
             
                    callback_steps: int = 1,
         
     | 
| 777 | 
         
             
                    clean_caption: bool = True,
         
     | 
| 778 | 
         
             
                    mask_feature: bool = True,
         
     | 
| 779 | 
         
            +
                    max_sequence_length: int = 120,
         
     | 
| 780 | 
         
            +
                    clip_image: Image = None,
         
     | 
| 781 | 
         
            +
                    clip_apply_ratio: float = 0.50,
         
     | 
| 782 | 
         
             
                ) -> Union[EasyAnimatePipelineOutput, Tuple]:
         
     | 
| 783 | 
         
             
                    """
         
     | 
| 784 | 
         
             
                    Function invoked when calling the pipeline for generation.
         
     | 
| 
         | 
|
| 852 | 
         
             
                    # 1. Check inputs. Raise error if not correct
         
     | 
| 853 | 
         
             
                    height = height or self.transformer.config.sample_size * self.vae_scale_factor
         
     | 
| 854 | 
         
             
                    width = width or self.transformer.config.sample_size * self.vae_scale_factor
         
     | 
| 855 | 
         
            +
                    height = int(height // 16 * 16)
         
     | 
| 856 | 
         
            +
                    width = int(width // 16 * 16)
         
     | 
| 857 | 
         | 
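Rounding the requested resolution down to a multiple of 16 keeps it compatible with the VAE downsampling and the transformer patch grid, e.g.:

    height = 1080
    height = int(height // 16 * 16)   # 1072, the nearest lower multiple of 16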
| 858 | 
         
             
                    # 2. Default height and width to transformer
         
     | 
| 859 | 
         
             
                    if prompt is not None and isinstance(prompt, str):
         
     | 
| 
         | 
|
| 893 | 
         
             
                        prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds], dim=0)
         
     | 
| 894 | 
         
             
                        prompt_attention_mask = torch.cat([negative_prompt_attention_mask, prompt_attention_mask], dim=0)
         
     | 
| 895 | 
         | 
| 896 | 
         
            +
                    # 4. set timesteps
         
     | 
| 897 | 
         
             
                    self.scheduler.set_timesteps(num_inference_steps, device=device)
         
     | 
| 898 | 
         
            +
                    timesteps, num_inference_steps = self.get_timesteps(
         
     | 
| 899 | 
         
            +
                        num_inference_steps=num_inference_steps, strength=strength, device=device
         
     | 
| 900 | 
         
            +
                    )
         
     | 
| 901 | 
         
             
                    # at which timestep to set the initial noise (n.b. 50% if strength is 0.5)
         
     | 
| 902 | 
         
            +
                    latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
         
     | 
| 903 | 
         
             
                    # create a boolean to check if the strength is set to 1. if so then initialise the latents with pure noise
         
     | 
| 904 | 
         
             
                    is_strength_max = strength == 1.0
         
     | 
| 905 | 
         | 
| 
         | 
|
| 914 | 
         
             
                    # Prepare latent variables
         
     | 
| 915 | 
         
             
                    num_channels_latents = self.vae.config.latent_channels
         
     | 
| 916 | 
         
             
                    num_channels_transformer = self.transformer.config.in_channels
         
     | 
| 917 | 
         
            +
                    return_image_latents = True # num_channels_transformer == 4
         
     | 
| 918 | 
         | 
| 919 | 
         
             
                    # 5. Prepare latents.
         
     | 
| 920 | 
         
             
                    latents_outputs = self.prepare_latents(
         
     | 
| 
         | 
|
| 946 | 
         
             
                        mask_condition = mask_condition.to(dtype=torch.float32)
         
     | 
| 947 | 
         
             
                        mask_condition = rearrange(mask_condition, "(b f) c h w -> b c f h w", f=video_length)
         
     | 
| 948 | 
         | 
| 949 | 
         
            +
                        if num_channels_transformer == 12:
         
     | 
| 950 | 
         
            +
                            mask_condition_tile = torch.tile(mask_condition, [1, 3, 1, 1, 1])
         
     | 
| 951 | 
         
            +
                            if masked_video_latents is None:
         
     | 
| 952 | 
         
            +
                                masked_video = init_video * (mask_condition_tile < 0.5) + torch.ones_like(init_video) * (mask_condition_tile > 0.5) * -1
         
     | 
| 953 | 
         
            +
                            else:
         
     | 
| 954 | 
         
            +
                                masked_video = masked_video_latents
         
     | 
| 955 | 
         
            +
             
     | 
| 956 | 
         
            +
                            mask_latents, masked_video_latents = self.prepare_mask_latents(
         
     | 
| 957 | 
         
            +
                                mask_condition_tile,
         
     | 
| 958 | 
         
            +
                                masked_video,
         
     | 
| 959 | 
         
            +
                                batch_size,
         
     | 
| 960 | 
         
            +
                                height,
         
     | 
| 961 | 
         
            +
                                width,
         
     | 
| 962 | 
         
            +
                                prompt_embeds.dtype,
         
     | 
| 963 | 
         
            +
                                device,
         
     | 
| 964 | 
         
            +
                                generator,
         
     | 
| 965 | 
         
            +
                                do_classifier_free_guidance,
         
     | 
| 966 | 
         
            +
                            )
         
     | 
| 967 | 
         
            +
                            mask = torch.tile(mask_condition, [1, num_channels_transformer // 3, 1, 1, 1])
         
     | 
| 968 | 
         
            +
                            mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
         
     | 
| 969 | 
         
            +
                            
         
     | 
| 970 | 
         
            +
                            mask_input = torch.cat([mask_latents] * 2) if do_classifier_free_guidance else mask_latents
         
     | 
| 971 | 
         
            +
                            masked_video_latents_input = (
         
     | 
| 972 | 
         
            +
                                torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
         
     | 
| 973 | 
         
            +
                            )
         
     | 
| 974 | 
         
            +
                            inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
         
     | 
| 975 | 
         
             
                        else:
         
     | 
| 976 | 
         
            +
                            mask = torch.tile(mask_condition, [1, num_channels_transformer, 1, 1, 1])
         
     | 
| 977 | 
         
            +
                            mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
         
     | 
| 978 | 
         
            +
                            
         
     | 
| 979 | 
         
            +
                            inpaint_latents = None
         
     | 
| 980 | 
         
            +
                    else:
         
     | 
| 981 | 
         
            +
                        if num_channels_transformer == 12:
         
     | 
| 982 | 
         
            +
                            mask = torch.zeros_like(latents).to(latents.device, latents.dtype)
         
     | 
| 983 | 
         
            +
                            masked_video_latents = torch.zeros_like(latents).to(latents.device, latents.dtype)
         
     | 
| 984 | 
         
            +
             
     | 
| 985 | 
         
            +
                            mask_input = torch.cat([mask] * 2) if do_classifier_free_guidance else mask
         
     | 
| 986 | 
         
            +
                            masked_video_latents_input = (
         
     | 
| 987 | 
         
            +
                                torch.cat([masked_video_latents] * 2) if do_classifier_free_guidance else masked_video_latents
         
     | 
| 988 | 
         
            +
                            )
         
     | 
| 989 | 
         
            +
                            inpaint_latents = torch.cat([mask_input, masked_video_latents_input], dim=1).to(latents.dtype)
         
     | 
| 990 | 
         
            +
                        else:
         
     | 
| 991 | 
         
            +
                            mask = torch.zeros_like(init_video[:, :1])
         
     | 
| 992 | 
         
            +
                            mask = torch.tile(mask, [1, num_channels_transformer, 1, 1, 1])
         
     | 
| 993 | 
         
            +
                            mask = F.interpolate(mask, size=latents.size()[-3:], mode='trilinear', align_corners=True).to(latents.device, latents.dtype)
         
     | 
| 994 | 
         
            +
             
     | 
| 995 | 
         
            +
                            inpaint_latents = None
         
     | 
| 996 | 
         
            +
                
         
     | 
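For the 12-channel transformer the conditioning is plain channel concatenation: 4 noise-latent channels enter separately, and `inpaint_latents` carries the 4 VAE-encoded mask channels plus the 4 VAE-encoded masked-video channels, which is exactly what the size check further down enforces. A shape-level sketch with illustrative sizes:

    import torch

    mask_latents         = torch.randn(2, 4, 16, 32, 32)   # b c f h w
    masked_video_latents = torch.randn(2, 4, 16, 32, 32)
    inpaint_latents = torch.cat([mask_latents, masked_video_latents], dim=1)
    # 8 channels here + 4 noise-latent channels passed alongside = 12 transformer input channels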
| 997 | 
         
            +
                    if clip_image is not None:
         
     | 
| 998 | 
         
            +
                        inputs = self.clip_image_processor(images=clip_image, return_tensors="pt")
         
     | 
| 999 | 
         
            +
                        inputs["pixel_values"] = inputs["pixel_values"].to(latents.device, dtype=latents.dtype)
         
     | 
| 1000 | 
         
            +
                        clip_encoder_hidden_states = self.clip_image_encoder(**inputs).image_embeds
         
     | 
| 1001 | 
         
            +
                        clip_encoder_hidden_states_neg = torch.zeros([batch_size, 768]).to(latents.device, dtype=latents.dtype)
         
     | 
| 1002 | 
         
            +
             
     | 
| 1003 | 
         
            +
                        clip_attention_mask = torch.ones([batch_size, 8]).to(latents.device, dtype=latents.dtype)
         
     | 
| 1004 | 
         
            +
                        clip_attention_mask_neg = torch.zeros([batch_size, 8]).to(latents.device, dtype=latents.dtype)
         
     | 
| 1005 | 
         
            +
             
     | 
| 1006 | 
         
            +
                        clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states_neg, clip_encoder_hidden_states]) if do_classifier_free_guidance else clip_encoder_hidden_states
         
     | 
| 1007 | 
         
            +
                        clip_attention_mask_input = torch.cat([clip_attention_mask_neg, clip_attention_mask]) if do_classifier_free_guidance else clip_attention_mask
         
     | 
| 1008 | 
         
            +
             
     | 
| 1009 | 
         
            +
                    elif clip_image is None and num_channels_transformer == 12:
         
     | 
| 1010 | 
         
            +
                        clip_encoder_hidden_states = torch.zeros([batch_size, 768]).to(latents.device, dtype=latents.dtype)
         
     | 
| 1011 | 
         
            +
             
     | 
| 1012 | 
         
            +
                        clip_attention_mask = torch.zeros([batch_size, 8])
         
     | 
| 1013 | 
         
            +
                        clip_attention_mask = clip_attention_mask.to(latents.device, dtype=latents.dtype)
         
     | 
| 1014 | 
         
            +
             
     | 
| 1015 | 
         
            +
                        clip_encoder_hidden_states_input = torch.cat([clip_encoder_hidden_states] * 2) if do_classifier_free_guidance else clip_encoder_hidden_states
         
     | 
| 1016 | 
         
            +
                        clip_attention_mask_input = torch.cat([clip_attention_mask] * 2) if do_classifier_free_guidance else clip_attention_mask
         
     | 
| 1017 | 
         
            +
             
     | 
| 1018 | 
         
             
                    else:
         
     | 
| 1019 | 
         
            +
                        clip_encoder_hidden_states_input = None
         
     | 
| 1020 | 
         
            +
                        clip_attention_mask_input = None
         
     | 
| 1021 | 
         | 
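When `clip_image` is given, its CLIP image embedding becomes an extra conditioning stream; under classifier-free guidance the negative branch receives an all-zero embedding and mask, so the unconditional pass ignores the reference image. A condensed sketch of the guided case (it only restates the branch above; the 768-dim embedding and 8-token mask come from the diff itself):

    inputs = self.clip_image_processor(images=clip_image, return_tensors="pt")
    image_embeds = self.clip_image_encoder(pixel_values=inputs["pixel_values"]).image_embeds  # (b, 768)
    neg_embeds = torch.zeros_like(image_embeds)                                               # dropped reference
    clip_cond = torch.cat([neg_embeds, image_embeds]) if do_classifier_free_guidance else image_embeds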
| 1022 | 
         
             
                    # Check that sizes of mask, masked image and latents match
         
     | 
| 1023 | 
         
             
                    if num_channels_transformer == 12:
         
     | 
| 1024 | 
         
             
                        # default case for runwayml/stable-diffusion-inpainting
         
     | 
| 1025 | 
         
            +
                        num_channels_mask = mask_latents.shape[1]
         
     | 
| 1026 | 
         
             
                        num_channels_masked_image = masked_video_latents.shape[1]
         
     | 
| 1027 | 
         
             
                        if num_channels_latents + num_channels_mask + num_channels_masked_image != self.transformer.config.in_channels:
         
     | 
| 1028 | 
         
             
                            raise ValueError(
         
     | 
| 1032 | 
         
             
                                f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
         
     | 
| 1033 | 
         
             
                                " `pipeline.transformer` or your `mask_image` or `image` input."
         
     | 
| 1034 | 
         
             
                            )
         
     | 
| 1035 | 
         
            +
                    elif num_channels_transformer != 4:
         
     | 
| 1036 | 
         
             
                        raise ValueError(
         
     | 
| 1037 | 
         
             
                            f"The transformer {self.transformer.__class__} should have 9 input channels, not {self.transformer.config.in_channels}."
         
     | 
| 1038 | 
         
             
                        )
         
     | 
| 1039 | 
         | 
| 1040 | 
         
            +
                    # 9. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
         
     | 
| 1041 | 
         
             
                    extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
         
     | 
| 1042 | 
         | 
| 1043 | 
         
             
                    # 6.1 Prepare micro-conditions.
         
     | 
| 
         | 
|
| 1054 | 
         | 
| 1055 | 
         
             
                        added_cond_kwargs = {"resolution": resolution, "aspect_ratio": aspect_ratio}
         
     | 
| 1056 | 
         | 
| 1057 | 
         
            +
                    gc.collect()
         
     | 
| 1058 | 
         
            +
                    torch.cuda.empty_cache()
         
     | 
| 1059 | 
         
            +
                    torch.cuda.ipc_collect()
         
     | 
| 1060 | 
         | 
| 1061 | 
         
            +
                    # 10. Denoising loop
         
     | 
| 1062 | 
         
            +
                    num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         
     | 
| 1063 | 
         
            +
                    self._num_timesteps = len(timesteps)
         
     | 
| 1064 | 
         
             
                    with self.progress_bar(total=num_inference_steps) as progress_bar:
         
     | 
| 1065 | 
         
             
                        for i, t in enumerate(timesteps):
         
     | 
| 1066 | 
         
             
                            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
         
     | 
| 1067 | 
         
             
                            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
         
     | 
| 1068 | 
         | 
| 1069 | 
         
            +
                            if i < len(timesteps) * (1 - clip_apply_ratio) and clip_encoder_hidden_states_input is not None:
         
     | 
| 1070 | 
         
            +
                                clip_encoder_hidden_states_actual_input = torch.zeros_like(clip_encoder_hidden_states_input)
         
     | 
| 1071 | 
         
            +
                                clip_attention_mask_actual_input = torch.zeros_like(clip_attention_mask_input)
         
     | 
| 1072 | 
         
            +
                            else:
         
     | 
| 1073 | 
         
            +
                                clip_encoder_hidden_states_actual_input = clip_encoder_hidden_states_input
         
     | 
| 1074 | 
         
            +
                                clip_attention_mask_actual_input = clip_attention_mask_input
         
     | 
| 1075 | 
         
            +
             
     | 
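`clip_apply_ratio` delays the image conditioning: during the first `(1 - clip_apply_ratio)` fraction of the schedule the CLIP inputs are zeroed out, so the text prompt and inpaint latents set the coarse structure and the reference image only steers the later, finer steps. For example:

    num_steps = 50
    clip_apply_ratio = 0.5
    for i in range(num_steps):
        drop_clip = i < num_steps * (1 - clip_apply_ratio)   # True for steps 0-24, False for 25-49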
| 1076 | 
         
             
                            current_timestep = t
         
     | 
| 1077 | 
         
             
                            if not torch.is_tensor(current_timestep):
         
     | 
| 1078 | 
         
             
                                # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
         
     | 
| 
         | 
|
| 1095 | 
         
             
                                encoder_attention_mask=prompt_attention_mask,
         
     | 
| 1096 | 
         
             
                                timestep=current_timestep,
         
     | 
| 1097 | 
         
             
                                added_cond_kwargs=added_cond_kwargs,
         
     | 
| 1098 | 
         
            +
                                inpaint_latents=inpaint_latents,
         
     | 
| 1099 | 
         
            +
                                clip_encoder_hidden_states=clip_encoder_hidden_states_actual_input,
         
     | 
| 1100 | 
         
            +
                                clip_attention_mask=clip_attention_mask_actual_input,
         
     | 
| 1101 | 
         
             
                                return_dict=False,
         
     | 
| 1102 | 
         
             
                            )[0]
         
     | 
| 1103 | 
         | 
| 
         | 
|
| 1112 | 
         
             
                            # compute previous image: x_t -> x_t-1
         
     | 
| 1113 | 
         
             
                            latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
         
     | 
| 1114 | 
         | 
| 1115 | 
         
            +
                            if num_channels_transformer == 4:
         
     | 
| 1116 | 
         
            +
                                init_latents_proper = image_latents
         
     | 
| 1117 | 
         
            +
                                init_mask = mask
         
     | 
| 1118 | 
         
            +
                                if i < len(timesteps) - 1:
         
     | 
| 1119 | 
         
            +
                                    noise_timestep = timesteps[i + 1]
         
     | 
| 1120 | 
         
            +
                                    init_latents_proper = self.scheduler.add_noise(
         
     | 
| 1121 | 
         
            +
                                        init_latents_proper, noise, torch.tensor([noise_timestep])
         
     | 
| 1122 | 
         
            +
                                    )
         
     | 
| 1123 | 
         
            +
                                
         
     | 
| 1124 | 
         
            +
                                latents = (1 - init_mask) * init_latents_proper + init_mask * latents
         
     | 
| 1125 | 
         
            +
             
     | 
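When the transformer has only 4 input channels there is no `inpaint_latents` path, so the branch above falls back to latent blending: after each scheduler step the unmasked region is overwritten with a re-noised copy of the original video latents and only the masked region keeps the freshly denoised values. A tiny numeric illustration (mask convention assumed: 1 marks the region to repaint):

    import torch

    init_mask           = torch.tensor([0., 0., 1., 1.])
    init_latents_proper = torch.tensor([5., 5., 5., 5.])   # re-noised original latents
    latents             = torch.tensor([9., 9., 9., 9.])   # freshly denoised latents
    latents = (1 - init_mask) * init_latents_proper + init_mask * latents
    # -> tensor([5., 5., 9., 9.]): known region restored, masked region keeps the new values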
| 1126 | 
         
             
                            # call the callback, if provided
         
     | 
| 1127 | 
         
             
                            if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
         
     | 
| 1128 | 
         
             
                                progress_bar.update()
         
     | 
| 
         | 
|
| 1130 | 
         
             
                                    step_idx = i // getattr(self.scheduler, "order", 1)
         
     | 
| 1131 | 
         
             
                                    callback(step_idx, t, latents)
         
     | 
| 1132 | 
         | 
| 1133 | 
         
            +
                    gc.collect()
         
     | 
| 1134 | 
         
            +
                    torch.cuda.empty_cache()
         
     | 
| 1135 | 
         
            +
                    torch.cuda.ipc_collect()
         
     | 
| 1136 | 
         
            +
             
     | 
| 1137 | 
         
             
                    # Post-processing
         
     | 
| 1138 | 
         
             
                    video = self.decode_latents(latents)
         
     | 
| 1139 | 
         
            +
                    
         
     | 
| 1140 | 
         
            +
                    gc.collect()
         
     | 
| 1141 | 
         
            +
                    torch.cuda.empty_cache()
         
     | 
| 1142 | 
         
            +
                    torch.cuda.ipc_collect()
         
     | 
| 1143 | 
         
             
                    # Convert to tensor
         
     | 
| 1144 | 
         
             
                    if output_type == "latent":
         
     | 
| 1145 | 
         
             
                        video = torch.from_numpy(video)
         
     | 
    	
easyanimate/ui/ui.py CHANGED
@@ -1,35 +1,40 @@
| 1 | 
         
             
            """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
         
     | 
| 2 | 
         
             
            """
         
     | 
| 3 | 
         
             
            import gc
         
     | 
| 4 | 
         
             
            import json
         
     | 
| 5 | 
         
             
            import os
         
     | 
| 6 | 
         
             
            import random
         
     | 
| 7 | 
         
            -
            import base64
         
     | 
| 8 | 
         
            -
            import requests
         
     | 
| 9 | 
         
            -
            import pkg_resources
         
     | 
| 10 | 
         
             
            from datetime import datetime
         
     | 
| 11 | 
         
             
            from glob import glob
         
     | 
| 12 | 
         | 
| 13 | 
         
             
            import gradio as gr
         
     | 
| 14 | 
         
            -
            import torch
         
     | 
| 15 | 
         
             
            import numpy as np
         
     | 
| 16 | 
         
             
            from diffusers import (AutoencoderKL, DDIMScheduler,
         
     | 
| 17 | 
         
             
                                   DPMSolverMultistepScheduler,
         
     | 
| 18 | 
         
             
                                   EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
         
     | 
| 19 | 
         
             
                                   PNDMScheduler)
         
     | 
| 20 | 
         
            -
            from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
         
     | 
| 21 | 
         
             
            from diffusers.utils.import_utils import is_xformers_available
         
     | 
| 22 | 
         
             
            from omegaconf import OmegaConf
         
     | 
| 23 | 
         
             
            from safetensors import safe_open
         
     | 
| 24 | 
         
            -
            from transformers import  
     | 
| 25 | 
         | 
| 26 | 
         
             
            from easyanimate.models.transformer3d import Transformer3DModel
         
     | 
| 27 | 
         
             
            from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline
         
     | 
| 28 | 
         
             
            from easyanimate.utils.lora_utils import merge_lora, unmerge_lora
         
     | 
| 29 | 
         
            -
            from easyanimate.utils.utils import  
     | 
| 30 | 
         
            -
             
     | 
| 31 | 
         | 
| 32 | 
         
            -
            sample_idx = 0
         
     | 
| 33 | 
         
             
            scheduler_dict = {
         
     | 
| 34 | 
         
             
                "Euler": EulerDiscreteScheduler,
         
     | 
| 35 | 
         
             
                "Euler A": EulerAncestralDiscreteScheduler,
         
     | 
@@ -60,8 +65,8 @@ class EasyAnimateController:
| 60 | 
         
             
                    self.personalized_model_dir     = os.path.join(self.basedir, "models", "Personalized_Model")
         
     | 
| 61 | 
         
             
                    self.savedir                    = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
         
     | 
| 62 | 
         
             
                    self.savedir_sample             = os.path.join(self.savedir, "sample")
         
     | 
| 63 | 
         
            -
                    self.edition                    = " 
     | 
| 64 | 
         
            -
                    self.inference_config           = OmegaConf.load(os.path.join(self.config_dir, " 
     | 
| 65 | 
         
             
                    os.makedirs(self.savedir, exist_ok=True)
         
     | 
| 66 | 
         | 
| 67 | 
         
             
                    self.diffusion_transformer_list = []
         
     | 
| 
         @@ -85,14 +90,14 @@ class EasyAnimateController: 
     | 
|
| 85 | 
         
             
                    self.weight_dtype = torch.bfloat16
         
     | 
| 86 | 
         | 
| 87 | 
         
             
                def refresh_diffusion_transformer(self):
         
     | 
| 88 | 
         
            -
                    self.diffusion_transformer_list = glob(os.path.join(self.diffusion_transformer_dir, "*/"))
         
     | 
| 89 | 
         | 
| 90 | 
         
             
                def refresh_motion_module(self):
         
     | 
| 91 | 
         
            -
                    motion_module_list = glob(os.path.join(self.motion_module_dir, "*.safetensors"))
         
     | 
| 92 | 
         
             
                    self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
         
     | 
| 93 | 
         | 
| 94 | 
         
             
                def refresh_personalized_model(self):
         
     | 
| 95 | 
         
            -
                    personalized_model_list = glob(os.path.join(self.personalized_model_dir, "*.safetensors"))
         
     | 
| 96 | 
         
             
                    self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
         
     | 
| 97 | 
         | 
| 98 | 
         
             
                def update_edition(self, edition):
         
     | 
| 
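The three refresh_* methods above share one pattern: glob a directory and keep either the folder paths or just the .safetensors file names for the dropdowns. A standalone sketch of that pattern (function names and directory arguments are illustrative, not from the commit):

import os
from glob import glob

def list_model_dirs(diffusion_transformer_dir):
    # Each pretrained transformer lives in its own sub-directory, so match
    # trailing "/" only.
    return glob(os.path.join(diffusion_transformer_dir, "*/"))

def list_safetensors_names(model_dir):
    # Motion modules and personalized models are single .safetensors files;
    # the dropdowns show just the file name.
    return [os.path.basename(p) for p in glob(os.path.join(model_dir, "*.safetensors"))]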
@@ -100,19 +105,24 @@ class EasyAnimateController:
        self.edition = edition
        if edition == "v1":
            self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_motion_module_v1.yaml"))
-            return gr.…
-                gr.update(…
                gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1)
-        …
            self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_magvit_motion_module_v2.yaml"))
-            return gr.…
-                gr.update(…
                gr.update(value=384, minimum=128, maximum=1280, step=16), gr.update(value=144, minimum=9, maximum=144, step=9)
+ …

    def update_diffusion_transformer(self, diffusion_transformer_dropdown):
        print("Update diffusion transformer")
        if diffusion_transformer_dropdown == "none":
-            return gr.…
        if OmegaConf.to_container(self.inference_config['vae_kwargs'])['enable_magvit']:
            Choosen_AutoencoderKL = AutoencoderKLMagvit
        else:
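update_edition swaps the inference config and re-ranges the existing sliders by returning gr.update(...) objects instead of rebuilding the components. The full return tuple is truncated in the diff, so the sketch below only covers the two sliders whose ranges are visible (the resolution slider and the frame-length slider):

import gradio as gr

def on_edition_change(edition):
    # gr.update(...) adjusts value/minimum/maximum/step on the sliders
    # already present in the layout.
    if edition == "v1":
        return (gr.update(value=512, minimum=384, maximum=704, step=32),
                gr.update(value=80, minimum=40, maximum=80, step=1))
    else:
        return (gr.update(value=384, minimum=128, maximum=1280, step=16),
                gr.update(value=144, minimum=9, maximum=144, step=9))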
@@ -130,25 +140,42 @@ class EasyAnimateController:
        self.text_encoder = T5EncoderModel.from_pretrained(diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype)

        # Get pipeline
-        self.…
-        …
+ …
        print("Update diffusion transformer done")
-        return gr.…

    def update_motion_module(self, motion_module_dropdown):
        self.motion_module_path = motion_module_dropdown
        print("Update motion module")
        if motion_module_dropdown == "none":
-            return gr.…
        if self.transformer is None:
            gr.Info(f"Please select a pretrained model path.")
-            return gr.…
        else:
            motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
            if motion_module_dropdown.endswith(".safetensors"):
@@ -160,16 +187,16 @@ class EasyAnimateController:
                motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
            missing, unexpected = self.transformer.load_state_dict(motion_module_state_dict, strict=False)
            print("Update motion module done.")
-            return gr.…

    def update_base_model(self, base_model_dropdown):
        self.base_model_path = base_model_dropdown
        print("Update base model")
        if base_model_dropdown == "none":
-            return gr.…
        if self.transformer is None:
            gr.Info(f"Please select a pretrained model path.")
-            return gr.…
        else:
            base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
            base_model_state_dict = {}
@@ -178,16 +205,16 @@ class EasyAnimateController:
                base_model_state_dict[key] = f.get_tensor(key)
            self.transformer.load_state_dict(base_model_state_dict, strict=False)
            print("Update base done")
-            return gr.…

    def update_lora_model(self, lora_model_dropdown):
        print("Update lora model")
        if lora_model_dropdown == "none":
            self.lora_model_path = "none"
-            return gr.…
        lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
        self.lora_model_path = lora_model_dropdown
-        return gr.…

    def generate(
        self,
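update_motion_module and update_base_model converge on the same loading step: read a .safetensors (or plain torch) checkpoint into a dict and load it with strict=False so a partial checkpoint can be merged into the transformer. A self-contained sketch of that step (the helper name is illustrative):

import torch
from safetensors import safe_open

def load_checkpoint_into(transformer, checkpoint_path):
    if checkpoint_path.endswith(".safetensors"):
        state_dict = {}
        # safe_open streams tensors one key at a time, so the whole file
        # never has to be materialised twice in memory.
        with safe_open(checkpoint_path, framework="pt", device="cpu") as f:
            for key in f.keys():
                state_dict[key] = f.get_tensor(key)
    else:
        state_dict = torch.load(checkpoint_path, map_location="cpu")
    # strict=False lets a partial checkpoint (e.g. a motion module) load into
    # the full transformer without raising on missing keys.
    missing, unexpected = transformer.load_state_dict(state_dict, strict=False)
    return missing, unexpected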
@@ -200,15 +227,24 @@ class EasyAnimateController:
        negative_prompt_textbox, 
        sampler_dropdown, 
        sample_step_slider, 
+ …
        width_slider, 
        height_slider, 
-        …
+ …
        length_slider, 
+ …
        cfg_scale_slider, 
+ …
        seed_textbox,
        is_api = False,
    ):
-        …
+ …
        if self.transformer is None:
            raise gr.Error(f"Please select a pretrained model path.")

@@ -221,6 +257,39 @@ class EasyAnimateController:
        if self.lora_model_path != lora_model_dropdown:
            print("Update lora model")
            self.update_lora_model(lora_model_dropdown)
+ …

        if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()

@@ -235,16 +304,98 @@ class EasyAnimateController:
        generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))

        try:
-            …
+ …
        except Exception as e:
            gc.collect()
            torch.cuda.empty_cache()
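Around the pipeline call, generate seeds both the global RNG and a dedicated CUDA generator, and frees GPU memory if sampling throws. The pipeline arguments themselves are truncated above, so run_pipeline below is a stand-in for the actual EasyAnimatePipeline call, not the repo's code:

import gc

import torch

def generate_with_seed(run_pipeline, seed_textbox):
    # Seed the global RNG only when the user supplied an explicit seed.
    if int(seed_textbox) != -1 and seed_textbox != "":
        torch.manual_seed(int(seed_textbox))
    # A per-request generator keeps sampling reproducible.
    generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
    try:
        return run_pipeline(generator=generator)
    except Exception:
        # Release GPU memory before the error reaches the Gradio UI.
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
        raise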
@@ -254,7 +405,11 @@ class EasyAnimateController:
            if is_api:
                return "", f"Error. error information is {str(e)}"
            else:
-                return gr.…
+ …

        # lora part
        if self.lora_model_path != "none":
@@ -296,7 +451,10 @@ class EasyAnimateController:
            if is_api:
                return save_sample_path, "Success"
            else:
-                …
+ …
        else:
            save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
            save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)
@@ -304,7 +462,10 @@ class EasyAnimateController:
            if is_api:
                return save_sample_path, "Success"
            else:
-                …
+ …


def ui():
@@ -325,24 +486,24 @@ def ui():
        with gr.Column(variant="panel"):
            gr.Markdown(
                """
-                ### 1. EasyAnimate Edition (…
                """
            )
            with gr.Row():
                easyanimate_edition_dropdown = gr.Dropdown(
-                    label="The config of EasyAnimate Edition",
-                    choices=["v1", "v2"],
-                    value="…
                    interactive=True,
                )
            gr.Markdown(
                """
-                ### 2. Model checkpoints (…
                """
            )
            with gr.Row():
                diffusion_transformer_dropdown = gr.Dropdown(
-                    label="Pretrained Model Path",
                    choices=controller.diffusion_transformer_list,
                    value="none",
                    interactive=True,
@@ -356,12 +517,12 @@ def ui():
                diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
                def refresh_diffusion_transformer():
                    controller.refresh_diffusion_transformer()
-                    return gr.…
                diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])

            with gr.Row():
                motion_module_dropdown = gr.Dropdown(
-                    label="Select motion module",
                    choices=controller.motion_module_list,
                    value="none",
                    interactive=True,
@@ -371,78 +532,139 @@ def ui():
                motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton", visible=False)
                def update_motion_module():
                    controller.refresh_motion_module()
-                    return gr.…
                motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])

                base_model_dropdown = gr.Dropdown(
-                    label="Select base Dreambooth model (…
                    choices=controller.personalized_model_list,
                    value="none",
                    interactive=True,
                )

                lora_model_dropdown = gr.Dropdown(
-                    label="Select LoRA model (…
                    choices=["none"] + controller.personalized_model_list,
                    value="none",
                    interactive=True,
                )

-                lora_alpha_slider = gr.Slider(label="LoRA alpha", value=0.55, minimum=0, maximum=2, interactive=True)

                personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
                def update_personalized_model():
                    controller.refresh_personalized_model()
                    return [
-                        gr.…
-                        gr.…
                    ]
                personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
-                ### 3. Configs for Generation.
                """
            )

-            prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="…
-            negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. …

            with gr.Row():
                with gr.Column():
                    with gr.Row():
-                        sampler_dropdown   = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
-                        sample_step_slider = gr.Slider(label="Sampling steps", value=…

-                    …
+ …

                    with gr.Row():
-                        seed_textbox = gr.Textbox(label="Seed", value=43)
                        seed_button  = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
-                        seed_button.click(…
+ …

-                    generate_button = gr.Button(value="Generate", variant='primary')

                with gr.Column():
-                    result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
-                    result_video = gr.Video(label="Generated Animation", interactive=False)
                    infer_progress = gr.Textbox(
-                        label="Generation Info",
                        value="No task currently",
                        interactive=False
                    )

-                …
+ …
            )
+ …
            easyanimate_edition_dropdown.change(
                fn=controller.update_edition, 
                inputs=[easyanimate_edition_dropdown], 
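The seed_button.click(...) wiring is cut off in the diff; a randomize-seed callback of the kind this button usually drives would look like the following sketch (the callback body and wiring are assumptions, not taken from the commit):

import random

import gradio as gr

def randomize_seed():
    # Push a fresh random seed back into the seed textbox.
    return gr.update(value=random.randint(0, 2**31 - 1))

# Assumed wiring:
# seed_button.click(fn=randomize_seed, inputs=[], outputs=[seed_textbox])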
@@ -451,7 +673,6 @@ def ui():
                    diffusion_transformer_dropdown, 
                    motion_module_dropdown, 
                    motion_module_refresh_button, 
-                    is_image, 
                    width_slider, 
                    height_slider, 
                    length_slider, 
@@ -469,11 +690,17 @@ def ui():
                negative_prompt_textbox, 
                sampler_dropdown, 
                sample_step_slider, 
+ …
                width_slider, 
                height_slider, 
-                …
+ …
                length_slider, 
+ …
                cfg_scale_slider, 
+ …
                seed_textbox,
            ],
            outputs=[result_image, result_video, infer_progress]
@@ -483,11 +710,18 @@ def ui():

class EasyAnimateController_Modelscope:
    def __init__(self, edition, config_path, model_name, savedir_sample):
-        # …
-        weight_dtype…
-        …
+ …
        os.makedirs(self.savedir_sample, exist_ok=True)

+ …
        self.edition = edition
        self.inference_config = OmegaConf.load(config_path)
        # Get Transformer
@@ -513,32 +747,107 @@ class EasyAnimateController_Modelscope:
            subfolder="text_encoder", 
            torch_dtype=weight_dtype
        )
-        …
+ …
        print("Update diffusion transformer done")

+ …
    def generate(
        self,
+ …
        prompt_textbox, 
        negative_prompt_textbox, 
        sampler_dropdown, 
        sample_step_slider, 
+ …
        width_slider, 
        height_slider, 
-        …
+ …
        length_slider, 
        cfg_scale_slider, 
-        …
+ …
    ):    
+ …
        if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()

        self.pipeline.scheduler = scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
+ …
        self.pipeline.to("cuda")

        if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
@@ -546,21 +855,52 @@ class EasyAnimateController_Modelscope:
        generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))

        try:
-            …
+ …
        except Exception as e:
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
-            …
+ …

        if not os.path.exists(self.savedir_sample):
            os.makedirs(self.savedir_sample, exist_ok=True)
@@ -578,11 +918,23 @@ class EasyAnimateController_Modelscope:
            image = (image * 255).numpy().astype(np.uint8)
            image = Image.fromarray(image)
            image.save(save_sample_path)
-            …
+ …
        else:
            save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
            save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)
-            …
+ …


def ui_modelscope(edition, config_path, model_name, savedir_sample):
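The single-frame branch above turns the sampled tensor into a uint8 image before writing it, while longer samples go through the repo's save_videos_grid helper at 12 fps for v1 and 24 fps otherwise. A sketch of just the image conversion, assuming a CPU tensor in [0, 1] with HxWxC layout:

import numpy as np
import torch
from PIL import Image

def save_frame(image_tensor: torch.Tensor, save_path: str) -> None:
    # Mirrors the visible lines: scale to 0-255, cast to uint8, write to disk.
    frame = (image_tensor * 255).numpy().astype(np.uint8)
    Image.fromarray(frame).save(save_path)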
@@ -601,71 +953,197 @@ def ui_modelscope(edition, config_path, model_name, savedir_sample):
                """
            )
            with gr.Column(variant="panel"):
-                …
+ …

                with gr.Row():
                    with gr.Column():
                        with gr.Row():
-                            sampler_dropdown   = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
-                            sample_step_slider = gr.Slider(label="Sampling steps", value=…

                        if edition == "v1":
-                            width_slider     = gr.Slider(label="Width",            value=512, minimum=384, maximum=704, step=32)
-                            height_slider    = gr.Slider(label="Height",           value=512, minimum=384, maximum=704, step=32)
-                            …
+ …
                        else:
-                            …
+ …
                            with gr.Column():
                                gr.Markdown(
                                    """                    
-                                    …
+ …
                                    """
                                )
+ …
                            with gr.Row():
-                                …
+ …

                        with gr.Row():
-                            seed_textbox = gr.Textbox(label="Seed", value=43)
                            seed_button  = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
-                            seed_button.click(…
+ …

-                        generate_button = gr.Button(value="Generate", variant='primary')

                    with gr.Column():
-                        result_image = gr.Image(label="Generated Image", interactive=False, visible=False)
-                        result_video = gr.Video(label="Generated Animation", interactive=False)
                        infer_progress = gr.Textbox(
-                            label="Generation Info",
                            value="No task currently",
                            interactive=False
                        )

-                …
+ …
            )

            generate_button.click(
                fn=controller.generate,
                inputs=[
+ …
                    prompt_textbox, 
                    negative_prompt_textbox, 
                    sampler_dropdown, 
                    sample_step_slider, 
+ …
                    width_slider, 
                    height_slider, 
-                    …
+ …
                    length_slider, 
                    cfg_scale_slider, 
+ …
                    seed_textbox,
                ],
                outputs=[result_image, result_video, infer_progress]
@@ -674,31 +1152,51 @@ def ui_modelscope(edition, config_path, model_name, savedir_sample):


def post_eas(
+ …
    prompt_textbox, negative_prompt_textbox, 
-    sampler_dropdown, sample_step_slider, width_slider, height_slider,
-    …
+ …
):
+ …
    datas = {
-        "base_model_path": …
-        "motion_module_path": …
-        "lora_model_path": …
-        "lora_alpha_slider": …
        "prompt_textbox": prompt_textbox, 
        "negative_prompt_textbox": negative_prompt_textbox, 
        "sampler_dropdown": sampler_dropdown, 
        "sample_step_slider": sample_step_slider, 
+ …
        "width_slider": width_slider, 
        "height_slider": height_slider, 
-        "…
+ …
        "length_slider": length_slider,
        "cfg_scale_slider": cfg_scale_slider,
+ …
        "seed_textbox": seed_textbox,
    }
-    …
    session = requests.session()
    session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})

-    response = session.post(url=f'{os.environ.get("EAS_URL")}/easyanimate/infer_forward', json=datas)
+ …
    outputs = response.json()
    return outputs
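post_eas packages the UI state as JSON, authenticates with the EAS_TOKEN header, and posts to the /easyanimate/infer_forward endpoint; the EAS controller then base64-decodes the returned sample. A condensed sketch of that round trip; the "base64_encoding" response key is inferred from the decode step below and may differ:

import base64
import os

import requests

def post_eas_infer(payload: dict) -> dict:
    # Token and service URL come from the environment, as in the diff.
    session = requests.session()
    session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
    url = f'{os.environ.get("EAS_URL")}/easyanimate/infer_forward'
    return session.post(url=url, json=payload).json()

def save_returned_sample(outputs: dict, save_path: str) -> None:
    # Decode the base64 payload and write it next to the local samples.
    decoded_data = base64.b64decode(outputs["base64_encoding"])
    with open(save_path, "wb") as f:
        f.write(decoded_data)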
         @@ -710,23 +1208,42 @@ class EasyAnimateController_EAS: 
     | 
|
| 710 | 
         | 
| 711 | 
         
             
                def generate(
         
     | 
| 712 | 
         
             
                    self,
         
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 713 | 
         
             
                    prompt_textbox, 
         
     | 
| 714 | 
         
             
                    negative_prompt_textbox, 
         
     | 
| 715 | 
         
             
                    sampler_dropdown, 
         
     | 
| 716 | 
         
             
                    sample_step_slider, 
         
     | 
| 
         | 
|
| 717 | 
         
             
                    width_slider, 
         
     | 
| 718 | 
         
             
                    height_slider, 
         
     | 
| 719 | 
         
            -
                     
     | 
| 
         | 
|
| 720 | 
         
             
                    length_slider, 
         
     | 
| 721 | 
         
             
                    cfg_scale_slider, 
         
     | 
| 
         | 
|
| 
         | 
|
| 722 | 
         
             
                    seed_textbox
         
     | 
| 723 | 
         
             
                ):
         
     | 
| 
         | 
|
| 
         | 
|
| 724 | 
         
             
                    outputs = post_eas(
         
     | 
| 
         | 
|
| 
         | 
|
| 725 | 
         
             
                        prompt_textbox, negative_prompt_textbox, 
         
     | 
| 726 | 
         
            -
                        sampler_dropdown, sample_step_slider, width_slider, height_slider,
         
     | 
| 727 | 
         
            -
                         
     | 
| 
         | 
|
| 
         | 
|
| 728 | 
         
             
                    )
         
     | 
| 729 | 
         
            -
                     
     | 
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 
         | 
|
| 730 | 
         
             
                    decoded_data = base64.b64decode(base64_encoding)
         
     | 
| 731 | 
         | 
| 732 | 
         
             
                    if not os.path.exists(self.savedir_sample):
         
@@ -768,35 +1285,134 @@ def ui_eas(edition, config_path, model_name, savedir_sample):
                         """
                     )
                     with gr.Column(variant="panel"):
                         negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. " )
 
                         with gr.Row():
                             with gr.Column():
                                 with gr.Row():
                                     sampler_dropdown   = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
-                                    sample_step_slider = gr.Slider(label="Sampling steps", value=
 
                                 if edition == "v1":
                                     width_slider     = gr.Slider(label="Width",            value=512, minimum=384, maximum=704, step=32)
                                     height_slider    = gr.Slider(label="Height",           value=512, minimum=384, maximum=704, step=32)
                                     cfg_scale_slider = gr.Slider(label="CFG Scale",        value=6.0, minimum=0,   maximum=20)
                                 else:
                                     with gr.Column():
                                         gr.Markdown(
                                             """
                                             """
                                         )
                                     cfg_scale_slider = gr.Slider(label="CFG Scale",        value=7.0, minimum=0,   maximum=20)
 
                                 with gr.Row():

@@ -819,24 +1435,45 @@ def ui_eas(edition, config_path, model_name, savedir_sample):
                                     interactive=False
                                 )
 
                         )
 
                         generate_button.click(
                             fn=controller.generate,
                             inputs=[
                                 prompt_textbox, 
                                 negative_prompt_textbox, 
                                 sampler_dropdown, 
                                 sample_step_slider, 
                                 width_slider, 
                                 height_slider, 
                                 length_slider, 
                                 cfg_scale_slider, 
                                 seed_textbox,
                             ],
                             outputs=[result_image, result_video, infer_progress]

 """Modified from https://github.com/guoyww/AnimateDiff/blob/main/app.py
 """
+import base64
 import gc
 import json
 import os
 import random
 from datetime import datetime
 from glob import glob
 
 import gradio as gr
 import numpy as np
+import pkg_resources
+import requests
+import torch
 from diffusers import (AutoencoderKL, DDIMScheduler,
                        DPMSolverMultistepScheduler,
                        EulerAncestralDiscreteScheduler, EulerDiscreteScheduler,
                        PNDMScheduler)
 from diffusers.utils.import_utils import is_xformers_available
 from omegaconf import OmegaConf
+from PIL import Image
 from safetensors import safe_open
+from transformers import (CLIPImageProcessor, CLIPVisionModelWithProjection,
+                          T5EncoderModel, T5Tokenizer)
 
+from easyanimate.data.bucket_sampler import ASPECT_RATIO_512, get_closest_ratio
+from easyanimate.models.autoencoder_magvit import AutoencoderKLMagvit
 from easyanimate.models.transformer3d import Transformer3DModel
 from easyanimate.pipeline.pipeline_easyanimate import EasyAnimatePipeline
+from easyanimate.pipeline.pipeline_easyanimate_inpaint import \
+    EasyAnimateInpaintPipeline
 from easyanimate.utils.lora_utils import merge_lora, unmerge_lora
+from easyanimate.utils.utils import (
+    get_image_to_video_latent,
+    get_width_and_height_from_image_and_base_resolution, save_videos_grid)
 
 scheduler_dict = {
     "Euler": EulerDiscreteScheduler,
     "Euler A": EulerAncestralDiscreteScheduler,
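Illustrative only: how the sampler name chosen in the UI becomes a scheduler instance via scheduler_dict. The kwargs normally come from inference_config.noise_scheduler_kwargs; the values below are placeholders, not the commit's config.

# Minimal sketch, assuming placeholder scheduler kwargs.
from diffusers import DDIMScheduler, EulerAncestralDiscreteScheduler, EulerDiscreteScheduler

scheduler_dict = {
    "Euler": EulerDiscreteScheduler,
    "Euler A": EulerAncestralDiscreteScheduler,
    "DDIM": DDIMScheduler,
}
noise_scheduler_kwargs = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012, "beta_schedule": "linear"}  # placeholder values

sampler_name = "Euler A"
scheduler = scheduler_dict[sampler_name](**noise_scheduler_kwargs)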
         self.personalized_model_dir     = os.path.join(self.basedir, "models", "Personalized_Model")
         self.savedir                    = os.path.join(self.basedir, "samples", datetime.now().strftime("Gradio-%Y-%m-%dT%H-%M-%S"))
         self.savedir_sample             = os.path.join(self.savedir, "sample")
+        self.edition                    = "v3"
+        self.inference_config           = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_slicevae_motion_module_v3.yaml"))
         os.makedirs(self.savedir, exist_ok=True)
 
         self.diffusion_transformer_list = []

         self.weight_dtype = torch.bfloat16
 
     def refresh_diffusion_transformer(self):
+        self.diffusion_transformer_list = sorted(glob(os.path.join(self.diffusion_transformer_dir, "*/")))
 
     def refresh_motion_module(self):
+        motion_module_list = sorted(glob(os.path.join(self.motion_module_dir, "*.safetensors")))
         self.motion_module_list = [os.path.basename(p) for p in motion_module_list]
 
     def refresh_personalized_model(self):
+        personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
         self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]
 
     def update_edition(self, edition):

         self.edition = edition
         if edition == "v1":
             self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_motion_module_v1.yaml"))
+            return gr.update(), gr.update(value="none"), gr.update(visible=True), gr.update(visible=True), \
+                gr.update(value=512, minimum=384, maximum=704, step=32), \
                 gr.update(value=512, minimum=384, maximum=704, step=32), gr.update(value=80, minimum=40, maximum=80, step=1)
+        elif edition == "v2":
             self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_magvit_motion_module_v2.yaml"))
+            return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
+                gr.update(value=672, minimum=128, maximum=1280, step=16), \
                 gr.update(value=384, minimum=128, maximum=1280, step=16), gr.update(value=144, minimum=9, maximum=144, step=9)
+        else:
+            self.inference_config = OmegaConf.load(os.path.join(self.config_dir, "easyanimate_video_slicevae_motion_module_v3.yaml"))
+            return gr.update(), gr.update(value="none"), gr.update(visible=False), gr.update(visible=False), \
+                gr.update(value=672, minimum=128, maximum=1280, step=16), \
+                gr.update(value=384, minimum=128, maximum=1280, step=16), gr.update(value=144, minimum=8, maximum=144, step=8)
 
     def update_diffusion_transformer(self, diffusion_transformer_dropdown):
         print("Update diffusion transformer")
         if diffusion_transformer_dropdown == "none":
+            return gr.update()
         if OmegaConf.to_container(self.inference_config['vae_kwargs'])['enable_magvit']:
             Choosen_AutoencoderKL = AutoencoderKLMagvit
         else:
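The tuple of gr.update() values returned by update_edition above is matched positionally against the outputs of the edition dropdown's change event; that wiring is not captured in this view, so the sketch below is an assumption about the component order (model dropdown, motion module dropdown, two visibility toggles, width, height, length), with placeholder component names.

# Hedged sketch, not the commit's exact wiring: how the 7 gr.update() values
# map onto the outputs list of the edition dropdown's .change() handler.
import gradio as gr

def update_edition(edition):
    if edition == "v1":
        return (gr.update(), gr.update(value="none"),
                gr.update(visible=True), gr.update(visible=True),
                gr.update(value=512, minimum=384, maximum=704, step=32),
                gr.update(value=512, minimum=384, maximum=704, step=32),
                gr.update(value=80, minimum=40, maximum=80, step=1))
    return (gr.update(), gr.update(value="none"),
            gr.update(visible=False), gr.update(visible=False),
            gr.update(value=672, minimum=128, maximum=1280, step=16),
            gr.update(value=384, minimum=128, maximum=1280, step=16),
            gr.update(value=144, minimum=8, maximum=144, step=8))

with gr.Blocks() as demo:
    edition = gr.Dropdown(choices=["v1", "v2", "v3"], value="v3")
    model_dropdown  = gr.Dropdown(choices=["none"], value="none")
    motion_dropdown = gr.Dropdown(choices=["none"], value="none")
    motion_toggle   = gr.Checkbox(visible=True)   # placeholder for a visibility-toggled component
    base_toggle     = gr.Checkbox(visible=True)   # placeholder for a visibility-toggled component
    width  = gr.Slider(value=672, minimum=128, maximum=1280, step=16)
    height = gr.Slider(value=384, minimum=128, maximum=1280, step=16)
    length = gr.Slider(value=144, minimum=8, maximum=144, step=8)
    edition.change(update_edition, inputs=[edition],
                   outputs=[model_dropdown, motion_dropdown, motion_toggle, base_toggle, width, height, length])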
         
         self.text_encoder = T5EncoderModel.from_pretrained(diffusion_transformer_dropdown, subfolder="text_encoder", torch_dtype=self.weight_dtype)
 
         # Get pipeline
+        if self.transformer.config.in_channels != 12:
+            self.pipeline = EasyAnimatePipeline(
+                vae=self.vae, 
+                text_encoder=self.text_encoder, 
+                tokenizer=self.tokenizer, 
+                transformer=self.transformer,
+                scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
+            )
+        else:
+            clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
+                diffusion_transformer_dropdown, subfolder="image_encoder"
+            ).to("cuda", self.weight_dtype)
+            clip_image_processor = CLIPImageProcessor.from_pretrained(
+                diffusion_transformer_dropdown, subfolder="image_encoder"
+            )
+            self.pipeline = EasyAnimateInpaintPipeline(
+                vae=self.vae, 
+                text_encoder=self.text_encoder, 
+                tokenizer=self.tokenizer, 
+                transformer=self.transformer,
+                scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)),
+                clip_image_encoder=clip_image_encoder,
+                clip_image_processor=clip_image_processor,
+            )
+
         print("Update diffusion transformer done")
+        return gr.update()
 
     def update_motion_module(self, motion_module_dropdown):
         self.motion_module_path = motion_module_dropdown
         print("Update motion module")
         if motion_module_dropdown == "none":
+            return gr.update()
         if self.transformer is None:
             gr.Info(f"Please select a pretrained model path.")
+            return gr.update(value=None)
         else:
             motion_module_dropdown = os.path.join(self.motion_module_dir, motion_module_dropdown)
             if motion_module_dropdown.endswith(".safetensors"):
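The body of this ".safetensors" branch is not captured in this rendering; it presumably reads the tensors with safetensors' safe_open (imported at the top of the file), mirroring the torch.load fallback shown just below. A hedged sketch with an illustrative path:

# Assumed shape of the elided branch; the file path is illustrative only.
from safetensors import safe_open

motion_module_state_dict = {}
with safe_open("motion_module.safetensors", framework="pt", device="cpu") as f:
    for key in f.keys():
        motion_module_state_dict[key] = f.get_tensor(key)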
         
                 motion_module_state_dict = torch.load(motion_module_dropdown, map_location="cpu")
             missing, unexpected = self.transformer.load_state_dict(motion_module_state_dict, strict=False)
             print("Update motion module done.")
+            return gr.update()
 
     def update_base_model(self, base_model_dropdown):
         self.base_model_path = base_model_dropdown
         print("Update base model")
         if base_model_dropdown == "none":
+            return gr.update()
         if self.transformer is None:
             gr.Info(f"Please select a pretrained model path.")
+            return gr.update(value=None)
         else:
             base_model_dropdown = os.path.join(self.personalized_model_dir, base_model_dropdown)
             base_model_state_dict = {}

                     base_model_state_dict[key] = f.get_tensor(key)
             self.transformer.load_state_dict(base_model_state_dict, strict=False)
             print("Update base done")
+            return gr.update()
 
     def update_lora_model(self, lora_model_dropdown):
         print("Update lora model")
         if lora_model_dropdown == "none":
             self.lora_model_path = "none"
+            return gr.update()
         lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
         self.lora_model_path = lora_model_dropdown
+        return gr.update()
 
     def generate(
         self,

         negative_prompt_textbox, 
         sampler_dropdown, 
         sample_step_slider, 
+        resize_method,
         width_slider, 
         height_slider, 
+        base_resolution, 
+        generation_method, 
         length_slider, 
+        overlap_video_length, 
+        partial_video_length, 
         cfg_scale_slider, 
+        start_image, 
+        end_image, 
         seed_textbox,
         is_api = False,
     ):
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
+
         if self.transformer is None:
             raise gr.Error(f"Please select a pretrained model path.")

         if self.lora_model_path != lora_model_dropdown:
             print("Update lora model")
             self.update_lora_model(lora_model_dropdown)
+
+        if resize_method == "Resize to the Start Image":
+            if start_image is None:
+                if is_api:
+                    return "", f"Please upload an image when using \"Resize to the Start Image\"."
+                else:
+                    raise gr.Error(f"Please upload an image when using \"Resize to the Start Image\".")
+
+            aspect_ratio_sample_size    = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
+
+            original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
+            closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
+            height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
+
+        if self.transformer.config.in_channels != 12 and start_image is not None:
+            if is_api:
+                return "", f"Please select an image to video pretrained model while using image to video."
+            else:
+                raise gr.Error(f"Please select an image to video pretrained model while using image to video.")
+
+        if self.transformer.config.in_channels != 12 and generation_method == "Long Video Generation":
+            if is_api:
+                return "", f"Please select an image to video pretrained model while using long video generation."
+            else:
+                raise gr.Error(f"Please select an image to video pretrained model while using long video generation.")
+
+        if start_image is None and end_image is not None:
+            if is_api:
+                return "", f"If specifying the ending image of the video, please specify a starting image of the video."
+            else:
+                raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")
+
+        is_image = True if generation_method == "Image Generation" else False
 
         if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()
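A worked example of the "Resize to the Start Image" math added above: the start image's aspect ratio picks the closest bucket, the bucket is scaled by base_resolution, and the result is snapped to a multiple of 16. The tiny bucket table and the stand-in get_closest_ratio below only mirror the assumed layout of ASPECT_RATIO_512 (keys are height/width ratios, values are [height, width]); the real table in easyanimate.data.bucket_sampler has many more entries.

# Stand-in bucket table and helper, assumptions for illustration only.
ASPECT_RATIO_512 = {
    "0.57": [384.0, 672.0],
    "1.0":  [512.0, 512.0],
    "1.75": [672.0, 384.0],
}

def get_closest_ratio(height, width, ratios):
    # Pick the bucket whose height/width ratio is closest to the image's.
    aspect = height / width
    closest = min(ratios.keys(), key=lambda r: abs(float(r) - aspect))
    return ratios[closest], float(closest)

base_resolution = 512
buckets = {k: [x / 512 * base_resolution for x in v] for k, v in ASPECT_RATIO_512.items()}

original_width, original_height = 1920, 1080   # e.g. a 16:9 start image
closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=buckets)
height, width = [int(x / 16) * 16 for x in closest_size]
print(height, width)   # -> 384 672, the chosen bucket snapped to multiples of 16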
         
         generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
 
         try:
+            if self.transformer.config.in_channels == 12:
+                if generation_method == "Long Video Generation":
+                    init_frames = 0
+                    last_frames = init_frames + partial_video_length
+                    while init_frames < length_slider:
+                        if last_frames >= length_slider:
+                            if self.pipeline.vae.quant_conv.weight.ndim==5:
+                                mini_batch_encoder = self.pipeline.vae.mini_batch_encoder
+                                _partial_video_length = length_slider - init_frames
+                                _partial_video_length = int(_partial_video_length // mini_batch_encoder * mini_batch_encoder)
+                            else:
+                                _partial_video_length = length_slider - init_frames
+
+                            if _partial_video_length <= 0:
+                                break
+                        else:
+                            _partial_video_length = partial_video_length
+
+                        if last_frames >= length_slider:
+                            input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
+                        else:
+                            input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, None, video_length=_partial_video_length, sample_size=(height_slider, width_slider))
+
+                        with torch.no_grad():
+                            sample = self.pipeline(
+                                prompt_textbox, 
+                                negative_prompt     = negative_prompt_textbox,
+                                num_inference_steps = sample_step_slider,
+                                guidance_scale      = cfg_scale_slider,
+                                width               = width_slider,
+                                height              = height_slider,
+                                video_length        = _partial_video_length,
+                                generator           = generator,
+
+                                video        = input_video,
+                                mask_video   = input_video_mask,
+                                clip_image   = clip_image, 
+                                strength     = 1,
+                            ).videos
+
+                        if init_frames != 0:
+                            mix_ratio = torch.from_numpy(
+                                np.array([float(_index) / float(overlap_video_length) for _index in range(overlap_video_length)], np.float32)
+                            ).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)
+
+                            new_sample[:, :, -overlap_video_length:] = new_sample[:, :, -overlap_video_length:] * (1 - mix_ratio) + \
+                                sample[:, :, :overlap_video_length] * mix_ratio
+                            new_sample = torch.cat([new_sample, sample[:, :, overlap_video_length:]], dim = 2)
+
+                            sample = new_sample
+                        else:
+                            new_sample = sample
+
+                        if last_frames >= length_slider:
+                            break
+
+                        start_image = [
+                            Image.fromarray(
+                                (sample[0, :, _index].transpose(0, 1).transpose(1, 2) * 255).numpy().astype(np.uint8)
+                            ) for _index in range(-overlap_video_length, 0)
+                        ]
+
+                        init_frames = init_frames + _partial_video_length - overlap_video_length
+                        last_frames = init_frames + _partial_video_length
+                else:
+                    input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))
+
+                    sample = self.pipeline(
+                        prompt_textbox,
+                        negative_prompt     = negative_prompt_textbox,
+                        num_inference_steps = sample_step_slider,
+                        guidance_scale      = cfg_scale_slider,
+                        width               = width_slider,
+                        height              = height_slider,
+                        video_length        = length_slider if not is_image else 1,
+                        generator           = generator,
+
+                        video        = input_video,
+                        mask_video   = input_video_mask,
+                        clip_image   = clip_image, 
+                    ).videos
+            else:
+                sample = self.pipeline(
+                    prompt_textbox,
+                    negative_prompt     = negative_prompt_textbox,
+                    num_inference_steps = sample_step_slider,
+                    guidance_scale      = cfg_scale_slider,
+                    width               = width_slider,
+                    height              = height_slider,
+                    video_length        = length_slider if not is_image else 1,
+                    generator           = generator
+                ).videos
         except Exception as e:
             gc.collect()
             torch.cuda.empty_cache()
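The long-video loop above stitches clips together by linearly cross-fading the last overlap_video_length frames of the running result into the first frames of the newly generated clip. A self-contained illustration of just that blending step, using placeholder tensors shaped like the pipeline's (batch, channels, frames, height, width) output:

# Illustration of the overlap cross-fade; tensor contents are placeholders.
import numpy as np
import torch

overlap = 4
new_sample = torch.ones(1, 3, 16, 8, 8)    # frames accumulated so far
sample     = torch.zeros(1, 3, 24, 8, 8)   # freshly generated clip

mix_ratio = torch.from_numpy(
    np.array([i / overlap for i in range(overlap)], np.float32)
).unsqueeze(0).unsqueeze(0).unsqueeze(-1).unsqueeze(-1)   # shape (1, 1, overlap, 1, 1)

new_sample[:, :, -overlap:] = new_sample[:, :, -overlap:] * (1 - mix_ratio) + sample[:, :, :overlap] * mix_ratio
merged = torch.cat([new_sample, sample[:, :, overlap:]], dim=2)
print(merged.shape)   # torch.Size([1, 3, 36, 8, 8]) -> 16 + (24 - 4) frames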
         
             if is_api:
                 return "", f"Error. error information is {str(e)}"
             else:
+                return gr.update(), gr.update(), f"Error. error information is {str(e)}"
+
+        gc.collect()
+        torch.cuda.empty_cache()
+        torch.cuda.ipc_collect()
 
         # lora part
         if self.lora_model_path != "none":

             if is_api:
                 return save_sample_path, "Success"
             else:
+                if gradio_version_is_above_4:
+                    return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
+                else:
+                    return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
         else:
             save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
             save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)

             if is_api:
                 return save_sample_path, "Success"
             else:
+                if gradio_version_is_above_4:
+                    return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
+                else:
+                    return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"
 
 
 def ui():
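gradio_version_is_above_4 is referenced above but its definition is not captured in this rendering. Given the pkg_resources import added at the top of the file, it is presumably derived along these lines (an assumption, not the commit's exact code); the flag selects between gradio 4.x component constructors and the 3.x gr.Image.update / gr.Video.update API.

# Assumed definition of the version flag used in the return branches above.
import pkg_resources

gradio_version = pkg_resources.get_distribution("gradio").version
gradio_version_is_above_4 = int(gradio_version.split(".")[0]) >= 4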
         
                    with gr.Column(variant="panel"):
         
     | 
| 487 | 
         
             
                        gr.Markdown(
         
     | 
| 488 | 
         
             
                            """
         
     | 
| 489 | 
         
            +
                            ### 1. EasyAnimate Edition (EasyAnimate版本).
         
     | 
| 490 | 
         
             
                            """
         
     | 
| 491 | 
         
             
                        )
         
     | 
| 492 | 
         
             
                        with gr.Row():
         
     | 
| 493 | 
         
             
                            easyanimate_edition_dropdown = gr.Dropdown(
         
     | 
| 494 | 
         
            +
                                label="The config of EasyAnimate Edition (EasyAnimate版本配置)",
         
     | 
| 495 | 
         
            +
                                choices=["v1", "v2", "v3"],
         
     | 
| 496 | 
         
            +
                                value="v3",
         
     | 
| 497 | 
         
             
                                interactive=True,
         
     | 
| 498 | 
         
             
                            )
         
     | 
| 499 | 
         
             
                        gr.Markdown(
         
     | 
| 500 | 
         
             
                            """
         
     | 
| 501 | 
         
            +
                            ### 2. Model checkpoints (模型路径).
         
     | 
| 502 | 
         
             
                            """
         
     | 
| 503 | 
         
             
                        )
         
     | 
| 504 | 
         
             
                        with gr.Row():
         
     | 
| 505 | 
         
             
                            diffusion_transformer_dropdown = gr.Dropdown(
         
     | 
| 506 | 
         
            +
                                label="Pretrained Model Path (预训练模型路径)",
         
     | 
| 507 | 
         
             
                                choices=controller.diffusion_transformer_list,
         
     | 
| 508 | 
         
             
                                value="none",
         
     | 
| 509 | 
         
             
                                interactive=True,
         
     | 
| 
         | 
|
| 517 | 
         
             
                            diffusion_transformer_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
         
     | 
| 518 | 
         
             
                            def refresh_diffusion_transformer():
         
     | 
| 519 | 
         
             
                                controller.refresh_diffusion_transformer()
         
     | 
| 520 | 
         
            +
                                return gr.update(choices=controller.diffusion_transformer_list)
         
     | 
| 521 | 
         
             
                            diffusion_transformer_refresh_button.click(fn=refresh_diffusion_transformer, inputs=[], outputs=[diffusion_transformer_dropdown])
         
     | 
| 522 | 
         | 
| 523 | 
         
             
                        with gr.Row():
         
     | 
| 524 | 
         
             
                            motion_module_dropdown = gr.Dropdown(
         
     | 
| 525 | 
         
            +
                                label="Select motion module (选择运动模块[非必需])",
         
     | 
| 526 | 
         
             
                                choices=controller.motion_module_list,
         
     | 
| 527 | 
         
             
                                value="none",
         
     | 
| 528 | 
         
             
                                interactive=True,
         
     | 
| 
         | 
|
| 532 | 
         
             
                            motion_module_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton", visible=False)
         
     | 
| 533 | 
         
             
                            def update_motion_module():
         
     | 
| 534 | 
         
             
                                controller.refresh_motion_module()
         
     | 
| 535 | 
         
            +
                                return gr.update(choices=controller.motion_module_list)
         
     | 
| 536 | 
         
             
                            motion_module_refresh_button.click(fn=update_motion_module, inputs=[], outputs=[motion_module_dropdown])
         
     | 
| 537 | 
         | 
| 538 | 
         
             
                            base_model_dropdown = gr.Dropdown(
         
     | 
| 539 | 
         
            +
                                label="Select base Dreambooth model (选择基模型[非必需])",
         
     | 
| 540 | 
         
             
                                choices=controller.personalized_model_list,
         
     | 
| 541 | 
         
             
                                value="none",
         
     | 
| 542 | 
         
             
                                interactive=True,
         
     | 
| 543 | 
         
             
                            )
         
     | 
| 544 | 
         | 
| 545 | 
         
             
                            lora_model_dropdown = gr.Dropdown(
         
     | 
| 546 | 
         
            +
                                label="Select LoRA model (选择LoRA模型[非必需])",
         
     | 
| 547 | 
         
             
                                choices=["none"] + controller.personalized_model_list,
         
     | 
| 548 | 
         
             
                                value="none",
         
     | 
| 549 | 
         
             
                                interactive=True,
         
     | 
| 550 | 
         
             
                            )
         
     | 
| 551 | 
         | 
| 552 | 
         
            +
                            lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
         
     | 
| 553 | 
         | 
| 554 | 
         
             
                            personalized_refresh_button = gr.Button(value="\U0001F503", elem_classes="toolbutton")
         
     | 
| 555 | 
         
             
                            def update_personalized_model():
         
     | 
| 556 | 
         
             
                                controller.refresh_personalized_model()
         
     | 
| 557 | 
         
             
                                return [
         
     | 
| 558 | 
         
            +
                                    gr.update(choices=controller.personalized_model_list),
         
     | 
| 559 | 
         
            +
                                    gr.update(choices=["none"] + controller.personalized_model_list)
         
     | 
| 560 | 
         
             
                                ]
         
     | 
| 561 | 
         
             
                            personalized_refresh_button.click(fn=update_personalized_model, inputs=[], outputs=[base_model_dropdown, lora_model_dropdown])
         
     | 
| 562 | 
         | 
| 563 | 
         
             
                    with gr.Column(variant="panel"):
         
     | 
| 564 | 
         
             
                        gr.Markdown(
         
     | 
| 565 | 
         
             
                            """
         
     | 
| 566 | 
         
            +
                            ### 3. Configs for Generation (生成参数配置).
         
     | 
| 567 | 
         
             
                            """
         
     | 
| 568 | 
         
             
            )

            prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
            negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion.")

            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        sampler_dropdown   = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                        sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=30, minimum=10, maximum=100, step=1)

                    resize_method = gr.Radio(
                        ["Generate by", "Resize to the Start Image"],
                        value="Generate by",
                        show_label=False,
                    )
                    width_slider     = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16)
                    height_slider    = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16)
                    base_resolution  = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], visible=False)

                    with gr.Group():
                        generation_method = gr.Radio(
                            ["Video Generation", "Image Generation", "Long Video Generation"],
                            value="Video Generation",
                            show_label=False,
                        )
                        with gr.Row():
                            length_slider = gr.Slider(label="Animation length (视频帧数)", value=144, minimum=8, maximum=144, step=8)
                            overlap_video_length = gr.Slider(label="Overlap length (视频续写的重叠帧数)", value=4, minimum=1, maximum=4, step=1, visible=False)
                            partial_video_length = gr.Slider(label="Partial video generation length (每个部分的视频生成帧数)", value=72, minimum=8, maximum=144, step=8, visible=False)

                    with gr.Accordion("Image to Video (图片到视频)", open=False):
                        start_image = gr.Image(label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")

                        template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
                        def select_template(evt: gr.SelectData):
                            text = {
                                "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
                                "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
                                "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
                                "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
                                "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
                            }[template_gallery_path[evt.index]]
                            return template_gallery_path[evt.index], text

                        template_gallery = gr.Gallery(
                            template_gallery_path,
                            columns=5, rows=1,
                            height=140,
                            allow_preview=False,
                            container=False,
                            label="Template Examples",
                        )
                        template_gallery.select(select_template, None, [start_image, prompt_textbox])

                        with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False):
                            end_image   = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")

                    cfg_scale_slider  = gr.Slider(label="CFG Scale (引导系数)", value=7.0, minimum=0, maximum=20)

                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
                        seed_button  = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        seed_button.click(
                            fn=lambda: gr.Textbox(value=random.randint(1, 1e8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 1e8)),
                            inputs=[],
                            outputs=[seed_textbox]
                        )
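                    # Gradio compatibility note: on Gradio 4.x an event handler updates a
                    # component by returning a new instance (gr.Textbox(value=...)), while
                    # Gradio 3.x expects gr.Textbox.update(value=...). The lambda above picks
                    # the right form through gradio_version_is_above_4; generate() reuses the
                    # same idiom for the gr.Image / gr.Video results further down.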

                    generate_button = gr.Button(value="Generate (生成)", variant='primary')

                with gr.Column():
                    result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
                    result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
                    infer_progress = gr.Textbox(
                        label="Generation Info (生成信息)",
                        value="No task currently",
                        interactive=False
                    )

        def upload_generation_method(generation_method):
            if generation_method == "Video Generation":
                return [gr.update(visible=True, maximum=144, value=144), gr.update(visible=False), gr.update(visible=False)]
            elif generation_method == "Image Generation":
                return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=False)]
            else:
                return [gr.update(visible=True, maximum=1440), gr.update(visible=True), gr.update(visible=True)]
        generation_method.change(
            upload_generation_method, generation_method, [length_slider, overlap_video_length, partial_video_length]
        )
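        # The three gr.update(...) objects returned above map positionally onto the outputs
        # list [length_slider, overlap_video_length, partial_video_length]: "Video Generation"
        # caps the length at 144 frames and hides the long-video controls, "Image Generation"
        # hides all three, and "Long Video Generation" raises the cap to 1440 frames and
        # reveals the overlap / partial-length sliders.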

        def upload_resize_method(resize_method):
            if resize_method == "Generate by":
                return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
            else:
                return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
        resize_method.change(
            upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
        )

        easyanimate_edition_dropdown.change(
            fn=controller.update_edition,
            inputs=[easyanimate_edition_dropdown],
            ...
                diffusion_transformer_dropdown,
                motion_module_dropdown,
                motion_module_refresh_button,
                ...
                width_slider,
                height_slider,
                length_slider,
                ...
                negative_prompt_textbox,
                sampler_dropdown,
                sample_step_slider,
                resize_method,
                width_slider,
                height_slider,
                base_resolution,
                generation_method,
                length_slider,
                overlap_video_length,
                partial_video_length,
                cfg_scale_slider,
                start_image,
                end_image,
                seed_textbox,
            ],
            outputs=[result_image, result_video, infer_progress]
        ...

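The wiring above is positional: Gradio passes the components listed in inputs=[...] to controller.generate in that exact order, and the values generate returns fill outputs=[result_image, result_video, infer_progress] in order. A minimal, self-contained sketch of the same pattern (hypothetical component and function names, not code from this commit):

import gradio as gr

def fake_generate(prompt, steps, seed):
    # Arguments arrive in the order of the inputs list below.
    info = f"prompt={prompt!r}, steps={int(steps)}, seed={seed}"
    # Returned values fill the outputs list in order: (video, info text).
    return gr.update(value=None), info

with gr.Blocks() as demo:
    prompt_box = gr.Textbox(label="Prompt")
    steps_sl   = gr.Slider(10, 100, value=30, step=1, label="Steps")
    seed_box   = gr.Textbox(value="43", label="Seed")
    video_out  = gr.Video(label="Result")
    info_box   = gr.Textbox(label="Info")
    gr.Button("Generate").click(
        fn=fake_generate,
        inputs=[prompt_box, steps_sl, seed_box],
        outputs=[video_out, info_box],
    )
# demo.launch()
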

class EasyAnimateController_Modelscope:
    def __init__(self, edition, config_path, model_name, savedir_sample):
        # Weight Dtype
        weight_dtype                    = torch.bfloat16

        # Basic dir
        self.basedir                    = os.getcwd()
        self.personalized_model_dir     = os.path.join(self.basedir, "models", "Personalized_Model")
        self.lora_model_path            = "none"
        self.savedir_sample             = savedir_sample
        self.refresh_personalized_model()
        os.makedirs(self.savedir_sample, exist_ok=True)

        # Config and model path
        self.edition = edition
        self.inference_config = OmegaConf.load(config_path)
        # Get Transformer
        ...
            subfolder="text_encoder",
            torch_dtype=weight_dtype
        )
        # Get pipeline
        if self.transformer.config.in_channels != 12:
            self.pipeline = EasyAnimatePipeline(
                vae=self.vae,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                transformer=self.transformer,
                scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
            )
        else:
            clip_image_encoder = CLIPVisionModelWithProjection.from_pretrained(
                model_name, subfolder="image_encoder"
            ).to("cuda", weight_dtype)
            clip_image_processor = CLIPImageProcessor.from_pretrained(
                model_name, subfolder="image_encoder"
            )
            self.pipeline = EasyAnimateInpaintPipeline(
                vae=self.vae,
                text_encoder=self.text_encoder,
                tokenizer=self.tokenizer,
                transformer=self.transformer,
                scheduler=scheduler_dict["Euler"](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs)),
                clip_image_encoder=clip_image_encoder,
                clip_image_processor=clip_image_processor,
            )

        print("Update diffusion transformer done")
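        # Dispatch note: a transformer whose config.in_channels is 12 is the image-to-video
        # (inpaint) variant, which takes masked-video conditioning, so it is wrapped in
        # EasyAnimateInpaintPipeline together with a CLIP vision encoder/processor for the
        # reference image; any other channel count gets the plain text-to-video
        # EasyAnimatePipeline.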

    def refresh_personalized_model(self):
        personalized_model_list = sorted(glob(os.path.join(self.personalized_model_dir, "*.safetensors")))
        self.personalized_model_list = [os.path.basename(p) for p in personalized_model_list]

    def update_lora_model(self, lora_model_dropdown):
        print("Update lora model")
        if lora_model_dropdown == "none":
            self.lora_model_path = "none"
            return gr.update()
        lora_model_dropdown = os.path.join(self.personalized_model_dir, lora_model_dropdown)
        self.lora_model_path = lora_model_dropdown
        return gr.update()

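    # update_lora_model above only records the selected LoRA path; generate() below merges
    # the weights into the pipeline just before sampling (merge_lora, scaled by
    # lora_alpha_slider) and merges them back out afterwards with unmerge_lora, so the base
    # model is restored after every run.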
    def generate(
        self,
        diffusion_transformer_dropdown,
        motion_module_dropdown,
        base_model_dropdown,
        lora_model_dropdown,
        lora_alpha_slider,
        prompt_textbox,
        negative_prompt_textbox,
        sampler_dropdown,
        sample_step_slider,
        resize_method,
        width_slider,
        height_slider,
        base_resolution,
        generation_method,
        length_slider,
        cfg_scale_slider,
        start_image,
        end_image,
        seed_textbox,
        is_api = False,
    ):
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

        if self.transformer is None:
            raise gr.Error(f"Please select a pretrained model path.")

        if self.lora_model_path != lora_model_dropdown:
            print("Update lora model")
            self.update_lora_model(lora_model_dropdown)

        if resize_method == "Resize to the Start Image":
            if start_image is None:
                raise gr.Error(f"Please upload an image when using \"Resize to the Start Image\".")

            aspect_ratio_sample_size    = {key : [x / 512 * base_resolution for x in ASPECT_RATIO_512[key]] for key in ASPECT_RATIO_512.keys()}
            original_width, original_height = start_image[0].size if type(start_image) is list else Image.open(start_image).size
            closest_size, closest_ratio = get_closest_ratio(original_height, original_width, ratios=aspect_ratio_sample_size)
            height_slider, width_slider = [int(x / 16) * 16 for x in closest_size]
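            # Worked example (illustrative numbers): with base_resolution=512, the scaled
            # ASPECT_RATIO_512 buckets keep roughly 512*512 pixels per frame, so a
            # 1920x1080 start image (height/width ratio 0.5625) lands in a wide bucket of
            # about 384x672; int(x / 16) * 16 then snaps both sides to multiples of 16,
            # matching the step used by the width/height sliders.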

        if self.transformer.config.in_channels != 12 and start_image is not None:
            raise gr.Error(f"Please select an image to video pretrained model while using image to video.")

        if start_image is None and end_image is not None:
            raise gr.Error(f"If specifying the ending image of the video, please specify a starting image of the video.")

        is_image = True if generation_method == "Image Generation" else False

        if is_xformers_available(): self.transformer.enable_xformers_memory_efficient_attention()

        self.pipeline.scheduler = scheduler_dict[sampler_dropdown](**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
        if self.lora_model_path != "none":
            # lora part
            self.pipeline = merge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
        self.pipeline.to("cuda")

        if int(seed_textbox) != -1 and seed_textbox != "": torch.manual_seed(int(seed_textbox))
        ...
        generator = torch.Generator(device="cuda").manual_seed(int(seed_textbox))
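        # Reproducibility: both the global torch RNG and the dedicated CUDA generator passed
        # to the pipeline are seeded from seed_textbox, so a fixed seed (e.g. the default 43)
        # should reproduce the same sample for identical settings.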

        try:
            if self.transformer.config.in_channels == 12:
                input_video, input_video_mask, clip_image = get_image_to_video_latent(start_image, end_image, length_slider if not is_image else 1, sample_size=(height_slider, width_slider))

                sample = self.pipeline(
                    prompt_textbox,
                    negative_prompt     = negative_prompt_textbox,
                    num_inference_steps = sample_step_slider,
                    guidance_scale      = cfg_scale_slider,
                    width               = width_slider,
                    height              = height_slider,
                    video_length        = length_slider if not is_image else 1,
                    generator           = generator,

                    video        = input_video,
                    mask_video   = input_video_mask,
                    clip_image   = clip_image,
                ).videos
            else:
                sample = self.pipeline(
                    prompt_textbox,
                    negative_prompt     = negative_prompt_textbox,
                    num_inference_steps = sample_step_slider,
                    guidance_scale      = cfg_scale_slider,
                    width               = width_slider,
                    height              = height_slider,
                    video_length        = length_slider if not is_image else 1,
                    generator           = generator
                ).videos
        except Exception as e:
            gc.collect()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            if self.lora_model_path != "none":
                self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)
            if is_api:
                return "", f"Error. error information is {str(e)}"
            else:
                return gr.update(), gr.update(), f"Error. error information is {str(e)}"

        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

        # lora part
        if self.lora_model_path != "none":
            self.pipeline = unmerge_lora(self.pipeline, self.lora_model_path, multiplier=lora_alpha_slider)

        if not os.path.exists(self.savedir_sample):
            os.makedirs(self.savedir_sample, exist_ok=True)
        ...
            image = (image * 255).numpy().astype(np.uint8)
            image = Image.fromarray(image)
            image.save(save_sample_path)
            if is_api:
                return save_sample_path, "Success"
            else:
                if gradio_version_is_above_4:
                    return gr.Image(value=save_sample_path, visible=True), gr.Video(value=None, visible=False), "Success"
                else:
                    return gr.Image.update(value=save_sample_path, visible=True), gr.Video.update(value=None, visible=False), "Success"
        else:
            save_sample_path = os.path.join(self.savedir_sample, prefix + f".mp4")
            save_videos_grid(sample, save_sample_path, fps=12 if self.edition == "v1" else 24)
            if is_api:
                return save_sample_path, "Success"
            else:
                if gradio_version_is_above_4:
                    return gr.Image(visible=False, value=None), gr.Video(value=save_sample_path, visible=True), "Success"
                else:
                    return gr.Image.update(visible=False, value=None), gr.Video.update(value=save_sample_path, visible=True), "Success"

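save_videos_grid writes the result at 12 fps for the v1 edition and 24 fps otherwise, so the frame-count sliders translate directly into clip duration. A quick check of that arithmetic (a standalone sketch with a hypothetical helper, not code from this commit):

def clip_seconds(num_frames: int, edition: str = "v3") -> float:
    """Clip duration in seconds, using the fps convention of generate() above."""
    fps = 12 if edition == "v1" else 24
    return num_frames / fps

assert clip_seconds(48) == 2.0    # the hosted demo cap of 48 frames is about 2 seconds
assert clip_seconds(144) == 6.0   # the full UI default of 144 frames is 6 seconds
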
def ui_modelscope(edition, config_path, model_name, savedir_sample):
    ...
            """
        )
        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 1. Model checkpoints (模型路径).
                """
            )
            with gr.Row():
                diffusion_transformer_dropdown = gr.Dropdown(
                    label="Pretrained Model Path (预训练模型路径)",
                    choices=[model_name],
                    value=model_name,
                    interactive=False,
                )
            with gr.Row():
                motion_module_dropdown = gr.Dropdown(
                    label="Select motion module (选择运动模块[非必需])",
                    choices=["none"],
                    value="none",
                    interactive=False,
                    visible=False
                )
                base_model_dropdown = gr.Dropdown(
                    label="Select base Dreambooth model (选择基模型[非必需])",
                    choices=["none"],
                    value="none",
                    interactive=False,
                    visible=False
                )
                with gr.Column(visible=False):
                    gr.Markdown(
                        """
                        ### Minimalism is an example portrait of Lora, triggered by specific prompt words. More details can be found on [Wiki](https://github.com/aigc-apps/EasyAnimate/wiki/Training-Lora).
                        """
                    )
                    with gr.Row():
                        lora_model_dropdown = gr.Dropdown(
                            label="Select LoRA model",
                            choices=["none", "easyanimatev2_minimalism_lora.safetensors"],
                            value="none",
                            interactive=True,
                        )

                        lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)

        with gr.Column(variant="panel"):
            gr.Markdown(
                """
                ### 2. Configs for Generation (生成参数配置).
                """
            )

            prompt_textbox = gr.Textbox(label="Prompt (正向提示词)", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
            negative_prompt_textbox = gr.Textbox(label="Negative prompt (负向提示词)", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion.")

            with gr.Row():
                with gr.Column():
                    with gr.Row():
                        sampler_dropdown   = gr.Dropdown(label="Sampling method (采样器种类)", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
                        sample_step_slider = gr.Slider(label="Sampling steps (生成步数)", value=20, minimum=10, maximum=30, step=1, interactive=False)

                    if edition == "v1":
                        width_slider     = gr.Slider(label="Width (视频宽度)", value=512, minimum=384, maximum=704, step=32)
                        height_slider    = gr.Slider(label="Height (视频高度)", value=512, minimum=384, maximum=704, step=32)

                        with gr.Group():
                            generation_method = gr.Radio(
                                ["Video Generation", "Image Generation"],
                                value="Video Generation",
                                show_label=False,
                                visible=False,
                            )
                            length_slider = gr.Slider(label="Animation length (视频帧数)", value=80, minimum=40, maximum=96, step=1)
                        cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=6.0, minimum=0, maximum=20)
                    else:
                        resize_method = gr.Radio(
                            ["Generate by", "Resize to the Start Image"],
                            value="Generate by",
                            show_label=False,
                        )
                        with gr.Column():
                            gr.Markdown(
                                """
                                We support video generation up to 720p with 144 frames, but for the trial experience, we have set certain limitations. We fix the max resolution of video to 384x672x48 (2s).

                                If the start image you uploaded does not match this resolution, you can use the "Resize to the Start Image" option above.

                                If you want to experience longer and larger video generation, you can go to our [Github](https://github.com/aigc-apps/EasyAnimate/).
                                """
                            )
                        width_slider     = gr.Slider(label="Width (视频宽度)", value=672, minimum=128, maximum=1280, step=16, interactive=False)
                        height_slider    = gr.Slider(label="Height (视频高度)", value=384, minimum=128, maximum=1280, step=16, interactive=False)
                        base_resolution  = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False)

                        with gr.Group():
                            generation_method = gr.Radio(
                                ["Video Generation", "Image Generation"],
                                value="Video Generation",
                                show_label=False,
                                visible=True,
                            )
                            length_slider = gr.Slider(label="Animation length (视频帧数)", value=48, minimum=8, maximum=48, step=8)

                        with gr.Accordion("Image to Video (图片到视频)", open=True):
                            with gr.Row():
                                start_image = gr.Image(label="The image at the beginning of the video (图片到视频的开始图片)", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")

                            template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
                            def select_template(evt: gr.SelectData):
                                text = {
                                    "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
                                    "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
                                    "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
                                    "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
                                    "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
                                }[template_gallery_path[evt.index]]
                                return template_gallery_path[evt.index], text

                            template_gallery = gr.Gallery(
                                template_gallery_path,
                                columns=5, rows=1,
                                height=140,
                                allow_preview=False,
                                container=False,
                                label="Template Examples",
                            )
                            template_gallery.select(select_template, None, [start_image, prompt_textbox])

                            with gr.Accordion("The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", open=False):
                                end_image   = gr.Image(label="The image at the ending of the video (图片到视频的结束图片[非必需, Optional])", show_label=False, elem_id="i2v_end", sources="upload", type="filepath")

                        cfg_scale_slider = gr.Slider(label="CFG Scale (引导系数)", value=7.0, minimum=0, maximum=20)

                    with gr.Row():
                        seed_textbox = gr.Textbox(label="Seed (随机种子)", value=43)
                        seed_button  = gr.Button(value="\U0001F3B2", elem_classes="toolbutton")
                        seed_button.click(
                            fn=lambda: gr.Textbox(value=random.randint(1, 1e8)) if gradio_version_is_above_4 else gr.Textbox.update(value=random.randint(1, 1e8)),
                            inputs=[],
                            outputs=[seed_textbox]
                        )

                    generate_button = gr.Button(value="Generate (生成)", variant='primary')

                with gr.Column():
                    result_image = gr.Image(label="Generated Image (生成图片)", interactive=False, visible=False)
                    result_video = gr.Video(label="Generated Animation (生成视频)", interactive=False)
                    infer_progress = gr.Textbox(
                        label="Generation Info (生成信息)",
                        value="No task currently",
                        interactive=False
                    )

        def upload_generation_method(generation_method):
            if generation_method == "Video Generation":
                return gr.update(visible=True, minimum=8, maximum=48, value=48, interactive=True)
            elif generation_method == "Image Generation":
                return gr.update(minimum=1, maximum=1, value=1, interactive=False)
        generation_method.change(
            upload_generation_method, generation_method, [length_slider]
        )
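        # In the hosted trial demo only the frame-count slider is toggled here: "Video
        # Generation" re-enables it with the 8-48 frame range, while "Image Generation"
        # pins it to a single frame. The long-video controls from the full ui() above are
        # not exposed in this UI.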
| 1116 | 
         
            +
             
     | 
| 1117 | 
         
            +
                        def upload_resize_method(resize_method):
         
     | 
| 1118 | 
         
            +
                            if resize_method == "Generate by":
         
     | 
| 1119 | 
         
            +
                                return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
         
     | 
| 1120 | 
         
            +
                            else:
         
     | 
| 1121 | 
         
            +
                                return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
         
     | 
| 1122 | 
         
            +
                        resize_method.change(
         
     | 
| 1123 | 
         
            +
                            upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
         
     | 
| 1124 | 
         
             
                        )
         
     | 
| 1125 | 
         | 
| 1126 | 
         
             
                        generate_button.click(
         
     | 
| 1127 | 
         
             
                            fn=controller.generate,
         
     | 
| 1128 | 
         
             
                            inputs=[
         
     | 
| 1129 | 
         
            +
                                diffusion_transformer_dropdown,
         
     | 
| 1130 | 
         
            +
                                motion_module_dropdown,
         
     | 
| 1131 | 
         
            +
                                base_model_dropdown,
         
     | 
| 1132 | 
         
            +
                                lora_model_dropdown, 
         
     | 
| 1133 | 
         
            +
                                lora_alpha_slider,
         
     | 
| 1134 | 
         
             
                                prompt_textbox, 
         
     | 
| 1135 | 
         
             
                                negative_prompt_textbox, 
         
     | 
| 1136 | 
         
             
                                sampler_dropdown, 
         
     | 
| 1137 | 
         
             
                                sample_step_slider, 
         
     | 
| 1138 | 
         
            +
                                resize_method,
         
     | 
| 1139 | 
         
             
                                width_slider, 
         
     | 
| 1140 | 
         
             
                                height_slider, 
         
     | 
| 1141 | 
         
            +
                                base_resolution, 
         
     | 
| 1142 | 
         
            +
                                generation_method, 
         
     | 
| 1143 | 
         
             
                                length_slider, 
         
     | 
| 1144 | 
         
             
                                cfg_scale_slider, 
         
     | 
| 1145 | 
         
            +
                                start_image, 
         
     | 
| 1146 | 
         
            +
                                end_image, 
         
     | 
| 1147 | 
         
             
                                seed_textbox,
         
     | 
| 1148 | 
         
             
                            ],
         
     | 
| 1149 | 
         
             
                            outputs=[result_image, result_video, infer_progress]
         
     | 
| 
         | 
|
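
The two handlers above follow the standard Gradio pattern: a .change callback returns gr.update(...) objects that reconfigure other components in place, so switching to "Image Generation" pins length_slider to a single frame, and "Resize to the Start Image" swaps the width/height sliders for base_resolution. A minimal, self-contained sketch of the same pattern (component names here are illustrative, not taken from ui.py):

import gradio as gr

def toggle_length(generation_method):
    # Mirror upload_generation_method above: reconfigure the slider
    # according to the selected mode by returning a gr.update object.
    if generation_method == "Video Generation":
        return gr.update(visible=True, minimum=8, maximum=48, value=48, interactive=True)
    return gr.update(minimum=1, maximum=1, value=1, interactive=False)

with gr.Blocks() as demo:
    generation_method = gr.Radio(
        ["Video Generation", "Image Generation"], value="Video Generation"
    )
    length_slider = gr.Slider(label="Animation length", minimum=8, maximum=48, value=48, step=8)
    # The radio's .change event feeds its current value into the callback and
    # applies the returned update to length_slider.
    generation_method.change(toggle_length, generation_method, [length_slider])

if __name__ == "__main__":
    demo.launch()
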
 def post_eas(
+    diffusion_transformer_dropdown, motion_module_dropdown,
+    base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
     prompt_textbox, negative_prompt_textbox,
+    sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
+    base_resolution, generation_method, length_slider, cfg_scale_slider,
+    start_image, end_image, seed_textbox,
 ):
+    if start_image is not None:
+        with open(start_image, 'rb') as file:
+            file_content = file.read()
+            start_image_encoded_content = base64.b64encode(file_content)
+            start_image = start_image_encoded_content.decode('utf-8')
+
+    if end_image is not None:
+        with open(end_image, 'rb') as file:
+            file_content = file.read()
+            end_image_encoded_content = base64.b64encode(file_content)
+            end_image = end_image_encoded_content.decode('utf-8')
+
     datas = {
+        "base_model_path": base_model_dropdown,
+        "motion_module_path": motion_module_dropdown,
+        "lora_model_path": lora_model_dropdown,
+        "lora_alpha_slider": lora_alpha_slider,
         "prompt_textbox": prompt_textbox,
         "negative_prompt_textbox": negative_prompt_textbox,
         "sampler_dropdown": sampler_dropdown,
         "sample_step_slider": sample_step_slider,
+        "resize_method": resize_method,
         "width_slider": width_slider,
         "height_slider": height_slider,
+        "base_resolution": base_resolution,
+        "generation_method": generation_method,
         "length_slider": length_slider,
         "cfg_scale_slider": cfg_scale_slider,
+        "start_image": start_image,
+        "end_image": end_image,
         "seed_textbox": seed_textbox,
     }
+
     session = requests.session()
     session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
 
+    response = session.post(url=f'{os.environ.get("EAS_URL")}/easyanimate/infer_forward', json=datas, timeout=300)
+
     outputs = response.json()
     return outputs
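
post_eas is a thin JSON client for the hosted EAS service: any start/end image is read from disk and base64-encoded into a string, the whole UI state is packed into one dict, and that dict is POSTed to the /easyanimate/infer_forward route with the EAS_TOKEN header for authentication. A sketch of the same request issued outside the UI; the sampler name and seed below are placeholder values, and the response is assumed to carry either a base64_encoding field or a message field, as the generate method further down expects:

import base64
import os

import requests

def encode_image(path):
    # Read an image file and return its base64 representation as a UTF-8
    # string, matching how post_eas prepares start_image / end_image.
    with open(path, "rb") as f:
        return base64.b64encode(f.read()).decode("utf-8")

payload = {
    "base_model_path": "none",
    "motion_module_path": "none",
    "lora_model_path": "none",
    "lora_alpha_slider": 0.55,
    "prompt_textbox": "a sailboat sailing in rough seas with a dramatic sunset.",
    "negative_prompt_textbox": "Low resolution, blurry, ugly, distortion.",
    "sampler_dropdown": "Euler",          # assumed sampler name
    "sample_step_slider": 20,
    "resize_method": "Resize to the Start Image",
    "width_slider": 672,
    "height_slider": 384,
    "base_resolution": 512,
    "generation_method": "Video Generation",
    "length_slider": 48,
    "cfg_scale_slider": 7.0,
    "start_image": encode_image("asset/2.png"),
    "end_image": None,
    "seed_textbox": 43,                   # placeholder seed
}

session = requests.session()
session.headers.update({"Authorization": os.environ.get("EAS_TOKEN")})
response = session.post(
    url=f'{os.environ.get("EAS_URL")}/easyanimate/infer_forward',
    json=payload,
    timeout=300,
)
outputs = response.json()
print(outputs.get("message", "request accepted"))
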
     def generate(
         self,
+        diffusion_transformer_dropdown,
+        motion_module_dropdown,
+        base_model_dropdown,
+        lora_model_dropdown,
+        lora_alpha_slider,
         prompt_textbox,
         negative_prompt_textbox,
         sampler_dropdown,
         sample_step_slider,
+        resize_method,
         width_slider,
         height_slider,
+        base_resolution,
+        generation_method,
         length_slider,
         cfg_scale_slider,
+        start_image,
+        end_image,
         seed_textbox
     ):
+        is_image = True if generation_method == "Image Generation" else False
+
         outputs = post_eas(
+            diffusion_transformer_dropdown, motion_module_dropdown,
+            base_model_dropdown, lora_model_dropdown, lora_alpha_slider,
             prompt_textbox, negative_prompt_textbox,
+            sampler_dropdown, sample_step_slider, resize_method, width_slider, height_slider,
+            base_resolution, generation_method, length_slider, cfg_scale_slider,
+            start_image, end_image,
+            seed_textbox
         )
+        try:
+            base64_encoding = outputs["base64_encoding"]
+        except:
+            return gr.Image(visible=False, value=None), gr.Video(None, visible=True), outputs["message"]
+
         decoded_data = base64.b64decode(base64_encoding)
 
         if not os.path.exists(self.savedir_sample):
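
The hunk is cut just before the saving step, but the flow is clear: generate forwards the UI state to post_eas, surfaces outputs["message"] if no base64_encoding comes back, and otherwise base64-decodes the payload and, judging by the check on self.savedir_sample, writes it to the sample directory. A rough sketch of that final step; the file naming and extension choice are assumptions, not the repository's code:

import base64
import os

def save_decoded_sample(base64_encoding, savedir_sample, is_image):
    # Decode the base64 payload returned by the EAS service and write it to
    # the sample directory; the ".png" / ".mp4" split is an assumption.
    decoded_data = base64.b64decode(base64_encoding)
    os.makedirs(savedir_sample, exist_ok=True)
    extension = ".png" if is_image else ".mp4"
    save_path = os.path.join(savedir_sample, "sample" + extension)
    with open(save_path, "wb") as f:
        f.write(decoded_data)
    return save_path
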
             """
         )
         with gr.Column(variant="panel"):
+            gr.Markdown(
+                """
+                ### 1. Model checkpoints.
+                """
+            )
+            with gr.Row():
+                diffusion_transformer_dropdown = gr.Dropdown(
+                    label="Pretrained Model Path",
+                    choices=[model_name],
+                    value=model_name,
+                    interactive=False,
+                )
+            with gr.Row():
+                motion_module_dropdown = gr.Dropdown(
+                    label="Select motion module",
+                    choices=["none"],
+                    value="none",
+                    interactive=False,
+                    visible=False
+                )
+                base_model_dropdown = gr.Dropdown(
+                    label="Select base Dreambooth model",
+                    choices=["none"],
+                    value="none",
+                    interactive=False,
+                    visible=False
+                )
+                with gr.Column(visible=False):
+                    gr.Markdown(
+                        """
+                        ### Minimalism is an example portrait of Lora, triggered by specific prompt words. More details can be found on [Wiki](https://github.com/aigc-apps/EasyAnimate/wiki/Training-Lora).
+                        """
+                    )
+                    with gr.Row():
+                        lora_model_dropdown = gr.Dropdown(
+                            label="Select LoRA model",
+                            choices=["none", "easyanimatev2_minimalism_lora.safetensors"],
+                            value="none",
+                            interactive=True,
+                        )
+
+                        lora_alpha_slider = gr.Slider(label="LoRA alpha (LoRA权重)", value=0.55, minimum=0, maximum=2, interactive=True)
+
+        with gr.Column(variant="panel"):
+            gr.Markdown(
+                """
+                ### 2. Configs for Generation.
+                """
+            )
+
+            prompt_textbox = gr.Textbox(label="Prompt", lines=2, value="A young woman with beautiful and clear eyes and blonde hair standing and white dress in a forest wearing a crown. She seems to be lost in thought, and the camera focuses on her face. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.")
             negative_prompt_textbox = gr.Textbox(label="Negative prompt", lines=2, value="The video is not of a high quality, it has a low resolution, and the audio quality is not clear. Strange motion trajectory, a poor composition and deformed video, low resolution, duplicate and ugly, strange body structure, long and strange neck, bad teeth, bad eyes, bad limbs, bad hands, rotating camera, blurry camera, shaking camera. Deformation, low-resolution, blurry, ugly, distortion. ")
 
             with gr.Row():
                 with gr.Column():
                     with gr.Row():
                         sampler_dropdown   = gr.Dropdown(label="Sampling method", choices=list(scheduler_dict.keys()), value=list(scheduler_dict.keys())[0])
+                        sample_step_slider = gr.Slider(label="Sampling steps", value=20, minimum=10, maximum=30, step=1, interactive=False)
 
                     if edition == "v1":
                         width_slider     = gr.Slider(label="Width",            value=512, minimum=384, maximum=704, step=32)
                         height_slider    = gr.Slider(label="Height",           value=512, minimum=384, maximum=704, step=32)
+
+                        with gr.Group():
+                            generation_method = gr.Radio(
+                                ["Video Generation", "Image Generation"],
+                                value="Video Generation",
+                                show_label=False,
+                                visible=False,
+                            )
+                            length_slider    = gr.Slider(label="Animation length", value=80,  minimum=40,  maximum=96,   step=1)
                         cfg_scale_slider = gr.Slider(label="CFG Scale",        value=6.0, minimum=0,   maximum=20)
                     else:
+                        resize_method = gr.Radio(
+                            ["Generate by", "Resize to the Start Image"],
+                            value="Generate by",
+                            show_label=False,
+                        )
                         with gr.Column():
                             gr.Markdown(
                                 """
+                                We support video generation up to 720p with 144 frames, but for the trial experience, we have set certain limitations. We fix the max resolution of video to 384x672x48 (2s).
+
+                                If the start image you uploaded does not match this resolution, you can use the "Resize to the Start Image" option above.
+
+                                If you want to experience longer and larger video generation, you can go to our [Github](https://github.com/aigc-apps/EasyAnimate/).
                                 """
                             )
+                        width_slider     = gr.Slider(label="Width (视频宽度)",            value=672, minimum=128, maximum=1280, step=16, interactive=False)
+                        height_slider    = gr.Slider(label="Height (视频高度)",           value=384, minimum=128, maximum=1280, step=16, interactive=False)
+                        base_resolution  = gr.Radio(label="Base Resolution of Pretrained Models", value=512, choices=[512, 768, 960], interactive=False, visible=False)
+
+                        with gr.Group():
+                            generation_method = gr.Radio(
+                                ["Video Generation", "Image Generation"],
+                                value="Video Generation",
+                                show_label=False,
+                                visible=True,
+                            )
+                            length_slider = gr.Slider(label="Animation length (视频帧数)", value=48, minimum=8,   maximum=48,  step=8)
+
+                        with gr.Accordion("Image to Video", open=True):
+                            start_image = gr.Image(label="The image at the beginning of the video", show_label=True, elem_id="i2v_start", sources="upload", type="filepath")
+
+                            template_gallery_path = ["asset/1.png", "asset/2.png", "asset/3.png", "asset/4.png", "asset/5.png"]
+                            def select_template(evt: gr.SelectData):
+                                text = {
+                                    "asset/1.png": "The dog is looking at camera and smiling. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
+                                    "asset/2.png": "a sailboat sailing in rough seas with a dramatic sunset. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
+                                    "asset/3.png": "a beautiful woman with long hair and a dress blowing in the wind. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
+                                    "asset/4.png": "a man in an astronaut suit playing a guitar. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
+                                    "asset/5.png": "fireworks display over night city. The video is of high quality, and the view is very clear. High quality, masterpiece, best quality, highres, ultra-detailed, fantastic.",
+                                }[template_gallery_path[evt.index]]
+                                return template_gallery_path[evt.index], text
+
+                            template_gallery = gr.Gallery(
+                                template_gallery_path,
+                                columns=5, rows=1,
+                                height=140,
+                                allow_preview=False,
+                                container=False,
+                                label="Template Examples",
+                            )
+                            template_gallery.select(select_template, None, [start_image, prompt_textbox])
+
+                            with gr.Accordion("The image at the ending of the video (Optional)", open=False):
+                                end_image   = gr.Image(label="The image at the ending of the video (Optional)", show_label=True, elem_id="i2v_end", sources="upload", type="filepath")
+
                         cfg_scale_slider = gr.Slider(label="CFG Scale",        value=7.0, minimum=0,   maximum=20)
 
                 with gr.Row():
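
template_gallery.select passes a gr.SelectData event into select_template, so clicking one of the five thumbnails fills both the start image and the prompt textbox in a single step. A stripped-down sketch of the same select-event pattern (the image paths and captions below are placeholders, not the ones used above):

import gradio as gr

images = ["asset/1.png", "asset/2.png"]
captions = {
    "asset/1.png": "A dog looking at the camera.",   # placeholder prompts
    "asset/2.png": "A sailboat at sunset.",
}

def on_select(evt: gr.SelectData):
    # evt.index is the position of the clicked thumbnail in the gallery.
    path = images[evt.index]
    return path, captions[path]

with gr.Blocks() as demo:
    prompt = gr.Textbox(label="Prompt")
    start_image = gr.Image(type="filepath")
    gallery = gr.Gallery(images, columns=2, allow_preview=False)
    # Gradio injects the SelectData event because on_select annotates it.
    gallery.select(on_select, None, [start_image, prompt])
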
                     interactive=False
                 )
 
+            def upload_generation_method(generation_method):
+                if generation_method == "Video Generation":
+                    return gr.update(visible=True, minimum=8, maximum=48, value=48, interactive=True)
+                elif generation_method == "Image Generation":
+                    return gr.update(minimum=1, maximum=1, value=1, interactive=False)
+            generation_method.change(
+                upload_generation_method, generation_method, [length_slider]
+            )
+
+            def upload_resize_method(resize_method):
+                if resize_method == "Generate by":
+                    return [gr.update(visible=True), gr.update(visible=True), gr.update(visible=False)]
+                else:
+                    return [gr.update(visible=False), gr.update(visible=False), gr.update(visible=True)]
+            resize_method.change(
+                upload_resize_method, resize_method, [width_slider, height_slider, base_resolution]
             )
 
             generate_button.click(
                 fn=controller.generate,
                 inputs=[
+                    diffusion_transformer_dropdown,
+                    motion_module_dropdown,
+                    base_model_dropdown,
+                    lora_model_dropdown,
+                    lora_alpha_slider,
                     prompt_textbox,
                     negative_prompt_textbox,
                     sampler_dropdown,
                     sample_step_slider,
+                    resize_method,
                     width_slider,
                     height_slider,
+                    base_resolution,
+                    generation_method,
                     length_slider,
                     cfg_scale_slider,
+                    start_image,
+                    end_image,
                     seed_textbox,
                 ],
                 outputs=[result_image, result_video, infer_progress]
easyanimate/utils/utils.py
CHANGED

@@ -8,6 +8,13 @@ import cv2
 from einops import rearrange
 from PIL import Image
 
+def get_width_and_height_from_image_and_base_resolution(image, base_resolution):
+    target_pixels = int(base_resolution) * int(base_resolution)
+    original_width, original_height = Image.open(image).size
+    ratio = (target_pixels / (original_width * original_height)) ** 0.5
+    width_slider = round(original_width * ratio)
+    height_slider = round(original_height * ratio)
+    return height_slider, width_slider
 
 def color_transfer(sc, dc):
     """
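
get_width_and_height_from_image_and_base_resolution rescales an image's dimensions so that width * height is roughly base_resolution squared while preserving the aspect ratio, and returns (height, width) in that order. For a 1920x1080 image with base_resolution=512, the ratio is sqrt(262144 / 2073600) ≈ 0.356, so the call returns (384, 683). A quick check (the temporary file path is only for illustration):

from PIL import Image

from easyanimate.utils.utils import get_width_and_height_from_image_and_base_resolution

# 512 * 512 = 262144 target pixels, 1920 * 1080 = 2073600 source pixels.
Image.new("RGB", (1920, 1080)).save("/tmp/example.png")
height, width = get_width_and_height_from_image_and_base_resolution("/tmp/example.png", 512)
print(height, width)  # 384 683
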
@@ -62,3 +69,103 @@
         if path.endswith("mp4"):
             path = path.replace('.mp4', '.gif')
         outputs[0].save(path, format='GIF', append_images=outputs, save_all=True, duration=100, loop=0)
+
+def get_image_to_video_latent(validation_image_start, validation_image_end, video_length, sample_size):
+    if validation_image_start is not None and validation_image_end is not None:
+        if type(validation_image_start) is str and os.path.isfile(validation_image_start):
+            image_start = clip_image = Image.open(validation_image_start)
+        else:
+            image_start = clip_image = validation_image_start
+        if type(validation_image_end) is str and os.path.isfile(validation_image_end):
+            image_end = Image.open(validation_image_end)
+        else:
+            image_end = validation_image_end
+
+        if type(image_start) is list:
+            clip_image = clip_image[0]
+            start_video = torch.cat(
+                [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
+                dim=2
+            )
+            input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
+            input_video[:, :, :len(image_start)] = start_video
+
+            input_video_mask = torch.zeros_like(input_video[:, :1])
+            input_video_mask[:, :, len(image_start):] = 255
+        else:
+            input_video = torch.tile(
+                torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
+                [1, 1, video_length, 1, 1]
+            )
+            input_video_mask = torch.zeros_like(input_video[:, :1])
+            input_video_mask[:, :, 1:] = 255
+
+        if type(image_end) is list:
+            image_end = [_image_end.resize(image_start[0].size if type(image_start) is list else image_start.size) for _image_end in image_end]
+            end_video = torch.cat(
+                [torch.from_numpy(np.array(_image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_end in image_end],
+                dim=2
+            )
+            input_video[:, :, -len(end_video):] = end_video
+
+            input_video_mask[:, :, -len(image_end):] = 0
+        else:
+            image_end = image_end.resize(image_start[0].size if type(image_start) is list else image_start.size)
+            input_video[:, :, -1:] = torch.from_numpy(np.array(image_end)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0)
+            input_video_mask[:, :, -1:] = 0
+
+        input_video = input_video / 255
+
+    elif validation_image_start is not None:
+        if type(validation_image_start) is str and os.path.isfile(validation_image_start):
+            image_start = clip_image = Image.open(validation_image_start).convert("RGB")
+        else:
+            image_start = clip_image = validation_image_start
+
+        if type(image_start) is list:
+            clip_image = clip_image[0]
+            start_video = torch.cat(
+                [torch.from_numpy(np.array(_image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0) for _image_start in image_start],
+                dim=2
+            )
+            input_video = torch.tile(start_video[:, :, :1], [1, 1, video_length, 1, 1])
+            input_video[:, :, :len(image_start)] = start_video
+            input_video = input_video / 255
+
+            input_video_mask = torch.zeros_like(input_video[:, :1])
+            input_video_mask[:, :, len(image_start):] = 255
+        else:
+            input_video = torch.tile(
+                torch.from_numpy(np.array(image_start)).permute(2, 0, 1).unsqueeze(1).unsqueeze(0),
+                [1, 1, video_length, 1, 1]
+            ) / 255
+            input_video_mask = torch.zeros_like(input_video[:, :1])
+            input_video_mask[:, :, 1:, ] = 255
+    else:
+        input_video = torch.zeros([1, 3, video_length, sample_size[0], sample_size[1]])
+        input_video_mask = torch.ones([1, 1, video_length, sample_size[0], sample_size[1]]) * 255
+        clip_image = None
+
+    return  input_video, input_video_mask, clip_image
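
get_image_to_video_latent builds the pixel-space conditioning for the inpaint pipeline: a [1, 3, video_length, H, W] video tensor scaled to [0, 1] with the supplied frames filled in, a [1, 1, video_length, H, W] mask in which 0 marks given frames and 255 marks frames to be generated, and the image that will be handed to the CLIP encoder. A small shape-checking sketch (the frame count and sample size are arbitrary):

import numpy as np
from PIL import Image

from easyanimate.utils.utils import get_image_to_video_latent

video_length = 48
sample_size = (384, 672)  # (height, width)

# A synthetic start image standing in for an uploaded file.
start = Image.fromarray(np.zeros((sample_size[0], sample_size[1], 3), dtype=np.uint8))

input_video, input_video_mask, clip_image = get_image_to_video_latent(
    start, None, video_length, sample_size
)
print(input_video.shape)       # torch.Size([1, 3, 48, 384, 672]), values in [0, 1]
print(input_video_mask.shape)  # torch.Size([1, 1, 48, 384, 672]); frame 0 is 0, the rest 255
print(clip_image is start)     # True: the start image is reused for the CLIP encoder
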
+
+def video_frames(input_video_path):
+    cap = cv2.VideoCapture(input_video_path)
+    frames = []
+    while True:
+        ret, frame = cap.read()
+        if not ret:
+            break
+        frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+    cap.release()
+    cv2.destroyAllWindows()
+    return frames
+
+def get_video_to_video_latent(validation_videos, video_length):
+    input_video = video_frames(validation_videos)
+    input_video = torch.from_numpy(np.array(input_video))[:video_length]
+    input_video = input_video.permute([3, 0, 1, 2]).unsqueeze(0) / 255
+
+    input_video_mask = torch.zeros_like(input_video[:, :1])
+    input_video_mask[:, :, :] = 255
+
+    return  input_video, input_video_mask, None
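
get_video_to_video_latent decodes a clip with OpenCV via video_frames, keeps at most video_length frames converted to RGB, and returns a [1, 3, T, H, W] tensor scaled to [0, 1] together with an all-255 mask (every frame is to be regenerated) and no CLIP image. A short usage sketch (the input path is a placeholder):

from easyanimate.utils.utils import get_video_to_video_latent

# "input.mp4" is a placeholder path; any clip readable by OpenCV works.
input_video, input_video_mask, clip_image = get_video_to_video_latent("input.mp4", video_length=48)

print(input_video.shape)            # [1, 3, T, H, W] with T <= 48, values in [0, 1]
print(int(input_video_mask.max()))  # 255: every frame is marked for regeneration
print(clip_image)                   # None
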