diff --git a/LICENSE b/LICENSE index f6b48afba6083ffc9f7bf5b4897740d5449cef7e..2027815db63bc7d3e744381e64c292b7fa1510ab 100644 --- a/LICENSE +++ b/LICENSE @@ -1,211 +1,211 @@ -Tencent is pleased to support the open source community by making SongGeneration available. - -Copyright (C) 2025 Tencent. All rights reserved. - -SongGeneration is licensed under the License Terms of SongGeneration except for the third-party components listed below, which is licensed under different terms. SongGeneration does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations. - - -License Terms of SongGeneration: --------------------------------------------------------------------- - -Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -- You agree to use the SongGeneration only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances. - -- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -For avoidance of doubts, "Software" means the SongGeneration inference-enabling code and the weights made available under this license excluding any pre-trained data and other AI components. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - - -Other dependencies and licenses: - - -Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein: --------------------------------------------------------------------- -1. stable_audio_tools -Copyright (c) 2023 Stability AI - - -Terms of the MIT: --------------------------------------------------------------------- -Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: - -The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. - -THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. - -For the license of other third party components, please refer to the following URL: -https://github.com/Stability-AI/stable-audio-tools/tree/main/LICENSES - - -Open Source Software Licensed under the MIT License: --------------------------------------------------------------------- -1. demucs -Copyright (c) Meta Platforms, Inc. and affiliates. - - -A copy of the MIT is included in this file. - - - -Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein: --------------------------------------------------------------------- -1. torch -From PyTorch: - -Copyright (c) 2016- Facebook, Inc (Adam Paszke) -Copyright (c) 2014- Facebook, Inc (Soumith Chintala) -Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) -Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) -Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) -Copyright (c) 2011-2013 NYU (Clement Farabet) -Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) -Copyright (c) 2006 Idiap Research Institute (Samy Bengio) -Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) - -From Caffe2: - -Copyright (c) 2016-present, Facebook Inc. All rights reserved. - -All contributions by Facebook: -Copyright (c) 2016 Facebook Inc. - -All contributions by Google: -Copyright (c) 2015 Google Inc. -All rights reserved. - -All contributions by Yangqing Jia: -Copyright (c) 2015 Yangqing Jia -All rights reserved. - -All contributions by Kakao Brain: -Copyright 2019-2020 Kakao Brain - -All contributions by Cruise LLC: -Copyright (c) 2022 Cruise LLC. -All rights reserved. - -All contributions from Caffe: -Copyright(c) 2013, 2014, 2015, the respective contributors -All rights reserved. - -All other contributions: -Copyright(c) 2015, 2016 the respective contributors -All rights reserved. - -Caffe2 uses a copyright model similar to Caffe: each contributor holds -copyright over their contributions to Caffe2. The project versioning records -all such contribution and copyright details. If a contributor wants to further -mark their specific copyright on a particular contribution, they should -indicate their copyright solely in the commit message of the change when it is -committed. - -All rights reserved. - - -Terms of the BSD 3-Clause: --------------------------------------------------------------------- -Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - -1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - -2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - -3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. 
- -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -For the license of other third party components, please refer to the following URL: -https://github.com/pytorch/pytorch/blob/v2.0.1/NOTICE - - -Open Source Software Licensed under the BSD 2-Clause License and Other Licenses of the Third-Party Components therein: --------------------------------------------------------------------- -1. torchaudio -Copyright (c) 2017 Facebook Inc. (Soumith Chintala), -All rights reserved. - - -Terms of the BSD 2-Clause: --------------------------------------------------------------------- -Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: - -* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. - -* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. - -THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - -For the license of other third party components, please refer to the following URL: -https://github.com/pytorch/audio/blob/v2.0.2/LICENSE - - -Open Source Software License under the Apache License Version 2.0: --------------------------------------------------------------------- -1. huggingface-hub -Copyright (c) huggingface-hub original author and authors - -2. transformers -Copyright 2018- The Hugging Face team. All rights reserved. - - -Terms of the Apache License Version 2.0: --------------------------------------------------------------------- -Apache License - -Version 2.0, January 2004 - -http://www.apache.org/licenses/ - -TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION -1. Definitions. - -"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. - -"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. 
- -"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. - -"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. - -"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. - -"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. - -"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). - -"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. - -"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." - -"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. - -2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. - -3. Grant of Patent License. 
Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. - -4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: - -You must give any other recipients of the Work or Derivative Works a copy of this License; and - -You must cause any modified files to carry prominent notices stating that You changed the files; and - -You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and - -If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. - -You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. - -5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. - -6. Trademarks. 
This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. - -7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. - -8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. - -9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. - -END OF TERMS AND CONDITIONS \ No newline at end of file +Tencent is pleased to support the open source community by making SongGeneration available. + +Copyright (C) 2025 Tencent. All rights reserved. + +SongGeneration is licensed under the License Terms of SongGeneration except for the third-party components listed below, which is licensed under different terms. SongGeneration does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations. 
+ + +License Terms of SongGeneration: +-------------------------------------------------------------------- + +Permission is hereby granted, free of charge, to any person obtaining a copy of this Software and associated documentation files, to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, and/or sublicense copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +- You agree to use the SongGeneration only for academic, research and education purposes, and refrain from using it for any commercial or production purposes under any circumstances. + +- The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +For avoidance of doubts, "Software" means the SongGeneration inference-enabling code and the weights made available under this license excluding any pre-trained data and other AI components. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +Other dependencies and licenses: + + +Open Source Software Licensed under the MIT License and Other Licenses of the Third-Party Components therein: +-------------------------------------------------------------------- +1. stable_audio_tools +Copyright (c) 2023 Stability AI + + +Terms of the MIT: +-------------------------------------------------------------------- +Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + +For the license of other third party components, please refer to the following URL: +https://github.com/Stability-AI/stable-audio-tools/tree/main/LICENSES + + +Open Source Software Licensed under the MIT License: +-------------------------------------------------------------------- +1. demucs +Copyright (c) Meta Platforms, Inc. and affiliates. + + +A copy of the MIT is included in this file. + + + +Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein: +-------------------------------------------------------------------- +1. 
torch +From PyTorch: + +Copyright (c) 2016- Facebook, Inc (Adam Paszke) +Copyright (c) 2014- Facebook, Inc (Soumith Chintala) +Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert) +Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu) +Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu) +Copyright (c) 2011-2013 NYU (Clement Farabet) +Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston) +Copyright (c) 2006 Idiap Research Institute (Samy Bengio) +Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz) + +From Caffe2: + +Copyright (c) 2016-present, Facebook Inc. All rights reserved. + +All contributions by Facebook: +Copyright (c) 2016 Facebook Inc. + +All contributions by Google: +Copyright (c) 2015 Google Inc. +All rights reserved. + +All contributions by Yangqing Jia: +Copyright (c) 2015 Yangqing Jia +All rights reserved. + +All contributions by Kakao Brain: +Copyright 2019-2020 Kakao Brain + +All contributions by Cruise LLC: +Copyright (c) 2022 Cruise LLC. +All rights reserved. + +All contributions from Caffe: +Copyright(c) 2013, 2014, 2015, the respective contributors +All rights reserved. + +All other contributions: +Copyright(c) 2015, 2016 the respective contributors +All rights reserved. + +Caffe2 uses a copyright model similar to Caffe: each contributor holds +copyright over their contributions to Caffe2. The project versioning records +all such contribution and copyright details. If a contributor wants to further +mark their specific copyright on a particular contribution, they should +indicate their copyright solely in the commit message of the change when it is +committed. + +All rights reserved. + + +Terms of the BSD 3-Clause: +-------------------------------------------------------------------- +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +For the license of other third party components, please refer to the following URL: +https://github.com/pytorch/pytorch/blob/v2.0.1/NOTICE + + +Open Source Software Licensed under the BSD 2-Clause License and Other Licenses of the Third-Party Components therein: +-------------------------------------------------------------------- +1. torchaudio +Copyright (c) 2017 Facebook Inc. (Soumith Chintala), +All rights reserved. + + +Terms of the BSD 2-Clause: +-------------------------------------------------------------------- +Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: + +* Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. + +* Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +For the license of other third party components, please refer to the following URL: +https://github.com/pytorch/audio/blob/v2.0.2/LICENSE + + +Open Source Software License under the Apache License Version 2.0: +-------------------------------------------------------------------- +1. huggingface-hub +Copyright (c) huggingface-hub original author and authors + +2. transformers +Copyright 2018- The Hugging Face team. All rights reserved. + + +Terms of the Apache License Version 2.0: +-------------------------------------------------------------------- +Apache License + +Version 2.0, January 2004 + +http://www.apache.org/licenses/ + +TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION +1. Definitions. + +"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. + +"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. + +"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. + +"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. + +"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. 
+ +"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. + +"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). + +"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. + +"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." + +"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. + +2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. + +3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. + +4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: + +You must give any other recipients of the Work or Derivative Works a copy of this License; and + +You must cause any modified files to carry prominent notices stating that You changed the files; and + +You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and + +If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. + +You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. + +5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. + +6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. + +7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. + +8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. + +9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. + +END OF TERMS AND CONDITIONS diff --git a/README.md b/README.md index fcd93af3149c494687e9254668cc4346dcb8ef33..8dd38096e04ba111a263dd0373145a56eed5e682 100644 --- a/README.md +++ b/README.md @@ -7,11 +7,9 @@ sdk: docker app_port: 7860 --- -
Demo  |  Paper  |  Code
- This repository is the official weight repository for LeVo: High-Quality Song Generation with Multi-Preference Alignment. In this repository, we provide the SongGeneration model, inference scripts, and the checkpoint that has been trained on the Million Song Dataset. ## Overview diff --git a/app.py b/app.py index 60f77262de245fe06616c4ccf5c01065927d6dc8..c60aafdbcfd44077633f41d382dd4de450af85af 100644 --- a/app.py +++ b/app.py @@ -124,9 +124,9 @@ def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_co # 创建Gradio界面 -with gr.Blocks(title="SongGeration Demo Space") as demo: - gr.Markdown("# 🎵 SongGeration Demo Space") - gr.Markdown("Demo interface for the song generation model. Provide a lyrics, and optionally an audio or text prompt, to generate a custom song.") +with gr.Blocks(title="SongGeneration Demo Space") as demo: + gr.Markdown("# 🎵 SongGeneration Demo Space") + gr.Markdown("Demo interface for the song generation model. Provide a lyrics, and optionally an audio or text prompt, to generate a custom song. The code is in [GIT](https://github.com/tencent-ailab/SongGeneration)") with gr.Row(): with gr.Column(): diff --git a/codeclm/models/codeclm.py b/codeclm/models/codeclm.py index ff61358eab911c3d11ceb4050955af3b41a981d5..e7b18e672e2c4806f8a4d004d8a9f28870c5b081 100644 --- a/codeclm/models/codeclm.py +++ b/codeclm/models/codeclm.py @@ -36,6 +36,10 @@ class CodecLM: max_duration: tp.Optional[float] = None, seperate_tokenizer: AudioTokenizer = None): self.name = name self.audiotokenizer = audiotokenizer + if self.audiotokenizer: + self.frame_rate = self.audiotokenizer.frame_rate + else: + self.frame_rate = 25 self.lm = lm self.seperate_tokenizer = seperate_tokenizer # import pdb; pdb.set_trace() @@ -47,7 +51,7 @@ class CodecLM: assert max_duration is not None self.max_duration: float = max_duration - self.device = next(iter(lm.parameters())).device + self.device = torch.device("cuda") self.generation_params: dict = {} # self.set_generation_params(duration=15) # 15 seconds by default self.set_generation_params(duration=15, extend_stride=self.max_duration // 2) @@ -57,23 +61,6 @@ class CodecLM: else: self.autocast = TorchAutocast(enabled=False) - - - @property - def frame_rate(self) -> float: - """Roughly the number of AR steps per seconds.""" - return self.audiotokenizer.frame_rate - - @property - def sample_rate(self) -> int: - """Sample rate of the generated audio.""" - return self.audiotokenizer.sample_rate - - @property - def audio_channels(self) -> int: - """Audio channels of the generated audio.""" - return self.audiotokenizer.channels - def set_generation_params(self, use_sampling: bool = True, top_k: int = 250, top_p: float = 0.0, temperature: float = 1.0, duration: float = 30.0, cfg_coef: float = 3.0, @@ -185,7 +172,7 @@ class CodecLM: assert len(lyrics) == 1 texts = [lyric for lyric in lyrics] audio_qt_embs = [] - target_melody_token_len = self.lm.cfg.prompt_len * self.audiotokenizer.frame_rate + target_melody_token_len = self.lm.cfg.prompt_len * self.frame_rate # import pdb; pdb.set_trace() if melody_wavs is None: melody_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long() @@ -207,39 +194,39 @@ class CodecLM: melody_tokens = melody_tokens[...,:target_melody_token_len] elif melody_tokens.shape[-1] < target_melody_token_len: melody_tokens = torch.cat([melody_tokens, torch.full((1,1,target_melody_token_len - melody_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1) - if self.seperate_tokenizer is not None: - if 
bgm_wavs is None: - assert vocal_wavs is None, "vocal_wavs is not None when bgm_wavs is None" - bgm_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long() - vocal_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long() + + if bgm_wavs is None: + assert vocal_wavs is None, "vocal_wavs is not None when bgm_wavs is None" + bgm_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long() + vocal_tokens = torch.full((1,1,target_melody_token_len), 16385, device=self.device).long() + else: + assert vocal_wavs is not None, "vocal_wavs is None when bgm_wavs is not None" + if type(vocal_wavs) == list: + vocal_wavs = torch.stack(vocal_wavs, dim=0) + if type(bgm_wavs) == list: + bgm_wavs = torch.stack(bgm_wavs, dim=0) + vocal_wavs = vocal_wavs.to(self.device) + bgm_wavs = bgm_wavs.to(self.device) + if melody_is_wav: + vocal_tokens, bgm_tokens = self.seperate_tokenizer.encode(vocal_wavs, bgm_wavs) else: - assert vocal_wavs is not None, "vocal_wavs is None when bgm_wavs is not None" - if type(vocal_wavs) == list: - vocal_wavs = torch.stack(vocal_wavs, dim=0) - if type(bgm_wavs) == list: - bgm_wavs = torch.stack(bgm_wavs, dim=0) - vocal_wavs = vocal_wavs.to(self.device) - bgm_wavs = bgm_wavs.to(self.device) - if melody_is_wav: - vocal_tokens, bgm_tokens = self.seperate_tokenizer.encode(vocal_wavs, bgm_wavs) - else: - vocal_tokens = vocal_wavs - bgm_tokens = bgm_wavs - assert len(vocal_tokens.shape) == len(bgm_tokens.shape) == 3, \ - f"vocal and bgm tokens should have a shape [B, C, T]! " \ - f"got vocal len={vocal_tokens.shape}, and bgm len={bgm_tokens.shape}" - assert vocal_tokens.shape[-1] == bgm_tokens.shape[-1], \ - f"vocal and bgm tokens should have the same length! " \ - f"got vocal len={vocal_tokens.shape[-1]}, and bgm len={bgm_tokens.shape[-1]}" - if bgm_tokens.shape[-1] > target_melody_token_len: - bgm_tokens = bgm_tokens[...,:target_melody_token_len] - elif bgm_tokens.shape[-1] < target_melody_token_len: - bgm_tokens = torch.cat([bgm_tokens, torch.full((1,1,target_melody_token_len - bgm_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1) - if vocal_tokens.shape[-1] > target_melody_token_len: - vocal_tokens = vocal_tokens[...,:target_melody_token_len] - elif vocal_tokens.shape[-1] < target_melody_token_len: - vocal_tokens = torch.cat([vocal_tokens, torch.full((1,1,target_melody_token_len - vocal_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1) - melody_tokens = torch.cat([melody_tokens, vocal_tokens, bgm_tokens], dim=1) + vocal_tokens = vocal_wavs + bgm_tokens = bgm_wavs + assert len(vocal_tokens.shape) == len(bgm_tokens.shape) == 3, \ + f"vocal and bgm tokens should have a shape [B, C, T]! " \ + f"got vocal len={vocal_tokens.shape}, and bgm len={bgm_tokens.shape}" + assert vocal_tokens.shape[-1] == bgm_tokens.shape[-1], \ + f"vocal and bgm tokens should have the same length! 
" \ + f"got vocal len={vocal_tokens.shape[-1]}, and bgm len={bgm_tokens.shape[-1]}" + if bgm_tokens.shape[-1] > target_melody_token_len: + bgm_tokens = bgm_tokens[...,:target_melody_token_len] + elif bgm_tokens.shape[-1] < target_melody_token_len: + bgm_tokens = torch.cat([bgm_tokens, torch.full((1,1,target_melody_token_len - bgm_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1) + if vocal_tokens.shape[-1] > target_melody_token_len: + vocal_tokens = vocal_tokens[...,:target_melody_token_len] + elif vocal_tokens.shape[-1] < target_melody_token_len: + vocal_tokens = torch.cat([vocal_tokens, torch.full((1,1,target_melody_token_len - vocal_tokens.shape[-1]), 16385, device=self.device).long()], dim=-1) + melody_tokens = torch.cat([melody_tokens, vocal_tokens, bgm_tokens], dim=1) assert melody_tokens.shape[-1] == target_melody_token_len audio_qt_embs = melody_tokens.long() return texts, audio_qt_embs @@ -284,7 +271,7 @@ class CodecLM: return gen_tokens @torch.no_grad() - def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None): + def generate_audio(self, gen_tokens: torch.Tensor, prompt=None, vocal_prompt=None, bgm_prompt=None, chunked=False): """Generate Audio from tokens""" assert gen_tokens.dim() == 3 if self.seperate_tokenizer is not None: @@ -292,7 +279,7 @@ class CodecLM: gen_tokens_vocal = gen_tokens[:, [1], :] gen_tokens_bgm = gen_tokens[:, [2], :] # gen_audio_song = self.audiotokenizer.decode(gen_tokens_song, prompt) - gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt) + gen_audio_seperate = self.seperate_tokenizer.decode([gen_tokens_vocal, gen_tokens_bgm], vocal_prompt, bgm_prompt, chunked=chunked) return gen_audio_seperate else: gen_audio = self.audiotokenizer.decode(gen_tokens, prompt) diff --git a/codeclm/tokenizer/Flow1dVAE/generate_septoken.py b/codeclm/tokenizer/Flow1dVAE/generate_septoken.py index 883e28d0252515321b931bed9114625ce0fbb07a..249838358a96dca29044f142e6585fe6713a20d9 100644 --- a/codeclm/tokenizer/Flow1dVAE/generate_septoken.py +++ b/codeclm/tokenizer/Flow1dVAE/generate_septoken.py @@ -173,7 +173,7 @@ class Tango: return codes_vocal, codes_bgm @torch.no_grad() - def code2sound(self, codes, prompt_vocal=None, prompt_bgm=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False): + def code2sound(self, codes, prompt_vocal=None, prompt_bgm=None, duration=40, guidance_scale=1.5, num_steps=20, disable_progress=False, chunked=False): codes_vocal,codes_bgm = codes codes_vocal = codes_vocal.to(self.device) codes_bgm = codes_bgm.to(self.device) @@ -268,11 +268,12 @@ class Tango: min_samples = int(min_samples * self.sample_rate // 1000 * 40) hop_samples = int(hop_samples * self.sample_rate // 1000 * 40) ovlp_samples = min_samples - hop_samples + torch.cuda.empty_cache() with torch.no_grad(): output = None for i in range(len(latent_list)): latent = latent_list[i] - cur_output = self.vae.decode_audio(latent)[0].detach().cpu() + cur_output = self.vae.decode_audio(latent, chunked=chunked)[0].detach().cpu() if output is None: output = cur_output diff --git a/codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py b/codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py index 1993fb6d00854e5ad749e66e88268c34800d777b..95a18c5b7fd1d5c4031afc27aa7369e35a06134c 100644 --- a/codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py +++ b/codeclm/tokenizer/Flow1dVAE/libs/rvq/core_vq.py @@ -1,366 +1,366 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. 
-# All rights reserved. -# -# This source code is licensed under the license found in the -# LICENSE file in the root directory of this source tree. -# -# This implementation is inspired from -# https://github.com/lucidrains/vector-quantize-pytorch -# which is released under MIT License. Hereafter, the original license: -# MIT License -# -# Copyright (c) 2020 Phil Wang -# -# Permission is hereby granted, free of charge, to any person obtaining a copy -# of this software and associated documentation files (the "Software"), to deal -# in the Software without restriction, including without limitation the rights -# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell -# copies of the Software, and to permit persons to whom the Software is -# furnished to do so, subject to the following conditions: -# -# The above copyright notice and this permission notice shall be included in all -# copies or substantial portions of the Software. -# -# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR -# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, -# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE -# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER -# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, -# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE -# SOFTWARE. - -"""Core vector quantization implementation.""" - -import typing as tp - -from einops import rearrange, repeat -import torch -from torch import nn -import torch.nn.functional as F - -# from .. import distrib - - -def default(val: tp.Any, d: tp.Any) -> tp.Any: - return val if val is not None else d - - -def ema_inplace(moving_avg, new, decay: float): - moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) - - -def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): - return (x + epsilon) / (x.sum() + n_categories * epsilon) - - -def uniform_init(*shape: int): - t = torch.empty(shape) - nn.init.kaiming_uniform_(t) - return t - - -def sample_vectors(samples, num: int): - num_samples, device = samples.shape[0], samples.device - - if num_samples >= num: - indices = torch.randperm(num_samples, device=device)[:num] - else: - indices = torch.randint(0, num_samples, (num,), device=device) - - return samples[indices] - - -def kmeans(samples, num_clusters: int, num_iters: int = 10): - dim, dtype = samples.shape[-1], samples.dtype - - means = sample_vectors(samples, num_clusters) - - for _ in range(num_iters): - diffs = rearrange(samples, "n d -> n () d") - rearrange( - means, "c d -> () c d" - ) - dists = -(diffs ** 2).sum(dim=-1) - - buckets = dists.max(dim=-1).indices - bins = torch.bincount(buckets, minlength=num_clusters) - zero_mask = bins == 0 - bins_min_clamped = bins.masked_fill(zero_mask, 1) - - new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) - new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) - new_means = new_means / bins_min_clamped[..., None] - - means = torch.where(zero_mask[..., None], means, new_means) - - return means, bins - - -class EuclideanCodebook(nn.Module): - """Codebook with Euclidean distance. - Args: - dim (int): Dimension. - codebook_size (int): Codebook size. - kmeans_init (bool): Whether to use k-means to initialize the codebooks. - If set to true, run the k-means algorithm on the first training batch and use - the learned centroids as initialization. 
- kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. - decay (float): Decay for exponential moving average over the codebooks. - epsilon (float): Epsilon value for numerical stability. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. - """ - def __init__( - self, - dim: int, - codebook_size: int, - kmeans_init: int = False, - kmeans_iters: int = 10, - decay: float = 0.99, - epsilon: float = 1e-5, - threshold_ema_dead_code: int = 2, - ): - super().__init__() - self.decay = decay - init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros - embed = init_fn(codebook_size, dim) - - self.codebook_size = codebook_size - - self.kmeans_iters = kmeans_iters - self.epsilon = epsilon - self.threshold_ema_dead_code = threshold_ema_dead_code - - self.register_buffer("inited", torch.Tensor([not kmeans_init])) - self.register_buffer("cluster_size", torch.zeros(codebook_size)) - self.register_buffer("embed", embed) - self.register_buffer("embed_avg", embed.clone()) - - self.runed_steps = 0 - self.stop_steps = 50_000 - - @torch.jit.ignore - def init_embed_(self, data): - if self.inited: - return - - embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) - self.embed.data.copy_(embed) - self.embed_avg.data.copy_(embed.clone()) - self.cluster_size.data.copy_(cluster_size) - self.inited.data.copy_(torch.Tensor([True])) - # Make sure all buffers across workers are in sync after initialization - distrib.broadcast_tensors(self.buffers()) - - def replace_(self, samples, mask): - modified_codebook = torch.where( - mask[..., None], sample_vectors(samples, self.codebook_size), self.embed - ) - self.embed.data.copy_(modified_codebook) - - def expire_codes_(self, batch_samples): - if self.threshold_ema_dead_code == 0: - return - - expired_codes = self.cluster_size < self.threshold_ema_dead_code - if not torch.any(expired_codes): - return - - batch_samples = rearrange(batch_samples, "... d -> (...) d") - self.replace_(batch_samples, mask=expired_codes) - # distrib.broadcast_tensors(self.buffers()) - - def preprocess(self, x): - x = rearrange(x, "... d -> (...) 
d") - return x - - def quantize(self, x): - embed = self.embed.t() - dist = -( - x.pow(2).sum(1, keepdim=True) - - 2 * x @ embed - + embed.pow(2).sum(0, keepdim=True) - ) - embed_ind = dist.max(dim=-1).indices - return embed_ind - - def postprocess_emb(self, embed_ind, shape): - return embed_ind.view(*shape[:-1]) - - def dequantize(self, embed_ind): - quantize = F.embedding(embed_ind, self.embed) - return quantize - - def encode(self, x): - shape = x.shape - # pre-process - x = self.preprocess(x) - # quantize - embed_ind = self.quantize(x) - # post-process - embed_ind = self.postprocess_emb(embed_ind, shape) - return embed_ind - - def decode(self, embed_ind): - quantize = self.dequantize(embed_ind) - return quantize - - def forward(self, x): - shape, dtype = x.shape, x.dtype - x = self.preprocess(x) - # self.init_embed_(x) - - embed_ind = self.quantize(x) - embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) - embed_ind = self.postprocess_emb(embed_ind, shape) - quantize = self.dequantize(embed_ind) - self.runed_steps += 1 - - if self.training and self.runed_steps < self.stop_steps: - # We do the expiry of code at that point as buffers are in sync - # and all the workers will take the same decision. - self.expire_codes_(x) - ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) - embed_sum = x.t() @ embed_onehot - ema_inplace(self.embed_avg, embed_sum.t(), self.decay) - cluster_size = ( - laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) - * self.cluster_size.sum() - ) - embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) - self.embed.data.copy_(embed_normalized) - - return quantize, embed_ind - - -class VectorQuantization(nn.Module): - """Vector quantization implementation. - Currently supports only euclidean distance. - Args: - dim (int): Dimension - codebook_size (int): Codebook size - codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. - decay (float): Decay for exponential moving average over the codebooks. - epsilon (float): Epsilon value for numerical stability. - kmeans_init (bool): Whether to use kmeans to initialize the codebooks. - kmeans_iters (int): Number of iterations used for kmeans initialization. - threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes - that have an exponential moving average cluster size less than the specified threshold with - randomly selected vector from the current batch. - commitment_weight (float): Weight for commitment loss. 
- """ - def __init__( - self, - dim: int, - codebook_size: int, - codebook_dim: tp.Optional[int] = None, - decay: float = 0.99, - epsilon: float = 1e-5, - kmeans_init: bool = True, - kmeans_iters: int = 50, - threshold_ema_dead_code: int = 2, - commitment_weight: float = 1., - ): - super().__init__() - _codebook_dim: int = default(codebook_dim, dim) - - requires_projection = _codebook_dim != dim - self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) - self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) - - self.epsilon = epsilon - self.commitment_weight = commitment_weight - - self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, - kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, - decay=decay, epsilon=epsilon, - threshold_ema_dead_code=threshold_ema_dead_code) - self.codebook_size = codebook_size - - @property - def codebook(self): - return self._codebook.embed - - def encode(self, x): - x = rearrange(x, "b d n -> b n d") - x = self.project_in(x) - embed_in = self._codebook.encode(x) - return embed_in - - def decode(self, embed_ind): - quantize = self._codebook.decode(embed_ind) - quantize = self.project_out(quantize) - quantize = rearrange(quantize, "b n d -> b d n") - return quantize - - def forward(self, x, do_debug=False): - device = x.device - x = rearrange(x, "b d n -> b n d") - x = self.project_in(x) - - quantize, embed_ind = self._codebook(x) - - if self.training: - quantize = x + (quantize - x).detach() - - loss = torch.tensor([0.0], device=device, requires_grad=self.training) - - if self.training: - if self.commitment_weight > 0: - commit_loss = F.mse_loss(quantize.detach(), x) - loss = loss + commit_loss * self.commitment_weight - quantize = self.project_out(quantize) - quantize = rearrange(quantize, "b n d -> b d n") - return quantize, embed_ind, loss - - -class ResidualVectorQuantization(nn.Module): - """Residual vector quantization implementation. - Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf - """ - def __init__(self, *, num_quantizers, **kwargs): - super().__init__() - self.layers = nn.ModuleList( - [VectorQuantization(**kwargs) for _ in range(num_quantizers)] - ) - - def forward(self, x, n_q: tp.Optional[int] = None): - quantized_out = 0.0 - residual = x - - all_losses = [] - all_indices = [] - - n_q = n_q or len(self.layers) - - for layerinx, layer in enumerate(self.layers[:n_q]): - print("Layer {} Used ratio {:.1f}".format(layerinx, (layer._codebook.cluster_size > 1.0).sum() / layer._codebook.cluster_size.shape[0] * 100.)) - quantized, indices, loss = layer(residual) - residual = residual - quantized - quantized_out = quantized_out + quantized - - all_indices.append(indices) - all_losses.append(loss) - - out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) - return quantized_out, out_indices, out_losses - - def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: - residual = x - all_indices = [] - n_q = n_q or len(self.layers) - for layer in self.layers[:n_q]: - indices = layer.encode(residual) - quantized = layer.decode(indices) - residual = residual - quantized - all_indices.append(indices) - out_indices = torch.stack(all_indices) - return out_indices - - def decode(self, q_indices: torch.Tensor) -> torch.Tensor: - quantized_out = torch.tensor(0.0, device=q_indices.device) - for i, indices in enumerate(q_indices): - layer = self.layers[i] - quantized = layer.decode(indices) - quantized_out = quantized_out + quantized - return quantized_out +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. +# +# This implementation is inspired from +# https://github.com/lucidrains/vector-quantize-pytorch +# which is released under MIT License. Hereafter, the original license: +# MIT License +# +# Copyright (c) 2020 Phil Wang +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Core vector quantization implementation.""" + +import typing as tp + +from einops import rearrange, repeat +import torch +from torch import nn +import torch.nn.functional as F + +# from .. 
import distrib + + +def default(val: tp.Any, d: tp.Any) -> tp.Any: + return val if val is not None else d + + +def ema_inplace(moving_avg, new, decay: float): + moving_avg.data.mul_(decay).add_(new, alpha=(1 - decay)) + + +def laplace_smoothing(x, n_categories: int, epsilon: float = 1e-5): + return (x + epsilon) / (x.sum() + n_categories * epsilon) + + +def uniform_init(*shape: int): + t = torch.empty(shape) + nn.init.kaiming_uniform_(t) + return t + + +def sample_vectors(samples, num: int): + num_samples, device = samples.shape[0], samples.device + + if num_samples >= num: + indices = torch.randperm(num_samples, device=device)[:num] + else: + indices = torch.randint(0, num_samples, (num,), device=device) + + return samples[indices] + + +def kmeans(samples, num_clusters: int, num_iters: int = 10): + dim, dtype = samples.shape[-1], samples.dtype + + means = sample_vectors(samples, num_clusters) + + for _ in range(num_iters): + diffs = rearrange(samples, "n d -> n () d") - rearrange( + means, "c d -> () c d" + ) + dists = -(diffs ** 2).sum(dim=-1) + + buckets = dists.max(dim=-1).indices + bins = torch.bincount(buckets, minlength=num_clusters) + zero_mask = bins == 0 + bins_min_clamped = bins.masked_fill(zero_mask, 1) + + new_means = buckets.new_zeros(num_clusters, dim, dtype=dtype) + new_means.scatter_add_(0, repeat(buckets, "n -> n d", d=dim), samples) + new_means = new_means / bins_min_clamped[..., None] + + means = torch.where(zero_mask[..., None], means, new_means) + + return means, bins + + +class EuclideanCodebook(nn.Module): + """Codebook with Euclidean distance. + Args: + dim (int): Dimension. + codebook_size (int): Codebook size. + kmeans_init (bool): Whether to use k-means to initialize the codebooks. + If set to true, run the k-means algorithm on the first training batch and use + the learned centroids as initialization. + kmeans_iters (int): Number of iterations used for k-means algorithm at initialization. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. 
+ """ + def __init__( + self, + dim: int, + codebook_size: int, + kmeans_init: int = False, + kmeans_iters: int = 10, + decay: float = 0.99, + epsilon: float = 1e-5, + threshold_ema_dead_code: int = 2, + ): + super().__init__() + self.decay = decay + init_fn: tp.Union[tp.Callable[..., torch.Tensor], tp.Any] = uniform_init if not kmeans_init else torch.zeros + embed = init_fn(codebook_size, dim) + + self.codebook_size = codebook_size + + self.kmeans_iters = kmeans_iters + self.epsilon = epsilon + self.threshold_ema_dead_code = threshold_ema_dead_code + + self.register_buffer("inited", torch.Tensor([not kmeans_init])) + self.register_buffer("cluster_size", torch.zeros(codebook_size)) + self.register_buffer("embed", embed) + self.register_buffer("embed_avg", embed.clone()) + + self.runed_steps = 0 + self.stop_steps = 50_000 + + @torch.jit.ignore + def init_embed_(self, data): + if self.inited: + return + + embed, cluster_size = kmeans(data, self.codebook_size, self.kmeans_iters) + self.embed.data.copy_(embed) + self.embed_avg.data.copy_(embed.clone()) + self.cluster_size.data.copy_(cluster_size) + self.inited.data.copy_(torch.Tensor([True])) + # Make sure all buffers across workers are in sync after initialization + distrib.broadcast_tensors(self.buffers()) + + def replace_(self, samples, mask): + modified_codebook = torch.where( + mask[..., None], sample_vectors(samples, self.codebook_size), self.embed + ) + self.embed.data.copy_(modified_codebook) + + def expire_codes_(self, batch_samples): + if self.threshold_ema_dead_code == 0: + return + + expired_codes = self.cluster_size < self.threshold_ema_dead_code + if not torch.any(expired_codes): + return + + batch_samples = rearrange(batch_samples, "... d -> (...) d") + self.replace_(batch_samples, mask=expired_codes) + # distrib.broadcast_tensors(self.buffers()) + + def preprocess(self, x): + x = rearrange(x, "... d -> (...) d") + return x + + def quantize(self, x): + embed = self.embed.t() + dist = -( + x.pow(2).sum(1, keepdim=True) + - 2 * x @ embed + + embed.pow(2).sum(0, keepdim=True) + ) + embed_ind = dist.max(dim=-1).indices + return embed_ind + + def postprocess_emb(self, embed_ind, shape): + return embed_ind.view(*shape[:-1]) + + def dequantize(self, embed_ind): + quantize = F.embedding(embed_ind, self.embed) + return quantize + + def encode(self, x): + shape = x.shape + # pre-process + x = self.preprocess(x) + # quantize + embed_ind = self.quantize(x) + # post-process + embed_ind = self.postprocess_emb(embed_ind, shape) + return embed_ind + + def decode(self, embed_ind): + quantize = self.dequantize(embed_ind) + return quantize + + def forward(self, x): + shape, dtype = x.shape, x.dtype + x = self.preprocess(x) + # self.init_embed_(x) + + embed_ind = self.quantize(x) + embed_onehot = F.one_hot(embed_ind, self.codebook_size).type(dtype) + embed_ind = self.postprocess_emb(embed_ind, shape) + quantize = self.dequantize(embed_ind) + self.runed_steps += 1 + + if self.training and self.runed_steps < self.stop_steps: + # We do the expiry of code at that point as buffers are in sync + # and all the workers will take the same decision. 
+ self.expire_codes_(x) + ema_inplace(self.cluster_size, embed_onehot.sum(0), self.decay) + embed_sum = x.t() @ embed_onehot + ema_inplace(self.embed_avg, embed_sum.t(), self.decay) + cluster_size = ( + laplace_smoothing(self.cluster_size, self.codebook_size, self.epsilon) + * self.cluster_size.sum() + ) + embed_normalized = self.embed_avg / cluster_size.unsqueeze(1) + self.embed.data.copy_(embed_normalized) + + return quantize, embed_ind + + +class VectorQuantization(nn.Module): + """Vector quantization implementation. + Currently supports only euclidean distance. + Args: + dim (int): Dimension + codebook_size (int): Codebook size + codebook_dim (int): Codebook dimension. If not defined, uses the specified dimension in dim. + decay (float): Decay for exponential moving average over the codebooks. + epsilon (float): Epsilon value for numerical stability. + kmeans_init (bool): Whether to use kmeans to initialize the codebooks. + kmeans_iters (int): Number of iterations used for kmeans initialization. + threshold_ema_dead_code (int): Threshold for dead code expiration. Replace any codes + that have an exponential moving average cluster size less than the specified threshold with + randomly selected vector from the current batch. + commitment_weight (float): Weight for commitment loss. + """ + def __init__( + self, + dim: int, + codebook_size: int, + codebook_dim: tp.Optional[int] = None, + decay: float = 0.99, + epsilon: float = 1e-5, + kmeans_init: bool = True, + kmeans_iters: int = 50, + threshold_ema_dead_code: int = 2, + commitment_weight: float = 1., + ): + super().__init__() + _codebook_dim: int = default(codebook_dim, dim) + + requires_projection = _codebook_dim != dim + self.project_in = (nn.Linear(dim, _codebook_dim) if requires_projection else nn.Identity()) + self.project_out = (nn.Linear(_codebook_dim, dim) if requires_projection else nn.Identity()) + + self.epsilon = epsilon + self.commitment_weight = commitment_weight + + self._codebook = EuclideanCodebook(dim=_codebook_dim, codebook_size=codebook_size, + kmeans_init=kmeans_init, kmeans_iters=kmeans_iters, + decay=decay, epsilon=epsilon, + threshold_ema_dead_code=threshold_ema_dead_code) + self.codebook_size = codebook_size + + @property + def codebook(self): + return self._codebook.embed + + def encode(self, x): + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + embed_in = self._codebook.encode(x) + return embed_in + + def decode(self, embed_ind): + quantize = self._codebook.decode(embed_ind) + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize + + def forward(self, x, do_debug=False): + device = x.device + x = rearrange(x, "b d n -> b n d") + x = self.project_in(x) + + quantize, embed_ind = self._codebook(x) + + if self.training: + quantize = x + (quantize - x).detach() + + loss = torch.tensor([0.0], device=device, requires_grad=self.training) + + if self.training: + if self.commitment_weight > 0: + commit_loss = F.mse_loss(quantize.detach(), x) + loss = loss + commit_loss * self.commitment_weight + quantize = self.project_out(quantize) + quantize = rearrange(quantize, "b n d -> b d n") + return quantize, embed_ind, loss + + +class ResidualVectorQuantization(nn.Module): + """Residual vector quantization implementation. + Follows Algorithm 1. 
in https://arxiv.org/pdf/2107.03312.pdf + """ + def __init__(self, *, num_quantizers, **kwargs): + super().__init__() + self.layers = nn.ModuleList( + [VectorQuantization(**kwargs) for _ in range(num_quantizers)] + ) + + def forward(self, x, n_q: tp.Optional[int] = None): + quantized_out = 0.0 + residual = x + + all_losses = [] + all_indices = [] + + n_q = n_q or len(self.layers) + + for layerinx, layer in enumerate(self.layers[:n_q]): + print("Layer {} Used ratio {:.1f}".format(layerinx, (layer._codebook.cluster_size > 1.0).sum() / layer._codebook.cluster_size.shape[0] * 100.)) + quantized, indices, loss = layer(residual) + residual = residual - quantized + quantized_out = quantized_out + quantized + + all_indices.append(indices) + all_losses.append(loss) + + out_losses, out_indices = map(torch.stack, (all_losses, all_indices)) + return quantized_out, out_indices, out_losses + + def encode(self, x: torch.Tensor, n_q: tp.Optional[int] = None) -> torch.Tensor: + residual = x + all_indices = [] + n_q = n_q or len(self.layers) + for layer in self.layers[:n_q]: + indices = layer.encode(residual) + quantized = layer.decode(indices) + residual = residual - quantized + all_indices.append(indices) + out_indices = torch.stack(all_indices) + return out_indices + + def decode(self, q_indices: torch.Tensor) -> torch.Tensor: + quantized_out = torch.tensor(0.0, device=q_indices.device) + for i, indices in enumerate(q_indices): + layer = self.layers[i] + quantized = layer.decode(indices) + quantized_out = quantized_out + quantized + return quantized_out diff --git a/codeclm/tokenizer/Flow1dVAE/model_1rvq.py b/codeclm/tokenizer/Flow1dVAE/model_1rvq.py index 7687447204ec7337870411fa8889a113a49e6265..80cb633394e358ed807819885dd9820d26b86d90 100644 --- a/codeclm/tokenizer/Flow1dVAE/model_1rvq.py +++ b/codeclm/tokenizer/Flow1dVAE/model_1rvq.py @@ -1,710 +1,710 @@ -import yaml -import random -import inspect -import numpy as np -from tqdm import tqdm -import typing as tp -from abc import ABC - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchaudio - -from tools.torch_tools import wav_to_fbank - -from diffusers.utils.torch_utils import randn_tensor -from transformers import HubertModel -from libs.rvq.descript_quantize3 import ResidualVectorQuantize - -from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model -from models_gpt.models.gpt2_config import GPT2Config - -from torch.cuda.amp import autocast - - -from our_MERT_BESTRQ.test import load_model - -class HubertModelWithFinalProj(HubertModel): - def __init__(self, config): - super().__init__(config) - - # The final projection layer is only used for backward compatibility. - # Following https://github.com/auspicious3000/contentvec/issues/6 - # Remove this layer is necessary to achieve the desired outcome. 
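# [Editor's note, usage sketch] A minimal example of the ResidualVectorQuantization
# class defined at the end of the quantization module above: each layer quantizes
# the residual left by the previous one, so the returned codes stack along a leading
# (num_quantizers) dimension. All sizes and variable names below are illustrative.
rvq = ResidualVectorQuantization(num_quantizers=4, dim=128, codebook_size=1024)
feats = torch.randn(2, 128, 50)   # (batch, dim, time), the "b d n" layout used by encode/decode
codes = rvq.encode(feats)         # (num_quantizers, batch, time) integer code indices
recon = rvq.decode(codes)         # (batch, dim, time), sum of the per-layer dequantized outputs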
- print("hidden_size:",config.hidden_size) - print("classifier_proj_size:",config.classifier_proj_size) - self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) - - -class SampleProcessor(torch.nn.Module): - def project_sample(self, x: torch.Tensor): - """Project the original sample to the 'space' where the diffusion will happen.""" - """Project back from diffusion space to the actual sample space.""" - return z - -class Feature1DProcessor(SampleProcessor): - def __init__(self, dim: int = 100, power_std = 1., \ - num_samples: int = 100_000, cal_num_frames: int = 600): - super().__init__() - - self.num_samples = num_samples - self.dim = dim - self.power_std = power_std - self.cal_num_frames = cal_num_frames - self.register_buffer('counts', torch.zeros(1)) - self.register_buffer('sum_x', torch.zeros(dim)) - self.register_buffer('sum_x2', torch.zeros(dim)) - self.register_buffer('sum_target_x2', torch.zeros(dim)) - self.counts: torch.Tensor - self.sum_x: torch.Tensor - self.sum_x2: torch.Tensor - - @property - def mean(self): - mean = self.sum_x / self.counts - if(self.counts < 10): - mean = torch.zeros_like(mean) - return mean - - @property - def std(self): - std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() - if(self.counts < 10): - std = torch.ones_like(std) - return std - - @property - def target_std(self): - return 1 - - def project_sample(self, x: torch.Tensor): - assert x.dim() == 3 - if self.counts.item() < self.num_samples: - self.counts += len(x) - self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0) - self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0) - rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size - x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1) - return x - - def return_sample(self, x: torch.Tensor): - assert x.dim() == 3 - rescale = (self.std / self.target_std) ** self.power_std - # print(rescale, self.mean) - x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1) - return x - -def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77): - if(prior_text_encoder_hidden_states.shape[1] 1.0): - - model_input = torch.cat([ \ - torch.cat([latent_mask_input, latent_mask_input], 0), \ - torch.cat([incontext_x, incontext_x], 0), \ - torch.cat([torch.zeros_like(mu), mu], 0), \ - torch.cat([x, x], 0), \ - ], 2) - timestep=t.unsqueeze(-1).repeat(2) - - dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state - dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0) - dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond) - else: - model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2) - timestep=t.unsqueeze(-1) - dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state - - dphi_dt = dphi_dt[: ,:, -x.shape[2]:] - # print("dphi_dt.shape:",dphi_dt.shape) - # print("x.shape:",x.shape) - - x = x + dt * dphi_dt - t = t + dt - sol.append(x) - if step < len(t_span) - 1: - dt = t_span[step + 1] - t - - return sol[-1] - - def projection_loss(self,hidden_proj, bestrq_emb): - bsz = hidden_proj.shape[0] - - hidden_proj_normalized = F.normalize(hidden_proj, dim=-1) - bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1) - - proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1) - proj_loss = 1+proj_loss.mean() - - return proj_loss - - def 
compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False): - """Computes diffusion loss - - Args: - x1 (torch.Tensor): Target - shape: (batch_size, n_channels, mel_timesteps, n_feats) - mu (torch.Tensor): output of encoder - shape: (batch_size, n_channels, mel_timesteps, n_feats) - - Returns: - loss: conditional flow matching loss - y: conditional flow - shape: (batch_size, n_channels, mel_timesteps, n_feats) - """ - b = mu[0].shape[0] - len_x = x1.shape[2] - # random timestep - if(validation_mode): - t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5 - else: - t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) - # sample noise p(x_0) - z = torch.randn_like(x1) - - y = (1 - (1 - self.sigma_min) * t) * z + t * x1 - u = x1 - (1 - self.sigma_min) * z - # print("y.shape:",y.shape) - #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state - model_input = torch.cat([*mu,y], 2) - t=t.squeeze(-1).squeeze(-1) - # print("model_input.shape:",model_input.shape) - # print("attention_mask.shape:",attention_mask.shape) - out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True) - hidden_layer = out.hidden_states[self.ssl_layer] - hidden_proj = self.mlp(hidden_layer) - # print("hidden_proj.shape:",hidden_proj.shape) - # print("mert_emb.shape:",mert_emb.shape) - # exit() - - - out = out.last_hidden_state - - out=out[:,:,-len_x:] - # out=self.proj_out(out) - - weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01 - # print("out.shape",out.shape) - # print("u.shape",u.shape) - loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum() - # print("hidden_proj.shape:",hidden_proj.shape) - # print("wav2vec_embeds.shape:",wav2vec_embeds.shape) - loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds) - loss = loss_re + loss_cos * 0.5 - # print("loss_cos:",loss_cos,loss_cos.device) - print("loss:",loss,loss.device) - # exit() - return loss, loss_re, loss_cos - -class PromptCondAudioDiffusion(nn.Module): - def __init__( - self, - num_channels, - unet_model_name=None, - unet_model_config_path=None, - snr_gamma=None, - hubert_layer=None, - ssl_layer=None, - uncondition=True, - out_paint=False, - ): - super().__init__() - - assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" - - self.unet_model_name = unet_model_name - self.unet_model_config_path = unet_model_config_path - self.snr_gamma = snr_gamma - self.uncondition = uncondition - self.num_channels = num_channels - self.hubert_layer = hubert_layer - self.ssl_layer = ssl_layer - - # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview - self.normfeat = Feature1DProcessor(dim=64) - - self.sample_rate = 48000 - self.num_samples_perseg = self.sample_rate * 20 // 1000 - self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000) - self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000) - # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) - # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) - self.bestrq = load_model( - 
model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq', - checkpoint_dir='ckpt/encode-s12k.pt', - ) - self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000) - self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000) - for v in self.bestrq.parameters():v.requires_grad = False - self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) - for v in self.rvq_bestrq_emb.parameters():v.requires_grad = False - self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68") - for v in self.hubert.parameters():v.requires_grad = False - self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,)) - # self.xvecmodel = XVECModel() - config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200) - unet = GPT2Model(config) - mlp = nn.Sequential( - nn.Linear(1200, 1024), - nn.SiLU(), - nn.Linear(1024, 1024), - nn.SiLU(), - nn.Linear(1024, 768) - ) - self.set_from = "random" - self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer) - self.mask_emb = torch.nn.Embedding(3, 48) - print("Transformer initialized from pretrain.") - torch.cuda.empty_cache() - # self.unet.set_attn_processor(AttnProcessor2_0()) - # self.unet.set_use_memory_efficient_attention_xformers(True) - - # self.start_embedding = nn.Parameter(torch.randn(1,1024)) - # self.end_embedding = nn.Parameter(torch.randn(1,1024)) - - def compute_snr(self, timesteps): - """ - Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 - """ - alphas_cumprod = self.noise_scheduler.alphas_cumprod - sqrt_alphas_cumprod = alphas_cumprod**0.5 - sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 - - # Expand the tensors. - # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] - alpha = sqrt_alphas_cumprod.expand(timesteps.shape) - - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] - sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) - - # Compute SNR. 
- snr = (alpha / sigma) ** 2 - return snr - - def preprocess_audio(self, input_audios, threshold=0.9): - assert len(input_audios.shape) == 2, input_audios.shape - norm_value = torch.ones_like(input_audios[:,0]) - max_volume = input_audios.abs().max(dim=-1)[0] - norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold - return input_audios/norm_value.unsqueeze(-1) - - def extract_wav2vec_embeds(self, input_audios,output_len): - wav2vec_stride = 2 - - wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024 - # print(wav2vec_embeds) - # print("audio.shape:",input_audios.shape) - wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer] - # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape) - wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1) - return wav2vec_embeds_last - - def extract_mert_embeds(self, input_audios): - prompt_stride = 3 - inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt") - input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) - prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024 - mert_emb= prompt_embeds[-1] - mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1) - - return mert_emb - - def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer): - self.bestrq.eval() - # print("audio shape:",input_audio_0.shape) - input_wav_mean = (input_audio_0 + input_audio_1) / 2.0 - # print("input_wav_mean.shape:",input_wav_mean.shape) - # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device) - input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True) - layer_results = input_wav_mean['layer_results'] - # print("layer_results.shape:",layer_results[layer].shape) - bestrq_emb = layer_results[layer] - bestrq_emb = bestrq_emb.permute(0,2,1).contiguous() - #[b,t,1024] t=t/960 - #35.84s->batch,896,1024 - return bestrq_emb - - - def extract_spk_embeds(self, input_audios): - spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios)) - spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32) - return spk_embeds - - def extract_lyric_feats(self, lyric): - with torch.no_grad(): - try: - text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False) - except: - text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False) - text_encoder_hidden_states = text_encoder_hidden_states.to(self.device) - text_mask = text_mask.to(self.device) - text_encoder_hidden_states, text_mask, text_prompt_embeds = \ - pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds) - text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous() - return text_encoder_hidden_states, text_mask - - def extract_energy_bar(self, input_audios): - if(input_audios.shape[-1] % self.num_samples_perseg > 0): - energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg) - else: - energy_bar = 
input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg) - energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T - energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int() - energy_embedding = self.energy_embedding(energy_bar) - energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t - return energy_embedding - - def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \ - additional_feats = ['spk', 'lyric'], \ - train_rvq=True, train_ssl=False,layer=5): - if not hasattr(self,"device"): - self.device = input_audios.device - if not hasattr(self,"dtype"): - self.dtype = input_audios.dtype - device = self.device - input_audio_0 = input_audios[:,0,:] - input_audio_1 = input_audios[:,1,:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0 - # energy_embedding = self.extract_energy_bar(input_audios) - # print("energy_embedding.shape:",energy_embedding.shape) - # with autocast(enabled=False): - if(train_ssl): - self.wav2vec.train() - wav2vec_embeds = self.extract_wav2vec_embeds(input_audios) - self.clap_embd_extractor.train() - prompt_embeds = self.extract_mert_embeds(input_audios) - if('spk' in additional_feats): - self.xvecmodel.train() - spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1) - else: - with torch.no_grad(): - with autocast(enabled=False): - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - # mert_emb = self.extract_mert_embeds(input_audios_mert) - - wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2]) - - bestrq_emb = bestrq_emb.detach() - if('lyric' in additional_feats): - text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric) - else: - text_encoder_hidden_states, text_mask = None, None - - # prompt_embeds_13 = torch.cat([mert_emb_13, energy_embedding], 1) - # print("prompt_embes.shape:",prompt_embeds.shape) - #prompt_embes.shape: torch.Size([3, 1088, 896]) - # print("wav2vec_embeds.shape:",wav2vec_embeds.shape) - #wav2vec_embeds.shape:torch.Size([3, 1024, 896]) - if(train_rvq): - quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t - else: - bestrq_emb = bestrq_emb.float() - self.rvq_bestrq_emb.eval() - # with autocast(enabled=False): - quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t - commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach() - codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach() - quantized_bestrq_emb = quantized_bestrq_emb.detach() - - commitment_loss = commitment_loss_bestrq_emb - codebook_loss = codebook_loss_bestrq_emb - - - alpha=1 - quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha) - - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # print("latent_masks.shape:",latent_masks.shape) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - - - - scenario = np.random.choice(['start_seg', 'other_seg']) - if(scenario == 'other_seg'): - for binx in range(input_audios.shape[0]): - # 
latent_masks[binx,0:64] = 1 - latent_masks[binx,0:random.randint(64,128)] = 1 - quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape) - # print("latent_masks.shape:",latent_masks.shape) - quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ - + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) - - - - - if self.uncondition: - mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1] - if len(mask_indices) > 0: - quantized_bestrq_emb[mask_indices] = 0 - # print("latents.shape:",latents.shape) - latents = latents.permute(0,2,1).contiguous() - latents = self.normfeat.project_sample(latents) - latents = latents.permute(0,2,1).contiguous() - incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() - attention_mask=(latent_masks > 0.5) - B, L = attention_mask.size() - attention_mask = attention_mask.view(B, 1, L) - attention_mask = attention_mask * attention_mask.transpose(-1, -2) - attention_mask = attention_mask.unsqueeze(1) - # print("incontext_latents.shape:",incontext_latents.shape) - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - latent_mask_input = self.mask_emb(latent_masks) - #64+48+64+1024 - loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode) - return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean() - - def init_device_dtype(self, device, dtype): - self.device = device - self.dtype = dtype - - @torch.no_grad() - def fetch_codes(self, input_audios, additional_feats,layer): - input_audio_0 = input_audios[[0],:] - input_audio_1 = input_audios[[1],:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - bestrq_emb = bestrq_emb.detach() - - # self.rvq_bestrq_middle.eval() - # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t - # self.rvq_bestrq_last.eval() - # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t - - self.rvq_bestrq_emb.eval() - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb], [bestrq_emb], spk_embeds - # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - - @torch.no_grad() - def fetch_codes_batch(self, input_audios, additional_feats,layer): - input_audio_0 = 
input_audios[:,0,:] - input_audio_1 = input_audios[:,1,:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - bestrq_emb = bestrq_emb.detach() - - # self.rvq_bestrq_middle.eval() - # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t - # self.rvq_bestrq_last.eval() - # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t - - self.rvq_bestrq_emb.eval() - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb], [bestrq_emb], spk_embeds - - @torch.no_grad() - def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127, - guidance_scale=2, num_steps=20, - disable_progress=True, scenario='start_seg'): - classifier_free_guidance = guidance_scale > 1.0 - device = self.device - dtype = self.dtype - # codes_bestrq_middle, codes_bestrq_last = codes - codes_bestrq_emb = codes[0] - - - batch_size = codes_bestrq_emb.shape[0] - - - quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() - print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - - - - - if('spk' in additional_feats): - spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach() - - num_frames = quantized_bestrq_emb.shape[1] - - num_channels_latents = self.num_channels - shape = (batch_size, num_frames, 64) - latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) - - - - latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device) - latent_masks[:,0:latent_length] = 2 - if(scenario=='other_seg'): - latent_masks[:,0:incontext_length] = 1 - - - - quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ - + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) - true_latents = true_latents.permute(0,2,1).contiguous() - true_latents = self.normfeat.project_sample(true_latents) - true_latents = true_latents.permute(0,2,1).contiguous() - incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() - incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0] - - - attention_mask=(latent_masks > 0.5) - 
B, L = attention_mask.size() - attention_mask = attention_mask.view(B, 1, L) - attention_mask = attention_mask * attention_mask.transpose(-1, -2) - attention_mask = attention_mask.unsqueeze(1) - latent_mask_input = self.mask_emb(latent_masks) - - if('spk' in additional_feats): - # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1) - additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1) - else: - # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1) - additional_model_input = torch.cat([quantized_bestrq_emb],1) - - temperature = 1.0 - t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device) - latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale) - - latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:] - latents = latents.permute(0,2,1).contiguous() - latents = self.normfeat.return_sample(latents) - # latents = latents.permute(0,2,1).contiguous() - return latents - - @torch.no_grad() - def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, - disable_progress=True,layer=5,scenario='start_seg'): - codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer) - - latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ - guidance_scale=guidance_scale, num_steps=num_steps, \ - disable_progress=disable_progress,scenario=scenario) - return latents - - @torch.no_grad() - def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, - disable_progress=True,layer=5,scenario='start_seg'): - codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer) - import time - start = time.time() - latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ - guidance_scale=guidance_scale, num_steps=num_steps, \ - disable_progress=disable_progress,scenario=scenario) - return latents,time.time()-start - - def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device): - divisor = 4 - shape = (batch_size, num_channels_latents, num_frames, 32) - if(num_frames%divisor>0): - num_frames = round(num_frames/float(divisor))*divisor - shape = (batch_size, num_channels_latents, num_frames, 32) - latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) - return latents - - +import yaml +import random +import inspect +import numpy as np +from tqdm import tqdm +import typing as tp +from abc import ABC + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +from tools.torch_tools import wav_to_fbank + +from diffusers.utils.torch_utils import randn_tensor +from transformers import HubertModel +from libs.rvq.descript_quantize3 import ResidualVectorQuantize + +from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model +from models_gpt.models.gpt2_config import GPT2Config + +from torch.cuda.amp import autocast + + +from our_MERT_BESTRQ.test import load_model + +class HubertModelWithFinalProj(HubertModel): + def __init__(self, config): + super().__init__(config) + + # The final projection layer is only used for backward compatibility. 
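# (Editor's note: concretely, final_proj is the nn.Linear(config.hidden_size,
# config.classifier_proj_size) created a few lines below; it appears to be kept only
# so that contentvec-style checkpoints, which ship weights for this projection, still
# load without missing keys.)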
+ # Following https://github.com/auspicious3000/contentvec/issues/6 + # Remove this layer is necessary to achieve the desired outcome. + print("hidden_size:",config.hidden_size) + print("classifier_proj_size:",config.classifier_proj_size) + self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) + + +class SampleProcessor(torch.nn.Module): + def project_sample(self, x: torch.Tensor): + """Project the original sample to the 'space' where the diffusion will happen.""" + """Project back from diffusion space to the actual sample space.""" + return z + +class Feature1DProcessor(SampleProcessor): + def __init__(self, dim: int = 100, power_std = 1., \ + num_samples: int = 100_000, cal_num_frames: int = 600): + super().__init__() + + self.num_samples = num_samples + self.dim = dim + self.power_std = power_std + self.cal_num_frames = cal_num_frames + self.register_buffer('counts', torch.zeros(1)) + self.register_buffer('sum_x', torch.zeros(dim)) + self.register_buffer('sum_x2', torch.zeros(dim)) + self.register_buffer('sum_target_x2', torch.zeros(dim)) + self.counts: torch.Tensor + self.sum_x: torch.Tensor + self.sum_x2: torch.Tensor + + @property + def mean(self): + mean = self.sum_x / self.counts + if(self.counts < 10): + mean = torch.zeros_like(mean) + return mean + + @property + def std(self): + std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() + if(self.counts < 10): + std = torch.ones_like(std) + return std + + @property + def target_std(self): + return 1 + + def project_sample(self, x: torch.Tensor): + assert x.dim() == 3 + if self.counts.item() < self.num_samples: + self.counts += len(x) + self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0) + self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0) + rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size + x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1) + return x + + def return_sample(self, x: torch.Tensor): + assert x.dim() == 3 + rescale = (self.std / self.target_std) ** self.power_std + # print(rescale, self.mean) + x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1) + return x + +def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77): + if(prior_text_encoder_hidden_states.shape[1] 1.0): + + model_input = torch.cat([ \ + torch.cat([latent_mask_input, latent_mask_input], 0), \ + torch.cat([incontext_x, incontext_x], 0), \ + torch.cat([torch.zeros_like(mu), mu], 0), \ + torch.cat([x, x], 0), \ + ], 2) + timestep=t.unsqueeze(-1).repeat(2) + + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0) + dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond) + else: + model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2) + timestep=t.unsqueeze(-1) + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + + dphi_dt = dphi_dt[: ,:, -x.shape[2]:] + # print("dphi_dt.shape:",dphi_dt.shape) + # print("x.shape:",x.shape) + + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def projection_loss(self,hidden_proj, bestrq_emb): + bsz = hidden_proj.shape[0] + + hidden_proj_normalized = F.normalize(hidden_proj, dim=-1) + bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1) + + 
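# Editor's note: both tensors are L2-normalized along the feature dimension, so the
# summed elementwise product below is the cosine similarity; the returned value is
# 1 - mean cosine similarity, reaching 0 only when the projected hidden states are
# perfectly aligned with the target SSL embeddings.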
proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1) + proj_loss = 1+proj_loss.mean() + + return proj_loss + + def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_channels, mel_timesteps, n_feats) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_channels, mel_timesteps, n_feats) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_channels, mel_timesteps, n_feats) + """ + b = mu[0].shape[0] + len_x = x1.shape[2] + # random timestep + if(validation_mode): + t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5 + else: + t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + # print("y.shape:",y.shape) + #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state + model_input = torch.cat([*mu,y], 2) + t=t.squeeze(-1).squeeze(-1) + # print("model_input.shape:",model_input.shape) + # print("attention_mask.shape:",attention_mask.shape) + out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True) + hidden_layer = out.hidden_states[self.ssl_layer] + hidden_proj = self.mlp(hidden_layer) + # print("hidden_proj.shape:",hidden_proj.shape) + # print("mert_emb.shape:",mert_emb.shape) + # exit() + + + out = out.last_hidden_state + + out=out[:,:,-len_x:] + # out=self.proj_out(out) + + weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01 + # print("out.shape",out.shape) + # print("u.shape",u.shape) + loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum() + # print("hidden_proj.shape:",hidden_proj.shape) + # print("wav2vec_embeds.shape:",wav2vec_embeds.shape) + loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds) + loss = loss_re + loss_cos * 0.5 + # print("loss_cos:",loss_cos,loss_cos.device) + print("loss:",loss,loss.device) + # exit() + return loss, loss_re, loss_cos + +class PromptCondAudioDiffusion(nn.Module): + def __init__( + self, + num_channels, + unet_model_name=None, + unet_model_config_path=None, + snr_gamma=None, + hubert_layer=None, + ssl_layer=None, + uncondition=True, + out_paint=False, + ): + super().__init__() + + assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" + + self.unet_model_name = unet_model_name + self.unet_model_config_path = unet_model_config_path + self.snr_gamma = snr_gamma + self.uncondition = uncondition + self.num_channels = num_channels + self.hubert_layer = hubert_layer + self.ssl_layer = ssl_layer + + # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview + self.normfeat = Feature1DProcessor(dim=64) + + self.sample_rate = 48000 + self.num_samples_perseg = self.sample_rate * 20 // 1000 + self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000) + self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000) + # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + # self.wav2vec_processor = 
AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + self.bestrq = load_model( + model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq', + checkpoint_dir='ckpt/encode-s12k.pt', + ) + self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000) + self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000) + for v in self.bestrq.parameters():v.requires_grad = False + self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) + for v in self.rvq_bestrq_emb.parameters():v.requires_grad = False + self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68") + for v in self.hubert.parameters():v.requires_grad = False + self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,)) + # self.xvecmodel = XVECModel() + config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200) + unet = GPT2Model(config) + mlp = nn.Sequential( + nn.Linear(1200, 1024), + nn.SiLU(), + nn.Linear(1024, 1024), + nn.SiLU(), + nn.Linear(1024, 768) + ) + self.set_from = "random" + self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer) + self.mask_emb = torch.nn.Embedding(3, 48) + print("Transformer initialized from pretrain.") + torch.cuda.empty_cache() + # self.unet.set_attn_processor(AttnProcessor2_0()) + # self.unet.set_use_memory_efficient_attention_xformers(True) + + # self.start_embedding = nn.Parameter(torch.randn(1,1024)) + # self.end_embedding = nn.Parameter(torch.randn(1,1024)) + + def compute_snr(self, timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = self.noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. 
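# [Editor's note, sketch] Here alpha = sqrt(alphas_cumprod[t]) and
# sigma = sqrt(1 - alphas_cumprod[t]), so the line below reduces to
# SNR(t) = alphas_cumprod[t] / (1 - alphas_cumprod[t]). A standalone closed-form
# equivalent (the helper name is illustrative):
def snr_from_alphas_cumprod(alphas_cumprod: torch.Tensor) -> torch.Tensor:
    # (alpha / sigma) ** 2 == alphas_cumprod / (1 - alphas_cumprod)
    return alphas_cumprod / (1.0 - alphas_cumprod)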
+ snr = (alpha / sigma) ** 2 + return snr + + def preprocess_audio(self, input_audios, threshold=0.9): + assert len(input_audios.shape) == 2, input_audios.shape + norm_value = torch.ones_like(input_audios[:,0]) + max_volume = input_audios.abs().max(dim=-1)[0] + norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold + return input_audios/norm_value.unsqueeze(-1) + + def extract_wav2vec_embeds(self, input_audios,output_len): + wav2vec_stride = 2 + + wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024 + # print(wav2vec_embeds) + # print("audio.shape:",input_audios.shape) + wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer] + # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape) + wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1) + return wav2vec_embeds_last + + def extract_mert_embeds(self, input_audios): + prompt_stride = 3 + inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt") + input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) + prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024 + mert_emb= prompt_embeds[-1] + mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1) + + return mert_emb + + def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer): + self.bestrq.eval() + # print("audio shape:",input_audio_0.shape) + input_wav_mean = (input_audio_0 + input_audio_1) / 2.0 + # print("input_wav_mean.shape:",input_wav_mean.shape) + # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device) + input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True) + layer_results = input_wav_mean['layer_results'] + # print("layer_results.shape:",layer_results[layer].shape) + bestrq_emb = layer_results[layer] + bestrq_emb = bestrq_emb.permute(0,2,1).contiguous() + #[b,t,1024] t=t/960 + #35.84s->batch,896,1024 + return bestrq_emb + + + def extract_spk_embeds(self, input_audios): + spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios)) + spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32) + return spk_embeds + + def extract_lyric_feats(self, lyric): + with torch.no_grad(): + try: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False) + except: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False) + text_encoder_hidden_states = text_encoder_hidden_states.to(self.device) + text_mask = text_mask.to(self.device) + text_encoder_hidden_states, text_mask, text_prompt_embeds = \ + pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds) + text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous() + return text_encoder_hidden_states, text_mask + + def extract_energy_bar(self, input_audios): + if(input_audios.shape[-1] % self.num_samples_perseg > 0): + energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg) + else: + energy_bar = 
input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg) + energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T + energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int() + energy_embedding = self.energy_embedding(energy_bar) + energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t + return energy_embedding + + def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \ + additional_feats = ['spk', 'lyric'], \ + train_rvq=True, train_ssl=False,layer=5): + if not hasattr(self,"device"): + self.device = input_audios.device + if not hasattr(self,"dtype"): + self.dtype = input_audios.dtype + device = self.device + input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0 + # energy_embedding = self.extract_energy_bar(input_audios) + # print("energy_embedding.shape:",energy_embedding.shape) + # with autocast(enabled=False): + if(train_ssl): + self.wav2vec.train() + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios) + self.clap_embd_extractor.train() + prompt_embeds = self.extract_mert_embeds(input_audios) + if('spk' in additional_feats): + self.xvecmodel.train() + spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1) + else: + with torch.no_grad(): + with autocast(enabled=False): + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + # mert_emb = self.extract_mert_embeds(input_audios_mert) + + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2]) + + bestrq_emb = bestrq_emb.detach() + if('lyric' in additional_feats): + text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric) + else: + text_encoder_hidden_states, text_mask = None, None + + # prompt_embeds_13 = torch.cat([mert_emb_13, energy_embedding], 1) + # print("prompt_embes.shape:",prompt_embeds.shape) + #prompt_embes.shape: torch.Size([3, 1088, 896]) + # print("wav2vec_embeds.shape:",wav2vec_embeds.shape) + #wav2vec_embeds.shape:torch.Size([3, 1024, 896]) + if(train_rvq): + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + else: + bestrq_emb = bestrq_emb.float() + self.rvq_bestrq_emb.eval() + # with autocast(enabled=False): + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach() + codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach() + quantized_bestrq_emb = quantized_bestrq_emb.detach() + + commitment_loss = commitment_loss_bestrq_emb + codebook_loss = codebook_loss_bestrq_emb + + + alpha=1 + quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha) + + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # print("latent_masks.shape:",latent_masks.shape) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + + + + scenario = np.random.choice(['start_seg', 'other_seg']) + if(scenario == 'other_seg'): + for binx in range(input_audios.shape[0]): + # 
latent_masks[binx,0:64] = 1 + latent_masks[binx,0:random.randint(64,128)] = 1 + quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape) + # print("latent_masks.shape:",latent_masks.shape) + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + + + + + if self.uncondition: + mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1] + if len(mask_indices) > 0: + quantized_bestrq_emb[mask_indices] = 0 + # print("latents.shape:",latents.shape) + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.project_sample(latents) + latents = latents.permute(0,2,1).contiguous() + incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + # print("incontext_latents.shape:",incontext_latents.shape) + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + latent_mask_input = self.mask_emb(latent_masks) + #64+48+64+1024 + loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode) + return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean() + + def init_device_dtype(self, device, dtype): + self.device = device + self.dtype = dtype + + @torch.no_grad() + def fetch_codes(self, input_audios, additional_feats,layer): + input_audio_0 = input_audios[[0],:] + input_audio_1 = input_audios[[1],:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + + @torch.no_grad() + def fetch_codes_batch(self, input_audios, additional_feats,layer): + input_audio_0 = 
input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + + @torch.no_grad() + def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127, + guidance_scale=2, num_steps=20, + disable_progress=True, scenario='start_seg'): + classifier_free_guidance = guidance_scale > 1.0 + device = self.device + dtype = self.dtype + # codes_bestrq_middle, codes_bestrq_last = codes + codes_bestrq_emb = codes[0] + + + batch_size = codes_bestrq_emb.shape[0] + + + quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() + print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + + + + + if('spk' in additional_feats): + spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach() + + num_frames = quantized_bestrq_emb.shape[1] + + num_channels_latents = self.num_channels + shape = (batch_size, num_frames, 64) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + + + + latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device) + latent_masks[:,0:latent_length] = 2 + if(scenario=='other_seg'): + latent_masks[:,0:incontext_length] = 1 + + + + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + true_latents = true_latents.permute(0,2,1).contiguous() + true_latents = self.normfeat.project_sample(true_latents) + true_latents = true_latents.permute(0,2,1).contiguous() + incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0] + + + attention_mask=(latent_masks > 0.5) + 
B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + latent_mask_input = self.mask_emb(latent_masks) + + if('spk' in additional_feats): + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1) + additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1) + else: + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1) + additional_model_input = torch.cat([quantized_bestrq_emb],1) + + temperature = 1.0 + t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device) + latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale) + + latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:] + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.return_sample(latents) + # latents = latents.permute(0,2,1).contiguous() + return latents + + @torch.no_grad() + def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, + disable_progress=True,layer=5,scenario='start_seg'): + codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer) + + latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ + guidance_scale=guidance_scale, num_steps=num_steps, \ + disable_progress=disable_progress,scenario=scenario) + return latents + + @torch.no_grad() + def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, + disable_progress=True,layer=5,scenario='start_seg'): + codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer) + import time + start = time.time() + latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ + guidance_scale=guidance_scale, num_steps=num_steps, \ + disable_progress=disable_progress,scenario=scenario) + return latents,time.time()-start + + def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device): + divisor = 4 + shape = (batch_size, num_channels_latents, num_frames, 32) + if(num_frames%divisor>0): + num_frames = round(num_frames/float(divisor))*divisor + shape = (batch_size, num_channels_latents, num_frames, 32) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + return latents + + diff --git a/codeclm/tokenizer/Flow1dVAE/model_2rvq.py b/codeclm/tokenizer/Flow1dVAE/model_2rvq.py index f1ac8c206dd6579af434b67ba0d0aa73c671dc5c..d9f3644d88a28d798527a1b6de19ee81a2d24ddb 100644 --- a/codeclm/tokenizer/Flow1dVAE/model_2rvq.py +++ b/codeclm/tokenizer/Flow1dVAE/model_2rvq.py @@ -1,774 +1,774 @@ -import yaml -import random -import inspect -import numpy as np -from tqdm import tqdm -import typing as tp -from abc import ABC - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchaudio - -from einops import repeat -from tools.torch_tools import wav_to_fbank - -import diffusers -from diffusers.utils.torch_utils import randn_tensor -from diffusers import DDPMScheduler -from models.transformer_2d_flow import Transformer2DModel -from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel -# from tools.get_mulan import get_mulan -from 
third_party.wespeaker.extract_embd import XVECModel -# from libs.rvq2 import RVQEmbedding -from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize - -from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model -from models_gpt.models.gpt2_config import GPT2Config - -from torch.cuda.amp import autocast - - -from our_MERT_BESTRQ.test import load_model - -class HubertModelWithFinalProj(HubertModel): - def __init__(self, config): - super().__init__(config) - - # The final projection layer is only used for backward compatibility. - # Following https://github.com/auspicious3000/contentvec/issues/6 - # Remove this layer is necessary to achieve the desired outcome. - print("hidden_size:",config.hidden_size) - print("classifier_proj_size:",config.classifier_proj_size) - self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) - - -class SampleProcessor(torch.nn.Module): - def project_sample(self, x: torch.Tensor): - """Project the original sample to the 'space' where the diffusion will happen.""" - """Project back from diffusion space to the actual sample space.""" - return z - -class Feature1DProcessor(SampleProcessor): - def __init__(self, dim: int = 100, power_std = 1., \ - num_samples: int = 100_000, cal_num_frames: int = 600): - super().__init__() - - self.num_samples = num_samples - self.dim = dim - self.power_std = power_std - self.cal_num_frames = cal_num_frames - self.register_buffer('counts', torch.zeros(1)) - self.register_buffer('sum_x', torch.zeros(dim)) - self.register_buffer('sum_x2', torch.zeros(dim)) - self.register_buffer('sum_target_x2', torch.zeros(dim)) - self.counts: torch.Tensor - self.sum_x: torch.Tensor - self.sum_x2: torch.Tensor - - @property - def mean(self): - mean = self.sum_x / self.counts - if(self.counts < 10): - mean = torch.zeros_like(mean) - return mean - - @property - def std(self): - std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() - if(self.counts < 10): - std = torch.ones_like(std) - return std - - @property - def target_std(self): - return 1 - - def project_sample(self, x: torch.Tensor): - assert x.dim() == 3 - if self.counts.item() < self.num_samples: - self.counts += len(x) - self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0) - self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0) - rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size - x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1) - return x - - def return_sample(self, x: torch.Tensor): - assert x.dim() == 3 - rescale = (self.std / self.target_std) ** self.power_std - # print(rescale, self.mean) - x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1) - return x - -def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77): - if(prior_text_encoder_hidden_states.shape[1] 1.0): - - model_input = torch.cat([ \ - torch.cat([latent_mask_input, latent_mask_input], 0), \ - torch.cat([incontext_x, incontext_x], 0), \ - torch.cat([torch.zeros_like(mu), mu], 0), \ - torch.cat([x, x], 0), \ - ], 2) - timestep=t.unsqueeze(-1).repeat(2) - - dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state - dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0) - dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond) - else: - model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2) - 
timestep=t.unsqueeze(-1) - dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state - - dphi_dt = dphi_dt[: ,:, -x.shape[2]:] - # print("dphi_dt.shape:",dphi_dt.shape) - # print("x.shape:",x.shape) - - x = x + dt * dphi_dt - t = t + dt - sol.append(x) - if step < len(t_span) - 1: - dt = t_span[step + 1] - t - - return sol[-1] - - def projection_loss(self,hidden_proj, bestrq_emb): - bsz = hidden_proj.shape[0] - - hidden_proj_normalized = F.normalize(hidden_proj, dim=-1) - bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1) - - proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1) - proj_loss = 1+proj_loss.mean() - - return proj_loss - - def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False): - """Computes diffusion loss - - Args: - x1 (torch.Tensor): Target - shape: (batch_size, n_channels, mel_timesteps, n_feats) - mu (torch.Tensor): output of encoder - shape: (batch_size, n_channels, mel_timesteps, n_feats) - - Returns: - loss: conditional flow matching loss - y: conditional flow - shape: (batch_size, n_channels, mel_timesteps, n_feats) - """ - b = mu[0].shape[0] - len_x = x1.shape[2] - # random timestep - if(validation_mode): - t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5 - else: - t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) - # sample noise p(x_0) - z = torch.randn_like(x1) - - y = (1 - (1 - self.sigma_min) * t) * z + t * x1 - u = x1 - (1 - self.sigma_min) * z - # print("y.shape:",y.shape) - #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state - model_input = torch.cat([*mu,y], 2) - t=t.squeeze(-1).squeeze(-1) - # print("model_input.shape:",model_input.shape) - # print("attention_mask.shape:",attention_mask.shape) - out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True) - hidden_layer = out.hidden_states[self.ssl_layer] - hidden_proj = self.mlp(hidden_layer) - # print("hidden_proj.shape:",hidden_proj.shape) - # print("mert_emb.shape:",mert_emb.shape) - # exit() - - - out = out.last_hidden_state - - out=out[:,:,-len_x:] - # out=self.proj_out(out) - - weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01 - # print("out.shape",out.shape) - # print("u.shape",u.shape) - loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum() - # print("hidden_proj.shape:",hidden_proj.shape) - # print("wav2vec_embeds.shape:",wav2vec_embeds.shape) - loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds) - loss = loss_re + loss_cos * 0.5 - # print("loss_cos:",loss_cos,loss_cos.device) - print("loss:",loss,loss.device) - # exit() - return loss, loss_re, loss_cos - -class PromptCondAudioDiffusion(nn.Module): - def __init__( - self, - num_channels, - unet_model_name=None, - unet_model_config_path=None, - snr_gamma=None, - hubert_layer=None, - ssl_layer=None, - uncondition=True, - out_paint=False, - ): - super().__init__() - - assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" - - self.unet_model_name = unet_model_name - self.unet_model_config_path = unet_model_config_path - self.snr_gamma = snr_gamma - self.uncondition = 
uncondition - self.num_channels = num_channels - self.hubert_layer = hubert_layer - self.ssl_layer = ssl_layer - - # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview - self.normfeat = Feature1DProcessor(dim=64) - - self.sample_rate = 48000 - self.num_samples_perseg = self.sample_rate * 20 // 1000 - self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000) - self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000) - # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) - # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) - self.bestrq = load_model( - model_dir='path/to/our-MERT/mert_fairseq', - checkpoint_dir='checkpoint-120000.pt', - ) - self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000) - self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000) - for v in self.bestrq.parameters():v.requires_grad = False - self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 2, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) - # for v in self.rvq_bestrq_emb.parameters(): - # print(v) - freeze_parameters='quantizers.0' - for name, param in self.rvq_bestrq_emb.named_parameters(): - if freeze_parameters in name: - param.requires_grad = False - print("Freezing RVQ parameters:", name) - self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68") - for v in self.hubert.parameters():v.requires_grad = False - self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,)) - # self.xvecmodel = XVECModel() - config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200) - unet = GPT2Model(config) - mlp = nn.Sequential( - nn.Linear(1200, 1024), - nn.SiLU(), - nn.Linear(1024, 1024), - nn.SiLU(), - nn.Linear(1024, 768) - ) - self.set_from = "random" - self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer) - self.mask_emb = torch.nn.Embedding(3, 48) - print("Transformer initialized from pretrain.") - torch.cuda.empty_cache() - # self.unet.set_attn_processor(AttnProcessor2_0()) - # self.unet.set_use_memory_efficient_attention_xformers(True) - - # self.start_embedding = nn.Parameter(torch.randn(1,1024)) - # self.end_embedding = nn.Parameter(torch.randn(1,1024)) - - def compute_snr(self, timesteps): - """ - Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 - """ - alphas_cumprod = self.noise_scheduler.alphas_cumprod - sqrt_alphas_cumprod = alphas_cumprod**0.5 - sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 - - # Expand the tensors. 
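For reference, the SNR returned by compute_snr is alpha_bar_t / (1 - alpha_bar_t): the helper gathers sqrt(alpha_bar) and sqrt(1 - alpha_bar) at the sampled timesteps and squares their ratio. Note also that self.noise_scheduler is never assigned in the constructor shown in this file, so this helper appears to be retained from an earlier DDPM-style objective; the loss actually used is the flow-matching loss in BASECFM.compute_loss. A minimal standalone check of the identity, using a toy cumulative-alpha schedule rather than the real scheduler:

    import torch

    # Toy cumulative-alpha schedule (illustrative values only, not the model's scheduler).
    alphas_cumprod = torch.tensor([0.99, 0.90, 0.50, 0.10])
    timesteps = torch.tensor([1, 3])

    alpha = alphas_cumprod[timesteps].sqrt()          # sqrt(alpha_bar_t)
    sigma = (1.0 - alphas_cumprod[timesteps]).sqrt()  # sqrt(1 - alpha_bar_t)
    snr = (alpha / sigma) ** 2                        # equals alpha_bar_t / (1 - alpha_bar_t)

    assert torch.allclose(snr, alphas_cumprod[timesteps] / (1.0 - alphas_cumprod[timesteps]))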
- # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] - alpha = sqrt_alphas_cumprod.expand(timesteps.shape) - - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] - sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) - - # Compute SNR. - snr = (alpha / sigma) ** 2 - return snr - - def preprocess_audio(self, input_audios, threshold=0.9): - assert len(input_audios.shape) == 2, input_audios.shape - norm_value = torch.ones_like(input_audios[:,0]) - max_volume = input_audios.abs().max(dim=-1)[0] - norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold - return input_audios/norm_value.unsqueeze(-1) - - def extract_wav2vec_embeds(self, input_audios,output_len): - wav2vec_stride = 2 - - wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024 - # print(wav2vec_embeds) - # print("audio.shape:",input_audios.shape) - wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer] - # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape) - wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1) - return wav2vec_embeds_last - - def extract_mert_embeds(self, input_audios): - prompt_stride = 3 - inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt") - input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) - prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024 - mert_emb= prompt_embeds[-1] - mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1) - - return mert_emb - - def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer): - self.bestrq.eval() - # print("audio shape:",input_audio_0.shape) - input_wav_mean = (input_audio_0 + input_audio_1) / 2.0 - # print("input_wav_mean.shape:",input_wav_mean.shape) - # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device) - input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True) - layer_results = input_wav_mean['layer_results'] - # print("layer_results.shape:",layer_results[layer].shape) - bestrq_emb = layer_results[layer] - bestrq_emb = bestrq_emb.permute(0,2,1).contiguous() - #[b,t,1024] t=t/960 - #35.84s->batch,896,1024 - return bestrq_emb - - - def extract_spk_embeds(self, input_audios): - spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios)) - spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32) - return spk_embeds - - def extract_lyric_feats(self, lyric): - with torch.no_grad(): - try: - text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False) - except: - 
text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False) - text_encoder_hidden_states = text_encoder_hidden_states.to(self.device) - text_mask = text_mask.to(self.device) - text_encoder_hidden_states, text_mask, text_prompt_embeds = \ - pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds) - text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous() - return text_encoder_hidden_states, text_mask - - def extract_energy_bar(self, input_audios): - if(input_audios.shape[-1] % self.num_samples_perseg > 0): - energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg) - else: - energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg) - energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T - energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int() - energy_embedding = self.energy_embedding(energy_bar) - energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t - return energy_embedding - - def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \ - additional_feats = ['spk', 'lyric'], \ - train_rvq=True, train_ssl=False,layer=5): - if not hasattr(self,"device"): - self.device = input_audios.device - if not hasattr(self,"dtype"): - self.dtype = input_audios.dtype - device = self.device - input_audio_0 = input_audios[:,0,:] - input_audio_1 = input_audios[:,1,:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0 - # energy_embedding = self.extract_energy_bar(input_audios) - # print("energy_embedding.shape:",energy_embedding.shape) - # with autocast(enabled=False): - if(train_ssl): - self.wav2vec.train() - wav2vec_embeds = self.extract_wav2vec_embeds(input_audios) - self.clap_embd_extractor.train() - prompt_embeds = self.extract_mert_embeds(input_audios) - if('spk' in additional_feats): - self.xvecmodel.train() - spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1) - else: - with torch.no_grad(): - with autocast(enabled=False): - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - # mert_emb = self.extract_mert_embeds(input_audios_mert) - - wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2]) - - bestrq_emb = bestrq_emb.detach() - if('lyric' in additional_feats): - text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric) - else: - text_encoder_hidden_states, text_mask = None, None - - - if(train_rvq): - random_num=random.random() - if(random_num<0.6): - rvq_layer = 1 - elif(random_num<0.8): - rvq_layer = 2 - else: - rvq_layer = 4 - quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t - else: - bestrq_emb = bestrq_emb.float() - self.rvq_bestrq_emb.eval() - # with autocast(enabled=False): - quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t - commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach() - codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach() - 
quantized_bestrq_emb = quantized_bestrq_emb.detach() - - commitment_loss = commitment_loss_bestrq_emb - codebook_loss = codebook_loss_bestrq_emb - - - alpha=1 - quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha) - - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # print("latent_masks.shape:",latent_masks.shape) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - - - - scenario = np.random.choice(['start_seg', 'other_seg']) - if(scenario == 'other_seg'): - for binx in range(input_audios.shape[0]): - # latent_masks[binx,0:64] = 1 - latent_masks[binx,0:random.randint(64,128)] = 1 - quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape) - # print("latent_masks.shape:",latent_masks.shape) - quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ - + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) - - - - - if self.uncondition: - mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1] - if len(mask_indices) > 0: - quantized_bestrq_emb[mask_indices] = 0 - # print("latents.shape:",latents.shape) - latents = latents.permute(0,2,1).contiguous() - latents = self.normfeat.project_sample(latents) - latents = latents.permute(0,2,1).contiguous() - incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() - attention_mask=(latent_masks > 0.5) - B, L = attention_mask.size() - attention_mask = attention_mask.view(B, 1, L) - attention_mask = attention_mask * attention_mask.transpose(-1, -2) - attention_mask = attention_mask.unsqueeze(1) - # print("incontext_latents.shape:",incontext_latents.shape) - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - latent_mask_input = self.mask_emb(latent_masks) - #64+48+64+1024 - loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode) - return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean() - - def init_device_dtype(self, device, dtype): - self.device = device - self.dtype = dtype - - @torch.no_grad() - def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1): - input_audio_0 = input_audios[[0],:] - input_audio_1 = input_audios[[1],:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - bestrq_emb = bestrq_emb.detach() - - # self.rvq_bestrq_middle.eval() - # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t - # self.rvq_bestrq_last.eval() - # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t - - self.rvq_bestrq_emb.eval() - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) - codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - # exit() - - - if('spk' in 
additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb], [bestrq_emb], spk_embeds - # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - @torch.no_grad() - def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1): - input_audio_0 = input_audios[:,0,:] - input_audio_1 = input_audios[:,1,:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - bestrq_emb = bestrq_emb.detach() - - # self.rvq_bestrq_middle.eval() - # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t - # self.rvq_bestrq_last.eval() - # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t - - self.rvq_bestrq_emb.eval() - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - # exit() - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb], [bestrq_emb], spk_embeds - # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - @torch.no_grad() - def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250): - input_audio_0 = input_audios[:,0,:] - input_audio_1 = input_audios[:,1,:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - bestrq_emb = bestrq_emb.detach() - - # self.rvq_bestrq_middle.eval() - # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t - # self.rvq_bestrq_last.eval() - # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t - - self.rvq_bestrq_emb.eval() - bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds) - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) - # 
print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - # exit() - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb], [bestrq_emb], spk_embeds - # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - @torch.no_grad() - def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127, - guidance_scale=2, num_steps=20, - disable_progress=True, scenario='start_seg'): - classifier_free_guidance = guidance_scale > 1.0 - device = self.device - dtype = self.dtype - # codes_bestrq_middle, codes_bestrq_last = codes - codes_bestrq_emb = codes[0] - - - batch_size = codes_bestrq_emb.shape[0] - - - quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() - print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - - - - - if('spk' in additional_feats): - spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach() - - num_frames = quantized_bestrq_emb.shape[1] - - num_channels_latents = self.num_channels - shape = (batch_size, num_frames, 64) - latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) - - - - latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device) - latent_masks[:,0:latent_length] = 2 - if(scenario=='other_seg'): - latent_masks[:,0:incontext_length] = 1 - - - - quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ - + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) - true_latents = true_latents.permute(0,2,1).contiguous() - true_latents = self.normfeat.project_sample(true_latents) - true_latents = true_latents.permute(0,2,1).contiguous() - incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() - incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0] - - - attention_mask=(latent_masks > 0.5) - B, L = attention_mask.size() - attention_mask = attention_mask.view(B, 1, L) - attention_mask = attention_mask * attention_mask.transpose(-1, -2) - attention_mask = attention_mask.unsqueeze(1) - latent_mask_input = self.mask_emb(latent_masks) - - if('spk' in additional_feats): - # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1) - additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1) - else: - # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1) - 
additional_model_input = torch.cat([quantized_bestrq_emb],1) - - temperature = 1.0 - t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device) - latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale) - - latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:] - latents = latents.permute(0,2,1).contiguous() - latents = self.normfeat.return_sample(latents) - # latents = latents.permute(0,2,1).contiguous() - return latents - - @torch.no_grad() - def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, - disable_progress=True,layer=5,scenario='start_seg',rvq_num=1): - codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num) - - latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ - guidance_scale=guidance_scale, num_steps=num_steps, \ - disable_progress=disable_progress,scenario=scenario) - return latents - - @torch.no_grad() - def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, - disable_progress=True,layer=5,scenario='start_seg'): - codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer) - import time - start = time.time() - latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ - guidance_scale=guidance_scale, num_steps=num_steps, \ - disable_progress=disable_progress,scenario=scenario) - return latents,time.time()-start - - def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device): - divisor = 4 - shape = (batch_size, num_channels_latents, num_frames, 32) - if(num_frames%divisor>0): - num_frames = round(num_frames/float(divisor))*divisor - shape = (batch_size, num_channels_latents, num_frames, 32) - latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) - return latents - - +import yaml +import random +import inspect +import numpy as np +from tqdm import tqdm +import typing as tp +from abc import ABC + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +from einops import repeat +from tools.torch_tools import wav_to_fbank + +import diffusers +from diffusers.utils.torch_utils import randn_tensor +from diffusers import DDPMScheduler +from models.transformer_2d_flow import Transformer2DModel +from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel +# from tools.get_mulan import get_mulan +from third_party.wespeaker.extract_embd import XVECModel +# from libs.rvq2 import RVQEmbedding +from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize + +from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model +from models_gpt.models.gpt2_config import GPT2Config + +from torch.cuda.amp import autocast + + +from our_MERT_BESTRQ.test import load_model + +class HubertModelWithFinalProj(HubertModel): + def __init__(self, config): + super().__init__(config) + + # The final projection layer is only used for backward compatibility. + # Following https://github.com/auspicious3000/contentvec/issues/6 + # Remove this layer is necessary to achieve the desired outcome. 
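The subclass being re-added here only restores the final_proj head so that ContentVec-style HuBERT checkpoints load without missing-key errors; downstream, extract_wav2vec_embeds reads the encoder's hidden states at self.hubert_layer and never calls final_proj. A minimal sketch of that usage pattern (the helper name hubert_layer_features, the 48 kHz input and the optional target_len are illustrative, mirroring extract_wav2vec_embeds below rather than adding new behaviour):

    import torch
    import torchaudio

    resample_48k_to_16k = torchaudio.transforms.Resample(48000, 16000)

    @torch.no_grad()
    def hubert_layer_features(hubert, wav_48k, layer, target_len=None):
        # hubert: a loaded HubertModelWithFinalProj; wav_48k: (batch, samples) at 48 kHz.
        feats = hubert(resample_48k_to_16k(wav_48k), output_hidden_states=True).hidden_states[layer]
        if target_len is not None:
            # Optionally resize along time to match another feature stream's frame count.
            feats = torch.nn.functional.interpolate(
                feats.permute(0, 2, 1), size=target_len, mode='linear', align_corners=False
            ).permute(0, 2, 1)
        return feats  # (batch, frames, hidden_size)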
+ print("hidden_size:",config.hidden_size) + print("classifier_proj_size:",config.classifier_proj_size) + self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) + + +class SampleProcessor(torch.nn.Module): + def project_sample(self, x: torch.Tensor): + """Project the original sample to the 'space' where the diffusion will happen.""" + """Project back from diffusion space to the actual sample space.""" + return z + +class Feature1DProcessor(SampleProcessor): + def __init__(self, dim: int = 100, power_std = 1., \ + num_samples: int = 100_000, cal_num_frames: int = 600): + super().__init__() + + self.num_samples = num_samples + self.dim = dim + self.power_std = power_std + self.cal_num_frames = cal_num_frames + self.register_buffer('counts', torch.zeros(1)) + self.register_buffer('sum_x', torch.zeros(dim)) + self.register_buffer('sum_x2', torch.zeros(dim)) + self.register_buffer('sum_target_x2', torch.zeros(dim)) + self.counts: torch.Tensor + self.sum_x: torch.Tensor + self.sum_x2: torch.Tensor + + @property + def mean(self): + mean = self.sum_x / self.counts + if(self.counts < 10): + mean = torch.zeros_like(mean) + return mean + + @property + def std(self): + std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() + if(self.counts < 10): + std = torch.ones_like(std) + return std + + @property + def target_std(self): + return 1 + + def project_sample(self, x: torch.Tensor): + assert x.dim() == 3 + if self.counts.item() < self.num_samples: + self.counts += len(x) + self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0) + self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0) + rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size + x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1) + return x + + def return_sample(self, x: torch.Tensor): + assert x.dim() == 3 + rescale = (self.std / self.target_std) ** self.power_std + # print(rescale, self.mean) + x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1) + return x + +def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77): + if(prior_text_encoder_hidden_states.shape[1] 1.0): + + model_input = torch.cat([ \ + torch.cat([latent_mask_input, latent_mask_input], 0), \ + torch.cat([incontext_x, incontext_x], 0), \ + torch.cat([torch.zeros_like(mu), mu], 0), \ + torch.cat([x, x], 0), \ + ], 2) + timestep=t.unsqueeze(-1).repeat(2) + + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0) + dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond) + else: + model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2) + timestep=t.unsqueeze(-1) + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + + dphi_dt = dphi_dt[: ,:, -x.shape[2]:] + # print("dphi_dt.shape:",dphi_dt.shape) + # print("x.shape:",x.shape) + + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def projection_loss(self,hidden_proj, bestrq_emb): + bsz = hidden_proj.shape[0] + + hidden_proj_normalized = F.normalize(hidden_proj, dim=-1) + bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1) + + proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1) + proj_loss = 1+proj_loss.mean() + + return proj_loss + + def 
compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_channels, mel_timesteps, n_feats) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_channels, mel_timesteps, n_feats) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_channels, mel_timesteps, n_feats) + """ + b = mu[0].shape[0] + len_x = x1.shape[2] + # random timestep + if(validation_mode): + t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5 + else: + t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + # print("y.shape:",y.shape) + #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state + model_input = torch.cat([*mu,y], 2) + t=t.squeeze(-1).squeeze(-1) + # print("model_input.shape:",model_input.shape) + # print("attention_mask.shape:",attention_mask.shape) + out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True) + hidden_layer = out.hidden_states[self.ssl_layer] + hidden_proj = self.mlp(hidden_layer) + # print("hidden_proj.shape:",hidden_proj.shape) + # print("mert_emb.shape:",mert_emb.shape) + # exit() + + + out = out.last_hidden_state + + out=out[:,:,-len_x:] + # out=self.proj_out(out) + + weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01 + # print("out.shape",out.shape) + # print("u.shape",u.shape) + loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum() + # print("hidden_proj.shape:",hidden_proj.shape) + # print("wav2vec_embeds.shape:",wav2vec_embeds.shape) + loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds) + loss = loss_re + loss_cos * 0.5 + # print("loss_cos:",loss_cos,loss_cos.device) + print("loss:",loss,loss.device) + # exit() + return loss, loss_re, loss_cos + +class PromptCondAudioDiffusion(nn.Module): + def __init__( + self, + num_channels, + unet_model_name=None, + unet_model_config_path=None, + snr_gamma=None, + hubert_layer=None, + ssl_layer=None, + uncondition=True, + out_paint=False, + ): + super().__init__() + + assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" + + self.unet_model_name = unet_model_name + self.unet_model_config_path = unet_model_config_path + self.snr_gamma = snr_gamma + self.uncondition = uncondition + self.num_channels = num_channels + self.hubert_layer = hubert_layer + self.ssl_layer = ssl_layer + + # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview + self.normfeat = Feature1DProcessor(dim=64) + + self.sample_rate = 48000 + self.num_samples_perseg = self.sample_rate * 20 // 1000 + self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000) + self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000) + # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + self.bestrq = load_model( + 
model_dir='path/to/our-MERT/mert_fairseq', + checkpoint_dir='checkpoint-120000.pt', + ) + self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000) + self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000) + for v in self.bestrq.parameters():v.requires_grad = False + self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 2, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) + # for v in self.rvq_bestrq_emb.parameters(): + # print(v) + freeze_parameters='quantizers.0' + for name, param in self.rvq_bestrq_emb.named_parameters(): + if freeze_parameters in name: + param.requires_grad = False + print("Freezing RVQ parameters:", name) + self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68") + for v in self.hubert.parameters():v.requires_grad = False + self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,)) + # self.xvecmodel = XVECModel() + config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200) + unet = GPT2Model(config) + mlp = nn.Sequential( + nn.Linear(1200, 1024), + nn.SiLU(), + nn.Linear(1024, 1024), + nn.SiLU(), + nn.Linear(1024, 768) + ) + self.set_from = "random" + self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer) + self.mask_emb = torch.nn.Embedding(3, 48) + print("Transformer initialized from pretrain.") + torch.cuda.empty_cache() + # self.unet.set_attn_processor(AttnProcessor2_0()) + # self.unet.set_use_memory_efficient_attention_xformers(True) + + # self.start_embedding = nn.Parameter(torch.randn(1,1024)) + # self.end_embedding = nn.Parameter(torch.randn(1,1024)) + + def compute_snr(self, timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = self.noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. 
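One training detail in the constructor above is easy to miss: only the parameters of the first residual codebook (names containing 'quantizers.0') have requires_grad disabled, so RVQ layer 1 stays fixed while the remaining codebooks keep training. A generic sketch of that freeze-by-name pattern (the function name and module argument are placeholders, not part of this file):

    import torch.nn as nn

    def freeze_by_name(module: nn.Module, substring: str) -> int:
        # Disable gradients for every parameter whose qualified name contains `substring`.
        frozen = 0
        for name, param in module.named_parameters():
            if substring in name:
                param.requires_grad = False
                frozen += 1
        return frozen

    # e.g. freeze_by_name(rvq_bestrq_emb, 'quantizers.0') reproduces the loop in __init__,
    # leaving the later codebooks trainable.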
+ snr = (alpha / sigma) ** 2 + return snr + + def preprocess_audio(self, input_audios, threshold=0.9): + assert len(input_audios.shape) == 2, input_audios.shape + norm_value = torch.ones_like(input_audios[:,0]) + max_volume = input_audios.abs().max(dim=-1)[0] + norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold + return input_audios/norm_value.unsqueeze(-1) + + def extract_wav2vec_embeds(self, input_audios,output_len): + wav2vec_stride = 2 + + wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024 + # print(wav2vec_embeds) + # print("audio.shape:",input_audios.shape) + wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer] + # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape) + wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1) + return wav2vec_embeds_last + + def extract_mert_embeds(self, input_audios): + prompt_stride = 3 + inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt") + input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) + prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024 + mert_emb= prompt_embeds[-1] + mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1) + + return mert_emb + + def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer): + self.bestrq.eval() + # print("audio shape:",input_audio_0.shape) + input_wav_mean = (input_audio_0 + input_audio_1) / 2.0 + # print("input_wav_mean.shape:",input_wav_mean.shape) + # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device) + input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True) + layer_results = input_wav_mean['layer_results'] + # print("layer_results.shape:",layer_results[layer].shape) + bestrq_emb = layer_results[layer] + bestrq_emb = bestrq_emb.permute(0,2,1).contiguous() + #[b,t,1024] t=t/960 + #35.84s->batch,896,1024 + return bestrq_emb + + + def extract_spk_embeds(self, input_audios): + spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios)) + spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32) + return spk_embeds + + def extract_lyric_feats(self, lyric): + with torch.no_grad(): + try: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False) + except: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False) + text_encoder_hidden_states = text_encoder_hidden_states.to(self.device) + text_mask = text_mask.to(self.device) + text_encoder_hidden_states, text_mask, text_prompt_embeds = \ + pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds) + text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous() + return text_encoder_hidden_states, text_mask + + def extract_energy_bar(self, input_audios): + if(input_audios.shape[-1] % self.num_samples_perseg > 0): + energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg) + else: + energy_bar = 
input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg) + energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T + energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int() + energy_embedding = self.energy_embedding(energy_bar) + energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t + return energy_embedding + + def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \ + additional_feats = ['spk', 'lyric'], \ + train_rvq=True, train_ssl=False,layer=5): + if not hasattr(self,"device"): + self.device = input_audios.device + if not hasattr(self,"dtype"): + self.dtype = input_audios.dtype + device = self.device + input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0 + # energy_embedding = self.extract_energy_bar(input_audios) + # print("energy_embedding.shape:",energy_embedding.shape) + # with autocast(enabled=False): + if(train_ssl): + self.wav2vec.train() + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios) + self.clap_embd_extractor.train() + prompt_embeds = self.extract_mert_embeds(input_audios) + if('spk' in additional_feats): + self.xvecmodel.train() + spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1) + else: + with torch.no_grad(): + with autocast(enabled=False): + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + # mert_emb = self.extract_mert_embeds(input_audios_mert) + + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2]) + + bestrq_emb = bestrq_emb.detach() + if('lyric' in additional_feats): + text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric) + else: + text_encoder_hidden_states, text_mask = None, None + + + if(train_rvq): + random_num=random.random() + if(random_num<0.6): + rvq_layer = 1 + elif(random_num<0.8): + rvq_layer = 2 + else: + rvq_layer = 4 + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t + else: + bestrq_emb = bestrq_emb.float() + self.rvq_bestrq_emb.eval() + # with autocast(enabled=False): + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach() + codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach() + quantized_bestrq_emb = quantized_bestrq_emb.detach() + + commitment_loss = commitment_loss_bestrq_emb + codebook_loss = codebook_loss_bestrq_emb + + + alpha=1 + quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha) + + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # print("latent_masks.shape:",latent_masks.shape) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + + + + scenario = np.random.choice(['start_seg', 'other_seg']) + if(scenario == 'other_seg'): + for binx in range(input_audios.shape[0]): + # latent_masks[binx,0:64] = 1 + latent_masks[binx,0:random.randint(64,128)] = 1 + quantized_bestrq_emb = 
quantized_bestrq_emb.permute(0,2,1).contiguous() + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape) + # print("latent_masks.shape:",latent_masks.shape) + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + + + + + if self.uncondition: + mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1] + if len(mask_indices) > 0: + quantized_bestrq_emb[mask_indices] = 0 + # print("latents.shape:",latents.shape) + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.project_sample(latents) + latents = latents.permute(0,2,1).contiguous() + incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + # print("incontext_latents.shape:",incontext_latents.shape) + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + latent_mask_input = self.mask_emb(latent_masks) + #64+48+64+1024 + loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode) + return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean() + + def init_device_dtype(self, device, dtype): + self.device = device + self.dtype = dtype + + @torch.no_grad() + def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1): + input_audio_0 = input_audios[[0],:] + input_audio_1 = input_audios[[1],:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) + codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + # exit() + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + @torch.no_grad() + def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1): + 
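# Descriptive note (inferred from the code in this diff): the batched
# code-extraction path below takes a 48 kHz stereo batch of shape
# (B, 2, n_samples). Each channel is peak-normalized, the two channels are
# averaged to mono inside extract_bestrq_embeds, BEST-RQ features from
# `layer` are quantized by the residual VQ, and only the first `rvq_num`
# codebooks of the resulting codes (B, n_codebooks, T) are returned.
# Speaker embeddings are attached only when 'spk' is in additional_feats.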
input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + # exit() + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + @torch.no_grad() + def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250): + input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds) + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + # exit() + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + # return 
[codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + @torch.no_grad() + def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127, + guidance_scale=2, num_steps=20, + disable_progress=True, scenario='start_seg'): + classifier_free_guidance = guidance_scale > 1.0 + device = self.device + dtype = self.dtype + # codes_bestrq_middle, codes_bestrq_last = codes + codes_bestrq_emb = codes[0] + + + batch_size = codes_bestrq_emb.shape[0] + + + quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() + print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + + + + + if('spk' in additional_feats): + spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach() + + num_frames = quantized_bestrq_emb.shape[1] + + num_channels_latents = self.num_channels + shape = (batch_size, num_frames, 64) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + + + + latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device) + latent_masks[:,0:latent_length] = 2 + if(scenario=='other_seg'): + latent_masks[:,0:incontext_length] = 1 + + + + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + true_latents = true_latents.permute(0,2,1).contiguous() + true_latents = self.normfeat.project_sample(true_latents) + true_latents = true_latents.permute(0,2,1).contiguous() + incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0] + + + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + latent_mask_input = self.mask_emb(latent_masks) + + if('spk' in additional_feats): + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1) + additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1) + else: + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1) + additional_model_input = torch.cat([quantized_bestrq_emb],1) + + temperature = 1.0 + t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device) + latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale) + + latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:] + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.return_sample(latents) + # latents = latents.permute(0,2,1).contiguous() + return latents + + @torch.no_grad() + def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, + 
disable_progress=True,layer=5,scenario='start_seg',rvq_num=1): + codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num) + + latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ + guidance_scale=guidance_scale, num_steps=num_steps, \ + disable_progress=disable_progress,scenario=scenario) + return latents + + @torch.no_grad() + def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, + disable_progress=True,layer=5,scenario='start_seg'): + codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer) + import time + start = time.time() + latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ + guidance_scale=guidance_scale, num_steps=num_steps, \ + disable_progress=disable_progress,scenario=scenario) + return latents,time.time()-start + + def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device): + divisor = 4 + shape = (batch_size, num_channels_latents, num_frames, 32) + if(num_frames%divisor>0): + num_frames = round(num_frames/float(divisor))*divisor + shape = (batch_size, num_channels_latents, num_frames, 32) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + return latents + + diff --git a/codeclm/tokenizer/Flow1dVAE/model_4rvq.py b/codeclm/tokenizer/Flow1dVAE/model_4rvq.py index 1cb3ea89f0b3f91b56e734d0f42257d8b48199ad..09f61d5f589a51853110504c9ebb396093836ef9 100644 --- a/codeclm/tokenizer/Flow1dVAE/model_4rvq.py +++ b/codeclm/tokenizer/Flow1dVAE/model_4rvq.py @@ -1,774 +1,774 @@ -import yaml -import random -import inspect -import numpy as np -from tqdm import tqdm -import typing as tp -from abc import ABC - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchaudio - -from einops import repeat -from tools.torch_tools import wav_to_fbank - -import diffusers -from diffusers.utils.torch_utils import randn_tensor -from diffusers import DDPMScheduler -from models.transformer_2d_flow import Transformer2DModel -from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel -# from tools.get_mulan import get_mulan -from third_party.wespeaker.extract_embd import XVECModel -# from libs.rvq2 import RVQEmbedding -from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize - -from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model -from models_gpt.models.gpt2_config import GPT2Config - -from torch.cuda.amp import autocast - - -from our_MERT_BESTRQ.test import load_model - -class HubertModelWithFinalProj(HubertModel): - def __init__(self, config): - super().__init__(config) - - # The final projection layer is only used for backward compatibility. - # Following https://github.com/auspicious3000/contentvec/issues/6 - # Remove this layer is necessary to achieve the desired outcome. 
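# Descriptive note: final_proj re-creates the contentvec projection head
# (hidden_size -> classifier_proj_size) purely so that the "content-vec-best"
# checkpoint loaded via HubertModelWithFinalProj.from_pretrained(...) further
# down restores cleanly; in the code shown here, features are read from
# hidden_states[self.hubert_layer] and the projection itself is never called.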
- print("hidden_size:",config.hidden_size) - print("classifier_proj_size:",config.classifier_proj_size) - self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) - - -class SampleProcessor(torch.nn.Module): - def project_sample(self, x: torch.Tensor): - """Project the original sample to the 'space' where the diffusion will happen.""" - """Project back from diffusion space to the actual sample space.""" - return z - -class Feature1DProcessor(SampleProcessor): - def __init__(self, dim: int = 100, power_std = 1., \ - num_samples: int = 100_000, cal_num_frames: int = 600): - super().__init__() - - self.num_samples = num_samples - self.dim = dim - self.power_std = power_std - self.cal_num_frames = cal_num_frames - self.register_buffer('counts', torch.zeros(1)) - self.register_buffer('sum_x', torch.zeros(dim)) - self.register_buffer('sum_x2', torch.zeros(dim)) - self.register_buffer('sum_target_x2', torch.zeros(dim)) - self.counts: torch.Tensor - self.sum_x: torch.Tensor - self.sum_x2: torch.Tensor - - @property - def mean(self): - mean = self.sum_x / self.counts - if(self.counts < 10): - mean = torch.zeros_like(mean) - return mean - - @property - def std(self): - std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() - if(self.counts < 10): - std = torch.ones_like(std) - return std - - @property - def target_std(self): - return 1 - - def project_sample(self, x: torch.Tensor): - assert x.dim() == 3 - if self.counts.item() < self.num_samples: - self.counts += len(x) - self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0) - self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0) - rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size - x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1) - return x - - def return_sample(self, x: torch.Tensor): - assert x.dim() == 3 - rescale = (self.std / self.target_std) ** self.power_std - # print(rescale, self.mean) - x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1) - return x - -def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77): - if(prior_text_encoder_hidden_states.shape[1] 1.0): - - model_input = torch.cat([ \ - torch.cat([latent_mask_input, latent_mask_input], 0), \ - torch.cat([incontext_x, incontext_x], 0), \ - torch.cat([torch.zeros_like(mu), mu], 0), \ - torch.cat([x, x], 0), \ - ], 2) - timestep=t.unsqueeze(-1).repeat(2) - - dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state - dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0) - dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond) - else: - model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2) - timestep=t.unsqueeze(-1) - dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state - - dphi_dt = dphi_dt[: ,:, -x.shape[2]:] - print("dphi_dt.shape:",dphi_dt.shape) - print("x.shape:",x.shape) - - x = x + dt * dphi_dt - t = t + dt - sol.append(x) - if step < len(t_span) - 1: - dt = t_span[step + 1] - t - - return sol[-1] - - def projection_loss(self,hidden_proj, bestrq_emb): - bsz = hidden_proj.shape[0] - - hidden_proj_normalized = F.normalize(hidden_proj, dim=-1) - bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1) - - proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1) - proj_loss = 1+proj_loss.mean() - - return proj_loss - - def 
compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False): - """Computes diffusion loss - - Args: - x1 (torch.Tensor): Target - shape: (batch_size, n_channels, mel_timesteps, n_feats) - mu (torch.Tensor): output of encoder - shape: (batch_size, n_channels, mel_timesteps, n_feats) - - Returns: - loss: conditional flow matching loss - y: conditional flow - shape: (batch_size, n_channels, mel_timesteps, n_feats) - """ - b = mu[0].shape[0] - len_x = x1.shape[2] - # random timestep - if(validation_mode): - t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5 - else: - t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) - # sample noise p(x_0) - z = torch.randn_like(x1) - - y = (1 - (1 - self.sigma_min) * t) * z + t * x1 - u = x1 - (1 - self.sigma_min) * z - # print("y.shape:",y.shape) - #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state - model_input = torch.cat([*mu,y], 2) - t=t.squeeze(-1).squeeze(-1) - # print("model_input.shape:",model_input.shape) - # print("attention_mask.shape:",attention_mask.shape) - out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True) - hidden_layer = out.hidden_states[self.ssl_layer] - hidden_proj = self.mlp(hidden_layer) - # print("hidden_proj.shape:",hidden_proj.shape) - # print("mert_emb.shape:",mert_emb.shape) - # exit() - - - out = out.last_hidden_state - - out=out[:,:,-len_x:] - # out=self.proj_out(out) - - weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01 - # print("out.shape",out.shape) - # print("u.shape",u.shape) - loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum() - # print("hidden_proj.shape:",hidden_proj.shape) - # print("wav2vec_embeds.shape:",wav2vec_embeds.shape) - loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds) - loss = loss_re + loss_cos * 0.5 - # print("loss_cos:",loss_cos,loss_cos.device) - print("loss:",loss,loss.device) - # exit() - return loss, loss_re, loss_cos - -class PromptCondAudioDiffusion(nn.Module): - def __init__( - self, - num_channels, - unet_model_name=None, - unet_model_config_path=None, - snr_gamma=None, - hubert_layer=None, - ssl_layer=None, - uncondition=True, - out_paint=False, - ): - super().__init__() - - assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" - - self.unet_model_name = unet_model_name - self.unet_model_config_path = unet_model_config_path - self.snr_gamma = snr_gamma - self.uncondition = uncondition - self.num_channels = num_channels - self.hubert_layer = hubert_layer - self.ssl_layer = ssl_layer - - # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview - self.normfeat = Feature1DProcessor(dim=64) - - self.sample_rate = 48000 - self.num_samples_perseg = self.sample_rate * 20 // 1000 - self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000) - self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000) - # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) - # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) - self.bestrq = load_model( - 
model_dir='path/to/our-MERT/mert_fairseq', - checkpoint_dir='checkpoint-120000.pt', - ) - self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000) - self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000) - for v in self.bestrq.parameters():v.requires_grad = False - self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) - # for v in self.rvq_bestrq_emb.parameters(): - # print(v) - freeze_parameters='quantizers.0' - for name, param in self.rvq_bestrq_emb.named_parameters(): - if freeze_parameters in name: - param.requires_grad = False - print("Freezing RVQ parameters:", name) - self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68") - for v in self.hubert.parameters():v.requires_grad = False - self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,)) - # self.xvecmodel = XVECModel() - config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200) - unet = GPT2Model(config) - mlp = nn.Sequential( - nn.Linear(1200, 1024), - nn.SiLU(), - nn.Linear(1024, 1024), - nn.SiLU(), - nn.Linear(1024, 768) - ) - self.set_from = "random" - self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer) - self.mask_emb = torch.nn.Embedding(3, 48) - print("Transformer initialized from pretrain.") - torch.cuda.empty_cache() - # self.unet.set_attn_processor(AttnProcessor2_0()) - # self.unet.set_use_memory_efficient_attention_xformers(True) - - # self.start_embedding = nn.Parameter(torch.randn(1,1024)) - # self.end_embedding = nn.Parameter(torch.randn(1,1024)) - - def compute_snr(self, timesteps): - """ - Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 - """ - alphas_cumprod = self.noise_scheduler.alphas_cumprod - sqrt_alphas_cumprod = alphas_cumprod**0.5 - sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 - - # Expand the tensors. - # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] - alpha = sqrt_alphas_cumprod.expand(timesteps.shape) - - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] - sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) - - # Compute SNR. 
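# Descriptive note: with alpha = sqrt(alphas_cumprod[t]) and
# sigma = sqrt(1 - alphas_cumprod[t]) built above, the line below is
# SNR(t) = alpha_t^2 / sigma_t^2 = alphas_cumprod[t] / (1 - alphas_cumprod[t]).
# It assumes self.noise_scheduler (e.g. a diffusers DDPMScheduler) has been
# attached elsewhere; it is not created in the __init__ shown in this diff.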
- snr = (alpha / sigma) ** 2 - return snr - - def preprocess_audio(self, input_audios, threshold=0.9): - assert len(input_audios.shape) == 2, input_audios.shape - norm_value = torch.ones_like(input_audios[:,0]) - max_volume = input_audios.abs().max(dim=-1)[0] - norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold - return input_audios/norm_value.unsqueeze(-1) - - def extract_wav2vec_embeds(self, input_audios,output_len): - wav2vec_stride = 2 - - wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024 - # print(wav2vec_embeds) - # print("audio.shape:",input_audios.shape) - wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer] - # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape) - wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1) - return wav2vec_embeds_last - - def extract_mert_embeds(self, input_audios): - prompt_stride = 3 - inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt") - input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) - prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024 - mert_emb= prompt_embeds[-1] - mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1) - - return mert_emb - - def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer): - self.bestrq.eval() - # print("audio shape:",input_audio_0.shape) - input_wav_mean = (input_audio_0 + input_audio_1) / 2.0 - # print("input_wav_mean.shape:",input_wav_mean.shape) - # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device) - input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True) - layer_results = input_wav_mean['layer_results'] - # print("layer_results.shape:",layer_results[layer].shape) - bestrq_emb = layer_results[layer] - bestrq_emb = bestrq_emb.permute(0,2,1).contiguous() - #[b,t,1024] t=t/960 - #35.84s->batch,896,1024 - return bestrq_emb - - - def extract_spk_embeds(self, input_audios): - spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios)) - spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32) - return spk_embeds - - def extract_lyric_feats(self, lyric): - with torch.no_grad(): - try: - text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False) - except: - text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False) - text_encoder_hidden_states = text_encoder_hidden_states.to(self.device) - text_mask = text_mask.to(self.device) - text_encoder_hidden_states, text_mask, text_prompt_embeds = \ - pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds) - text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous() - return text_encoder_hidden_states, text_mask - - def extract_energy_bar(self, input_audios): - if(input_audios.shape[-1] % self.num_samples_perseg > 0): - energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg) - else: - energy_bar = 
input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg) - energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T - energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int() - energy_embedding = self.energy_embedding(energy_bar) - energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t - return energy_embedding - - def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \ - additional_feats = ['spk', 'lyric'], \ - train_rvq=True, train_ssl=False,layer=5): - if not hasattr(self,"device"): - self.device = input_audios.device - if not hasattr(self,"dtype"): - self.dtype = input_audios.dtype - device = self.device - input_audio_0 = input_audios[:,0,:] - input_audio_1 = input_audios[:,1,:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0 - # energy_embedding = self.extract_energy_bar(input_audios) - # print("energy_embedding.shape:",energy_embedding.shape) - # with autocast(enabled=False): - if(train_ssl): - self.wav2vec.train() - wav2vec_embeds = self.extract_wav2vec_embeds(input_audios) - self.clap_embd_extractor.train() - prompt_embeds = self.extract_mert_embeds(input_audios) - if('spk' in additional_feats): - self.xvecmodel.train() - spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1) - else: - with torch.no_grad(): - with autocast(enabled=False): - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - # mert_emb = self.extract_mert_embeds(input_audios_mert) - - wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2]) - - bestrq_emb = bestrq_emb.detach() - if('lyric' in additional_feats): - text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric) - else: - text_encoder_hidden_states, text_mask = None, None - - - if(train_rvq): - random_num=random.random() - if(random_num<0.6): - rvq_layer = 1 - elif(random_num<0.8): - rvq_layer = 2 - else: - rvq_layer = 4 - quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t - else: - bestrq_emb = bestrq_emb.float() - self.rvq_bestrq_emb.eval() - # with autocast(enabled=False): - quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t - commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach() - codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach() - quantized_bestrq_emb = quantized_bestrq_emb.detach() - - commitment_loss = commitment_loss_bestrq_emb - codebook_loss = codebook_loss_bestrq_emb - - - alpha=1 - quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha) - - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # print("latent_masks.shape:",latent_masks.shape) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - - - - scenario = np.random.choice(['start_seg', 'other_seg']) - if(scenario == 'other_seg'): - for binx in range(input_audios.shape[0]): - # latent_masks[binx,0:64] = 1 - latent_masks[binx,0:random.randint(64,128)] = 1 - quantized_bestrq_emb = 
quantized_bestrq_emb.permute(0,2,1).contiguous() - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape) - # print("latent_masks.shape:",latent_masks.shape) - quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ - + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) - - - - - if self.uncondition: - mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1] - if len(mask_indices) > 0: - quantized_bestrq_emb[mask_indices] = 0 - # print("latents.shape:",latents.shape) - latents = latents.permute(0,2,1).contiguous() - latents = self.normfeat.project_sample(latents) - latents = latents.permute(0,2,1).contiguous() - incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() - attention_mask=(latent_masks > 0.5) - B, L = attention_mask.size() - attention_mask = attention_mask.view(B, 1, L) - attention_mask = attention_mask * attention_mask.transpose(-1, -2) - attention_mask = attention_mask.unsqueeze(1) - # print("incontext_latents.shape:",incontext_latents.shape) - # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - latent_mask_input = self.mask_emb(latent_masks) - #64+48+64+1024 - loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode) - return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean() - - def init_device_dtype(self, device, dtype): - self.device = device - self.dtype = dtype - - @torch.no_grad() - def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1): - input_audio_0 = input_audios[[0],:] - input_audio_1 = input_audios[[1],:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - bestrq_emb = bestrq_emb.detach() - - # self.rvq_bestrq_middle.eval() - # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t - # self.rvq_bestrq_last.eval() - # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t - - self.rvq_bestrq_emb.eval() - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) - codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - # exit() - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb], [bestrq_emb], spk_embeds - # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - @torch.no_grad() - def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1): - 
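# Descriptive note on the three code-fetching variants in this class:
# fetch_codes() indexes input_audios[[0],:] / [[1],:], i.e. it treats the
# input as a single stereo pair with the channels along dim 0;
# fetch_codes_batch() (below) expects (B, 2, n_samples) and indexes
# [:,0,:] / [:,1,:]; fetch_codes_batch_ds() additionally average-pools the
# BEST-RQ features with kernel and stride `ds` before quantization.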
input_audio_0 = input_audios[:,0,:] - input_audio_1 = input_audios[:,1,:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - bestrq_emb = bestrq_emb.detach() - - # self.rvq_bestrq_middle.eval() - # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t - # self.rvq_bestrq_last.eval() - # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t - - self.rvq_bestrq_emb.eval() - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - # exit() - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb], [bestrq_emb], spk_embeds - # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - @torch.no_grad() - def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250): - input_audio_0 = input_audios[:,0,:] - input_audio_1 = input_audios[:,1,:] - input_audio_0 = self.preprocess_audio(input_audio_0) - input_audio_1 = self.preprocess_audio(input_audio_1) - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) - bestrq_emb = bestrq_emb.detach() - - # self.rvq_bestrq_middle.eval() - # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t - # self.rvq_bestrq_last.eval() - # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t - - self.rvq_bestrq_emb.eval() - bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds) - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] - # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) - # exit() - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb], [bestrq_emb], spk_embeds - # return 
[codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - @torch.no_grad() - def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127, - guidance_scale=2, num_steps=20, - disable_progress=True, scenario='start_seg'): - classifier_free_guidance = guidance_scale > 1.0 - device = self.device - dtype = self.dtype - # codes_bestrq_middle, codes_bestrq_last = codes - codes_bestrq_emb = codes[0] - - - batch_size = codes_bestrq_emb.shape[0] - - - quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() - print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) - # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) - - - - - if('spk' in additional_feats): - spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach() - - num_frames = quantized_bestrq_emb.shape[1] - - num_channels_latents = self.num_channels - shape = (batch_size, num_frames, 64) - latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) - - - - latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device) - latent_masks[:,0:latent_length] = 2 - if(scenario=='other_seg'): - latent_masks[:,0:incontext_length] = 1 - - - - quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ - + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) - true_latents = true_latents.permute(0,2,1).contiguous() - true_latents = self.normfeat.project_sample(true_latents) - true_latents = true_latents.permute(0,2,1).contiguous() - incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() - incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0] - - - attention_mask=(latent_masks > 0.5) - B, L = attention_mask.size() - attention_mask = attention_mask.view(B, 1, L) - attention_mask = attention_mask * attention_mask.transpose(-1, -2) - attention_mask = attention_mask.unsqueeze(1) - latent_mask_input = self.mask_emb(latent_masks) - - if('spk' in additional_feats): - # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1) - additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1) - else: - # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1) - additional_model_input = torch.cat([quantized_bestrq_emb],1) - - temperature = 1.0 - t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device) - latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale) - - latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:] - latents = latents.permute(0,2,1).contiguous() - latents = self.normfeat.return_sample(latents) - # latents = latents.permute(0,2,1).contiguous() - return latents - - @torch.no_grad() - def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, - 
disable_progress=True,layer=5,scenario='start_seg',rvq_num=1): - codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num) - - latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ - guidance_scale=guidance_scale, num_steps=num_steps, \ - disable_progress=disable_progress,scenario=scenario) - return latents - - @torch.no_grad() - def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, - disable_progress=True,layer=5,scenario='start_seg'): - codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer) - import time - start = time.time() - latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ - guidance_scale=guidance_scale, num_steps=num_steps, \ - disable_progress=disable_progress,scenario=scenario) - return latents,time.time()-start - - def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device): - divisor = 4 - shape = (batch_size, num_channels_latents, num_frames, 32) - if(num_frames%divisor>0): - num_frames = round(num_frames/float(divisor))*divisor - shape = (batch_size, num_channels_latents, num_frames, 32) - latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) - return latents - - +import yaml +import random +import inspect +import numpy as np +from tqdm import tqdm +import typing as tp +from abc import ABC + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +from einops import repeat +from tools.torch_tools import wav_to_fbank + +import diffusers +from diffusers.utils.torch_utils import randn_tensor +from diffusers import DDPMScheduler +from models.transformer_2d_flow import Transformer2DModel +from transformers import AutoFeatureExtractor, Wav2Vec2BertModel,HubertModel +# from tools.get_mulan import get_mulan +from third_party.wespeaker.extract_embd import XVECModel +# from libs.rvq2 import RVQEmbedding +from libs.rvq.descript_quantize3_4layer_freezelayer1 import ResidualVectorQuantize + +from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model +from models_gpt.models.gpt2_config import GPT2Config + +from torch.cuda.amp import autocast + + +from our_MERT_BESTRQ.test import load_model + +class HubertModelWithFinalProj(HubertModel): + def __init__(self, config): + super().__init__(config) + + # The final projection layer is only used for backward compatibility. + # Following https://github.com/auspicious3000/contentvec/issues/6 + # Remove this layer is necessary to achieve the desired outcome. 
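# Illustrative sketch (not from this repository): time-aligning SSL frame
# features to a target length with linear interpolation, the same operation
# extract_wav2vec_embeds further down applies to
# hidden_states[self.hubert_layer]. The shapes are assumptions for the demo.
import torch
import torch.nn.functional as F

feats = torch.randn(2, 1499, 1024)   # (batch, frames, channels) from an SSL encoder
target_len = 896                     # e.g. the BEST-RQ frame count to match
aligned = F.interpolate(feats.permute(0, 2, 1), size=target_len,
                        mode='linear', align_corners=False).permute(0, 2, 1)
print(aligned.shape)                 # torch.Size([2, 896, 1024])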
+ print("hidden_size:",config.hidden_size) + print("classifier_proj_size:",config.classifier_proj_size) + self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) + + +class SampleProcessor(torch.nn.Module): + def project_sample(self, x: torch.Tensor): + """Project the original sample to the 'space' where the diffusion will happen.""" + """Project back from diffusion space to the actual sample space.""" + return z + +class Feature1DProcessor(SampleProcessor): + def __init__(self, dim: int = 100, power_std = 1., \ + num_samples: int = 100_000, cal_num_frames: int = 600): + super().__init__() + + self.num_samples = num_samples + self.dim = dim + self.power_std = power_std + self.cal_num_frames = cal_num_frames + self.register_buffer('counts', torch.zeros(1)) + self.register_buffer('sum_x', torch.zeros(dim)) + self.register_buffer('sum_x2', torch.zeros(dim)) + self.register_buffer('sum_target_x2', torch.zeros(dim)) + self.counts: torch.Tensor + self.sum_x: torch.Tensor + self.sum_x2: torch.Tensor + + @property + def mean(self): + mean = self.sum_x / self.counts + if(self.counts < 10): + mean = torch.zeros_like(mean) + return mean + + @property + def std(self): + std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() + if(self.counts < 10): + std = torch.ones_like(std) + return std + + @property + def target_std(self): + return 1 + + def project_sample(self, x: torch.Tensor): + assert x.dim() == 3 + if self.counts.item() < self.num_samples: + self.counts += len(x) + self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0) + self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0) + rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size + x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1) + return x + + def return_sample(self, x: torch.Tensor): + assert x.dim() == 3 + rescale = (self.std / self.target_std) ** self.power_std + # print(rescale, self.mean) + x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1) + return x + +def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77): + if(prior_text_encoder_hidden_states.shape[1] 1.0): + + model_input = torch.cat([ \ + torch.cat([latent_mask_input, latent_mask_input], 0), \ + torch.cat([incontext_x, incontext_x], 0), \ + torch.cat([torch.zeros_like(mu), mu], 0), \ + torch.cat([x, x], 0), \ + ], 2) + timestep=t.unsqueeze(-1).repeat(2) + + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0) + dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond) + else: + model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2) + timestep=t.unsqueeze(-1) + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + + dphi_dt = dphi_dt[: ,:, -x.shape[2]:] + print("dphi_dt.shape:",dphi_dt.shape) + print("x.shape:",x.shape) + + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def projection_loss(self,hidden_proj, bestrq_emb): + bsz = hidden_proj.shape[0] + + hidden_proj_normalized = F.normalize(hidden_proj, dim=-1) + bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1) + + proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1) + proj_loss = 1+proj_loss.mean() + + return proj_loss + + def 
compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False): + """Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_channels, mel_timesteps, n_feats) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_channels, mel_timesteps, n_feats) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_channels, mel_timesteps, n_feats) + """ + b = mu[0].shape[0] + len_x = x1.shape[2] + # random timestep + if(validation_mode): + t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5 + else: + t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + # print("y.shape:",y.shape) + #self.unet(inputs_embeds=model_input, attention_mask=attention_mask,encoder_hidden_states=text_embedding,encoder_attention_mask=txt_attn_mask,time_step=timesteps).last_hidden_state + model_input = torch.cat([*mu,y], 2) + t=t.squeeze(-1).squeeze(-1) + # print("model_input.shape:",model_input.shape) + # print("attention_mask.shape:",attention_mask.shape) + out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True) + hidden_layer = out.hidden_states[self.ssl_layer] + hidden_proj = self.mlp(hidden_layer) + # print("hidden_proj.shape:",hidden_proj.shape) + # print("mert_emb.shape:",mert_emb.shape) + # exit() + + + out = out.last_hidden_state + + out=out[:,:,-len_x:] + # out=self.proj_out(out) + + weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01 + # print("out.shape",out.shape) + # print("u.shape",u.shape) + loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum() + # print("hidden_proj.shape:",hidden_proj.shape) + # print("wav2vec_embeds.shape:",wav2vec_embeds.shape) + loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds) + loss = loss_re + loss_cos * 0.5 + # print("loss_cos:",loss_cos,loss_cos.device) + print("loss:",loss,loss.device) + # exit() + return loss, loss_re, loss_cos + +class PromptCondAudioDiffusion(nn.Module): + def __init__( + self, + num_channels, + unet_model_name=None, + unet_model_config_path=None, + snr_gamma=None, + hubert_layer=None, + ssl_layer=None, + uncondition=True, + out_paint=False, + ): + super().__init__() + + assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" + + self.unet_model_name = unet_model_name + self.unet_model_config_path = unet_model_config_path + self.snr_gamma = snr_gamma + self.uncondition = uncondition + self.num_channels = num_channels + self.hubert_layer = hubert_layer + self.ssl_layer = ssl_layer + + # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview + self.normfeat = Feature1DProcessor(dim=64) + + self.sample_rate = 48000 + self.num_samples_perseg = self.sample_rate * 20 // 1000 + self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000) + self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000) + # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + self.bestrq = load_model( + 
model_dir='path/to/our-MERT/mert_fairseq', + checkpoint_dir='checkpoint-120000.pt', + ) + self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000) + self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000) + for v in self.bestrq.parameters():v.requires_grad = False + self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 4, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) + # for v in self.rvq_bestrq_emb.parameters(): + # print(v) + freeze_parameters='quantizers.0' + for name, param in self.rvq_bestrq_emb.named_parameters(): + if freeze_parameters in name: + param.requires_grad = False + print("Freezing RVQ parameters:", name) + self.hubert = HubertModelWithFinalProj.from_pretrained("huggingface_cache/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68") + for v in self.hubert.parameters():v.requires_grad = False + self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,)) + # self.xvecmodel = XVECModel() + config = GPT2Config(n_positions=1000,n_layer=39,n_head=30,n_embd=1200) + unet = GPT2Model(config) + mlp = nn.Sequential( + nn.Linear(1200, 1024), + nn.SiLU(), + nn.Linear(1024, 1024), + nn.SiLU(), + nn.Linear(1024, 768) + ) + self.set_from = "random" + self.cfm_wrapper = BASECFM(unet, mlp,self.ssl_layer) + self.mask_emb = torch.nn.Embedding(3, 48) + print("Transformer initialized from pretrain.") + torch.cuda.empty_cache() + # self.unet.set_attn_processor(AttnProcessor2_0()) + # self.unet.set_use_memory_efficient_attention_xformers(True) + + # self.start_embedding = nn.Parameter(torch.randn(1,1024)) + # self.end_embedding = nn.Parameter(torch.randn(1,1024)) + + def compute_snr(self, timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = self.noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. 
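# Illustrative sketch (not from this repository): the same SNR quantity
# computed directly from a diffusers DDPMScheduler's alphas_cumprod,
# matching the alpha/sigma construction above. The scheduler settings and
# timesteps are assumptions for the demo.
import torch
from diffusers import DDPMScheduler

sched = DDPMScheduler(num_train_timesteps=1000)
timesteps = torch.tensor([10, 500, 990])
ac = sched.alphas_cumprod[timesteps]      # \bar{alpha}_t
snr_demo = ac / (1.0 - ac)                # == (sqrt(ac) / sqrt(1 - ac)) ** 2
print(snr_demo)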
+ snr = (alpha / sigma) ** 2 + return snr + + def preprocess_audio(self, input_audios, threshold=0.9): + assert len(input_audios.shape) == 2, input_audios.shape + norm_value = torch.ones_like(input_audios[:,0]) + max_volume = input_audios.abs().max(dim=-1)[0] + norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold + return input_audios/norm_value.unsqueeze(-1) + + def extract_wav2vec_embeds(self, input_audios,output_len): + wav2vec_stride = 2 + + wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024 + # print(wav2vec_embeds) + # print("audio.shape:",input_audios.shape) + wav2vec_embeds_last=wav2vec_embeds[self.hubert_layer] + # print("wav2vec_embeds_last.shape:",wav2vec_embeds_last.shape) + wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1) + return wav2vec_embeds_last + + def extract_mert_embeds(self, input_audios): + prompt_stride = 3 + inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt") + input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) + prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024 + mert_emb= prompt_embeds[-1] + mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=500, mode='linear', align_corners=False).permute(0, 2, 1) + + return mert_emb + + def extract_bestrq_embeds(self, input_audio_0,input_audio_1,layer): + self.bestrq.eval() + # print("audio shape:",input_audio_0.shape) + input_wav_mean = (input_audio_0 + input_audio_1) / 2.0 + # print("input_wav_mean.shape:",input_wav_mean.shape) + # input_wav_mean = torch.randn(2,1720320*2).to(input_audio_0.device) + input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True) + layer_results = input_wav_mean['layer_results'] + # print("layer_results.shape:",layer_results[layer].shape) + bestrq_emb = layer_results[layer] + bestrq_emb = bestrq_emb.permute(0,2,1).contiguous() + #[b,t,1024] t=t/960 + #35.84s->batch,896,1024 + return bestrq_emb + + + def extract_spk_embeds(self, input_audios): + spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios)) + spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32) + return spk_embeds + + def extract_lyric_feats(self, lyric): + with torch.no_grad(): + try: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False) + except: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False) + text_encoder_hidden_states = text_encoder_hidden_states.to(self.device) + text_mask = text_mask.to(self.device) + text_encoder_hidden_states, text_mask, text_prompt_embeds = \ + pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds) + text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous() + return text_encoder_hidden_states, text_mask + + def extract_energy_bar(self, input_audios): + if(input_audios.shape[-1] % self.num_samples_perseg > 0): + energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg) + else: + energy_bar = 
input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg) + energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T + energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int() + energy_embedding = self.energy_embedding(energy_bar) + energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t + return energy_embedding + + def forward(self, input_audios, lyric, latents, latent_masks, validation_mode=False, \ + additional_feats = ['spk', 'lyric'], \ + train_rvq=True, train_ssl=False,layer=5): + if not hasattr(self,"device"): + self.device = input_audios.device + if not hasattr(self,"dtype"): + self.dtype = input_audios.dtype + device = self.device + input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + input_audios_wav2vec = (input_audio_0 + input_audio_1) / 2.0 + # energy_embedding = self.extract_energy_bar(input_audios) + # print("energy_embedding.shape:",energy_embedding.shape) + # with autocast(enabled=False): + if(train_ssl): + self.wav2vec.train() + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios) + self.clap_embd_extractor.train() + prompt_embeds = self.extract_mert_embeds(input_audios) + if('spk' in additional_feats): + self.xvecmodel.train() + spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1) + else: + with torch.no_grad(): + with autocast(enabled=False): + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + # mert_emb = self.extract_mert_embeds(input_audios_mert) + + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_wav2vec,bestrq_emb.shape[2]) + + bestrq_emb = bestrq_emb.detach() + if('lyric' in additional_feats): + text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric) + else: + text_encoder_hidden_states, text_mask = None, None + + + if(train_rvq): + random_num=random.random() + if(random_num<0.6): + rvq_layer = 1 + elif(random_num<0.8): + rvq_layer = 2 + else: + rvq_layer = 4 + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb,n_quantizers=rvq_layer) # b,d,t + else: + bestrq_emb = bestrq_emb.float() + self.rvq_bestrq_emb.eval() + # with autocast(enabled=False): + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach() + codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach() + quantized_bestrq_emb = quantized_bestrq_emb.detach() + + commitment_loss = commitment_loss_bestrq_emb + codebook_loss = codebook_loss_bestrq_emb + + + alpha=1 + quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha) + + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # print("latent_masks.shape:",latent_masks.shape) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + + + + scenario = np.random.choice(['start_seg', 'other_seg']) + if(scenario == 'other_seg'): + for binx in range(input_audios.shape[0]): + # latent_masks[binx,0:64] = 1 + latent_masks[binx,0:random.randint(64,128)] = 1 + quantized_bestrq_emb = 
quantized_bestrq_emb.permute(0,2,1).contiguous() + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # print("quantized_bestrq_emb1.shape:",quantized_bestrq_emb.shape) + # print("latent_masks.shape:",latent_masks.shape) + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + + + + + if self.uncondition: + mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1] + if len(mask_indices) > 0: + quantized_bestrq_emb[mask_indices] = 0 + # print("latents.shape:",latents.shape) + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.project_sample(latents) + latents = latents.permute(0,2,1).contiguous() + incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + # print("incontext_latents.shape:",incontext_latents.shape) + # print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + latent_mask_input = self.mask_emb(latent_masks) + #64+48+64+1024 + loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode) + return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean() + + def init_device_dtype(self, device, dtype): + self.device = device + self.dtype = dtype + + @torch.no_grad() + def fetch_codes(self, input_audios, additional_feats,layer,rvq_num=1): + input_audio_0 = input_audios[[0],:] + input_audio_1 = input_audios[[1],:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) + codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + # exit() + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + @torch.no_grad() + def fetch_codes_batch(self, input_audios, additional_feats,layer,rvq_num=1): + 
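# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): what the `codes_bestrq_emb[:, :rvq_num, :]`
# slice in fetch_codes / fetch_codes_batch does. RVQ codes come out shaped
# (batch, n_codebooks, time); residual quantizers are ordered coarse-to-fine, so
# keeping only the first `rvq_num` books is a bitrate/precision trade-off.
# The tensor below is a random stand-in, not a real model output.
import torch

toy_codes = torch.randint(0, 16_384, (2, 4, 1500))   # (B, n_codebooks=4, T)
rvq_num = 1
kept = toy_codes[:, :rvq_num, :]                     # coarsest codebook only
print(kept.shape)                                    # torch.Size([2, 1, 1500])
# ---------------------------------------------------------------------------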
input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + # exit() + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + @torch.no_grad() + def fetch_codes_batch_ds(self, input_audios, additional_feats, layer, rvq_num=1, ds=250): + input_audio_0 = input_audios[:,0,:] + input_audio_1 = input_audios[:,1,:] + input_audio_0 = self.preprocess_audio(input_audio_0) + input_audio_1 = self.preprocess_audio(input_audio_1) + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_0,input_audio_1,layer) + bestrq_emb = bestrq_emb.detach() + + # self.rvq_bestrq_middle.eval() + # quantized_bestrq_middle, codes_bestrq_middle, *_ = self.rvq_bestrq_middle(bestrq_middle) # b,d,t + # self.rvq_bestrq_last.eval() + # quantized_bestrq_last, codes_bestrq_last, *_ = self.rvq_bestrq_last(bestrq_last) # b,d,t + + self.rvq_bestrq_emb.eval() + bestrq_emb = torch.nn.functional.avg_pool1d(bestrq_emb, kernel_size=ds, stride=ds) + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + codes_bestrq_emb = codes_bestrq_emb[:,:rvq_num,:] + # print("codes_bestrq_emb.shape:",codes_bestrq_emb.shape) + # exit() + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb], [bestrq_emb], spk_embeds + # return 
[codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + @torch.no_grad() + def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats, incontext_length=127, + guidance_scale=2, num_steps=20, + disable_progress=True, scenario='start_seg'): + classifier_free_guidance = guidance_scale > 1.0 + device = self.device + dtype = self.dtype + # codes_bestrq_middle, codes_bestrq_last = codes + codes_bestrq_emb = codes[0] + + + batch_size = codes_bestrq_emb.shape[0] + + + quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() + print("quantized_bestrq_emb.shape:",quantized_bestrq_emb.shape) + # quantized_bestrq_emb = torch.nn.functional.interpolate(quantized_bestrq_emb, size=(int(quantized_bestrq_emb.shape[-1]/999*937),), mode='linear', align_corners=True) + + + + + if('spk' in additional_feats): + spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach() + + num_frames = quantized_bestrq_emb.shape[1] + + num_channels_latents = self.num_channels + shape = (batch_size, num_frames, 64) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + + + + latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device) + latent_masks[:,0:latent_length] = 2 + if(scenario=='other_seg'): + latent_masks[:,0:incontext_length] = 1 + + + + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + true_latents = true_latents.permute(0,2,1).contiguous() + true_latents = self.normfeat.project_sample(true_latents) + true_latents = true_latents.permute(0,2,1).contiguous() + incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0] + + + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + latent_mask_input = self.mask_emb(latent_masks) + + if('spk' in additional_feats): + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1) + additional_model_input = torch.cat([quantized_bestrq_emb, spk_embeds],1) + else: + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1) + additional_model_input = torch.cat([quantized_bestrq_emb],1) + + temperature = 1.0 + t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device) + latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale) + + latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:] + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.return_sample(latents) + # latents = latents.permute(0,2,1).contiguous() + return latents + + @torch.no_grad() + def inference(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, + 
disable_progress=True,layer=5,scenario='start_seg',rvq_num=1): + codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer,rvq_num) + + latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ + guidance_scale=guidance_scale, num_steps=num_steps, \ + disable_progress=disable_progress,scenario=scenario) + return latents + + @torch.no_grad() + def inference_rtf(self, input_audios, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, + disable_progress=True,layer=5,scenario='start_seg'): + codes, embeds, spk_embeds = self.fetch_codes(input_audios, additional_feats,layer) + import time + start = time.time() + latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ + guidance_scale=guidance_scale, num_steps=num_steps, \ + disable_progress=disable_progress,scenario=scenario) + return latents,time.time()-start + + def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device): + divisor = 4 + shape = (batch_size, num_channels_latents, num_frames, 32) + if(num_frames%divisor>0): + num_frames = round(num_frames/float(divisor))*divisor + shape = (batch_size, num_channels_latents, num_frames, 32) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + return latents + + diff --git a/codeclm/tokenizer/Flow1dVAE/model_septoken.py b/codeclm/tokenizer/Flow1dVAE/model_septoken.py index 331f2a3fa23e5f8873c62b532ef41d0519e93d11..22c56db5385e17a86252e440e8f7f7849ce344a1 100644 --- a/codeclm/tokenizer/Flow1dVAE/model_septoken.py +++ b/codeclm/tokenizer/Flow1dVAE/model_septoken.py @@ -1,670 +1,670 @@ -import yaml -import random -import inspect -import numpy as np -from tqdm import tqdm -import typing as tp -from abc import ABC - -import torch -import torch.nn as nn -import torch.nn.functional as F -import torchaudio - -from einops import repeat -from tools.torch_tools import wav_to_fbank - -from diffusers.utils.torch_utils import randn_tensor -from transformers import HubertModel -from libs.rvq.descript_quantize3 import ResidualVectorQuantize - -from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model -from models_gpt.models.gpt2_config import GPT2Config - -from torch.cuda.amp import autocast -from our_MERT_BESTRQ.test import load_model - -class HubertModelWithFinalProj(HubertModel): - def __init__(self, config): - super().__init__(config) - - # The final projection layer is only used for backward compatibility. - # Following https://github.com/auspicious3000/contentvec/issues/6 - # Remove this layer is necessary to achieve the desired outcome. 
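# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the compatibility projection that
# HubertModelWithFinalProj declares below. content-vec checkpoints carry an extra
# Linear that maps the HuBERT hidden states down to the classifier projection
# size, so declaring it here keeps those checkpoints loadable. The sizes used in
# this demo (768 -> 256) are typical defaults, assumed for illustration only.
import torch
import torch.nn as nn

hidden_size, classifier_proj_size = 768, 256
final_proj = nn.Linear(hidden_size, classifier_proj_size)
toy_hidden = torch.randn(1, 100, hidden_size)        # (batch, frames, hidden)
print(final_proj(toy_hidden).shape)                  # torch.Size([1, 100, 256])
# ---------------------------------------------------------------------------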
- print("hidden_size:",config.hidden_size) - print("classifier_proj_size:",config.classifier_proj_size) - self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) - - -class SampleProcessor(torch.nn.Module): - def project_sample(self, x: torch.Tensor): - """Project the original sample to the 'space' where the diffusion will happen.""" - """Project back from diffusion space to the actual sample space.""" - return z - -class Feature1DProcessor(SampleProcessor): - def __init__(self, dim: int = 100, power_std = 1., \ - num_samples: int = 100_000, cal_num_frames: int = 600): - super().__init__() - - self.num_samples = num_samples - self.dim = dim - self.power_std = power_std - self.cal_num_frames = cal_num_frames - self.register_buffer('counts', torch.zeros(1)) - self.register_buffer('sum_x', torch.zeros(dim)) - self.register_buffer('sum_x2', torch.zeros(dim)) - self.register_buffer('sum_target_x2', torch.zeros(dim)) - self.counts: torch.Tensor - self.sum_x: torch.Tensor - self.sum_x2: torch.Tensor - - @property - def mean(self): - mean = self.sum_x / self.counts - if(self.counts < 10): - mean = torch.zeros_like(mean) - return mean - - @property - def std(self): - std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() - if(self.counts < 10): - std = torch.ones_like(std) - return std - - @property - def target_std(self): - return 1 - - def project_sample(self, x: torch.Tensor): - assert x.dim() == 3 - if self.counts.item() < self.num_samples: - self.counts += len(x) - self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0) - self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0) - rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size - x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1) - return x - - def return_sample(self, x: torch.Tensor): - assert x.dim() == 3 - rescale = (self.std / self.target_std) ** self.power_std - x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1) - return x - -def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77): - if(prior_text_encoder_hidden_states.shape[1] 1.0): - - model_input = torch.cat([ \ - torch.cat([latent_mask_input, latent_mask_input], 0), \ - torch.cat([incontext_x, incontext_x], 0), \ - torch.cat([torch.zeros_like(mu), mu], 0), \ - torch.cat([x, x], 0), \ - ], 2) - timestep=t.unsqueeze(-1).repeat(2) - - dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state - dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0) - dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond) - else: - model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2) - timestep=t.unsqueeze(-1) - dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state - - dphi_dt = dphi_dt[: ,:, -x.shape[2]:] - x = x + dt * dphi_dt - t = t + dt - sol.append(x) - if step < len(t_span) - 1: - dt = t_span[step + 1] - t - - return sol[-1] - - def projection_loss(self,hidden_proj, bestrq_emb): - bsz = hidden_proj.shape[0] - - hidden_proj_normalized = F.normalize(hidden_proj, dim=-1) - bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1) - - proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1) - proj_loss = 1+proj_loss.mean() - - return proj_loss - - def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False): - 
"""Computes diffusion loss - - Args: - x1 (torch.Tensor): Target - shape: (batch_size, n_channels, mel_timesteps, n_feats) - mu (torch.Tensor): output of encoder - shape: (batch_size, n_channels, mel_timesteps, n_feats) - - Returns: - loss: conditional flow matching loss - y: conditional flow - shape: (batch_size, n_channels, mel_timesteps, n_feats) - """ - b = mu[0].shape[0] - len_x = x1.shape[2] - # random timestep - if(validation_mode): - t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5 - else: - t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) - # sample noise p(x_0) - z = torch.randn_like(x1) - - y = (1 - (1 - self.sigma_min) * t) * z + t * x1 - u = x1 - (1 - self.sigma_min) * z - model_input = torch.cat([*mu,y], 2) - t=t.squeeze(-1).squeeze(-1) - out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True) - hidden_layer_7 = out.hidden_states[7] - hidden_proj = self.mlp(hidden_layer_7) - out = out.last_hidden_state - out=out[:,:,-len_x:] - - weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01 - loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum() - loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds) - loss = loss_re + loss_cos * 0.5 - return loss, loss_re, loss_cos - -class PromptCondAudioDiffusion(nn.Module): - def __init__( - self, - num_channels, - unet_model_name=None, - unet_model_config_path=None, - snr_gamma=None, - uncondition=True, - out_paint=False, - ): - super().__init__() - - assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" - - self.unet_model_name = unet_model_name - self.unet_model_config_path = unet_model_config_path - self.snr_gamma = snr_gamma - self.uncondition = uncondition - self.num_channels = num_channels - - # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview - self.normfeat = Feature1DProcessor(dim=64) - - self.sample_rate = 48000 - self.num_samples_perseg = self.sample_rate * 20 // 1000 - self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000) - self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000) - # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) - # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) - self.bestrq = load_model( - model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq', - checkpoint_dir='ckpt/encode-s12k.pt', - ) - self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000) - self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000) - for v in self.bestrq.parameters():v.requires_grad = False - self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) - self.rvq_bestrq_bgm_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) - self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68") - for v in self.hubert.parameters():v.requires_grad = False - self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,)) - # self.xvecmodel = XVECModel() - 
config = GPT2Config(n_positions=1000,n_layer=16,n_head=20,n_embd=2200,n_inner=4400) - unet = GPT2Model(config) - mlp = nn.Sequential( - nn.Linear(2200, 1024), - nn.SiLU(), - nn.Linear(1024, 1024), - nn.SiLU(), - nn.Linear(1024, 768) - ) - self.set_from = "random" - self.cfm_wrapper = BASECFM(unet, mlp) - self.mask_emb = torch.nn.Embedding(3, 24) - print("Transformer initialized from pretrain.") - torch.cuda.empty_cache() - - def compute_snr(self, timesteps): - """ - Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 - """ - alphas_cumprod = self.noise_scheduler.alphas_cumprod - sqrt_alphas_cumprod = alphas_cumprod**0.5 - sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 - - # Expand the tensors. - # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 - sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] - alpha = sqrt_alphas_cumprod.expand(timesteps.shape) - - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() - while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): - sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] - sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) - - # Compute SNR. - snr = (alpha / sigma) ** 2 - return snr - - def preprocess_audio(self, input_audios, threshold=0.9): - assert len(input_audios.shape) == 2, input_audios.shape - norm_value = torch.ones_like(input_audios[:,0]) - max_volume = input_audios.abs().max(dim=-1)[0] - norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold - return input_audios/norm_value.unsqueeze(-1) - - def extract_wav2vec_embeds(self, input_audios,output_len): - wav2vec_stride = 2 - - wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024 - wav2vec_embeds_last=wav2vec_embeds[-1] - wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1) - return wav2vec_embeds_last - - def extract_mert_embeds(self, input_audios): - prompt_stride = 3 - inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt") - input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) - prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024 - mert_emb= prompt_embeds[-1] - mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=375, mode='linear', align_corners=False).permute(0, 2, 1) - - return mert_emb - - def extract_bestrq_embeds(self, input_audio_vocal_0,input_audio_vocal_1,layer): - input_wav_mean = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0 - input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True) - layer_results = input_wav_mean['layer_results'] - bestrq_emb = layer_results[layer] - bestrq_emb = bestrq_emb.permute(0,2,1).contiguous() - return bestrq_emb - - - def extract_spk_embeds(self, 
input_audios): - spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios)) - spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32) - return spk_embeds - - def extract_lyric_feats(self, lyric): - with torch.no_grad(): - try: - text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False) - except: - text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False) - text_encoder_hidden_states = text_encoder_hidden_states.to(self.device) - text_mask = text_mask.to(self.device) - text_encoder_hidden_states, text_mask, text_prompt_embeds = \ - pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds) - text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous() - return text_encoder_hidden_states, text_mask - - def extract_energy_bar(self, input_audios): - if(input_audios.shape[-1] % self.num_samples_perseg > 0): - energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg) - else: - energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg) - energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T - energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int() - energy_embedding = self.energy_embedding(energy_bar) - energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t - return energy_embedding - - def forward(self, input_audios_vocal,input_audios_bgm, lyric, latents, latent_masks, validation_mode=False, \ - additional_feats = ['spk', 'lyric'], \ - train_rvq=True, train_ssl=False,layer_vocal=7,layer_bgm=7): - if not hasattr(self,"device"): - self.device = input_audios_vocal.device - if not hasattr(self,"dtype"): - self.dtype = input_audios_vocal.dtype - device = self.device - input_audio_vocal_0 = input_audios_vocal[:,0,:] - input_audio_vocal_1 = input_audios_vocal[:,1,:] - input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0) - input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1) - input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0 - - input_audio_bgm_0 = input_audios_bgm[:,0,:] - input_audio_bgm_1 = input_audios_bgm[:,1,:] - input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0) - input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1) - input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0 - - if(train_ssl): - self.wav2vec.train() - wav2vec_embeds = self.extract_wav2vec_embeds(input_audios) - self.clap_embd_extractor.train() - prompt_embeds = self.extract_mert_embeds(input_audios) - if('spk' in additional_feats): - self.xvecmodel.train() - spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1) - else: - with torch.no_grad(): - with autocast(enabled=False): - bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal) - bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm) - # mert_emb = self.extract_mert_embeds(input_audios_mert) - output_len = bestrq_emb.shape[2] - wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_vocal_wav2vec+input_audios_bgm_wav2vec,output_len) - - - bestrq_emb = bestrq_emb.detach() - bestrq_emb_bgm = 
bestrq_emb_bgm.detach() - - if('lyric' in additional_feats): - text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric) - else: - text_encoder_hidden_states, text_mask = None, None - - - if(train_rvq): - quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t - quantized_bestrq_emb_bgm, _, _, commitment_loss_bestrq_emb_bgm, codebook_loss_bestrq_emb_bgm,_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t - else: - bestrq_emb = bestrq_emb.float() - self.rvq_bestrq_emb.eval() - # with autocast(enabled=False): - quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t - commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach() - codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach() - quantized_bestrq_emb = quantized_bestrq_emb.detach() - - commitment_loss = commitment_loss_bestrq_emb+commitment_loss_bestrq_emb_bgm - codebook_loss = codebook_loss_bestrq_emb+codebook_loss_bestrq_emb_bgm - - - alpha=1 - quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha) - quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm * alpha + bestrq_emb_bgm * (1-alpha) - - - - - scenario = np.random.choice(['start_seg', 'other_seg']) - if(scenario == 'other_seg'): - for binx in range(input_audios_vocal.shape[0]): - # latent_masks[binx,0:64] = 1 - latent_masks[binx,0:random.randint(64,128)] = 1 - quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() - quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0,2,1).contiguous() - quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ - + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) - quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \ - + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) - - - - - if self.uncondition: - mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1] - if len(mask_indices) > 0: - quantized_bestrq_emb[mask_indices] = 0 - quantized_bestrq_emb_bgm[mask_indices] = 0 - latents = latents.permute(0,2,1).contiguous() - latents = self.normfeat.project_sample(latents) - latents = latents.permute(0,2,1).contiguous() - incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() - attention_mask=(latent_masks > 0.5) - B, L = attention_mask.size() - attention_mask = attention_mask.view(B, 1, L) - attention_mask = attention_mask * attention_mask.transpose(-1, -2) - attention_mask = attention_mask.unsqueeze(1) - latent_mask_input = self.mask_emb(latent_masks) - loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb,quantized_bestrq_emb_bgm], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode) - return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean() - - def init_device_dtype(self, device, dtype): - self.device = device - self.dtype = dtype - - @torch.no_grad() - def fetch_codes(self, input_audios_vocal,input_audios_bgm, additional_feats,layer_vocal=7,layer_bgm=7): - input_audio_vocal_0 = input_audios_vocal[[0],:] - input_audio_vocal_1 = input_audios_vocal[[1],:] - input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0) - input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1) - input_audios_vocal_wav2vec = (input_audio_vocal_0 
+ input_audio_vocal_1) / 2.0 - - input_audio_bgm_0 = input_audios_bgm[[0],:] - input_audio_bgm_1 = input_audios_bgm[[1],:] - input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0) - input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1) - input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0 - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal) - bestrq_emb = bestrq_emb.detach() - - bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm) - bestrq_emb_bgm = bestrq_emb_bgm.detach() - - - - self.rvq_bestrq_emb.eval() - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t - - self.rvq_bestrq_bgm_emb.eval() - quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb,codes_bestrq_emb_bgm], [bestrq_emb,bestrq_emb_bgm], spk_embeds - # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - @torch.no_grad() - def fetch_codes_batch(self, input_audios_vocal, input_audios_bgm, additional_feats,layer_vocal=7,layer_bgm=7): - input_audio_vocal_0 = input_audios_vocal[:,0,:] - input_audio_vocal_1 = input_audios_vocal[:,1,:] - input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0) - input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1) - input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0 - - input_audio_bgm_0 = input_audios_bgm[:,0,:] - input_audio_bgm_1 = input_audios_bgm[:,1,:] - input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0) - input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1) - input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0 - - self.bestrq.eval() - - # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) - # bestrq_middle = bestrq_middle.detach() - # bestrq_last = bestrq_last.detach() - bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal) - bestrq_emb = bestrq_emb.detach() - - bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm) - bestrq_emb_bgm = bestrq_emb_bgm.detach() - - - - self.rvq_bestrq_emb.eval() - quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t - - self.rvq_bestrq_bgm_emb.eval() - quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t - - - if('spk' in additional_feats): - self.xvecmodel.eval() - spk_embeds = self.extract_spk_embeds(input_audios) - else: - spk_embeds = None - - # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds - # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], 
[prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds - # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds - return [codes_bestrq_emb,codes_bestrq_emb_bgm], [bestrq_emb,bestrq_emb_bgm], spk_embeds - # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds - - - @torch.no_grad() - def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats,incontext_length=127, - guidance_scale=2, num_steps=20, - disable_progress=True, scenario='start_seg'): - classifier_free_guidance = guidance_scale > 1.0 - device = self.device - dtype = self.dtype - # codes_bestrq_middle, codes_bestrq_last = codes - codes_bestrq_emb,codes_bestrq_emb_bgm = codes - - - batch_size = codes_bestrq_emb.shape[0] - - - quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb) - quantized_bestrq_emb_bgm,_,_=self.rvq_bestrq_bgm_emb.from_codes(codes_bestrq_emb_bgm) - quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() - quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0,2,1).contiguous() - if('spk' in additional_feats): - spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach() - - num_frames = quantized_bestrq_emb.shape[1] - - num_channels_latents = self.num_channels - shape = (batch_size, num_frames, 64) - latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) - - - - latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device) - latent_masks[:,0:latent_length] = 2 - if(scenario=='other_seg'): - latent_masks[:,0:incontext_length] = 1 - - - - quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ - + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) - quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \ - + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) - true_latents = true_latents.permute(0,2,1).contiguous() - true_latents = self.normfeat.project_sample(true_latents) - true_latents = true_latents.permute(0,2,1).contiguous() - incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() - incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0] - - - attention_mask=(latent_masks > 0.5) - B, L = attention_mask.size() - attention_mask = attention_mask.view(B, 1, L) - attention_mask = attention_mask * attention_mask.transpose(-1, -2) - attention_mask = attention_mask.unsqueeze(1) - latent_mask_input = self.mask_emb(latent_masks) - - if('spk' in additional_feats): - # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1) - additional_model_input = torch.cat([quantized_bestrq_emb,quantized_bestrq_emb_bgm, spk_embeds],2) - else: - # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1) - additional_model_input = torch.cat([quantized_bestrq_emb,quantized_bestrq_emb_bgm],2) - - temperature = 1.0 - t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device) - latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale) - - latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:] - latents = 
latents.permute(0,2,1).contiguous() - latents = self.normfeat.return_sample(latents) - # latents = latents.permute(0,2,1).contiguous() - return latents - - @torch.no_grad() - def inference(self, input_audios_vocal,input_audios_bgm, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, - disable_progress=True,layer_vocal=7,layer_bgm=3,scenario='start_seg'): - codes, embeds, spk_embeds = self.fetch_codes(input_audios_vocal,input_audios_bgm, additional_feats,layer_vocal,layer_bgm) - - latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ - guidance_scale=guidance_scale, num_steps=num_steps, \ - disable_progress=disable_progress,scenario=scenario) - return latents - - def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device): - divisor = 4 - shape = (batch_size, num_channels_latents, num_frames, 32) - if(num_frames%divisor>0): - num_frames = round(num_frames/float(divisor))*divisor - shape = (batch_size, num_channels_latents, num_frames, 32) - latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) - return latents - - +import yaml +import random +import inspect +import numpy as np +from tqdm import tqdm +import typing as tp +from abc import ABC + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchaudio + +from einops import repeat +from tools.torch_tools import wav_to_fbank + +from diffusers.utils.torch_utils import randn_tensor +from transformers import HubertModel +from libs.rvq.descript_quantize3 import ResidualVectorQuantize + +from models_gpt.models.gpt2_rope2_time_new_correct_mask_noncasual_reflow import GPT2Model +from models_gpt.models.gpt2_config import GPT2Config + +from torch.cuda.amp import autocast +from our_MERT_BESTRQ.test import load_model + +class HubertModelWithFinalProj(HubertModel): + def __init__(self, config): + super().__init__(config) + + # The final projection layer is only used for backward compatibility. + # Following https://github.com/auspicious3000/contentvec/issues/6 + # Remove this layer is necessary to achieve the desired outcome. 
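# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the patch): the projection loss used in
# BASECFM.compute_loss further below. An intermediate transformer layer is mapped
# through the small MLP and aligned with the SSL (HuBERT/content-vec) features
# via cosine similarity; the loss equals 1 - mean(cos_sim), so it is 0 when the
# two are perfectly aligned and 2 when anti-aligned. Tensors here are toy values.
import torch
import torch.nn.functional as F

hidden_proj = torch.randn(2, 750, 768)       # MLP-projected hidden states (B, T, 768)
ssl_embeds = torch.randn(2, 750, 768)        # interpolated SSL features, same shape

cos_sim = (F.normalize(hidden_proj, dim=-1) * F.normalize(ssl_embeds, dim=-1)).sum(dim=-1)
proj_loss = 1.0 - cos_sim.mean()             # same as 1 + (-cos_sim).mean() in the code
print(float(proj_loss))
# ---------------------------------------------------------------------------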
+ print("hidden_size:",config.hidden_size) + print("classifier_proj_size:",config.classifier_proj_size) + self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) + + +class SampleProcessor(torch.nn.Module): + def project_sample(self, x: torch.Tensor): + """Project the original sample to the 'space' where the diffusion will happen.""" + """Project back from diffusion space to the actual sample space.""" + return z + +class Feature1DProcessor(SampleProcessor): + def __init__(self, dim: int = 100, power_std = 1., \ + num_samples: int = 100_000, cal_num_frames: int = 600): + super().__init__() + + self.num_samples = num_samples + self.dim = dim + self.power_std = power_std + self.cal_num_frames = cal_num_frames + self.register_buffer('counts', torch.zeros(1)) + self.register_buffer('sum_x', torch.zeros(dim)) + self.register_buffer('sum_x2', torch.zeros(dim)) + self.register_buffer('sum_target_x2', torch.zeros(dim)) + self.counts: torch.Tensor + self.sum_x: torch.Tensor + self.sum_x2: torch.Tensor + + @property + def mean(self): + mean = self.sum_x / self.counts + if(self.counts < 10): + mean = torch.zeros_like(mean) + return mean + + @property + def std(self): + std = (self.sum_x2 / self.counts - self.mean**2).clamp(min=0).sqrt() + if(self.counts < 10): + std = torch.ones_like(std) + return std + + @property + def target_std(self): + return 1 + + def project_sample(self, x: torch.Tensor): + assert x.dim() == 3 + if self.counts.item() < self.num_samples: + self.counts += len(x) + self.sum_x += x[:,:,0:self.cal_num_frames].mean(dim=(2,)).sum(dim=0) + self.sum_x2 += x[:,:,0:self.cal_num_frames].pow(2).mean(dim=(2,)).sum(dim=0) + rescale = (self.target_std / self.std.clamp(min=1e-12)) ** self.power_std # same output size + x = (x - self.mean.view(1, -1, 1)) * rescale.view(1, -1, 1) + return x + + def return_sample(self, x: torch.Tensor): + assert x.dim() == 3 + rescale = (self.std / self.target_std) ** self.power_std + x = x * rescale.view(1, -1, 1) + self.mean.view(1, -1, 1) + return x + +def pad_or_tunc_tolen(prior_text_encoder_hidden_states, prior_text_mask, prior_prompt_embeds, len_size=77): + if(prior_text_encoder_hidden_states.shape[1] 1.0): + + model_input = torch.cat([ \ + torch.cat([latent_mask_input, latent_mask_input], 0), \ + torch.cat([incontext_x, incontext_x], 0), \ + torch.cat([torch.zeros_like(mu), mu], 0), \ + torch.cat([x, x], 0), \ + ], 2) + timestep=t.unsqueeze(-1).repeat(2) + + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + dphi_dt_uncond, dhpi_dt_cond = dphi_dt.chunk(2,0) + dphi_dt = dphi_dt_uncond + guidance_scale * (dhpi_dt_cond - dphi_dt_uncond) + else: + model_input = torch.cat([latent_mask_input, incontext_x, mu, x], 2) + timestep=t.unsqueeze(-1) + dphi_dt = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=timestep).last_hidden_state + + dphi_dt = dphi_dt[: ,:, -x.shape[2]:] + x = x + dt * dphi_dt + t = t + dt + sol.append(x) + if step < len(t_span) - 1: + dt = t_span[step + 1] - t + + return sol[-1] + + def projection_loss(self,hidden_proj, bestrq_emb): + bsz = hidden_proj.shape[0] + + hidden_proj_normalized = F.normalize(hidden_proj, dim=-1) + bestrq_emb_normalized = F.normalize(bestrq_emb, dim=-1) + + proj_loss = -(hidden_proj_normalized * bestrq_emb_normalized).sum(dim=-1) + proj_loss = 1+proj_loss.mean() + + return proj_loss + + def compute_loss(self, x1, mu, latent_masks,attention_mask,wav2vec_embeds, validation_mode=False): + 
"""Computes diffusion loss + + Args: + x1 (torch.Tensor): Target + shape: (batch_size, n_channels, mel_timesteps, n_feats) + mu (torch.Tensor): output of encoder + shape: (batch_size, n_channels, mel_timesteps, n_feats) + + Returns: + loss: conditional flow matching loss + y: conditional flow + shape: (batch_size, n_channels, mel_timesteps, n_feats) + """ + b = mu[0].shape[0] + len_x = x1.shape[2] + # random timestep + if(validation_mode): + t = torch.ones([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) * 0.5 + else: + t = torch.rand([b, 1, 1], device=mu[0].device, dtype=mu[0].dtype) + # sample noise p(x_0) + z = torch.randn_like(x1) + + y = (1 - (1 - self.sigma_min) * t) * z + t * x1 + u = x1 - (1 - self.sigma_min) * z + model_input = torch.cat([*mu,y], 2) + t=t.squeeze(-1).squeeze(-1) + out = self.estimator(inputs_embeds=model_input, attention_mask=attention_mask,time_step=t,output_hidden_states=True) + hidden_layer_7 = out.hidden_states[7] + hidden_proj = self.mlp(hidden_layer_7) + out = out.last_hidden_state + out=out[:,:,-len_x:] + + weight = (latent_masks > 1.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() + (latent_masks < 0.5).unsqueeze(-1).repeat(1, 1, out.shape[-1]).float() * 0.01 + loss_re = F.mse_loss(out * weight, u * weight, reduction="sum") / weight.sum() + loss_cos = self.projection_loss(hidden_proj, wav2vec_embeds) + loss = loss_re + loss_cos * 0.5 + return loss, loss_re, loss_cos + +class PromptCondAudioDiffusion(nn.Module): + def __init__( + self, + num_channels, + unet_model_name=None, + unet_model_config_path=None, + snr_gamma=None, + uncondition=True, + out_paint=False, + ): + super().__init__() + + assert unet_model_name is not None or unet_model_config_path is not None, "Either UNet pretrain model name or a config file path is required" + + self.unet_model_name = unet_model_name + self.unet_model_config_path = unet_model_config_path + self.snr_gamma = snr_gamma + self.uncondition = uncondition + self.num_channels = num_channels + + # https://huggingface.co/docs/diffusers/v0.14.0/en/api/schedulers/overview + self.normfeat = Feature1DProcessor(dim=64) + + self.sample_rate = 48000 + self.num_samples_perseg = self.sample_rate * 20 // 1000 + self.rsp48toclap = torchaudio.transforms.Resample(48000, 24000) + self.rsq48towav2vec = torchaudio.transforms.Resample(48000, 16000) + # self.wav2vec = Wav2Vec2BertModel.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + # self.wav2vec_processor = AutoFeatureExtractor.from_pretrained("facebook/w2v-bert-2.0", trust_remote_code=True) + self.bestrq = load_model( + model_dir='codeclm/tokenizer/Flow1dVAE/our_MERT_BESTRQ/mert_fairseq', + checkpoint_dir='ckpt/encode-s12k.pt', + ) + self.rsq48tobestrq = torchaudio.transforms.Resample(48000, 24000) + self.rsq48tohubert = torchaudio.transforms.Resample(48000, 16000) + for v in self.bestrq.parameters():v.requires_grad = False + self.rvq_bestrq_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) + self.rvq_bestrq_bgm_emb = ResidualVectorQuantize(input_dim = 1024, n_codebooks = 1, codebook_size = 16_384, codebook_dim = 32, quantizer_dropout = 0.0, stale_tolerance=200) + self.hubert = HubertModelWithFinalProj.from_pretrained("ckpt/models--lengyue233--content-vec-best/snapshots/c0b9ba13db21beaa4053faae94c102ebe326fd68") + for v in self.hubert.parameters():v.requires_grad = False + self.zero_cond_embedding1 = nn.Parameter(torch.randn(32*32,)) + # self.xvecmodel = XVECModel() + 
config = GPT2Config(n_positions=1000,n_layer=16,n_head=20,n_embd=2200,n_inner=4400) + unet = GPT2Model(config) + mlp = nn.Sequential( + nn.Linear(2200, 1024), + nn.SiLU(), + nn.Linear(1024, 1024), + nn.SiLU(), + nn.Linear(1024, 768) + ) + self.set_from = "random" + self.cfm_wrapper = BASECFM(unet, mlp) + self.mask_emb = torch.nn.Embedding(3, 24) + print("Transformer initialized from pretrain.") + torch.cuda.empty_cache() + + def compute_snr(self, timesteps): + """ + Computes SNR as per https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = self.noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + def preprocess_audio(self, input_audios, threshold=0.9): + assert len(input_audios.shape) == 2, input_audios.shape + norm_value = torch.ones_like(input_audios[:,0]) + max_volume = input_audios.abs().max(dim=-1)[0] + norm_value[max_volume>threshold] = max_volume[max_volume>threshold] / threshold + return input_audios/norm_value.unsqueeze(-1) + + def extract_wav2vec_embeds(self, input_audios,output_len): + wav2vec_stride = 2 + + wav2vec_embeds = self.hubert(self.rsq48tohubert(input_audios), output_hidden_states=True).hidden_states # 1, 4096, 1024 + wav2vec_embeds_last=wav2vec_embeds[-1] + wav2vec_embeds_last=torch.nn.functional.interpolate(wav2vec_embeds_last.permute(0, 2, 1), size=output_len, mode='linear', align_corners=False).permute(0, 2, 1) + return wav2vec_embeds_last + + def extract_mert_embeds(self, input_audios): + prompt_stride = 3 + inputs = self.clap_embd_extractor.mulan.audio.processor(self.rsp48toclap(input_audios), sampling_rate=self.clap_embd_extractor.mulan.audio.sr, return_tensors="pt") + input_values = inputs['input_values'].squeeze(0).to(input_audios.device, dtype = input_audios.dtype) + prompt_embeds = self.clap_embd_extractor.mulan.audio.model(input_values, output_hidden_states=True).hidden_states # batch_size, Time steps, 1024 + mert_emb= prompt_embeds[-1] + mert_emb = torch.nn.functional.interpolate(mert_emb.permute(0, 2, 1), size=375, mode='linear', align_corners=False).permute(0, 2, 1) + + return mert_emb + + def extract_bestrq_embeds(self, input_audio_vocal_0,input_audio_vocal_1,layer): + input_wav_mean = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0 + input_wav_mean = self.bestrq(self.rsq48tobestrq(input_wav_mean), features_only = True) + layer_results = input_wav_mean['layer_results'] + bestrq_emb = layer_results[layer] + bestrq_emb = bestrq_emb.permute(0,2,1).contiguous() + return bestrq_emb + + + def extract_spk_embeds(self, 
input_audios): + spk_embeds = self.xvecmodel(self.rsq48towav2vec(input_audios)) + spk_embeds = self.spk_linear(spk_embeds).reshape(spk_embeds.shape[0], 16, 1, 32) + return spk_embeds + + def extract_lyric_feats(self, lyric): + with torch.no_grad(): + try: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = lyric, return_one=False) + except: + text_encoder_hidden_states, text_mask, text_prompt_embeds = self.clap_embd_extractor(texts = [""] * len(lyric), return_one=False) + text_encoder_hidden_states = text_encoder_hidden_states.to(self.device) + text_mask = text_mask.to(self.device) + text_encoder_hidden_states, text_mask, text_prompt_embeds = \ + pad_or_tunc_tolen(text_encoder_hidden_states, text_mask, text_prompt_embeds) + text_encoder_hidden_states = text_encoder_hidden_states.permute(0,2,1).contiguous() + return text_encoder_hidden_states, text_mask + + def extract_energy_bar(self, input_audios): + if(input_audios.shape[-1] % self.num_samples_perseg > 0): + energy_bar = input_audios[:,:-1 * (input_audios.shape[-1] % self.num_samples_perseg)].reshape(input_audios.shape[0],-1,self.num_samples_perseg) + else: + energy_bar = input_audios.reshape(input_audios.shape[0],-1,self.num_samples_perseg) + energy_bar = (energy_bar.pow(2.0).mean(-1).sqrt() + 1e-6).log10() * 20 # B T + energy_bar = (energy_bar / 2.0 + 16).clamp(0,16).int() + energy_embedding = self.energy_embedding(energy_bar) + energy_embedding = energy_embedding.view(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 2, 32).reshape(energy_embedding.shape[0], energy_embedding.shape[1] // 2, 64).permute(0,2,1) # b 128 t + return energy_embedding + + def forward(self, input_audios_vocal,input_audios_bgm, lyric, latents, latent_masks, validation_mode=False, \ + additional_feats = ['spk', 'lyric'], \ + train_rvq=True, train_ssl=False,layer_vocal=7,layer_bgm=7): + if not hasattr(self,"device"): + self.device = input_audios_vocal.device + if not hasattr(self,"dtype"): + self.dtype = input_audios_vocal.dtype + device = self.device + input_audio_vocal_0 = input_audios_vocal[:,0,:] + input_audio_vocal_1 = input_audios_vocal[:,1,:] + input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0) + input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1) + input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0 + + input_audio_bgm_0 = input_audios_bgm[:,0,:] + input_audio_bgm_1 = input_audios_bgm[:,1,:] + input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0) + input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1) + input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0 + + if(train_ssl): + self.wav2vec.train() + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios) + self.clap_embd_extractor.train() + prompt_embeds = self.extract_mert_embeds(input_audios) + if('spk' in additional_feats): + self.xvecmodel.train() + spk_embeds = self.extract_spk_embeds(input_audios).repeat(1,1,prompt_embeds.shape[-1]//2,1) + else: + with torch.no_grad(): + with autocast(enabled=False): + bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal) + bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm) + # mert_emb = self.extract_mert_embeds(input_audios_mert) + output_len = bestrq_emb.shape[2] + wav2vec_embeds = self.extract_wav2vec_embeds(input_audios_vocal_wav2vec+input_audios_bgm_wav2vec,output_len) + + + bestrq_emb = bestrq_emb.detach() + bestrq_emb_bgm = 
bestrq_emb_bgm.detach() + + if('lyric' in additional_feats): + text_encoder_hidden_states, text_mask = self.extract_lyric_feats(lyric) + else: + text_encoder_hidden_states, text_mask = None, None + + + if(train_rvq): + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + quantized_bestrq_emb_bgm, _, _, commitment_loss_bestrq_emb_bgm, codebook_loss_bestrq_emb_bgm,_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t + else: + bestrq_emb = bestrq_emb.float() + self.rvq_bestrq_emb.eval() + # with autocast(enabled=False): + quantized_bestrq_emb, _, _, commitment_loss_bestrq_emb, codebook_loss_bestrq_emb,_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + commitment_loss_bestrq_emb = commitment_loss_bestrq_emb.detach() + codebook_loss_bestrq_emb = codebook_loss_bestrq_emb.detach() + quantized_bestrq_emb = quantized_bestrq_emb.detach() + + commitment_loss = commitment_loss_bestrq_emb+commitment_loss_bestrq_emb_bgm + codebook_loss = codebook_loss_bestrq_emb+codebook_loss_bestrq_emb_bgm + + + alpha=1 + quantized_bestrq_emb = quantized_bestrq_emb * alpha + bestrq_emb * (1-alpha) + quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm * alpha + bestrq_emb_bgm * (1-alpha) + + + + + scenario = np.random.choice(['start_seg', 'other_seg']) + if(scenario == 'other_seg'): + for binx in range(input_audios_vocal.shape[0]): + # latent_masks[binx,0:64] = 1 + latent_masks[binx,0:random.randint(64,128)] = 1 + quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() + quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0,2,1).contiguous() + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + + + + + if self.uncondition: + mask_indices = [k for k in range(quantized_bestrq_emb.shape[0]) if random.random() < 0.1] + if len(mask_indices) > 0: + quantized_bestrq_emb[mask_indices] = 0 + quantized_bestrq_emb_bgm[mask_indices] = 0 + latents = latents.permute(0,2,1).contiguous() + latents = self.normfeat.project_sample(latents) + latents = latents.permute(0,2,1).contiguous() + incontext_latents = latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + latent_mask_input = self.mask_emb(latent_masks) + loss,loss_re, loss_cos = self.cfm_wrapper.compute_loss(latents, [latent_mask_input,incontext_latents, quantized_bestrq_emb,quantized_bestrq_emb_bgm], latent_masks,attention_mask,wav2vec_embeds, validation_mode=validation_mode) + return loss,loss_re, loss_cos, commitment_loss.mean(), codebook_loss.mean() + + def init_device_dtype(self, device, dtype): + self.device = device + self.dtype = dtype + + @torch.no_grad() + def fetch_codes(self, input_audios_vocal,input_audios_bgm, additional_feats,layer_vocal=7,layer_bgm=7): + input_audio_vocal_0 = input_audios_vocal[[0],:] + input_audio_vocal_1 = input_audios_vocal[[1],:] + input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0) + input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1) + input_audios_vocal_wav2vec = (input_audio_vocal_0 
+ input_audio_vocal_1) / 2.0 + + input_audio_bgm_0 = input_audios_bgm[[0],:] + input_audio_bgm_1 = input_audios_bgm[[1],:] + input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0) + input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1) + input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0 + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal) + bestrq_emb = bestrq_emb.detach() + + bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm) + bestrq_emb_bgm = bestrq_emb_bgm.detach() + + + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + + self.rvq_bestrq_bgm_emb.eval() + quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], [prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb,codes_bestrq_emb_bgm], [bestrq_emb,bestrq_emb_bgm], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + @torch.no_grad() + def fetch_codes_batch(self, input_audios_vocal, input_audios_bgm, additional_feats,layer_vocal=7,layer_bgm=7): + input_audio_vocal_0 = input_audios_vocal[:,0,:] + input_audio_vocal_1 = input_audios_vocal[:,1,:] + input_audio_vocal_0 = self.preprocess_audio(input_audio_vocal_0) + input_audio_vocal_1 = self.preprocess_audio(input_audio_vocal_1) + input_audios_vocal_wav2vec = (input_audio_vocal_0 + input_audio_vocal_1) / 2.0 + + input_audio_bgm_0 = input_audios_bgm[:,0,:] + input_audio_bgm_1 = input_audios_bgm[:,1,:] + input_audio_bgm_0 = self.preprocess_audio(input_audio_bgm_0) + input_audio_bgm_1 = self.preprocess_audio(input_audio_bgm_1) + input_audios_bgm_wav2vec = (input_audio_bgm_0 + input_audio_bgm_1) / 2.0 + + self.bestrq.eval() + + # bestrq_middle,bestrq_last = self.extract_bestrq_embeds(input_audios) + # bestrq_middle = bestrq_middle.detach() + # bestrq_last = bestrq_last.detach() + bestrq_emb = self.extract_bestrq_embeds(input_audio_vocal_0,input_audio_vocal_1,layer_vocal) + bestrq_emb = bestrq_emb.detach() + + bestrq_emb_bgm = self.extract_bestrq_embeds(input_audio_bgm_0,input_audio_bgm_1,layer_bgm) + bestrq_emb_bgm = bestrq_emb_bgm.detach() + + + + self.rvq_bestrq_emb.eval() + quantized_bestrq_emb, codes_bestrq_emb, *_ = self.rvq_bestrq_emb(bestrq_emb) # b,d,t + + self.rvq_bestrq_bgm_emb.eval() + quantized_bestrq_emb_bgm, codes_bestrq_emb_bgm, *_ = self.rvq_bestrq_bgm_emb(bestrq_emb_bgm) # b,d,t + + + if('spk' in additional_feats): + self.xvecmodel.eval() + spk_embeds = self.extract_spk_embeds(input_audios) + else: + spk_embeds = None + + # return [codes_prompt, codes_wav2vec], [prompt_embeds, wav2vec_embeds], spk_embeds + # return [codes_prompt_7, codes_prompt_13, codes_prompt_20, codes_wav2vec_half, codes_wav2vec_last], 
[prompt_embeds_7, prompt_embeds_13, prompt_embeds_20, wav2vec_embeds_half, wav2vec_embeds_last], spk_embeds + # return [codes_bestrq_middle, codes_bestrq_last], [bestrq_middle, bestrq_last], spk_embeds + return [codes_bestrq_emb,codes_bestrq_emb_bgm], [bestrq_emb,bestrq_emb_bgm], spk_embeds + # return [codes_prompt_13, codes_wav2vec_last], [prompt_embeds_13, wav2vec_embeds_last], spk_embeds + + + @torch.no_grad() + def inference_codes(self, codes, spk_embeds, true_latents, latent_length, additional_feats,incontext_length=127, + guidance_scale=2, num_steps=20, + disable_progress=True, scenario='start_seg'): + classifier_free_guidance = guidance_scale > 1.0 + device = self.device + dtype = self.dtype + # codes_bestrq_middle, codes_bestrq_last = codes + codes_bestrq_emb,codes_bestrq_emb_bgm = codes + + + batch_size = codes_bestrq_emb.shape[0] + + + quantized_bestrq_emb,_,_=self.rvq_bestrq_emb.from_codes(codes_bestrq_emb) + quantized_bestrq_emb_bgm,_,_=self.rvq_bestrq_bgm_emb.from_codes(codes_bestrq_emb_bgm) + quantized_bestrq_emb = quantized_bestrq_emb.permute(0,2,1).contiguous() + quantized_bestrq_emb_bgm = quantized_bestrq_emb_bgm.permute(0,2,1).contiguous() + if('spk' in additional_feats): + spk_embeds = spk_embeds.repeat(1,1,quantized_bestrq_emb.shape[-2],1).detach() + + num_frames = quantized_bestrq_emb.shape[1] + + num_channels_latents = self.num_channels + shape = (batch_size, num_frames, 64) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + + + + latent_masks = torch.zeros(latents.shape[0], latents.shape[1], dtype=torch.int64, device=latents.device) + latent_masks[:,0:latent_length] = 2 + if(scenario=='other_seg'): + latent_masks[:,0:incontext_length] = 1 + + + + quantized_bestrq_emb = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + quantized_bestrq_emb_bgm = (latent_masks > 0.5).unsqueeze(-1) * quantized_bestrq_emb_bgm \ + + (latent_masks < 0.5).unsqueeze(-1) * self.zero_cond_embedding1.reshape(1,1,1024) + true_latents = true_latents.permute(0,2,1).contiguous() + true_latents = self.normfeat.project_sample(true_latents) + true_latents = true_latents.permute(0,2,1).contiguous() + incontext_latents = true_latents * ((latent_masks > 0.5) * (latent_masks < 1.5)).unsqueeze(-1).float() + incontext_length = ((latent_masks > 0.5) * (latent_masks < 1.5)).sum(-1)[0] + + + attention_mask=(latent_masks > 0.5) + B, L = attention_mask.size() + attention_mask = attention_mask.view(B, 1, L) + attention_mask = attention_mask * attention_mask.transpose(-1, -2) + attention_mask = attention_mask.unsqueeze(1) + latent_mask_input = self.mask_emb(latent_masks) + + if('spk' in additional_feats): + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last, spk_embeds],1) + additional_model_input = torch.cat([quantized_bestrq_emb,quantized_bestrq_emb_bgm, spk_embeds],2) + else: + # additional_model_input = torch.cat([quantized_bestrq_middle, quantized_bestrq_last],1) + additional_model_input = torch.cat([quantized_bestrq_emb,quantized_bestrq_emb_bgm],2) + + temperature = 1.0 + t_span = torch.linspace(0, 1, num_steps + 1, device=quantized_bestrq_emb.device) + latents = self.cfm_wrapper.solve_euler(latents * temperature, latent_mask_input,incontext_latents, incontext_length, t_span, additional_model_input,attention_mask, guidance_scale) + + latents[:,0:incontext_length,:] = incontext_latents[:,0:incontext_length,:] + latents = 
latents.permute(0,2,1).contiguous() + latents = self.normfeat.return_sample(latents) + # latents = latents.permute(0,2,1).contiguous() + return latents + + @torch.no_grad() + def inference(self, input_audios_vocal,input_audios_bgm, lyric, true_latents, latent_length, additional_feats, guidance_scale=2, num_steps=20, + disable_progress=True,layer_vocal=7,layer_bgm=3,scenario='start_seg'): + codes, embeds, spk_embeds = self.fetch_codes(input_audios_vocal,input_audios_bgm, additional_feats,layer_vocal,layer_bgm) + + latents = self.inference_codes(codes, spk_embeds, true_latents, latent_length, additional_feats, \ + guidance_scale=guidance_scale, num_steps=num_steps, \ + disable_progress=disable_progress,scenario=scenario) + return latents + + def prepare_latents(self, batch_size, num_frames, num_channels_latents, dtype, device): + divisor = 4 + shape = (batch_size, num_channels_latents, num_frames, 32) + if(num_frames%divisor>0): + num_frames = round(num_frames/float(divisor))*divisor + shape = (batch_size, num_channels_latents, num_frames, 32) + latents = randn_tensor(shape, generator=None, device=device, dtype=dtype) + return latents + + diff --git a/codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_additionalemb.py b/codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_additionalemb.py index 1c733ab9ee1092da1ebc0330a65892d68b608a34..acdf10894e786400e5d81a9106c5e45fbd9849aa 100644 --- a/codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_additionalemb.py +++ b/codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_additionalemb.py @@ -1,996 +1,996 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union - -import torch -import torch.nn as nn -import torch.utils.checkpoint - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin -from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers -from diffusers.models.activations import get_activation -from diffusers.models.embeddings import ( - GaussianFourierProjection, - GLIGENTextBoundingboxProjection, - ImageHintTimeEmbedding, - ImageProjection, - ImageTimeEmbedding, - TextImageProjection, - TextImageTimeEmbedding, - TextTimeEmbedding, - TimestepEmbedding, - Timesteps, -) -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.unet_2d_blocks import ( - UNetMidBlock2D, - UNetMidBlock2DCrossAttn, - UNetMidBlock2DSimpleCrossAttn, - get_down_block, - get_up_block, -) - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class UNet2DConditionOutput(BaseOutput): - """ - The output of [`UNet2DConditionModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - The hidden states output conditioned on `encoder_hidden_states` input. 
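The `inference_codes` path above integrates the flow-matching model with a plain Euler scheme over `t_span = torch.linspace(0, 1, num_steps + 1)` and applies classifier-free guidance via `guidance_scale`, with the unconditional branch corresponding to zeroed conditioning (the same convention `self.uncondition` uses during training). The repository's actual `BASECFM.solve_euler` implementation is not shown in this hunk, so the following is only a hedged sketch of that recipe with a hypothetical stand-in velocity field `velocity_fn`:

import torch

@torch.no_grad()
def euler_cfg_solve(velocity_fn, x, cond, t_span, guidance_scale=2.0):
    # Explicit Euler integration of dx/dt = v(x, t, cond) from t_span[0] to t_span[-1].
    for t_cur, t_next in zip(t_span[:-1], t_span[1:]):
        dt = t_next - t_cur
        t = t_cur.expand(x.shape[0])
        v_cond = velocity_fn(x, t, cond)                       # conditional velocity
        v_uncond = velocity_fn(x, t, torch.zeros_like(cond))   # zeroed-condition branch
        v = v_uncond + guidance_scale * (v_cond - v_uncond)    # classifier-free guidance blend
        x = x + dt * v
    return x

# Toy usage: a velocity field that drifts the sample toward its condition.
velocity_fn = lambda x, t, c: c - x
x1 = euler_cfg_solve(velocity_fn, torch.randn(2, 64), torch.randn(2, 64),
                     torch.linspace(0, 1, 21))

In the real model, the in-context latents are additionally pasted back over the first `incontext_length` frames after the solve, as shown above.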
Output of last layer of model. - """ - - sample: torch.FloatTensor = None - - -class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): - r""" - A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample - shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented - for all models (such as downloading or saving). - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. - center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): - Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or - `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): - The tuple of upsample blocks to use. - only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): - Whether to include self-attention in the basic transformer blocks, see - [`~models.attention.BasicTransformerBlock`]. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - If `None`, normalization and activation layers is skipped in post-processing. - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): - The dimension of the cross attention features. - transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for - [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling - blocks of the U-Net. 
Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for - [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - encoder_hid_dim (`int`, *optional*, defaults to None): - If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` - dimension to `cross_attention_dim`. - encoder_hid_dim_type (`str`, *optional*, defaults to `None`): - If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text - embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. - num_attention_heads (`int`, *optional*): - The number of attention heads. If not defined, defaults to `attention_head_dim` - resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to `None`): - The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, - `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. - addition_embed_type (`str`, *optional*, defaults to `None`): - Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or - "text". "text" will use the `TextTimeEmbedding` layer. - addition_time_embed_dim: (`int`, *optional*, defaults to `None`): - Dimension for the timestep embeddings. - num_class_embeds (`int`, *optional*, defaults to `None`): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - time_embedding_type (`str`, *optional*, defaults to `positional`): - The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. - time_embedding_dim (`int`, *optional*, defaults to `None`): - An optional override for the dimension of the projected time embedding. - time_embedding_act_fn (`str`, *optional*, defaults to `None`): - Optional activation function to use only once on the time embeddings before they are passed to the rest of - the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. - timestep_post_act (`str`, *optional*, defaults to `None`): - The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. - time_cond_proj_dim (`int`, *optional*, defaults to `None`): - The dimension of `cond_proj` layer in the timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, - *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, - *optional*): The dimension of the `class_labels` input when - `class_embed_type="projection"`. Required when `class_embed_type="projection"`. - class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time - embeddings with the class embeddings. - mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): - Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. 
If - `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the - `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` - otherwise. - """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: Union[int, Tuple[int]] = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - dropout: float = 0.0, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: Union[int, Tuple[int]] = 1280, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, - reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, - encoder_hid_dim: Optional[int] = None, - encoder_hid_dim_type: Optional[str] = None, - attention_head_dim: Union[int, Tuple[int]] = 8, - num_attention_heads: Optional[Union[int, Tuple[int]]] = None, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - addition_embed_type: Optional[str] = None, - addition_time_embed_dim: Optional[int] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - resnet_skip_time_act: bool = False, - resnet_out_scale_factor: int = 1.0, - time_embedding_type: str = "positional", - time_embedding_dim: Optional[int] = None, - time_embedding_act_fn: Optional[str] = None, - timestep_post_act: Optional[str] = None, - time_cond_proj_dim: Optional[int] = None, - conv_in_kernel: int = 3, - conv_out_kernel: int = 3, - projection_class_embeddings_input_dim: Optional[int] = None, - attention_type: str = "default", - class_embeddings_concat: bool = False, - mid_block_only_cross_attention: Optional[bool] = None, - cross_attention_norm: Optional[str] = None, - addition_embed_type_num_heads=64, - ): - super().__init__() - - self.sample_size = sample_size - - if num_attention_heads is not None: - raise ValueError( - "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." - ) - - # If `num_attention_heads` is not defined (which is the case for most models) - # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. - # The reason for this behavior is to correct for incorrectly named variables that were introduced - # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 - # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking - # which is why we correct for the naming here. 
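Most of the constructor arguments listed above correspond to the standard diffusers `UNet2DConditionModel` configuration surface. As a hedged illustration only (the block layout and dimensions below are invented, and the import path simply follows this file's location in the diff header), a small instance could be configured like this:

import torch
from codeclm.tokenizer.Flow1dVAE.models.unet_2d_condition_additionalemb import UNet2DConditionModel

# Deliberately tiny: two resolution levels, one cross-attention down/up pair,
# and a 256-dim cross-attention space.
unet = UNet2DConditionModel(
    sample_size=32,
    in_channels=4,
    out_channels=4,
    down_block_types=("CrossAttnDownBlock2D", "DownBlock2D"),
    up_block_types=("UpBlock2D", "CrossAttnUpBlock2D"),
    block_out_channels=(64, 128),
    layers_per_block=1,
    cross_attention_dim=256,
    attention_head_dim=8,
)

Unlike the stock class, the forward method of this variant also takes an `additional_embd` tensor that is summed with the time embedding (see the forward pass later in this file).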
- num_attention_heads = num_attention_heads or attention_head_dim - - # Check inputs - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." - ) - - if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." - ) - if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: - for layer_number_per_block in transformer_layers_per_block: - if isinstance(layer_number_per_block, list): - raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") - - # input - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) - - # time - if time_embedding_type == "fourier": - time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 - if time_embed_dim % 2 != 0: - raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") - self.time_proj = GaussianFourierProjection( - time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos - ) - timestep_input_dim = time_embed_dim - elif time_embedding_type == "positional": - time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 - - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - else: - raise ValueError( - f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." 
- ) - - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - post_act_fn=timestep_post_act, - cond_proj_dim=time_cond_proj_dim, - ) - - if encoder_hid_dim_type is None and encoder_hid_dim is not None: - encoder_hid_dim_type = "text_proj" - self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) - logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") - - if encoder_hid_dim is None and encoder_hid_dim_type is not None: - raise ValueError( - f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." - ) - - if encoder_hid_dim_type == "text_proj": - self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) - elif encoder_hid_dim_type == "text_image_proj": - # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` - self.encoder_hid_proj = TextImageProjection( - text_embed_dim=encoder_hid_dim, - image_embed_dim=cross_attention_dim, - cross_attention_dim=cross_attention_dim, - ) - elif encoder_hid_dim_type == "image_proj": - # Kandinsky 2.2 - self.encoder_hid_proj = ImageProjection( - image_embed_dim=encoder_hid_dim, - cross_attention_dim=cross_attention_dim, - ) - elif encoder_hid_dim_type is not None: - raise ValueError( - f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." - ) - else: - self.encoder_hid_proj = None - - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - # The projection `class_embed_type` is the same as the timestep `class_embed_type` except - # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings - # 2. it projects from an arbitrary input dimension. - # - # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. - # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. - # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
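The time conditioning assembled above is a two-stage pipeline: a fixed sinusoidal projection (`Timesteps`, or `GaussianFourierProjection` when `time_embedding_type="fourier"`) followed by the learned `TimestepEmbedding` MLP. A minimal standalone sketch of the default `"positional"` branch, reusing the same diffusers helpers with the default dimensions (`block_out_channels[0] = 320`, `time_embed_dim = 4 * 320`); the example timestep values are arbitrary:

import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(in_channels=320, time_embed_dim=1280, act_fn="silu")

timesteps = torch.tensor([1, 250, 999])
t_emb = time_proj(timesteps)     # (3, 320) sinusoidal features, always float32
emb = time_embedding(t_emb)      # (3, 1280) learned conditioning vector

The class-embedding branches reuse the same `TimestepEmbedding` module, which is why the comments above note that it can also be fed arbitrary (non-sinusoidal) vectors when `class_embed_type="projection"`.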
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - elif class_embed_type == "simple_projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" - ) - self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) - else: - self.class_embedding = None - - if addition_embed_type == "text": - if encoder_hid_dim is not None: - text_time_embedding_from_dim = encoder_hid_dim - else: - text_time_embedding_from_dim = cross_attention_dim - - self.add_embedding = TextTimeEmbedding( - text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads - ) - elif addition_embed_type == "text_image": - # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` - self.add_embedding = TextImageTimeEmbedding( - text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim - ) - elif addition_embed_type == "text_time": - self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - elif addition_embed_type == "image": - # Kandinsky 2.2 - self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) - elif addition_embed_type == "image_hint": - # Kandinsky 2.2 ControlNet - self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) - elif addition_embed_type is not None: - raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") - - if time_embedding_act_fn is None: - self.time_embed_act = None - else: - self.time_embed_act = get_activation(time_embedding_act_fn) - - self.down_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) - - if isinstance(only_cross_attention, bool): - if mid_block_only_cross_attention is None: - mid_block_only_cross_attention = only_cross_attention - - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if mid_block_only_cross_attention is None: - mid_block_only_cross_attention = False - - if isinstance(num_attention_heads, int): - num_attention_heads = (num_attention_heads,) * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - if isinstance(cross_attention_dim, int): - cross_attention_dim = (cross_attention_dim,) * len(down_block_types) - - if isinstance(layers_per_block, int): - layers_per_block = [layers_per_block] * len(down_block_types) - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - - if class_embeddings_concat: - # The time embeddings are concatenated with the class embeddings. 
The dimension of the - # time embeddings passed to the down, middle, and up blocks is twice the dimension of the - # regular time embeddings - blocks_time_embed_dim = time_embed_dim * 2 - else: - blocks_time_embed_dim = time_embed_dim - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block[i], - transformer_layers_per_block=transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - temb_channels=blocks_time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim[i], - num_attention_heads=num_attention_heads[i], - downsample_padding=downsample_padding, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_type=attention_type, - resnet_skip_time_act=resnet_skip_time_act, - resnet_out_scale_factor=resnet_out_scale_factor, - cross_attention_norm=cross_attention_norm, - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, - dropout=dropout, - ) - self.down_blocks.append(down_block) - - # mid - if mid_block_type == "UNetMidBlock2DCrossAttn": - self.mid_block = UNetMidBlock2DCrossAttn( - transformer_layers_per_block=transformer_layers_per_block[-1], - in_channels=block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - dropout=dropout, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim[-1], - num_attention_heads=num_attention_heads[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": - self.mid_block = UNetMidBlock2DSimpleCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - dropout=dropout, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim[-1], - attention_head_dim=attention_head_dim[-1], - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - skip_time_act=resnet_skip_time_act, - only_cross_attention=mid_block_only_cross_attention, - cross_attention_norm=cross_attention_norm, - ) - elif mid_block_type == "UNetMidBlock2D": - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - dropout=dropout, - num_layers=0, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - add_attention=False, - ) - elif mid_block_type is None: - self.mid_block = None - else: - raise ValueError(f"unknown mid_block_type : {mid_block_type}") - - # count how many layers upsample the images - self.num_upsamplers = 0 - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_num_attention_heads = 
list(reversed(num_attention_heads)) - reversed_layers_per_block = list(reversed(layers_per_block)) - reversed_cross_attention_dim = list(reversed(cross_attention_dim)) - reversed_transformer_layers_per_block = ( - list(reversed(transformer_layers_per_block)) - if reverse_transformer_layers_per_block is None - else reverse_transformer_layers_per_block - ) - only_cross_attention = list(reversed(only_cross_attention)) - - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 - - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - # add upsample block for all BUT final layer - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - - up_block = get_up_block( - up_block_type, - num_layers=reversed_layers_per_block[i] + 1, - transformer_layers_per_block=reversed_transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=blocks_time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resolution_idx=i, - resnet_groups=norm_num_groups, - cross_attention_dim=reversed_cross_attention_dim[i], - num_attention_heads=reversed_num_attention_heads[i], - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_type=attention_type, - resnet_skip_time_act=resnet_skip_time_act, - resnet_out_scale_factor=resnet_out_scale_factor, - cross_attention_norm=cross_attention_norm, - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, - dropout=dropout, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_num_groups is not None: - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps - ) - - self.conv_act = get_activation(act_fn) - - else: - self.conv_norm_out = None - self.conv_act = None - - conv_out_padding = (conv_out_kernel - 1) // 2 - self.conv_out = nn.Conv2d( - block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding - ) - - if attention_type in ["gated", "gated-text-image"]: - positive_len = 768 - if isinstance(cross_attention_dim, int): - positive_len = cross_attention_dim - elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): - positive_len = cross_attention_dim[0] - - feature_type = "text-only" if attention_type == "gated" else "text-image" - self.position_net = GLIGENTextBoundingboxProjection( - positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type - ) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - def forward( - self, - sample: torch.FloatTensor, - additional_embd: torch.Tensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - added_cond_kwargs: 
Optional[Dict[str, torch.Tensor]] = None, - down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - mid_block_additional_residual: Optional[torch.Tensor] = None, - down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - return_dict: bool = True, - ) -> Union[UNet2DConditionOutput, Tuple]: - r""" - The [`UNet2DConditionModel`] forward method. - - Args: - sample (`torch.FloatTensor`): - The noisy input tensor with the following shape `(batch, channel, height, width)`. - timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. - encoder_hidden_states (`torch.FloatTensor`): - The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. - class_labels (`torch.Tensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. - timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): - Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed - through the `self.time_embedding` layer to obtain the timestep embeddings. - attention_mask (`torch.Tensor`, *optional*, defaults to `None`): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - added_cond_kwargs: (`dict`, *optional*): - A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that - are passed along to the UNet blocks. - down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): - A tuple of tensors that if specified are added to the residuals of down unet blocks. - mid_block_additional_residual: (`torch.Tensor`, *optional*): - A tensor that if specified is added to the residual of the middle unet block. - encoder_attention_mask (`torch.Tensor`): - A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If - `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, - which adds large negative values to the attention scores corresponding to "discard" tokens. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. - added_cond_kwargs: (`dict`, *optional*): - A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that - are passed along to the UNet blocks. 
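The docstring above fixes the mask convention used throughout the forward pass: 0/1 masks with 1 = keep and 0 = discard are converted into an additive bias whose discarded positions receive a large negative value, which is exactly what the first lines of the method body do. A small illustration of that conversion (the shapes are made up):

import torch

encoder_attention_mask = torch.tensor([[1, 1, 1, 0, 0]])    # (batch, key_tokens), 1 = keep
bias = (1 - encoder_attention_mask.float()) * -10000.0      # keep -> 0.0, discard -> -10000.0
bias = bias.unsqueeze(1)                                    # (batch, 1, key_tokens)

The singleton dimension lets the bias broadcast over attention scores of shape `(batch, heads, query_tokens, key_tokens)`, so masked tokens end up with near-zero weight after the softmax.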
- down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): - additional residuals to be added to UNet long skip connections from down blocks to up blocks for - example from ControlNet side model(s) - mid_block_additional_residual (`torch.Tensor`, *optional*): - additional residual to be added to UNet mid block output, for example from ControlNet side model - down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): - additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) - - Returns: - [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise - a `tuple` is returned where the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - default_overall_up_factor = 2**self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False - upsample_size = None - - for dim in sample.shape[-2:]: - if dim % default_overall_up_factor != 0: - # Forward upsample size to force interpolation output size. - forward_upsample_size = True - break - - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if attention_mask is not None: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None: - encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - t_emb = self.time_proj(timesteps) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # but time_embedding might actually be running in fp16. 
so we need to cast here. - # there might be better ways to encapsulate this. - t_emb = t_emb.to(dtype=sample.dtype) - - emb = self.time_embedding(t_emb, timestep_cond) - aug_emb = None - - if self.class_embedding is not None: - if class_labels is None: - raise ValueError("class_labels should be provided when num_class_embeds > 0") - - if self.config.class_embed_type == "timestep": - class_labels = self.time_proj(class_labels) - - # `Timesteps` does not contain any weights and will always return f32 tensors - # there might be better ways to encapsulate this. - class_labels = class_labels.to(dtype=sample.dtype) - - class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) - - if self.config.class_embeddings_concat: - emb = torch.cat([emb, class_emb], dim=-1) - else: - emb = emb + class_emb - - if self.config.addition_embed_type == "text": - aug_emb = self.add_embedding(encoder_hidden_states) - elif self.config.addition_embed_type == "text_image": - # Kandinsky 2.1 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" - ) - - image_embs = added_cond_kwargs.get("image_embeds") - text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) - aug_emb = self.add_embedding(text_embs, image_embs) - elif self.config.addition_embed_type == "text_time": - # SDXL - style - if "text_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" - ) - text_embeds = added_cond_kwargs.get("text_embeds") - if "time_ids" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" - ) - time_ids = added_cond_kwargs.get("time_ids") - time_embeds = self.add_time_proj(time_ids.flatten()) - time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) - add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) - add_embeds = add_embeds.to(emb.dtype) - aug_emb = self.add_embedding(add_embeds) - elif self.config.addition_embed_type == "image": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" - ) - image_embs = added_cond_kwargs.get("image_embeds") - aug_emb = self.add_embedding(image_embs) - elif self.config.addition_embed_type == "image_hint": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" - ) - image_embs = added_cond_kwargs.get("image_embeds") - hint = added_cond_kwargs.get("hint") - aug_emb, hint = self.add_embedding(image_embs, hint) - sample = torch.cat([sample, hint], dim=1) - - emb = emb + aug_emb if aug_emb is not None else emb - - if self.time_embed_act is not None: - emb = self.time_embed_act(emb) - - emb = emb + additional_embd - - if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == 
"text_proj": - encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": - # Kadinsky 2.1 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - - image_embeds = added_cond_kwargs.get("image_embeds") - encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - image_embeds = added_cond_kwargs.get("image_embeds") - encoder_hidden_states = self.encoder_hid_proj(image_embeds) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - image_embeds = added_cond_kwargs.get("image_embeds") - image_embeds = self.encoder_hid_proj(image_embeds) - encoder_hidden_states = (encoder_hidden_states, image_embeds) - - # 2. pre-process - sample = self.conv_in(sample) - - # 2.5 GLIGEN position net - if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: - cross_attention_kwargs = cross_attention_kwargs.copy() - gligen_args = cross_attention_kwargs.pop("gligen") - cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} - - # 3. down - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - - is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None - # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets - is_adapter = down_intrablock_additional_residuals is not None - # maintain backward compatibility for legacy usage, where - # T2I-Adapter and ControlNet both use down_block_additional_residuals arg - # but can only use one or the other - if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: - deprecate( - "T2I should not use down_block_additional_residuals", - "1.3.0", - "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ - and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ - for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. 
", - standard_warn=False, - ) - down_intrablock_additional_residuals = down_block_additional_residuals - is_adapter = True - - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - # For t2i-adapter CrossAttnDownBlock2D - additional_residuals = {} - if is_adapter and len(down_intrablock_additional_residuals) > 0: - additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) - - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - **additional_residuals, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) - if is_adapter and len(down_intrablock_additional_residuals) > 0: - sample += down_intrablock_additional_residuals.pop(0) - - down_block_res_samples += res_samples - - if is_controlnet: - new_down_block_res_samples = () - - for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals - ): - down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) - - down_block_res_samples = new_down_block_res_samples - - # 4. mid - if self.mid_block is not None: - if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - ) - else: - sample = self.mid_block(sample, emb) - - # To support T2I-Adapter-XL - if ( - is_adapter - and len(down_intrablock_additional_residuals) > 0 - and sample.shape == down_intrablock_additional_residuals[0].shape - ): - sample += down_intrablock_additional_residuals.pop(0) - - if is_controlnet: - sample = sample + mid_block_additional_residual - - # 5. up - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - upsample_size=upsample_size, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - ) - else: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - upsample_size=upsample_size, - scale=lora_scale, - ) - - # 6. 
post-process - if self.conv_norm_out: - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers +from diffusers.models.activations import get_activation +from diffusers.models.embeddings import ( + GaussianFourierProjection, + GLIGENTextBoundingboxProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unet_2d_blocks import ( + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. + """ + + sample: torch.FloatTensor = None + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. 
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. 
+ class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, + *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, + *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. 
+ """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + ): + super().__init__() + + self.sample_size = sample_size + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. 
+ num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." 
+ ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
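To illustrate the distinction drawn in the comment above, here is a short sketch contrasting the "timestep" and "projection" class-embedding paths. It reuses the `Timesteps` and `TimestepEmbedding` layers imported in this file; all dimensions are arbitrary examples, not values from this repository's configuration.

# "timestep": integer class labels are first turned into sinusoidal features, then embedded.
# "projection": arbitrary float vectors go straight into the same MLP, so the input
# dimension (`projection_class_embeddings_input_dim`) can be any size (2816 is just an example).
import torch
from diffusers.models.embeddings import TimestepEmbedding, Timesteps

time_embed_dim = 1280
time_proj = Timesteps(320, True, 0)                       # sinusoidal projection to 320 dims
timestep_style = TimestepEmbedding(320, time_embed_dim)
projection_style = TimestepEmbedding(2816, time_embed_dim)

class_emb_a = timestep_style(time_proj(torch.tensor([3, 7])))  # integer labels -> (2, 1280)
class_emb_b = projection_style(torch.randn(2, 2816))           # raw vectors    -> (2, 1280)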
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. 
The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = 
list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + additional_embd: torch.Tensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: 
Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. 
+ down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. 
so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + emb = emb + additional_embd + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == 
"text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. 
", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + ) + + # 6. 
post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_flow.py b/codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_flow.py index 71eb958b9188df5f6fef76bf3d7d19dd329bbc16..aff46c720b6440905c1ee78e176feb7e4f12f283 100644 --- a/codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_flow.py +++ b/codeclm/tokenizer/Flow1dVAE/models/unet_2d_condition_flow.py @@ -1,934 +1,934 @@ -# Copyright 2023 The HuggingFace Team. All rights reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -from dataclasses import dataclass -from typing import Any, Dict, List, Optional, Tuple, Union -import math - -import torch -import torch.nn as nn -import torch.utils.checkpoint - -from diffusers.configuration_utils import ConfigMixin, register_to_config -from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin -from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers -from diffusers.models.activations import get_activation -from diffusers.models.embeddings import ( - GaussianFourierProjection, - GLIGENTextBoundingboxProjection, - ImageHintTimeEmbedding, - ImageProjection, - ImageTimeEmbedding, - TextImageProjection, - TextImageTimeEmbedding, - TextTimeEmbedding, - TimestepEmbedding, - Timesteps, -) -from diffusers.models.modeling_utils import ModelMixin -from diffusers.models.unet_2d_blocks import ( - UNetMidBlock2D, - UNetMidBlock2DCrossAttn, - UNetMidBlock2DSimpleCrossAttn, - get_down_block, - get_up_block, -) - - -logger = logging.get_logger(__name__) # pylint: disable=invalid-name - - -@dataclass -class UNet2DConditionOutput(BaseOutput): - """ - The output of [`UNet2DConditionModel`]. - - Args: - sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): - The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. - """ - - sample: torch.FloatTensor = None - - -class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): - r""" - A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample - shaped output. - - This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented - for all models (such as downloading or saving). - - Parameters: - sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): - Height and width of input/output sample. - in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. - out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. 
- center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. - flip_sin_to_cos (`bool`, *optional*, defaults to `False`): - Whether to flip the sin to cos in the time embedding. - freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. - down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): - The tuple of downsample blocks to use. - mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): - Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or - `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. - up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): - The tuple of upsample blocks to use. - only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): - Whether to include self-attention in the basic transformer blocks, see - [`~models.attention.BasicTransformerBlock`]. - block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): - The tuple of output channels for each block. - layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. - downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. - mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. - dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. - act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. - norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. - If `None`, normalization and activation layers is skipped in post-processing. - norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. - cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): - The dimension of the cross attention features. - transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for - [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): - The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling - blocks of the U-Net. Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for - [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], - [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. - encoder_hid_dim (`int`, *optional*, defaults to None): - If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` - dimension to `cross_attention_dim`. - encoder_hid_dim_type (`str`, *optional*, defaults to `None`): - If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text - embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. - attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. 
- num_attention_heads (`int`, *optional*): - The number of attention heads. If not defined, defaults to `attention_head_dim` - resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config - for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. - class_embed_type (`str`, *optional*, defaults to `None`): - The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, - `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. - addition_embed_type (`str`, *optional*, defaults to `None`): - Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or - "text". "text" will use the `TextTimeEmbedding` layer. - addition_time_embed_dim: (`int`, *optional*, defaults to `None`): - Dimension for the timestep embeddings. - num_class_embeds (`int`, *optional*, defaults to `None`): - Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing - class conditioning with `class_embed_type` equal to `None`. - time_embedding_type (`str`, *optional*, defaults to `positional`): - The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. - time_embedding_dim (`int`, *optional*, defaults to `None`): - An optional override for the dimension of the projected time embedding. - time_embedding_act_fn (`str`, *optional*, defaults to `None`): - Optional activation function to use only once on the time embeddings before they are passed to the rest of - the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. - timestep_post_act (`str`, *optional*, defaults to `None`): - The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. - time_cond_proj_dim (`int`, *optional*, defaults to `None`): - The dimension of `cond_proj` layer in the timestep embedding. - conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, - *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, - *optional*): The dimension of the `class_labels` input when - `class_embed_type="projection"`. Required when `class_embed_type="projection"`. - class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time - embeddings with the class embeddings. - mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): - Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. If - `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the - `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` - otherwise. 
- """ - - _supports_gradient_checkpointing = True - - @register_to_config - def __init__( - self, - sample_size: Optional[int] = None, - in_channels: int = 4, - out_channels: int = 4, - center_input_sample: bool = False, - flip_sin_to_cos: bool = True, - freq_shift: int = 0, - down_block_types: Tuple[str] = ( - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "CrossAttnDownBlock2D", - "DownBlock2D", - ), - mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", - up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), - only_cross_attention: Union[bool, Tuple[bool]] = False, - block_out_channels: Tuple[int] = (320, 640, 1280, 1280), - layers_per_block: Union[int, Tuple[int]] = 2, - downsample_padding: int = 1, - mid_block_scale_factor: float = 1, - dropout: float = 0.0, - act_fn: str = "silu", - norm_num_groups: Optional[int] = 32, - norm_eps: float = 1e-5, - cross_attention_dim: Union[int, Tuple[int]] = 1280, - transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, - reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, - encoder_hid_dim: Optional[int] = None, - encoder_hid_dim_type: Optional[str] = None, - attention_head_dim: Union[int, Tuple[int]] = 8, - num_attention_heads: Optional[Union[int, Tuple[int]]] = None, - dual_cross_attention: bool = False, - use_linear_projection: bool = False, - class_embed_type: Optional[str] = None, - addition_embed_type: Optional[str] = None, - addition_time_embed_dim: Optional[int] = None, - num_class_embeds: Optional[int] = None, - upcast_attention: bool = False, - resnet_time_scale_shift: str = "default", - resnet_skip_time_act: bool = False, - resnet_out_scale_factor: int = 1.0, - time_embedding_type: str = "positional", - time_embedding_dim: Optional[int] = None, - time_embedding_act_fn: Optional[str] = None, - timestep_post_act: Optional[str] = None, - time_cond_proj_dim: Optional[int] = None, - conv_in_kernel: int = 3, - conv_out_kernel: int = 3, - projection_class_embeddings_input_dim: Optional[int] = None, - attention_type: str = "default", - class_embeddings_concat: bool = False, - mid_block_only_cross_attention: Optional[bool] = None, - cross_attention_norm: Optional[str] = None, - addition_embed_type_num_heads=64, - ): - super().__init__() - - self.sample_size = sample_size - self.block_out_channels = block_out_channels - - if num_attention_heads is not None: - raise ValueError( - "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." - ) - - # If `num_attention_heads` is not defined (which is the case for most models) - # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. - # The reason for this behavior is to correct for incorrectly named variables that were introduced - # when this library was created. The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 - # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking - # which is why we correct for the naming here. 
- num_attention_heads = num_attention_heads or attention_head_dim - - # Check inputs - if len(down_block_types) != len(up_block_types): - raise ValueError( - f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." - ) - - if len(block_out_channels) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." - ) - - if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." - ) - - if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): - raise ValueError( - f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." - ) - if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: - for layer_number_per_block in transformer_layers_per_block: - if isinstance(layer_number_per_block, list): - raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") - - # input - conv_in_padding = (conv_in_kernel - 1) // 2 - self.conv_in = nn.Conv2d( - in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding - ) - - # time - if time_embedding_type == "fourier": - time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 - if time_embed_dim % 2 != 0: - raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") - self.time_proj = GaussianFourierProjection( - time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos - ) - timestep_input_dim = time_embed_dim - elif time_embedding_type == "positional": - time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 - - self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) - timestep_input_dim = block_out_channels[0] - else: - raise ValueError( - f"{time_embedding_type} does not exist. Please make sure to use one of `fourier` or `positional`." 
- ) - - self.time_embedding = TimestepEmbedding( - timestep_input_dim, - time_embed_dim, - act_fn=act_fn, - post_act_fn=timestep_post_act, - cond_proj_dim=time_cond_proj_dim, - ) - - if encoder_hid_dim_type is None and encoder_hid_dim is not None: - encoder_hid_dim_type = "text_proj" - self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) - logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") - - if encoder_hid_dim is None and encoder_hid_dim_type is not None: - raise ValueError( - f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." - ) - - if encoder_hid_dim_type == "text_proj": - self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) - elif encoder_hid_dim_type == "text_image_proj": - # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)` - self.encoder_hid_proj = TextImageProjection( - text_embed_dim=encoder_hid_dim, - image_embed_dim=cross_attention_dim, - cross_attention_dim=cross_attention_dim, - ) - elif encoder_hid_dim_type == "image_proj": - # Kandinsky 2.2 - self.encoder_hid_proj = ImageProjection( - image_embed_dim=encoder_hid_dim, - cross_attention_dim=cross_attention_dim, - ) - elif encoder_hid_dim_type is not None: - raise ValueError( - f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'." - ) - else: - self.encoder_hid_proj = None - - # class embedding - if class_embed_type is None and num_class_embeds is not None: - self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) - elif class_embed_type == "timestep": - self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) - elif class_embed_type == "identity": - self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) - elif class_embed_type == "projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" - ) - # The projection `class_embed_type` is the same as the timestep `class_embed_type` except - # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings - # 2. it projects from an arbitrary input dimension. - # - # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. - # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. - # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
- self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - elif class_embed_type == "simple_projection": - if projection_class_embeddings_input_dim is None: - raise ValueError( - "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" - ) - self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) - else: - self.class_embedding = None - - if addition_embed_type == "text": - if encoder_hid_dim is not None: - text_time_embedding_from_dim = encoder_hid_dim - else: - text_time_embedding_from_dim = cross_attention_dim - - self.add_embedding = TextTimeEmbedding( - text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads - ) - elif addition_embed_type == "text_image": - # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much - # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use - # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)` - self.add_embedding = TextImageTimeEmbedding( - text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim - ) - elif addition_embed_type == "text_time": - self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) - self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) - elif addition_embed_type == "image": - # Kandinsky 2.2 - self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) - elif addition_embed_type == "image_hint": - # Kandinsky 2.2 ControlNet - self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) - elif addition_embed_type is not None: - raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text' or 'text_image'.") - - if time_embedding_act_fn is None: - self.time_embed_act = None - else: - self.time_embed_act = get_activation(time_embedding_act_fn) - - self.down_blocks = nn.ModuleList([]) - self.up_blocks = nn.ModuleList([]) - - if isinstance(only_cross_attention, bool): - if mid_block_only_cross_attention is None: - mid_block_only_cross_attention = only_cross_attention - - only_cross_attention = [only_cross_attention] * len(down_block_types) - - if mid_block_only_cross_attention is None: - mid_block_only_cross_attention = False - - if isinstance(num_attention_heads, int): - num_attention_heads = (num_attention_heads,) * len(down_block_types) - - if isinstance(attention_head_dim, int): - attention_head_dim = (attention_head_dim,) * len(down_block_types) - - if isinstance(cross_attention_dim, int): - cross_attention_dim = (cross_attention_dim,) * len(down_block_types) - - if isinstance(layers_per_block, int): - layers_per_block = [layers_per_block] * len(down_block_types) - - if isinstance(transformer_layers_per_block, int): - transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) - - if class_embeddings_concat: - # The time embeddings are concatenated with the class embeddings. 
The dimension of the - # time embeddings passed to the down, middle, and up blocks is twice the dimension of the - # regular time embeddings - blocks_time_embed_dim = time_embed_dim * 2 - else: - blocks_time_embed_dim = time_embed_dim - - # down - output_channel = block_out_channels[0] - for i, down_block_type in enumerate(down_block_types): - input_channel = output_channel - output_channel = block_out_channels[i] - is_final_block = i == len(block_out_channels) - 1 - - down_block = get_down_block( - down_block_type, - num_layers=layers_per_block[i], - transformer_layers_per_block=transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - temb_channels=blocks_time_embed_dim, - add_downsample=not is_final_block, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resnet_groups=norm_num_groups, - cross_attention_dim=cross_attention_dim[i], - num_attention_heads=num_attention_heads[i], - downsample_padding=downsample_padding, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_type=attention_type, - resnet_skip_time_act=resnet_skip_time_act, - resnet_out_scale_factor=resnet_out_scale_factor, - cross_attention_norm=cross_attention_norm, - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, - dropout=dropout, - ) - self.down_blocks.append(down_block) - - # mid - if mid_block_type == "UNetMidBlock2DCrossAttn": - self.mid_block = UNetMidBlock2DCrossAttn( - transformer_layers_per_block=transformer_layers_per_block[-1], - in_channels=block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - dropout=dropout, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_time_scale_shift=resnet_time_scale_shift, - cross_attention_dim=cross_attention_dim[-1], - num_attention_heads=num_attention_heads[-1], - resnet_groups=norm_num_groups, - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - upcast_attention=upcast_attention, - attention_type=attention_type, - ) - elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": - self.mid_block = UNetMidBlock2DSimpleCrossAttn( - in_channels=block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - dropout=dropout, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - cross_attention_dim=cross_attention_dim[-1], - attention_head_dim=attention_head_dim[-1], - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - skip_time_act=resnet_skip_time_act, - only_cross_attention=mid_block_only_cross_attention, - cross_attention_norm=cross_attention_norm, - ) - elif mid_block_type == "UNetMidBlock2D": - self.mid_block = UNetMidBlock2D( - in_channels=block_out_channels[-1], - temb_channels=blocks_time_embed_dim, - dropout=dropout, - num_layers=0, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - output_scale_factor=mid_block_scale_factor, - resnet_groups=norm_num_groups, - resnet_time_scale_shift=resnet_time_scale_shift, - add_attention=False, - ) - elif mid_block_type is None: - self.mid_block = None - else: - raise ValueError(f"unknown mid_block_type : {mid_block_type}") - - # count how many layers upsample the images - self.num_upsamplers = 0 - - # up - reversed_block_out_channels = list(reversed(block_out_channels)) - reversed_num_attention_heads = 
list(reversed(num_attention_heads)) - reversed_layers_per_block = list(reversed(layers_per_block)) - reversed_cross_attention_dim = list(reversed(cross_attention_dim)) - reversed_transformer_layers_per_block = ( - list(reversed(transformer_layers_per_block)) - if reverse_transformer_layers_per_block is None - else reverse_transformer_layers_per_block - ) - only_cross_attention = list(reversed(only_cross_attention)) - - output_channel = reversed_block_out_channels[0] - for i, up_block_type in enumerate(up_block_types): - is_final_block = i == len(block_out_channels) - 1 - - prev_output_channel = output_channel - output_channel = reversed_block_out_channels[i] - input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] - - # add upsample block for all BUT final layer - if not is_final_block: - add_upsample = True - self.num_upsamplers += 1 - else: - add_upsample = False - - up_block = get_up_block( - up_block_type, - num_layers=reversed_layers_per_block[i] + 1, - transformer_layers_per_block=reversed_transformer_layers_per_block[i], - in_channels=input_channel, - out_channels=output_channel, - prev_output_channel=prev_output_channel, - temb_channels=blocks_time_embed_dim, - add_upsample=add_upsample, - resnet_eps=norm_eps, - resnet_act_fn=act_fn, - resolution_idx=i, - resnet_groups=norm_num_groups, - cross_attention_dim=reversed_cross_attention_dim[i], - num_attention_heads=reversed_num_attention_heads[i], - dual_cross_attention=dual_cross_attention, - use_linear_projection=use_linear_projection, - only_cross_attention=only_cross_attention[i], - upcast_attention=upcast_attention, - resnet_time_scale_shift=resnet_time_scale_shift, - attention_type=attention_type, - resnet_skip_time_act=resnet_skip_time_act, - resnet_out_scale_factor=resnet_out_scale_factor, - cross_attention_norm=cross_attention_norm, - attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, - dropout=dropout, - ) - self.up_blocks.append(up_block) - prev_output_channel = output_channel - - # out - if norm_num_groups is not None: - self.conv_norm_out = nn.GroupNorm( - num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps - ) - - self.conv_act = get_activation(act_fn) - - else: - self.conv_norm_out = None - self.conv_act = None - - conv_out_padding = (conv_out_kernel - 1) // 2 - self.conv_out = nn.Conv2d( - block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding - ) - - if attention_type in ["gated", "gated-text-image"]: - positive_len = 768 - if isinstance(cross_attention_dim, int): - positive_len = cross_attention_dim - elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): - positive_len = cross_attention_dim[0] - - feature_type = "text-only" if attention_type == "gated" else "text-image" - self.position_net = GLIGENTextBoundingboxProjection( - positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type - ) - - def _set_gradient_checkpointing(self, module, value=False): - if hasattr(module, "gradient_checkpointing"): - module.gradient_checkpointing = value - - # https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/nn.py#L87 - def timestep_embedding(self, timesteps, max_period=10000, scale=1000): - """Create sinusoidal timestep embeddings. - - :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. - :param dim: the dimension of the output. 
- :param max_period: controls the minimum frequency of the embeddings. - :return: an [N x dim] Tensor of positional embeddings. - """ - dim = self.block_out_channels[-1] - half = dim // 2 - freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, device=timesteps.device) / half).type(timesteps.type()) - args = timesteps[:, None] * freqs[None] * scale - embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) - if dim % 2: - embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) - return embedding - - def forward( - self, - sample: torch.FloatTensor, - timestep: Union[torch.Tensor, float, int], - encoder_hidden_states: torch.Tensor, - class_labels: Optional[torch.Tensor] = None, - timestep_cond: Optional[torch.Tensor] = None, - attention_mask: Optional[torch.Tensor] = None, - cross_attention_kwargs: Optional[Dict[str, Any]] = None, - added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, - down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - mid_block_additional_residual: Optional[torch.Tensor] = None, - down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, - encoder_attention_mask: Optional[torch.Tensor] = None, - return_dict: bool = True, - ) -> Union[UNet2DConditionOutput, Tuple]: - r""" - The [`UNet2DConditionModel`] forward method. - - Args: - sample (`torch.FloatTensor`): - The noisy input tensor with the following shape `(batch, channel, height, width)`. - timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. - encoder_hidden_states (`torch.FloatTensor`): - The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. - class_labels (`torch.Tensor`, *optional*, defaults to `None`): - Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. - timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): - Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed - through the `self.time_embedding` layer to obtain the timestep embeddings. - attention_mask (`torch.Tensor`, *optional*, defaults to `None`): - An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask - is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large - negative values to the attention scores corresponding to "discard" tokens. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under - `self.processor` in - [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). - added_cond_kwargs: (`dict`, *optional*): - A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that - are passed along to the UNet blocks. - down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): - A tuple of tensors that if specified are added to the residuals of down unet blocks. - mid_block_additional_residual: (`torch.Tensor`, *optional*): - A tensor that if specified is added to the residual of the middle unet block. - encoder_attention_mask (`torch.Tensor`): - A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If - `True` the mask is kept, otherwise if `False` it is discarded. 
Mask will be converted into a bias, - which adds large negative values to the attention scores corresponding to "discard" tokens. - return_dict (`bool`, *optional*, defaults to `True`): - Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain - tuple. - cross_attention_kwargs (`dict`, *optional*): - A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. - added_cond_kwargs: (`dict`, *optional*): - A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that - are passed along to the UNet blocks. - down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): - additional residuals to be added to UNet long skip connections from down blocks to up blocks for - example from ControlNet side model(s) - mid_block_additional_residual (`torch.Tensor`, *optional*): - additional residual to be added to UNet mid block output, for example from ControlNet side model - down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): - additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) - - Returns: - [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: - If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise - a `tuple` is returned where the first element is the sample tensor. - """ - # By default samples have to be AT least a multiple of the overall upsampling factor. - # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). - # However, the upsampling interpolation output size can be forced to fit any upsampling size - # on the fly if necessary. - default_overall_up_factor = 2**self.num_upsamplers - - # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` - forward_upsample_size = False - upsample_size = None - - for dim in sample.shape[-2:]: - if dim % default_overall_up_factor != 0: - # Forward upsample size to force interpolation output size. - forward_upsample_size = True - break - - # ensure attention_mask is a bias, and give it a singleton query_tokens dimension - # expects mask of shape: - # [batch, key_tokens] - # adds singleton query_tokens dimension: - # [batch, 1, key_tokens] - # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: - # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) - # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) - if attention_mask is not None: - # assume that mask is expressed as: - # (1 = keep, 0 = discard) - # convert mask into a bias that can be added to attention scores: - # (keep = +0, discard = -10000.0) - attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 - attention_mask = attention_mask.unsqueeze(1) - - # convert encoder_attention_mask to a bias the same way we do for attention_mask - if encoder_attention_mask is not None: - encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 - encoder_attention_mask = encoder_attention_mask.unsqueeze(1) - - # 0. center input if necessary - if self.config.center_input_sample: - sample = 2 * sample - 1.0 - - # 1. time - timesteps = timestep - if not torch.is_tensor(timesteps): - # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can - # This would be a good case for the `match` statement (Python 3.10+) - is_mps = sample.device.type == "mps" - if isinstance(timestep, float): - dtype = torch.float32 if is_mps else torch.float64 - else: - dtype = torch.int32 if is_mps else torch.int64 - timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) - elif len(timesteps.shape) == 0: - timesteps = timesteps[None].to(sample.device) - - # broadcast to batch dimension in a way that's compatible with ONNX/Core ML - timesteps = timesteps.expand(sample.shape[0]) - - emb = self.timestep_embedding(timesteps) - - if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": - encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": - # Kadinsky 2.1 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - - image_embeds = added_cond_kwargs.get("image_embeds") - encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": - # Kandinsky 2.2 - style - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - image_embeds = added_cond_kwargs.get("image_embeds") - encoder_hidden_states = self.encoder_hid_proj(image_embeds) - elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": - if "image_embeds" not in added_cond_kwargs: - raise ValueError( - f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" - ) - image_embeds = added_cond_kwargs.get("image_embeds") - image_embeds = self.encoder_hid_proj(image_embeds) - encoder_hidden_states = (encoder_hidden_states, image_embeds) - - # 2. pre-process - sample = self.conv_in(sample) - - # 2.5 GLIGEN position net - if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: - cross_attention_kwargs = cross_attention_kwargs.copy() - gligen_args = cross_attention_kwargs.pop("gligen") - cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} - - # 3. 
down - lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 - if USE_PEFT_BACKEND: - # weight the lora layers by setting `lora_scale` for each PEFT layer - scale_lora_layers(self, lora_scale) - - is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None - # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets - is_adapter = down_intrablock_additional_residuals is not None - # maintain backward compatibility for legacy usage, where - # T2I-Adapter and ControlNet both use down_block_additional_residuals arg - # but can only use one or the other - if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: - deprecate( - "T2I should not use down_block_additional_residuals", - "1.3.0", - "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ - and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ - for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", - standard_warn=False, - ) - down_intrablock_additional_residuals = down_block_additional_residuals - is_adapter = True - - down_block_res_samples = (sample,) - for downsample_block in self.down_blocks: - if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: - # For t2i-adapter CrossAttnDownBlock2D - additional_residuals = {} - if is_adapter and len(down_intrablock_additional_residuals) > 0: - additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) - - sample, res_samples = downsample_block( - hidden_states=sample, - temb=emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - **additional_residuals, - ) - else: - sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) - if is_adapter and len(down_intrablock_additional_residuals) > 0: - sample += down_intrablock_additional_residuals.pop(0) - - down_block_res_samples += res_samples - - if is_controlnet: - new_down_block_res_samples = () - - for down_block_res_sample, down_block_additional_residual in zip( - down_block_res_samples, down_block_additional_residuals - ): - down_block_res_sample = down_block_res_sample + down_block_additional_residual - new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) - - down_block_res_samples = new_down_block_res_samples - - # 4. mid - if self.mid_block is not None: - if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: - sample = self.mid_block( - sample, - emb, - encoder_hidden_states=encoder_hidden_states, - attention_mask=attention_mask, - cross_attention_kwargs=cross_attention_kwargs, - encoder_attention_mask=encoder_attention_mask, - ) - else: - sample = self.mid_block(sample, emb) - - # To support T2I-Adapter-XL - if ( - is_adapter - and len(down_intrablock_additional_residuals) > 0 - and sample.shape == down_intrablock_additional_residuals[0].shape - ): - sample += down_intrablock_additional_residuals.pop(0) - - if is_controlnet: - sample = sample + mid_block_additional_residual - - # 5. 
up - for i, upsample_block in enumerate(self.up_blocks): - is_final_block = i == len(self.up_blocks) - 1 - - res_samples = down_block_res_samples[-len(upsample_block.resnets) :] - down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] - - # if we have not reached the final block and need to forward the - # upsample size, we do it here - if not is_final_block and forward_upsample_size: - upsample_size = down_block_res_samples[-1].shape[2:] - - if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - encoder_hidden_states=encoder_hidden_states, - cross_attention_kwargs=cross_attention_kwargs, - upsample_size=upsample_size, - attention_mask=attention_mask, - encoder_attention_mask=encoder_attention_mask, - ) - else: - sample = upsample_block( - hidden_states=sample, - temb=emb, - res_hidden_states_tuple=res_samples, - upsample_size=upsample_size, - scale=lora_scale, - ) - - # 6. post-process - if self.conv_norm_out: - sample = self.conv_norm_out(sample) - sample = self.conv_act(sample) - sample = self.conv_out(sample) - - if USE_PEFT_BACKEND: - # remove `lora_scale` from each PEFT layer - unscale_lora_layers(self, lora_scale) - - if not return_dict: - return (sample,) - - return UNet2DConditionOutput(sample=sample) +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Tuple, Union +import math + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin +from diffusers.utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers +from diffusers.models.activations import get_activation +from diffusers.models.embeddings import ( + GaussianFourierProjection, + GLIGENTextBoundingboxProjection, + ImageHintTimeEmbedding, + ImageProjection, + ImageTimeEmbedding, + TextImageProjection, + TextImageTimeEmbedding, + TextTimeEmbedding, + TimestepEmbedding, + Timesteps, +) +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.unet_2d_blocks import ( + UNetMidBlock2D, + UNetMidBlock2DCrossAttn, + UNetMidBlock2DSimpleCrossAttn, + get_down_block, + get_up_block, +) + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet2DConditionOutput(BaseOutput): + """ + The output of [`UNet2DConditionModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`): + The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model. 
+ """ + + sample: torch.FloatTensor = None + + +class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin): + r""" + A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample + shaped output. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented + for all models (such as downloading or saving). + + Parameters: + sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`): + Height and width of input/output sample. + in_channels (`int`, *optional*, defaults to 4): Number of channels in the input sample. + out_channels (`int`, *optional*, defaults to 4): Number of channels in the output. + center_input_sample (`bool`, *optional*, defaults to `False`): Whether to center the input sample. + flip_sin_to_cos (`bool`, *optional*, defaults to `False`): + Whether to flip the sin to cos in the time embedding. + freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding. + down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`): + The tuple of downsample blocks to use. + mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`): + Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`, `UNetMidBlock2D`, or + `UNetMidBlock2DSimpleCrossAttn`. If `None`, the mid block layer is skipped. + up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`): + The tuple of upsample blocks to use. + only_cross_attention(`bool` or `Tuple[bool]`, *optional*, default to `False`): + Whether to include self-attention in the basic transformer blocks, see + [`~models.attention.BasicTransformerBlock`]. + block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`): + The tuple of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block. + downsample_padding (`int`, *optional*, defaults to 1): The padding to use for the downsampling convolution. + mid_block_scale_factor (`float`, *optional*, defaults to 1.0): The scale factor to use for the mid block. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization. + If `None`, normalization and activation layers is skipped in post-processing. + norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization. + cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280): + The dimension of the cross attention features. + transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + reverse_transformer_layers_per_block : (`Tuple[Tuple]`, *optional*, defaults to None): + The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`], in the upsampling + blocks of the U-Net. 
Only relevant if `transformer_layers_per_block` is of type `Tuple[Tuple]` and for + [`~models.unet_2d_blocks.CrossAttnDownBlock2D`], [`~models.unet_2d_blocks.CrossAttnUpBlock2D`], + [`~models.unet_2d_blocks.UNetMidBlock2DCrossAttn`]. + encoder_hid_dim (`int`, *optional*, defaults to None): + If `encoder_hid_dim_type` is defined, `encoder_hidden_states` will be projected from `encoder_hid_dim` + dimension to `cross_attention_dim`. + encoder_hid_dim_type (`str`, *optional*, defaults to `None`): + If given, the `encoder_hidden_states` and potentially other embeddings are down-projected to text + embeddings of dimension `cross_attention` according to `encoder_hid_dim_type`. + attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads. + num_attention_heads (`int`, *optional*): + The number of attention heads. If not defined, defaults to `attention_head_dim` + resnet_time_scale_shift (`str`, *optional*, defaults to `"default"`): Time scale shift config + for ResNet blocks (see [`~models.resnet.ResnetBlock2D`]). Choose from `default` or `scale_shift`. + class_embed_type (`str`, *optional*, defaults to `None`): + The type of class embedding to use which is ultimately summed with the time embeddings. Choose from `None`, + `"timestep"`, `"identity"`, `"projection"`, or `"simple_projection"`. + addition_embed_type (`str`, *optional*, defaults to `None`): + Configures an optional embedding which will be summed with the time embeddings. Choose from `None` or + "text". "text" will use the `TextTimeEmbedding` layer. + addition_time_embed_dim: (`int`, *optional*, defaults to `None`): + Dimension for the timestep embeddings. + num_class_embeds (`int`, *optional*, defaults to `None`): + Input dimension of the learnable embedding matrix to be projected to `time_embed_dim`, when performing + class conditioning with `class_embed_type` equal to `None`. + time_embedding_type (`str`, *optional*, defaults to `positional`): + The type of position embedding to use for timesteps. Choose from `positional` or `fourier`. + time_embedding_dim (`int`, *optional*, defaults to `None`): + An optional override for the dimension of the projected time embedding. + time_embedding_act_fn (`str`, *optional*, defaults to `None`): + Optional activation function to use only once on the time embeddings before they are passed to the rest of + the UNet. Choose from `silu`, `mish`, `gelu`, and `swish`. + timestep_post_act (`str`, *optional*, defaults to `None`): + The second activation function to use in timestep embedding. Choose from `silu`, `mish` and `gelu`. + time_cond_proj_dim (`int`, *optional*, defaults to `None`): + The dimension of `cond_proj` layer in the timestep embedding. + conv_in_kernel (`int`, *optional*, default to `3`): The kernel size of `conv_in` layer. conv_out_kernel (`int`, + *optional*, default to `3`): The kernel size of `conv_out` layer. projection_class_embeddings_input_dim (`int`, + *optional*): The dimension of the `class_labels` input when + `class_embed_type="projection"`. Required when `class_embed_type="projection"`. + class_embeddings_concat (`bool`, *optional*, defaults to `False`): Whether to concatenate the time + embeddings with the class embeddings. + mid_block_only_cross_attention (`bool`, *optional*, defaults to `None`): + Whether to use cross attention with the mid block when using the `UNetMidBlock2DSimpleCrossAttn`. 
If + `only_cross_attention` is given as a single boolean and `mid_block_only_cross_attention` is `None`, the + `only_cross_attention` value is used as the value for `mid_block_only_cross_attention`. Default to `False` + otherwise. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "CrossAttnDownBlock2D", + "DownBlock2D", + ), + mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn", + up_block_types: Tuple[str] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D"), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: Union[int, Tuple[int]] = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + dropout: float = 0.0, + act_fn: str = "silu", + norm_num_groups: Optional[int] = 32, + norm_eps: float = 1e-5, + cross_attention_dim: Union[int, Tuple[int]] = 1280, + transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1, + reverse_transformer_layers_per_block: Optional[Tuple[Tuple[int]]] = None, + encoder_hid_dim: Optional[int] = None, + encoder_hid_dim_type: Optional[str] = None, + attention_head_dim: Union[int, Tuple[int]] = 8, + num_attention_heads: Optional[Union[int, Tuple[int]]] = None, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + addition_embed_type: Optional[str] = None, + addition_time_embed_dim: Optional[int] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + resnet_skip_time_act: bool = False, + resnet_out_scale_factor: int = 1.0, + time_embedding_type: str = "positional", + time_embedding_dim: Optional[int] = None, + time_embedding_act_fn: Optional[str] = None, + timestep_post_act: Optional[str] = None, + time_cond_proj_dim: Optional[int] = None, + conv_in_kernel: int = 3, + conv_out_kernel: int = 3, + projection_class_embeddings_input_dim: Optional[int] = None, + attention_type: str = "default", + class_embeddings_concat: bool = False, + mid_block_only_cross_attention: Optional[bool] = None, + cross_attention_norm: Optional[str] = None, + addition_embed_type_num_heads=64, + ): + super().__init__() + + self.sample_size = sample_size + self.block_out_channels = block_out_channels + + if num_attention_heads is not None: + raise ValueError( + "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue as described in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131. Passing `num_attention_heads` will only be supported in diffusers v0.19." + ) + + # If `num_attention_heads` is not defined (which is the case for most models) + # it will default to `attention_head_dim`. This looks weird upon first reading it and it is. + # The reason for this behavior is to correct for incorrectly named variables that were introduced + # when this library was created. 
The incorrect naming was only discovered much later in https://github.com/huggingface/diffusers/issues/2011#issuecomment-1547958131 + # Changing `attention_head_dim` to `num_attention_heads` for 40,000+ configurations is too backwards breaking + # which is why we correct for the naming here. + num_attention_heads = num_attention_heads or attention_head_dim + + # Check inputs + if len(down_block_types) != len(up_block_types): + raise ValueError( + f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}." + ) + + if len(block_out_channels) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(only_cross_attention, bool) and len(only_cross_attention) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `only_cross_attention` as `down_block_types`. `only_cross_attention`: {only_cross_attention}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(attention_head_dim, int) and len(attention_head_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `attention_head_dim` as `down_block_types`. `attention_head_dim`: {attention_head_dim}. `down_block_types`: {down_block_types}." + ) + + if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}." + ) + + if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types): + raise ValueError( + f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}." + ) + if isinstance(transformer_layers_per_block, list) and reverse_transformer_layers_per_block is None: + for layer_number_per_block in transformer_layers_per_block: + if isinstance(layer_number_per_block, list): + raise ValueError("Must provide 'reverse_transformer_layers_per_block` if using asymmetrical UNet.") + + # input + conv_in_padding = (conv_in_kernel - 1) // 2 + self.conv_in = nn.Conv2d( + in_channels, block_out_channels[0], kernel_size=conv_in_kernel, padding=conv_in_padding + ) + + # time + if time_embedding_type == "fourier": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 2 + if time_embed_dim % 2 != 0: + raise ValueError(f"`time_embed_dim` should be divisible by 2, but is {time_embed_dim}.") + self.time_proj = GaussianFourierProjection( + time_embed_dim // 2, set_W_to_weight=False, log=False, flip_sin_to_cos=flip_sin_to_cos + ) + timestep_input_dim = time_embed_dim + elif time_embedding_type == "positional": + time_embed_dim = time_embedding_dim or block_out_channels[0] * 4 + + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + else: + raise ValueError( + f"{time_embedding_type} does not exist. 
Please make sure to use one of `fourier` or `positional`." + ) + + self.time_embedding = TimestepEmbedding( + timestep_input_dim, + time_embed_dim, + act_fn=act_fn, + post_act_fn=timestep_post_act, + cond_proj_dim=time_cond_proj_dim, + ) + + if encoder_hid_dim_type is None and encoder_hid_dim is not None: + encoder_hid_dim_type = "text_proj" + self.register_to_config(encoder_hid_dim_type=encoder_hid_dim_type) + logger.info("encoder_hid_dim_type defaults to 'text_proj' as `encoder_hid_dim` is defined.") + + if encoder_hid_dim is None and encoder_hid_dim_type is not None: + raise ValueError( + f"`encoder_hid_dim` has to be defined when `encoder_hid_dim_type` is set to {encoder_hid_dim_type}." + ) + + if encoder_hid_dim_type == "text_proj": + self.encoder_hid_proj = nn.Linear(encoder_hid_dim, cross_attention_dim) + elif encoder_hid_dim_type == "text_image_proj": + # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1) + self.encoder_hid_proj = TextImageProjection( + text_embed_dim=encoder_hid_dim, + image_embed_dim=cross_attention_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 + self.encoder_hid_proj = ImageProjection( + image_embed_dim=encoder_hid_dim, + cross_attention_dim=cross_attention_dim, + ) + elif encoder_hid_dim_type is not None: + raise ValueError( + f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'." + ) + else: + self.encoder_hid_proj = None + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim, act_fn=act_fn) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + elif class_embed_type == "projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'projection' requires `projection_class_embeddings_input_dim` be set" + ) + # The projection `class_embed_type` is the same as the timestep `class_embed_type` except + # 1. the `class_labels` inputs are not first converted to sinusoidal embeddings + # 2. it projects from an arbitrary input dimension. + # + # Note that `TimestepEmbedding` is quite general, being mainly linear layers and activations. + # When used for embedding actual timesteps, the timesteps are first converted to sinusoidal embeddings. + # As a result, `TimestepEmbedding` can be passed arbitrary vectors. 
+ self.class_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif class_embed_type == "simple_projection": + if projection_class_embeddings_input_dim is None: + raise ValueError( + "`class_embed_type`: 'simple_projection' requires `projection_class_embeddings_input_dim` be set" + ) + self.class_embedding = nn.Linear(projection_class_embeddings_input_dim, time_embed_dim) + else: + self.class_embedding = None + + if addition_embed_type == "text": + if encoder_hid_dim is not None: + text_time_embedding_from_dim = encoder_hid_dim + else: + text_time_embedding_from_dim = cross_attention_dim + + self.add_embedding = TextTimeEmbedding( + text_time_embedding_from_dim, time_embed_dim, num_heads=addition_embed_type_num_heads + ) + elif addition_embed_type == "text_image": + # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much + # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use + # case when `addition_embed_type == "text_image"` (Kandinsky 2.1) + self.add_embedding = TextImageTimeEmbedding( + text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim + ) + elif addition_embed_type == "text_time": + self.add_time_proj = Timesteps(addition_time_embed_dim, flip_sin_to_cos, freq_shift) + self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim) + elif addition_embed_type == "image": + # Kandinsky 2.2 + self.add_embedding = ImageTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type == "image_hint": + # Kandinsky 2.2 ControlNet + self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim) + elif addition_embed_type is not None: + raise ValueError(f"addition_embed_type: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'.") + + if time_embedding_act_fn is None: + self.time_embed_act = None + else: + self.time_embed_act = get_activation(time_embedding_act_fn) + + self.down_blocks = nn.ModuleList([]) + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = only_cross_attention + + only_cross_attention = [only_cross_attention] * len(down_block_types) + + if mid_block_only_cross_attention is None: + mid_block_only_cross_attention = False + + if isinstance(num_attention_heads, int): + num_attention_heads = (num_attention_heads,) * len(down_block_types) + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + if isinstance(cross_attention_dim, int): + cross_attention_dim = (cross_attention_dim,) * len(down_block_types) + + if isinstance(layers_per_block, int): + layers_per_block = [layers_per_block] * len(down_block_types) + + if isinstance(transformer_layers_per_block, int): + transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types) + + if class_embeddings_concat: + # The time embeddings are concatenated with the class embeddings. 
The dimension of the + # time embeddings passed to the down, middle, and up blocks is twice the dimension of the + # regular time embeddings + blocks_time_embed_dim = time_embed_dim * 2 + else: + blocks_time_embed_dim = time_embed_dim + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block[i], + transformer_layers_per_block=transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + temb_channels=blocks_time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim[i], + num_attention_heads=num_attention_heads[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock2DCrossAttn": + self.mid_block = UNetMidBlock2DCrossAttn( + transformer_layers_per_block=transformer_layers_per_block[-1], + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim[-1], + num_attention_heads=num_attention_heads[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + attention_type=attention_type, + ) + elif mid_block_type == "UNetMidBlock2DSimpleCrossAttn": + self.mid_block = UNetMidBlock2DSimpleCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + cross_attention_dim=cross_attention_dim[-1], + attention_head_dim=attention_head_dim[-1], + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + skip_time_act=resnet_skip_time_act, + only_cross_attention=mid_block_only_cross_attention, + cross_attention_norm=cross_attention_norm, + ) + elif mid_block_type == "UNetMidBlock2D": + self.mid_block = UNetMidBlock2D( + in_channels=block_out_channels[-1], + temb_channels=blocks_time_embed_dim, + dropout=dropout, + num_layers=0, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_groups=norm_num_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + add_attention=False, + ) + elif mid_block_type is None: + self.mid_block = None + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the images + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_num_attention_heads = 
list(reversed(num_attention_heads)) + reversed_layers_per_block = list(reversed(layers_per_block)) + reversed_cross_attention_dim = list(reversed(cross_attention_dim)) + reversed_transformer_layers_per_block = ( + list(reversed(transformer_layers_per_block)) + if reverse_transformer_layers_per_block is None + else reverse_transformer_layers_per_block + ) + only_cross_attention = list(reversed(only_cross_attention)) + + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=reversed_layers_per_block[i] + 1, + transformer_layers_per_block=reversed_transformer_layers_per_block[i], + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=blocks_time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resolution_idx=i, + resnet_groups=norm_num_groups, + cross_attention_dim=reversed_cross_attention_dim[i], + num_attention_heads=reversed_num_attention_heads[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + attention_type=attention_type, + resnet_skip_time_act=resnet_skip_time_act, + resnet_out_scale_factor=resnet_out_scale_factor, + cross_attention_norm=cross_attention_norm, + attention_head_dim=attention_head_dim[i] if attention_head_dim[i] is not None else output_channel, + dropout=dropout, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + if norm_num_groups is not None: + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps + ) + + self.conv_act = get_activation(act_fn) + + else: + self.conv_norm_out = None + self.conv_act = None + + conv_out_padding = (conv_out_kernel - 1) // 2 + self.conv_out = nn.Conv2d( + block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding + ) + + if attention_type in ["gated", "gated-text-image"]: + positive_len = 768 + if isinstance(cross_attention_dim, int): + positive_len = cross_attention_dim + elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list): + positive_len = cross_attention_dim[0] + + feature_type = "text-only" if attention_type == "gated" else "text-image" + self.position_net = GLIGENTextBoundingboxProjection( + positive_len=positive_len, out_dim=cross_attention_dim, feature_type=feature_type + ) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + # https://github.com/atong01/conditional-flow-matching/blob/main/torchcfm/models/unet/nn.py#L87 + def timestep_embedding(self, timesteps, max_period=10000, scale=1000): + """Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. These may be fractional. + :param dim: the dimension of the output. 
+ :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + dim = self.block_out_channels[-1] + half = dim // 2 + freqs = torch.exp(-math.log(max_period) * torch.arange(start=0, end=half, device=timesteps.device) / half).type(timesteps.type()) + args = timesteps[:, None] * freqs[None] * scale + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor, + class_labels: Optional[torch.Tensor] = None, + timestep_cond: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + mid_block_additional_residual: Optional[torch.Tensor] = None, + down_intrablock_additional_residuals: Optional[Tuple[torch.Tensor]] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet2DConditionOutput, Tuple]: + r""" + The [`UNet2DConditionModel`] forward method. + + Args: + sample (`torch.FloatTensor`): + The noisy input tensor with the following shape `(batch, channel, height, width)`. + timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input. + encoder_hidden_states (`torch.FloatTensor`): + The encoder hidden states with shape `(batch, sequence_length, feature_dim)`. + class_labels (`torch.Tensor`, *optional*, defaults to `None`): + Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings. + timestep_cond: (`torch.Tensor`, *optional*, defaults to `None`): + Conditional embeddings for timestep. If provided, the embeddings will be summed with the samples passed + through the `self.time_embedding` layer to obtain the timestep embeddings. + attention_mask (`torch.Tensor`, *optional*, defaults to `None`): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containing additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals: (`tuple` of `torch.Tensor`, *optional*): + A tuple of tensors that if specified are added to the residuals of down unet blocks. + mid_block_additional_residual: (`torch.Tensor`, *optional*): + A tensor that if specified is added to the residual of the middle unet block. + encoder_attention_mask (`torch.Tensor`): + A cross-attention mask of shape `(batch, sequence_length)` is applied to `encoder_hidden_states`. If + `True` the mask is kept, otherwise if `False` it is discarded. 
Mask will be converted into a bias, + which adds large negative values to the attention scores corresponding to "discard" tokens. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the [`AttnProcessor`]. + added_cond_kwargs: (`dict`, *optional*): + A kwargs dictionary containin additional embeddings that if specified are added to the embeddings that + are passed along to the UNet blocks. + down_block_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added to UNet long skip connections from down blocks to up blocks for + example from ControlNet side model(s) + mid_block_additional_residual (`torch.Tensor`, *optional*): + additional residual to be added to UNet mid block output, for example from ControlNet side model + down_intrablock_additional_residuals (`tuple` of `torch.Tensor`, *optional*): + additional residuals to be added within UNet down blocks, for example from T2I-Adapter side model(s) + + Returns: + [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise + a `tuple` is returned where the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layers). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + for dim in sample.shape[-2:]: + if dim % default_overall_up_factor != 0: + # Forward upsample size to force interpolation output size. + forward_upsample_size = True + break + + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 0. center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # 1. time + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. 
So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + emb = self.timestep_embedding(timesteps) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj": + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + image_embeds = self.encoder_hid_proj(image_embeds) + encoder_hidden_states = (encoder_hidden_states, image_embeds) + + # 2. pre-process + sample = self.conv_in(sample) + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. 
down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + if USE_PEFT_BACKEND: + # weight the lora layers by setting `lora_scale` for each PEFT layer + scale_lora_layers(self, lora_scale) + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + # using new arg down_intrablock_additional_residuals for T2I-Adapters, to distinguish from controlnets + is_adapter = down_intrablock_additional_residuals is not None + # maintain backward compatibility for legacy usage, where + # T2I-Adapter and ControlNet both use down_block_additional_residuals arg + # but can only use one or the other + if not is_adapter and mid_block_additional_residual is None and down_block_additional_residuals is not None: + deprecate( + "T2I should not use down_block_additional_residuals", + "1.3.0", + "Passing intrablock residual connections with `down_block_additional_residuals` is deprecated \ + and will be removed in diffusers 1.3.0. `down_block_additional_residuals` should only be used \ + for ControlNet. Please make sure use `down_intrablock_additional_residuals` instead. ", + standard_warn=False, + ) + down_intrablock_additional_residuals = down_block_additional_residuals + is_adapter = True + + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_intrablock_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_intrablock_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) + if is_adapter and len(down_intrablock_additional_residuals) > 0: + sample += down_intrablock_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + if hasattr(self.mid_block, "has_cross_attention") and self.mid_block.has_cross_attention: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = self.mid_block(sample, emb) + + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_intrablock_additional_residuals) > 0 + and sample.shape == down_intrablock_additional_residuals[0].shape + ): + sample += down_intrablock_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. 
up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + ) + + # 6. post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if USE_PEFT_BACKEND: + # remove `lora_scale` from each PEFT layer + unscale_lora_layers(self, lora_scale) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample) diff --git a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/symbols.py b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/symbols.py index c387dfe285fa0a10624e9bb1d4aebc248bea8d78..3d82af5540da2db3e8e532107ba8328b7ca14700 100644 --- a/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/symbols.py +++ b/codeclm/tokenizer/Flow1dVAE/models_gpt/models/tokenizer/pinyin/symbols.py @@ -1,71 +1,71 @@ -_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"] - -_initials = [ - "^", - "b", - "c", - "ch", - "d", - "f", - "g", - "h", - "j", - "k", - "l", - "m", - "n", - "p", - "q", - "r", - "s", - "sh", - "t", - "x", - "z", - "zh", -] - -_tones = ["1", "2", "3", "4", "5"] - -_finals = [ - "a", - "ai", - "an", - "ang", - "ao", - "e", - "ei", - "en", - "eng", - "er", - "i", - "ia", - "ian", - "iang", - "iao", - "ie", - "ii", - "iii", - "in", - "ing", - "iong", - "iou", - "o", - "ong", - "ou", - "u", - "ua", - "uai", - "uan", - "uang", - "uei", - "uen", - "ueng", - "uo", - "v", - "van", - "ve", - "vn", -] - -symbols = _pause + _initials + [i + j for i in _finals for j in _tones] +_pause = ["sil", "eos", "sp", "#0", "#1", "#2", "#3"] + +_initials = [ + "^", + "b", + "c", + "ch", + "d", + "f", + "g", + "h", + "j", + "k", + "l", + "m", + "n", + "p", + "q", + "r", + "s", + "sh", + "t", + "x", + "z", + "zh", +] + +_tones = ["1", "2", "3", "4", "5"] + +_finals = [ + "a", + "ai", + "an", + "ang", + "ao", + "e", + "ei", + "en", + "eng", + "er", + "i", + "ia", + "ian", + "iang", + "iao", + "ie", + "ii", + "iii", + "in", + "ing", + "iong", + "iou", + "o", + "ong", + "ou", + "u", + "ua", + "uai", + "uan", + "uang", + "uei", + "uen", + "ueng", + "uo", + "v", + "van", + "ve", + "vn", +] + +symbols = _pause + _initials + [i + j for i in _finals for j in _tones] diff --git a/codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k.py b/codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k.py index 4ae71d252f589e9756efd706aab8b69ec391a28b..ec8a54ff9c66bb99e646b60758c363fc10c0aebe 100644 --- a/codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k.py +++ b/codeclm/tokenizer/Flow1dVAE/tools/infer_bsrnnvae441k.py @@ -1,47 
+1,47 @@ -import json -import torch -from tqdm import tqdm -import torchaudio -import librosa -import os -import math -import numpy as np -from tools.get_bsrnnvae import get_bsrnnvae -import tools.torch_tools as torch_tools - -class Tango: - def __init__(self, \ - device="cuda:0"): - - self.sample_rate = 44100 - self.device = device - - self.vae = get_bsrnnvae() - self.vae = self.vae.eval().to(device) - - def sound2sound_generate_longterm(self, fname, batch_size=1, duration=15.36, steps=200, disable_progress=False): - """ Genrate audio without condition. """ - num_frames = math.ceil(duration * 100. / 8) - with torch.no_grad(): - orig_samples, fs = torchaudio.load(fname) - if(fs!=44100): - orig_samples = torchaudio.functional.resample(orig_samples, fs, 44100) - fs = 44100 - if(orig_samples.shape[-1] segment_length: - return waveform[:segment_length] - else: - pad_wav = torch.zeros(segment_length - waveform_length).to(waveform.device) - waveform = torch.cat([waveform, pad_wav]) - return waveform - - -def _pad_spec(fbank, target_length=1024): - batch, n_frames, channels = fbank.shape - p = target_length - n_frames - if p > 0: - pad = torch.zeros(batch, p, channels).to(fbank.device) - fbank = torch.cat([fbank, pad], 1) - elif p < 0: - fbank = fbank[:, :target_length, :] - - if channels % 2 != 0: - fbank = fbank[:, :, :-1] - - return fbank - - -def read_wav_file(filename, segment_length): - waveform, sr = torchaudio.load(filename) # Faster!!! - waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0] - try: - waveform = normalize_wav(waveform) - except: - print ("Exception normalizing:", filename) - waveform = torch.ones(160000) - waveform = pad_wav(waveform, segment_length).unsqueeze(0) - waveform = waveform / torch.max(torch.abs(waveform)) - waveform = 0.5 * waveform - return waveform - - -def get_mel_from_wav(audio, _stft): - audio = torch.nan_to_num(torch.clip(audio, -1, 1)) - audio = torch.autograd.Variable(audio, requires_grad=False) - melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio) - return melspec, log_magnitudes_stft, energy - - -def wav_to_fbank(paths, target_length=1024, fn_STFT=None): - assert fn_STFT is not None - - waveform = torch.cat([read_wav_file(path, target_length * 160) for path in paths], 0) # hop size is 160 - - fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) - fbank = fbank.transpose(1, 2) - log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2) - - fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( - log_magnitudes_stft, target_length - ) - - return fbank, log_magnitudes_stft, waveform - -def wav_to_fbank2(waveform, target_length=-1, fn_STFT=None): - assert fn_STFT is not None - - fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) - fbank = fbank.transpose(1, 2) - log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2) - # print(fbank.shape, log_magnitudes_stft.shape) - - if(target_length>0): - fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( - log_magnitudes_stft, target_length - ) - - return fbank, log_magnitudes_stft, waveform - - -def uncapitalize(s): - if s: - return s[:1].lower() + s[1:] - else: - return "" - - -def mix_wavs_and_captions(path1, path2, caption1, caption2, target_length=1024): - sound1 = read_wav_file(path1, target_length * 160)[0].numpy() - sound2 = read_wav_file(path2, target_length * 160)[0].numpy() - mixed_sound = mix(sound1, sound2, 0.5, 16000).reshape(1, -1) - mixed_caption = "{} and 
{}".format(caption1, uncapitalize(caption2)) - return mixed_sound, mixed_caption - - -def augment(paths, texts, num_items=4, target_length=1024): - mixed_sounds, mixed_captions = [], [] - combinations = list(itertools.combinations(list(range(len(texts))), 2)) - random.shuffle(combinations) - if len(combinations) < num_items: - selected_combinations = combinations - else: - selected_combinations = combinations[:num_items] - - for (i, j) in selected_combinations: - new_sound, new_caption = mix_wavs_and_captions(paths[i], paths[j], texts[i], texts[j], target_length) - mixed_sounds.append(new_sound) - mixed_captions.append(new_caption) - - waveform = torch.tensor(np.concatenate(mixed_sounds, 0)) - waveform = waveform / torch.max(torch.abs(waveform)) - waveform = 0.5 * waveform - - return waveform, mixed_captions - - -def augment_wav_to_fbank(paths, texts, num_items=4, target_length=1024, fn_STFT=None): - assert fn_STFT is not None - - waveform, captions = augment(paths, texts) - fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) - fbank = fbank.transpose(1, 2) - log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2) - - fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( - log_magnitudes_stft, target_length - ) - +import torch +import torchaudio +import random +import itertools +import numpy as np +from tools.mix import mix + + +def normalize_wav(waveform): + waveform = waveform - torch.mean(waveform) + waveform = waveform / (torch.max(torch.abs(waveform)) + 1e-8) + return waveform * 0.5 + + +def pad_wav(waveform, segment_length): + waveform_length = len(waveform) + + if segment_length is None or waveform_length == segment_length: + return waveform + elif waveform_length > segment_length: + return waveform[:segment_length] + else: + pad_wav = torch.zeros(segment_length - waveform_length).to(waveform.device) + waveform = torch.cat([waveform, pad_wav]) + return waveform + + +def _pad_spec(fbank, target_length=1024): + batch, n_frames, channels = fbank.shape + p = target_length - n_frames + if p > 0: + pad = torch.zeros(batch, p, channels).to(fbank.device) + fbank = torch.cat([fbank, pad], 1) + elif p < 0: + fbank = fbank[:, :target_length, :] + + if channels % 2 != 0: + fbank = fbank[:, :, :-1] + + return fbank + + +def read_wav_file(filename, segment_length): + waveform, sr = torchaudio.load(filename) # Faster!!! 
+ waveform = torchaudio.functional.resample(waveform, orig_freq=sr, new_freq=16000)[0] + try: + waveform = normalize_wav(waveform) + except: + print ("Exception normalizing:", filename) + waveform = torch.ones(160000) + waveform = pad_wav(waveform, segment_length).unsqueeze(0) + waveform = waveform / torch.max(torch.abs(waveform)) + waveform = 0.5 * waveform + return waveform + + +def get_mel_from_wav(audio, _stft): + audio = torch.nan_to_num(torch.clip(audio, -1, 1)) + audio = torch.autograd.Variable(audio, requires_grad=False) + melspec, log_magnitudes_stft, energy = _stft.mel_spectrogram(audio) + return melspec, log_magnitudes_stft, energy + + +def wav_to_fbank(paths, target_length=1024, fn_STFT=None): + assert fn_STFT is not None + + waveform = torch.cat([read_wav_file(path, target_length * 160) for path in paths], 0) # hop size is 160 + + fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) + fbank = fbank.transpose(1, 2) + log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2) + + fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( + log_magnitudes_stft, target_length + ) + + return fbank, log_magnitudes_stft, waveform + +def wav_to_fbank2(waveform, target_length=-1, fn_STFT=None): + assert fn_STFT is not None + + fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) + fbank = fbank.transpose(1, 2) + log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2) + # print(fbank.shape, log_magnitudes_stft.shape) + + if(target_length>0): + fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( + log_magnitudes_stft, target_length + ) + + return fbank, log_magnitudes_stft, waveform + + +def uncapitalize(s): + if s: + return s[:1].lower() + s[1:] + else: + return "" + + +def mix_wavs_and_captions(path1, path2, caption1, caption2, target_length=1024): + sound1 = read_wav_file(path1, target_length * 160)[0].numpy() + sound2 = read_wav_file(path2, target_length * 160)[0].numpy() + mixed_sound = mix(sound1, sound2, 0.5, 16000).reshape(1, -1) + mixed_caption = "{} and {}".format(caption1, uncapitalize(caption2)) + return mixed_sound, mixed_caption + + +def augment(paths, texts, num_items=4, target_length=1024): + mixed_sounds, mixed_captions = [], [] + combinations = list(itertools.combinations(list(range(len(texts))), 2)) + random.shuffle(combinations) + if len(combinations) < num_items: + selected_combinations = combinations + else: + selected_combinations = combinations[:num_items] + + for (i, j) in selected_combinations: + new_sound, new_caption = mix_wavs_and_captions(paths[i], paths[j], texts[i], texts[j], target_length) + mixed_sounds.append(new_sound) + mixed_captions.append(new_caption) + + waveform = torch.tensor(np.concatenate(mixed_sounds, 0)) + waveform = waveform / torch.max(torch.abs(waveform)) + waveform = 0.5 * waveform + + return waveform, mixed_captions + + +def augment_wav_to_fbank(paths, texts, num_items=4, target_length=1024, fn_STFT=None): + assert fn_STFT is not None + + waveform, captions = augment(paths, texts) + fbank, log_magnitudes_stft, energy = get_mel_from_wav(waveform, fn_STFT) + fbank = fbank.transpose(1, 2) + log_magnitudes_stft = log_magnitudes_stft.transpose(1, 2) + + fbank, log_magnitudes_stft = _pad_spec(fbank, target_length), _pad_spec( + log_magnitudes_stft, target_length + ) + return fbank, log_magnitudes_stft, waveform, captions \ No newline at end of file diff --git a/codeclm/tokenizer/audio_tokenizer.py b/codeclm/tokenizer/audio_tokenizer.py index 
aa448a00e73907b2733406034b73e5676b00cafe..00d9af132dcd24d0ffcbca9c16c852ea6ee41385 100755 --- a/codeclm/tokenizer/audio_tokenizer.py +++ b/codeclm/tokenizer/audio_tokenizer.py @@ -208,9 +208,9 @@ class Flow1dVAESeparate(AudioTokenizer): return codes_vocal, codes_bgm @torch.no_grad() - def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None): + def decode(self, codes: torch.Tensor, prompt_vocal = None, prompt_bgm = None, chunked=False): wav = self.model.code2sound(codes, prompt_vocal=prompt_vocal, prompt_bgm=prompt_bgm, guidance_scale=1.5, - num_steps=50, disable_progress=False) # [B,N,T] -> [B,T] + num_steps=50, disable_progress=False, chunked=chunked) # [B,N,T] -> [B,T] return wav[None] diff --git a/generate_lowmem.py b/generate_lowmem.py new file mode 100644 index 0000000000000000000000000000000000000000..ee5c77f204f60cd1ab51575fbe39f9fdfe74fb97 --- /dev/null +++ b/generate_lowmem.py @@ -0,0 +1,240 @@ +import sys +import os + +import time +import json +import torch +import torchaudio +import numpy as np +from omegaconf import OmegaConf +from codeclm.models import builders + +from codeclm.trainer.codec_song_pl import CodecLM_PL +from codeclm.models import CodecLM +from third_party.demucs.models.pretrained import get_model_from_yaml + +auto_prompt_type = ['Pop', 'R&B', 'Dance', 'Jazz', 'Folk', 'Rock', 'Chinese Style', 'Chinese Tradition', 'Metal', 'Reggae', 'Chinese Opera', 'Auto'] + +class Separator: + def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None: + if torch.cuda.is_available() and gpu_id < torch.cuda.device_count(): + self.device = torch.device(f"cuda:{gpu_id}") + else: + self.device = torch.device("cpu") + self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path) + + def init_demucs_model(self, model_path, config_path): + model = get_model_from_yaml(config_path, model_path) + model.to(self.device) + model.eval() + return model + + def load_audio(self, f): + a, fs = torchaudio.load(f) + if (fs != 48000): + a = torchaudio.functional.resample(a, fs, 48000) + if a.shape[-1] >= 48000*10: + a = a[..., :48000*10] + else: + a = torch.cat([a, a], -1) + return a[:, 0:48000*10] + + def run(self, audio_path, output_dir='tmp', ext=".flac"): + os.makedirs(output_dir, exist_ok=True) + name, _ = os.path.splitext(os.path.split(audio_path)[-1]) + output_paths = [] + + for stem in self.demucs_model.sources: + output_path = os.path.join(output_dir, f"{name}_{stem}{ext}") + if os.path.exists(output_path): + output_paths.append(output_path) + if len(output_paths) == 1: # 4 + vocal_path = output_paths[0] + else: + drums_path, bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device) + for path in [drums_path, bass_path, other_path]: + os.remove(path) + full_audio = self.load_audio(audio_path) + vocal_audio = self.load_audio(vocal_path) + bgm_audio = full_audio - vocal_audio + return full_audio, vocal_audio, bgm_audio + + + +if __name__ == "__main__": + torch.backends.cudnn.enabled = False + OmegaConf.register_new_resolver("eval", lambda x: eval(x)) + OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx]) + OmegaConf.register_new_resolver("get_fname", lambda: os.path.splitext(os.path.basename(sys.argv[1]))[0]) + OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x))) + np.random.seed(int(time.time())) + ckpt_path = sys.argv[1] + input_jsonl = sys.argv[2] + save_dir = 
sys.argv[3] + cfg_path = os.path.join(ckpt_path, 'config.yaml') + ckpt_path = os.path.join(ckpt_path, 'model.pt') + cfg = OmegaConf.load(cfg_path) + cfg.mode = 'inference' + max_duration = cfg.max_dur + + separator = Separator() + auto_prompt = torch.load('ckpt/prompt.pt') + audio_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint, cfg) + if "audio_tokenizer_checkpoint_sep" in cfg.keys(): + seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg) + else: + seperate_tokenizer = None + audio_tokenizer = audio_tokenizer.eval().cuda() + if seperate_tokenizer is not None: + seperate_tokenizer = seperate_tokenizer.eval().cuda() + + merge_prompt = [item for sublist in auto_prompt.values() for item in sublist] + with open(input_jsonl, "r") as fp: + lines = fp.readlines() + new_items = [] + for line in lines: + item = json.loads(line) + target_wav_name = f"{save_dir}/audios/{item['idx']}.flac" + # get prompt audio + if "prompt_audio_path" in item: + assert os.path.exists(item['prompt_audio_path']), f"prompt_audio_path {item['prompt_audio_path']} not found" + assert 'auto_prompt_audio_type' not in item, f"auto_prompt_audio_type and prompt_audio_path cannot be used together" + pmt_wav, vocal_wav, bgm_wav = separator.run(item['prompt_audio_path']) + item['raw_pmt_wav'] = pmt_wav + item['raw_vocal_wav'] = vocal_wav + item['raw_bgm_wav'] = bgm_wav + if pmt_wav.dim() == 2: + pmt_wav = pmt_wav[None] + if pmt_wav.dim() != 3: + raise ValueError("Melody wavs should have a shape [B, C, T].") + pmt_wav = list(pmt_wav) + if vocal_wav.dim() == 2: + vocal_wav = vocal_wav[None] + if vocal_wav.dim() != 3: + raise ValueError("Vocal wavs should have a shape [B, C, T].") + vocal_wav = list(vocal_wav) + if bgm_wav.dim() == 2: + bgm_wav = bgm_wav[None] + if bgm_wav.dim() != 3: + raise ValueError("BGM wavs should have a shape [B, C, T].") + bgm_wav = list(bgm_wav) + if type(pmt_wav) == list: + pmt_wav = torch.stack(pmt_wav, dim=0) + if type(vocal_wav) == list: + vocal_wav = torch.stack(vocal_wav, dim=0) + if type(bgm_wav) == list: + bgm_wav = torch.stack(bgm_wav, dim=0) + pmt_wav = pmt_wav.cuda() + vocal_wav = vocal_wav.cuda() + bgm_wav = bgm_wav.cuda() + pmt_wav, _ = audio_tokenizer.encode(pmt_wav) + vocal_wav, bgm_wav = seperate_tokenizer.encode(vocal_wav, bgm_wav) + melody_is_wav = False + elif "auto_prompt_audio_type" in item: + assert item["auto_prompt_audio_type"] in auto_prompt_type, f"auto_prompt_audio_type {item['auto_prompt_audio_type']} not found" + if item["auto_prompt_audio_type"] == "Auto": + prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))] + else: + prompt_token = auto_prompt[item["auto_prompt_audio_type"]][np.random.randint(0, len(auto_prompt[item["auto_prompt_audio_type"]]))] + pmt_wav = prompt_token[:,[0],:] + vocal_wav = prompt_token[:,[1],:] + bgm_wav = prompt_token[:,[2],:] + melody_is_wav = False + else: + pmt_wav = None + vocal_wav = None + bgm_wav = None + melody_is_wav = True + item['pmt_wav'] = pmt_wav + item['vocal_wav'] = vocal_wav + item['bgm_wav'] = bgm_wav + item['melody_is_wav'] = melody_is_wav + item["idx"] = f"{item['idx']}" + item["wav_path"] = target_wav_name + new_items.append(item) + + del audio_tokenizer + del seperate_tokenizer + del separator + + # Define model or load pretrained model + model_light = CodecLM_PL(cfg, ckpt_path) + model_light = model_light.eval() + model_light.audiolm.cfg = cfg + model = CodecLM(name = "tmp", + lm = model_light.audiolm, + audiotokenizer = None, + max_duration = 
max_duration, + seperate_tokenizer = None, + ) + del model_light + model.lm = model.lm.cuda().to(torch.float16) + + cfg_coef = 1.5 #25 + temp = 0.9 + top_k = 50 + top_p = 0.0 + record_tokens = True + record_window = 50 + + model.set_generation_params(duration=max_duration, extend_stride=5, temperature=temp, cfg_coef=cfg_coef, + top_k=top_k, top_p=top_p, record_tokens=record_tokens, record_window=record_window) + os.makedirs(save_dir, exist_ok=True) + os.makedirs(save_dir + "/audios", exist_ok=True) + os.makedirs(save_dir + "/jsonl", exist_ok=True) + + + for item in new_items: + lyric = item["gt_lyric"] + descriptions = item["descriptions"] if "descriptions" in item else None + pmt_wav = item['pmt_wav'] + vocal_wav = item['vocal_wav'] + bgm_wav = item['bgm_wav'] + melody_is_wav = item['melody_is_wav'] + + generate_inp = { + 'lyrics': [lyric.replace(" ", " ")], + 'descriptions': [descriptions], + 'melody_wavs': pmt_wav, + 'vocal_wavs': vocal_wav, + 'bgm_wavs': bgm_wav, + 'melody_is_wav': melody_is_wav, + } + with torch.autocast(device_type="cuda", dtype=torch.float16): + tokens = model.generate(**generate_inp, return_tokens=True) + item['tokens'] = tokens + + del model + torch.cuda.empty_cache() + + + seperate_tokenizer = builders.get_audio_tokenizer_model(cfg.audio_tokenizer_checkpoint_sep, cfg) + seperate_tokenizer = seperate_tokenizer.eval().cuda() + + model = CodecLM(name = "tmp", + lm = None, + audiotokenizer = None, + max_duration = max_duration, + seperate_tokenizer = seperate_tokenizer, + ) + for item in new_items: + with torch.no_grad(): + if 'raw_pmt_wav' in item: + wav_seperate = model.generate_audio(item['tokens'], item['raw_pmt_wav'], item['raw_vocal_wav'], item['raw_bgm_wav'], chunked=True) + del item['raw_pmt_wav'] + del item['raw_vocal_wav'] + del item['raw_bgm_wav'] + else: + wav_seperate = model.generate_audio(item['tokens'], chunked=True) + torchaudio.save(item['wav_path'], wav_seperate[0].cpu().float(), cfg.sample_rate) + del item['tokens'] + del item['pmt_wav'] + del item['vocal_wav'] + del item['bgm_wav'] + del item['melody_is_wav'] + + torch.cuda.empty_cache() + src_jsonl_name = os.path.split(input_jsonl)[-1] + with open(f"{save_dir}/jsonl/{src_jsonl_name}.jsonl", "w", encoding='utf-8') as fw: + for item in new_items: + fw.writelines(json.dumps(item, ensure_ascii=False)+"\n") diff --git a/generate_lowmem.sh b/generate_lowmem.sh new file mode 100644 index 0000000000000000000000000000000000000000..0578aada78060c290ee71e1cd276ba190627b80c --- /dev/null +++ b/generate_lowmem.sh @@ -0,0 +1,10 @@ +export USER=root +export PYTHONDONTWRITEBYTECODE=1 +export TRANSFORMERS_CACHE="$(pwd)/third_party/hub" +export NCCL_HOME=/usr/local/tccl +export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH + +CKPT_PATH=$1 +JSONL=$2 +SAVE_DIR=$3 +python3 generate_lowmem.py $CKPT_PATH $JSONL $SAVE_DIR diff --git a/requirements.txt b/requirements.txt index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..bd555da80b8a06f2b8fa5f91959ace34f79140a4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -0,0 +1,24 @@ +alias-free-torch>=0.0.6 +descript-audio-codec>=1.0.0 +diffusers==0.27.2 +einops>=0.8.1 +einops-exts==0.0.4 +flashy>=0.0.2 +huggingface-hub==0.25.2 +julius>=0.2.7 +k-diffusion==0.1.1 +kaldiio>=2.18.1 +lameenc>=1.8.1 +librosa>=0.11.0 +lightning>=2.5.2 +ninja>=1.11.1.4 +nnAudio>=0.3.3 +openunmix>=1.3.0 +peft==0.10.0 +torch==2.6.0 +torchaudio==2.6.0 +torchvision==0.21.0 +transformers==4.37.2 
+vector-quantize-pytorch>=1.22.17 +wheel>=0.45.1 +x-transformers>=2.3.25 \ No newline at end of file diff --git a/requirements_nodeps.txt b/requirements_nodeps.txt new file mode 100644 index 0000000000000000000000000000000000000000..38032c0f7e850b329d6072bd27eaa3b7bda66e4d --- /dev/null +++ b/requirements_nodeps.txt @@ -0,0 +1,13 @@ +fairseq==0.12.2 +antlr4-python3-runtime==4.8 +bitarray==3.4.3 +cffi==1.17.1 +colorama==0.4.6 +cython==3.1.2 +hydra-core==1.0.7 +lxml==5.4.0 +omegaconf==2.2.0 +portalocker==3.2.0 +pycparser==2.22 +sacrebleu==2.5.1 +tabulate==0.9.0 \ No newline at end of file diff --git a/sample/lyrics.jsonl b/sample/lyrics.jsonl index 43c5fa278237d72e598aa5477f4c52c909d8cf85..8401ca36e025f8accaa68d507dd603b775421648 100644 --- a/sample/lyrics.jsonl +++ b/sample/lyrics.jsonl @@ -1,4 +1,4 @@ {"idx": "sample_01_autoprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "auto_prompt_audio_type": "Auto"} {"idx": "sample_01_noprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"} {"idx": "sample_01_textprompt", "descriptions": "female, dark, pop, sad, piano and drums, the bpm is 125.", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]"} -{"idx": "sample_01_audioprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "prompt_audio_path": "sample/sample_prompt_audio.wav"} +{"idx": "sample_01_audioprompt", "gt_lyric": "[intro-short] ; [verse] 雪花舞动在无尽的天际.情缘如同雪花般轻轻逝去.希望与真挚.永不磨灭.你的忧虑.随风而逝 ; [chorus] 我怀抱着守护这片梦境.在这世界中寻找爱与虚幻.苦辣酸甜.我们一起品尝.在雪的光芒中.紧紧相拥 ; [inst-short] ; [verse] 雪花再次在风中飘扬.情愿如同雪花般消失无踪.希望与真挚.永不消失.在痛苦与喧嚣中.你找到解脱 ; [chorus] 我环绕着守护这片梦境.在这世界中感受爱与虚假.苦辣酸甜.我们一起分享.在白银的光芒中.我们同在 ; [outro-short]", "prompt_audio_path": "input/sample_prompt_audio.wav"} diff --git a/tools/gradio/app.py b/tools/gradio/app.py new file mode 100644 index 0000000000000000000000000000000000000000..47719b37475839f72112b064569f6731c41cdb77 --- /dev/null +++ b/tools/gradio/app.py @@ -0,0 +1,236 @@ +import sys +import gradio as gr +import json +from datetime import datetime +import yaml +import time +import re +import os.path as op +from levo_inference_lowmem import LeVoInference + +EXAMPLE_LYRICS = """ +[intro-short] + +[verse] +夜晚的街灯闪烁 +我漫步在熟悉的角落 +回忆像潮水般涌来 +你的笑容如此清晰 +在心头无法抹去 +那些曾经的甜蜜 +如今只剩我独自回忆 + +[verse] +手机屏幕亮起 +是你发来的消息 +简单的几个字 +却让我泪流满面 +曾经的拥抱温暖 +如今却变得遥远 +我多想回到从前 +重新拥有你的陪伴 + +[chorus] +回忆的温度还在 +你却已不在 +我的心被爱填满 +却又被思念刺痛 +音乐的节奏奏响 +我的心却在流浪 +没有你的日子 +我该如何继续向前 + +[outro-short] +""".strip() + +APP_DIR = op.dirname(op.dirname(op.dirname(op.abspath(__file__)))) +MODEL = LeVoInference(sys.argv[1]) +with open(op.join(APP_DIR, 'conf/vocab.yaml'), 'r', encoding='utf-8') as file: + STRUCTS = 
yaml.safe_load(file) + + +def generate_song(lyric, description=None, prompt_audio=None, genre=None, cfg_coef=None, temperature=None, top_k=None, progress=gr.Progress(track_tqdm=True)): + global MODEL + global STRUCTS + params = {'cfg_coef':cfg_coef, 'temperature':temperature, 'top_k':top_k} + params = {k:v for k,v in params.items() if v is not None} + vocal_structs = ['[verse]', '[chorus]', '[bridge]'] + sample_rate = MODEL.cfg.sample_rate + + # format lyric + lyric = lyric.replace("[intro]", "[intro-short]").replace("[inst]", "[inst-short]").replace("[outro]", "[outro-short]") + paragraphs = [p.strip() for p in lyric.strip().split('\n\n') if p.strip()] + if len(paragraphs) < 1: + return None, json.dumps("Lyrics can not be left blank") + paragraphs_norm = [] + vocal_flag = False + for para in paragraphs: + lines = para.splitlines() + struct_tag = lines[0].strip().lower() + if struct_tag not in STRUCTS: + return None, json.dumps(f"Segments should start with a structure tag in {STRUCTS}") + if struct_tag in vocal_structs: + vocal_flag = True + if len(lines) < 2 or not [line.strip() for line in lines[1:] if line.strip()]: + return None, json.dumps("The following segments require lyrics: [verse], [chorus], [bridge]") + else: + new_para_list = [] + for line in lines[1:]: + new_para_list.append(re.sub(r"[^\w\s\[\]\-\u4e00-\u9fff\u3040-\u309f\u30a0-\u30ff\uac00-\ud7af\u00c0-\u017f]", "", line)) + new_para_str = f"{struct_tag} {'.'.join(new_para_list)}" + else: + if len(lines) > 1: + return None, json.dumps("The following segments should not contain lyrics: [intro], [intro-short], [intro-medium], [inst], [inst-short], [inst-medium], [outro], [outro-short], [outro-medium]") + else: + new_para_str = struct_tag + paragraphs_norm.append(new_para_str) + if not vocal_flag: + return None, json.dumps(f"The lyric must contain at least one of the following structures: {vocal_structs}") + lyric_norm = " ; ".join(paragraphs_norm) + + # format prompt + if prompt_audio is not None: + genre = None + description = None + elif description is not None and description != "": + genre = None + + progress(0.0, "Start Generation") + start = time.time() + + audio_data = MODEL(lyric_norm, description, prompt_audio, genre, op.join(APP_DIR, "ckpt/prompt.pt"), params).cpu().permute(1, 0).float().numpy() + + end = time.time() + + # 创建输入配置的JSON + input_config = { + "lyric": lyric_norm, + "genre": genre, + "prompt_audio": prompt_audio, + "description": description, + "params": params, + "inference_duration": end - start, + "timestamp": datetime.now().isoformat(), + } + + return (sample_rate, audio_data), json.dumps(input_config, indent=2) + + +# 创建Gradio界面 +with gr.Blocks(title="SongGeneration Demo Space") as demo: + gr.Markdown("# 🎵 SongGeneration Demo Space") + gr.Markdown("Demo interface for the song generation model. Provide a lyrics, and optionally an audio or text prompt, to generate a custom song.") + + with gr.Row(): + with gr.Column(): + lyric = gr.Textbox( + label="Lyrics", + lines=5, + max_lines=15, + value=EXAMPLE_LYRICS, + info="Each paragraph represents a segment starting with a structure tag and ending with a blank line, each line is a sentence without punctuation, segments [intro], [inst], [outro] should not contain lyrics, while [verse], [chorus], and [bridge] require lyrics.", + placeholder="""Lyric Format +''' +[structure tag] +lyrics + +[structure tag] +lyrics +''' +1. One paragraph represents one segments, starting with a structure tag and ending with a blank line +2. 
One line represents one sentence, punctuation is not recommended inside the sentence +3. The following segments should not contain lyrics: [intro-short], [intro-medium], [inst-short], [inst-medium], [outro-short], [outro-medium] +4. The following segments require lyrics: [verse], [chorus], [bridge] +""" + ) + + with gr.Tabs(elem_id="extra-tabs"): + with gr.Tab("Genre Select"): + genre = gr.Radio( + choices=["Pop", "R&B", "Dance", "Jazz", "Folk", "Rock", "Chinese Style", "Chinese Tradition", "Metal", "Reggae", "Chinese Opera", "Auto"], + label="Genre Select(Optional)", + value="Pop", + interactive=True, + elem_id="single-select-radio" + ) + with gr.Tab("Audio Prompt"): + prompt_audio = gr.Audio( + label="Prompt Audio (Optional)", + type="filepath", + elem_id="audio-prompt" + ) + with gr.Tab("Text Prompt"): + gr.Markdown("For detailed usage, please refer to [here](https://github.com/tencent-ailab/SongGeneration?tab=readme-ov-file#-description-input-format)") + description = gr.Textbox( + label="Song Description (Optional)", + info="Describe the gender, timbre, genre, emotion, instrument and bpm of the song. Only English is supported currently.​", + placeholder="female, dark, pop, sad, piano and drums, the bpm is 125.", + lines=1, + max_lines=2 + ) + + with gr.Accordion("Advanced Config", open=False): + cfg_coef = gr.Slider( + label="CFG Coefficient", + minimum=0.1, + maximum=3.0, + step=0.1, + value=1.5, + interactive=True, + elem_id="cfg-coef", + ) + temperature = gr.Slider( + label="Temperature", + minimum=0.1, + maximum=2.0, + step=0.1, + value=0.9, + interactive=True, + elem_id="temperature", + ) + top_k = gr.Slider( + label="Top-K", + minimum=1, + maximum=100, + step=1, + value=50, + interactive=True, + elem_id="top_k", + ) + generate_btn = gr.Button("Generate Song", variant="primary") + + with gr.Column(): + output_audio = gr.Audio(label="Generated Song", type="numpy") + output_json = gr.JSON(label="Generated Info") + + # # 示例按钮 + # examples = gr.Examples( + # examples=[ + # ["male, bright, rock, happy, electric guitar and drums, the bpm is 150."], + # ["female, warm, jazz, romantic, synthesizer and piano, the bpm is 100."] + # ], + # inputs=[description], + # label="Text Prompt examples" + # ) + + # examples = gr.Examples( + # examples=[ + # "[intro-medium]\n\n[verse]\n在这个疯狂的世界里\n谁不渴望一点改变\n在爱情面前\n我们都显得那么不安全\n你紧紧抱着我\n告诉我再靠近一点\n别让这璀璨的夜晚白白浪费\n我那迷茫的眼睛\n看不见未来的路\n在情感消散之前\n我们对爱的渴望永不熄灭\n你给我留下一句誓言\n想知道我们的爱是否能持续到永远\n[chorus]\n\n约定在那最后的夜晚\n不管命运如何摆布\n我们的心是否依然如初\n我会穿上红衬衫\n带着摇滚的激情\n回到我们初遇的地方\n约定在那最后的夜晚\n就算全世界都变了样\n我依然坚守诺言\n铭记这一天\n你永远是我心中的爱恋\n\n[outro-medium]\n", + # "[intro-short]\n\n[verse]\nThrough emerald canyons where fireflies dwell\nCerulean berries kiss morning's first swell\nCrystalline dew crowns each Vitamin Dawn's confection dissolves slowly on me\nAmbrosia breezes through honeycomb vines\nNature's own candy in Fibonacci lines\n[chorus] Blueberry fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n You're under its spell\n [verse] Resin of sunlight in candied retreat\nMarmalade moonbeams melt under bare feet\nNectar spirals bloom chloroplast champagne\nPhotosynthesis sings through my veins\nChlorophyll rhythms pulse warm in my blood\nThe forest's green pharmacy floods every bud[chorus] Blueberry fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n You're under its spell\n feel the buzz\n ride the wave\n Limey me\n blueberry\n your mind's enslaved\n In the haze\n lose all time\n floating free\n feeling fine\n 
Blueberry\n fruit so sweet\n takes you higher\n can't be beat\n In your lungs\n it starts to swell\n cry\n You're under its spell\n\n[outro-short]\n", + # ], + # inputs=[lyric], + # label="Lyrics examples", + # ) + + # 生成按钮点击事件 + generate_btn.click( + fn=generate_song, + inputs=[lyric, description, prompt_audio, genre, cfg_coef, temperature, top_k], + outputs=[output_audio, output_json] + ) + + +# 启动应用 +if __name__ == "__main__": + demo.launch(server_name="0.0.0.0", server_port=8081) + diff --git a/tools/gradio/levo_inference.py b/tools/gradio/levo_inference.py new file mode 100644 index 0000000000000000000000000000000000000000..73c2dd90d5cf2fffb28aad968e72fff5ccc5229f --- /dev/null +++ b/tools/gradio/levo_inference.py @@ -0,0 +1,110 @@ +import os +import sys + +import torch + +import json +import numpy as np +from omegaconf import OmegaConf + +from codeclm.trainer.codec_song_pl import CodecLM_PL +from codeclm.models import CodecLM + +from separator import Separator + + +class LeVoInference(torch.nn.Module): + def __init__(self, ckpt_path): + super().__init__() + + torch.backends.cudnn.enabled = False + OmegaConf.register_new_resolver("eval", lambda x: eval(x)) + OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx]) + OmegaConf.register_new_resolver("get_fname", lambda: 'default') + OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x))) + + cfg_path = os.path.join(ckpt_path, 'config.yaml') + pt_path = os.path.join(ckpt_path, 'model.pt') + + self.cfg = OmegaConf.load(cfg_path) + self.cfg.mode = 'inference' + self.max_duration = self.cfg.max_dur + + # Define model or load pretrained model + model_light = CodecLM_PL(self.cfg, pt_path) + + model_light = model_light.eval().cuda() + model_light.audiolm.cfg = self.cfg + + self.model_lm = model_light.audiolm + self.model_audio_tokenizer = model_light.audio_tokenizer + self.model_seperate_tokenizer = model_light.seperate_tokenizer + + self.model = CodecLM(name = "tmp", + lm = self.model_lm, + audiotokenizer = self.model_audio_tokenizer, + max_duration = self.max_duration, + seperate_tokenizer = self.model_seperate_tokenizer, + ) + self.separator = Separator() + + + self.default_params = dict( + cfg_coef = 1.5, + temperature = 1.0, + top_k = 50, + top_p = 0.0, + record_tokens = True, + record_window = 50, + extend_stride = 5, + duration = self.max_duration, + ) + + self.model.set_generation_params(**self.default_params) + + def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, params = dict()): + params = {**self.default_params, **params} + self.model.set_generation_params(**params) + + if prompt_audio_path is not None and os.path.exists(prompt_audio_path): + pmt_wav, vocal_wav, bgm_wav = self.separator.run(prompt_audio_path) + melody_is_wav = True + elif genre is not None and auto_prompt_path is not None: + auto_prompt = torch.load(auto_prompt_path) + merge_prompt = [item for sublist in auto_prompt.values() for item in sublist] + if genre == "Auto": + prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))] + else: + prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))] + pmt_wav = prompt_token[:,[0],:] + vocal_wav = prompt_token[:,[1],:] + bgm_wav = prompt_token[:,[2],:] + melody_is_wav = False + else: + pmt_wav = None + vocal_wav = None + bgm_wav = None + melody_is_wav = True + + generate_inp = { + 'lyrics': [lyric.replace(" ", " ")], + 'descriptions': 
[description], + 'melody_wavs': pmt_wav, + 'vocal_wavs': vocal_wav, + 'bgm_wavs': bgm_wav, + 'melody_is_wav': melody_is_wav, + } + + with torch.autocast(device_type="cuda", dtype=torch.float16): + tokens = self.model.generate(**generate_inp, return_tokens=True) + + if tokens.shape[-1] > 3000: + tokens = tokens[..., :3000] + + with torch.no_grad(): + if melody_is_wav: + wav_seperate = self.model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav) + else: + wav_seperate = self.model.generate_audio(tokens) + + return wav_seperate[0] diff --git a/tools/gradio/levo_inference_lowmem.py b/tools/gradio/levo_inference_lowmem.py new file mode 100644 index 0000000000000000000000000000000000000000..83baa65107aa5ed2798718300d86df708d804733 --- /dev/null +++ b/tools/gradio/levo_inference_lowmem.py @@ -0,0 +1,129 @@ +import os +import sys + +import torch + +import json +import numpy as np +from omegaconf import OmegaConf + +from codeclm.trainer.codec_song_pl import CodecLM_PL +from codeclm.models import CodecLM +from codeclm.models import builders + +from separator import Separator + + +class LeVoInference(torch.nn.Module): + def __init__(self, ckpt_path): + super().__init__() + + torch.backends.cudnn.enabled = False + OmegaConf.register_new_resolver("eval", lambda x: eval(x)) + OmegaConf.register_new_resolver("concat", lambda *x: [xxx for xx in x for xxx in xx]) + OmegaConf.register_new_resolver("get_fname", lambda: 'default') + OmegaConf.register_new_resolver("load_yaml", lambda x: list(OmegaConf.load(x))) + + cfg_path = os.path.join(ckpt_path, 'config.yaml') + self.pt_path = os.path.join(ckpt_path, 'model.pt') + + self.cfg = OmegaConf.load(cfg_path) + self.cfg.mode = 'inference' + self.max_duration = self.cfg.max_dur + + self.default_params = dict( + top_p = 0.0, + record_tokens = True, + record_window = 50, + extend_stride = 5, + duration = self.max_duration, + ) + + + def forward(self, lyric: str, description: str = None, prompt_audio_path: os.PathLike = None, genre: str = None, auto_prompt_path: os.PathLike = None, params = dict()): + if prompt_audio_path is not None and os.path.exists(prompt_audio_path): + separator = Separator() + audio_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint, self.cfg) + audio_tokenizer = audio_tokenizer.eval().cuda() + seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg) + seperate_tokenizer = seperate_tokenizer.eval().cuda() + pmt_wav, vocal_wav, bgm_wav = separator.run(prompt_audio_path) + pmt_wav = pmt_wav.cuda() + vocal_wav = vocal_wav.cuda() + bgm_wav = bgm_wav.cuda() + pmt_wav, _ = audio_tokenizer.encode(pmt_wav) + vocal_wav, bgm_wav = seperate_tokenizer.encode(vocal_wav, bgm_wav) + melody_is_wav = False + melody_is_wav = False + del audio_tokenizer + del seperate_tokenizer + del separator + elif genre is not None and auto_prompt_path is not None: + auto_prompt = torch.load(auto_prompt_path) + merge_prompt = [item for sublist in auto_prompt.values() for item in sublist] + if genre == "Auto": + prompt_token = merge_prompt[np.random.randint(0, len(merge_prompt))] + else: + prompt_token = auto_prompt[genre][np.random.randint(0, len(auto_prompt[genre]))] + pmt_wav = prompt_token[:,[0],:] + vocal_wav = prompt_token[:,[1],:] + bgm_wav = prompt_token[:,[2],:] + melody_is_wav = False + else: + pmt_wav = None + vocal_wav = None + bgm_wav = None + melody_is_wav = True + + model_light = CodecLM_PL(self.cfg, self.pt_path) + model_light = model_light.eval() + model_light.audiolm.cfg = 
self.cfg + model = CodecLM(name = "tmp", + lm = model_light.audiolm, + audiotokenizer = None, + max_duration = self.max_duration, + seperate_tokenizer = None, + ) + del model_light + model.lm = model.lm.cuda().to(torch.float16) + params = {**self.default_params, **params} + model.set_generation_params(**params) + + generate_inp = { + 'lyrics': [lyric.replace(" ", " ")], + 'descriptions': [description], + 'melody_wavs': pmt_wav, + 'vocal_wavs': vocal_wav, + 'bgm_wavs': bgm_wav, + 'melody_is_wav': melody_is_wav, + } + + with torch.autocast(device_type="cuda", dtype=torch.float16): + tokens = model.generate(**generate_inp, return_tokens=True) + + del model + torch.cuda.empty_cache() + + seperate_tokenizer = builders.get_audio_tokenizer_model(self.cfg.audio_tokenizer_checkpoint_sep, self.cfg) + seperate_tokenizer = seperate_tokenizer.eval().cuda() + model = CodecLM(name = "tmp", + lm = None, + audiotokenizer = None, + max_duration = self.max_duration, + seperate_tokenizer = seperate_tokenizer, + ) + + if tokens.shape[-1] > 3000: + tokens = tokens[..., :3000] + + with torch.no_grad(): + if melody_is_wav: + wav_seperate = model.generate_audio(tokens, pmt_wav, vocal_wav, bgm_wav) + else: + wav_seperate = model.generate_audio(tokens) + + del seperate_tokenizer + del model + torch.cuda.empty_cache() + + return wav_seperate[0] diff --git a/tools/gradio/run.sh b/tools/gradio/run.sh new file mode 100644 index 0000000000000000000000000000000000000000..6961792b7f78eb22b6df8b34a05ec0b810c1f131 --- /dev/null +++ b/tools/gradio/run.sh @@ -0,0 +1,9 @@ +export USER=root +export PYTHONDONTWRITEBYTECODE=1 +export TRANSFORMERS_CACHE="$(pwd)/third_party/hub" +export NCCL_HOME=/usr/local/tccl +export PYTHONPATH="$(pwd)/codeclm/tokenizer/":"$(pwd)":"$(pwd)/codeclm/tokenizer/Flow1dVAE/":"$(pwd)/codeclm/tokenizer/":$PYTHONPATH + + +CKPT_PATH=$1 +python3 tools/gradio/app.py $CKPT_PATH diff --git a/tools/gradio/separator.py b/tools/gradio/separator.py new file mode 100644 index 0000000000000000000000000000000000000000..f6444d897eb09a515d4318bfc85879959b39d801 --- /dev/null +++ b/tools/gradio/separator.py @@ -0,0 +1,50 @@ +import torchaudio +import os +import torch +from third_party.demucs.models.pretrained import get_model_from_yaml + + +class Separator(torch.nn.Module): + def __init__(self, dm_model_path='third_party/demucs/ckpt/htdemucs.pth', dm_config_path='third_party/demucs/ckpt/htdemucs.yaml', gpu_id=0) -> None: + super().__init__() + if torch.cuda.is_available() and gpu_id < torch.cuda.device_count(): + self.device = torch.device(f"cuda:{gpu_id}") + else: + self.device = torch.device("cpu") + self.demucs_model = self.init_demucs_model(dm_model_path, dm_config_path) + + def init_demucs_model(self, model_path, config_path): + model = get_model_from_yaml(config_path, model_path) + model.to(self.device) + model.eval() + return model + + def load_audio(self, f): + a, fs = torchaudio.load(f) + if (fs != 48000): + a = torchaudio.functional.resample(a, fs, 48000) + if a.shape[-1] >= 48000*10: + a = a[..., :48000*10] + else: + a = torch.cat([a, a], -1) + return a[:, 0:48000*10] + + def run(self, audio_path, output_dir='tmp', ext=".flac"): + os.makedirs(output_dir, exist_ok=True) + name, _ = os.path.splitext(os.path.split(audio_path)[-1]) + output_paths = [] + + for stem in self.demucs_model.sources: + output_path = os.path.join(output_dir, f"{name}_{stem}{ext}") + if os.path.exists(output_path): + output_paths.append(output_path) + if len(output_paths) == 1: # 4 + vocal_path = output_paths[0] + else: + drums_path, 
bass_path, other_path, vocal_path = self.demucs_model.separate(audio_path, output_dir, device=self.device) + for path in [drums_path, bass_path, other_path]: + os.remove(path) + full_audio = self.load_audio(audio_path) + vocal_audio = self.load_audio(vocal_path) + bgm_audio = full_audio - vocal_audio + return full_audio, vocal_audio, bgm_audio
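
Usage sketch for the low-memory batch path introduced by this patch (a minimal example, not part of the diff): generate_lowmem.sh takes a checkpoint directory, an input JSONL, and an output directory. The checkpoint directory name below is a placeholder; the script only requires that it contain config.yaml and model.pt, and it assumes the repository root as the working directory so that ckpt/prompt.pt, third_party/..., and the PYTHONPATH entries in generate_lowmem.sh resolve.

    bash generate_lowmem.sh <checkpoint_dir> sample/lyrics.jsonl output/lowmem_run

Generated songs are written under <save_dir>/audios/ as FLAC files named after each item's idx, and the processed items are dumped under <save_dir>/jsonl/.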
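
A minimal Python sketch of the single-song low-memory path via tools/gradio/levo_inference_lowmem.py, under the following assumptions: the repository root and tools/gradio are importable (as set up by tools/gradio/run.sh), the working directory is the repo root so ckpt/prompt.pt resolves, "ckpt/songgeneration_base" is a hypothetical checkpoint directory containing config.yaml and model.pt, and the lyric is already in the normalized "[tag] ; ..." format used in sample/lyrics.jsonl (the Gradio app performs that normalization; LeVoInference itself does not).

import sys
import torchaudio

# tools/gradio is not a package, so make levo_inference_lowmem and separator importable
sys.path.append("tools/gradio")
from levo_inference_lowmem import LeVoInference

# hypothetical checkpoint directory: must contain config.yaml and model.pt
model = LeVoInference("ckpt/songgeneration_base")

# lyric already normalized to the " ; "-joined segment format
lyric = "[intro-short] ; [verse] 夜晚的街灯闪烁.我漫步在熟悉的角落 ; [chorus] 回忆的温度还在.你却已不在 ; [outro-short]"

# genre-based prompting, mirroring the auto-prompt branch used by generate_lowmem.py and the Gradio app
wav = model(
    lyric,
    description=None,
    prompt_audio_path=None,
    genre="Pop",
    auto_prompt_path="ckpt/prompt.pt",
    params={"cfg_coef": 1.5, "temperature": 0.9, "top_k": 50},
)  # returns a [channels, samples] tensor

torchaudio.save("demo_song.flac", wav.cpu().float(), model.cfg.sample_rate)

Passing prompt_audio_path instead of a genre takes the Demucs separation branch, which loads and frees the tokenizers before the language model is instantiated; that staging (tokenizers, then LM in fp16, then the separate decoder) is what keeps peak GPU memory low relative to tools/gradio/levo_inference.py.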