diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000000000000000000000000000000000000..7ba3ccc36df91a590346529677507d707a6562b7 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,4 @@ +*.jpg filter=lfs diff=lfs merge=lfs -text +*.jpeg filter=lfs diff=lfs merge=lfs -text +*.png filter=lfs diff=lfs merge=lfs -text +*.gif filter=lfs diff=lfs merge=lfs -text diff --git a/.github/workflows/pylint.yml b/.github/workflows/pylint.yml new file mode 100644 index 0000000000000000000000000000000000000000..a07ba7b7d61913a2ea95ee85f7a0b68151e8f602 --- /dev/null +++ b/.github/workflows/pylint.yml @@ -0,0 +1,27 @@ +name: Ruff + +on: [push] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: ["3.10"] + steps: + - name: Checkout repository and submodules + uses: actions/checkout@v3 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v3 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install ruff==0.2.2 black==24.2.0 + - name: Analyzing the code with ruff + run: | + ruff $(git ls-files '*.py') + - name: Verify that no Black changes are required + run: | + black --check $(git ls-files '*.py') diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..8e91c5bbfcc7b236fe045a9e5ba71892da88e7bb --- /dev/null +++ b/.gitignore @@ -0,0 +1,164 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. +# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/latest/usage/project/#working-with-version-control +.pdm.toml +.pdm-python +.pdm-build/ + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +# From inference.py \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e0981140b422f092bfa889ac91528d4b2cfa152d --- /dev/null +++ b/.pre-commit-config.yaml @@ -0,0 +1,16 @@ +repos: + - repo: https://github.com/astral-sh/ruff-pre-commit + # Ruff version. + rev: v0.2.2 + hooks: + # Run the linter. + - id: ruff + args: [--fix] # Automatically fix issues if possible. + types: [python] # Ensure it only runs on .py files. + + - repo: https://github.com/psf/black + rev: 24.2.0 # Specify the version of Black you want + hooks: + - id: black + name: Black code formatter + language_version: python3 # Use the Python version you're targeting (e.g., 3.10) \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md new file mode 100644 index 0000000000000000000000000000000000000000..970e15687f01650a2121143c952afbaa499373bc --- /dev/null +++ b/README.md @@ -0,0 +1,49 @@ +--- +title: Chuyển Văn Bản Thành Video AI +emoji: 🎬 +colorFrom: blue +colorTo: purple +sdk: gradio +sdk_version: 4.19.2 +app_file: app.py +pinned: false +--- + +# Chuyển Văn Bản Thành Video AI + +Ứng dụng này sử dụng mô hình LTX-Video để tạo video từ văn bản mô tả. Bạn có thể dễ dàng tạo video ngắn bằng cách nhập mô tả bằng văn bản. + +## Tính năng + +- 🎯 Tạo video từ mô tả văn bản +- 🎨 Tùy chỉnh từ khóa loại trừ để cải thiện chất lượng +- 📱 Giao diện thân thiện, dễ sử dụng +- 🖼️ Xem lại các video đã tạo trước đó + +## Cách sử dụng + +1. Nhập mô tả cho video bạn muốn tạo vào ô "Nhập nội dung video" +2. (Tùy chọn) Điều chỉnh từ khóa loại trừ nếu cần +3. Nhấn nút "Tạo Video" +4. Đợi trong giây lát để hệ thống tạo video +5. Video được tạo sẽ hiển thị ở khung bên phải +6. Xem lại các video đã tạo trước đó ở phần gallery bên dưới + +## Thông số kỹ thuật + +- Model: Lightricks/LTX-Video +- Độ phân giải: 704x480 +- Số frame: 161 +- FPS: 24 +- Số bước inference: 50 + +## Lưu ý + +- Thời gian tạo video có thể mất vài phút tùy thuộc vào độ phức tạp của mô tả +- Chất lượng video phụ thuộc vào độ chi tiết của mô tả văn bản +- Nên sử dụng mô tả rõ ràng, chi tiết để có kết quả tốt nhất + +## Credits + +- Model: [Lightricks/LTX-Video](https://huggingface.co/Lightricks/LTX-Video) +- Framework: [Gradio](https://gradio.app/) diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..5d5f1c8984bb11ffb039a723fd144c4bf372fb66 --- /dev/null +++ b/app.py @@ -0,0 +1,71 @@ +import gradio as gr +import os +import subprocess +from ttv import generate_video # We'll modify ttv.py to make it a function + +def install_dependencies(): + try: + # Cài đặt diffusers từ git + subprocess.run(["pip", "install", "-U", "git+https://github.com/huggingface/diffusers"], + check=True, + capture_output=True) + + # Cài đặt inference-script + subprocess.run(["pip", "install", "-e", ".[inference-script]"], + check=True, + capture_output=True) + + return "✅ Đã cài đặt thành công các gói phụ thuộc!" + except subprocess.CalledProcessError as e: + return f"❌ Lỗi khi cài đặt: {str(e)}" + +def text_to_video(prompt, negative_prompt): + # Generate video from text + output_path = generate_video(prompt, negative_prompt) + return output_path + +def list_videos(): + # List all MP4 files in the current directory + videos = [f for f in os.listdir('.') if f.endswith('.mp4')] + return videos + +# Create Gradio interface +with gr.Blocks() as demo: + gr.Markdown("# Chuyển Văn Bản Thành Video") + + # Add installation button at the top + install_btn = gr.Button("Cài đặt Dependencies") + install_output = gr.Textbox(label="Trạng thái cài đặt", interactive=False) + + with gr.Row(): + with gr.Column(): + # Input components + text_input = gr.Textbox(label="Nhập nội dung video", lines=3) + neg_prompt = gr.Textbox(label="Từ khóa loại trừ", + value="chất lượng kém, chuyển động không đồng nhất, mờ, giật, biến dạng") + generate_btn = gr.Button("Tạo Video") + + with gr.Column(): + # Output video display + video_output = gr.Video(label="Video đã tạo") + + # Gallery of existing videos + gr.Markdown("### Video đã tạo trước đó") + gallery = gr.Gallery(value=list_videos(), label="Video có sẵn", + show_label=True, elem_id="gallery").style(grid=2) + + # Connect components + generate_btn.click( + fn=text_to_video, + inputs=[text_input, neg_prompt], + outputs=[video_output] + ) + + # Connect install button + install_btn.click( + fn=install_dependencies, + inputs=[], + outputs=[install_output] + ) + +demo.launch() diff --git a/docs/_static/ltx-video_example_00001.gif b/docs/_static/ltx-video_example_00001.gif new file mode 100644 index 0000000000000000000000000000000000000000..19ff700c5b0bed1764d4a9afe27345312bdbb36c --- /dev/null +++ b/docs/_static/ltx-video_example_00001.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:b679f14a09d2321b7e34b3ecd23bc01c2cfa75c8d4214a1e59af09826003e2ec +size 7963919 diff --git a/docs/_static/ltx-video_example_00002.gif b/docs/_static/ltx-video_example_00002.gif new file mode 100644 index 0000000000000000000000000000000000000000..03892c267af59cfaf99867ca3f675f5efc03440a --- /dev/null +++ b/docs/_static/ltx-video_example_00002.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:336f4baec79c1bd754c7c1bf3ac0792910cc85b6a3bde15fabeb0fb0f33299ff +size 7897781 diff --git a/docs/_static/ltx-video_example_00003.gif b/docs/_static/ltx-video_example_00003.gif new file mode 100644 index 0000000000000000000000000000000000000000..ddd889db7363bccc569e5cdb91bd20db873a0409 --- /dev/null +++ b/docs/_static/ltx-video_example_00003.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:ab2cb063b872d487fbbab821de7fe8157e7f87af03bd780d55116cb98fc8fc45 +size 4429543 diff --git a/docs/_static/ltx-video_example_00004.gif b/docs/_static/ltx-video_example_00004.gif new file mode 100644 index 0000000000000000000000000000000000000000..53f2d435d0ac9f8fc6a73530dc25f18b52c404cc --- /dev/null +++ b/docs/_static/ltx-video_example_00004.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:0a599a641cc3367fab5a6dd75fc89be63208cc708a1173b2ce7bfeac7208f831 +size 6713603 diff --git a/docs/_static/ltx-video_example_00005.gif b/docs/_static/ltx-video_example_00005.gif new file mode 100644 index 0000000000000000000000000000000000000000..58a0a4ad9847e37fae7979070972dd84450370c8 --- /dev/null +++ b/docs/_static/ltx-video_example_00005.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:87fdb9556c1218db4b929994e9b807d1d63f4676defef5b418a4edb1ddaa8422 +size 5732587 diff --git a/docs/_static/ltx-video_example_00006.gif b/docs/_static/ltx-video_example_00006.gif new file mode 100644 index 0000000000000000000000000000000000000000..788ab9411150a051b39f9e1141559041247f958d --- /dev/null +++ b/docs/_static/ltx-video_example_00006.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f56f3dcc84a871ab4ef1510120f7a4586c7044c5609a897d8177ae8d52eb3eae +size 4239543 diff --git a/docs/_static/ltx-video_example_00007.gif b/docs/_static/ltx-video_example_00007.gif new file mode 100644 index 0000000000000000000000000000000000000000..d34d86ff54960073315d8ef3789ef84f3fb36cf7 --- /dev/null +++ b/docs/_static/ltx-video_example_00007.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a08a06681334856db516e969a9ae4290acfd7550f7b970331e87d0223e282bcc +size 7829259 diff --git a/docs/_static/ltx-video_example_00008.gif b/docs/_static/ltx-video_example_00008.gif new file mode 100644 index 0000000000000000000000000000000000000000..45f5c9e27d1b256c06e4868cacfdd05a7954fd3c --- /dev/null +++ b/docs/_static/ltx-video_example_00008.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3242c65e11a40177c91b48d8ee18084dc4f907ffe5f11217c5f3e5aa2ca3fe36 +size 6229734 diff --git a/docs/_static/ltx-video_example_00009.gif b/docs/_static/ltx-video_example_00009.gif new file mode 100644 index 0000000000000000000000000000000000000000..0a2a4878ecfa35572a20bae70eb2f6af310e9696 --- /dev/null +++ b/docs/_static/ltx-video_example_00009.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa1e0a2ba75c6bda530a798e8aaeb3edc19413970b99d2a67b79839cd14f2fe5 +size 6389700 diff --git a/docs/_static/ltx-video_example_00010.gif b/docs/_static/ltx-video_example_00010.gif new file mode 100644 index 0000000000000000000000000000000000000000..999b04171b2b4087728498cf940b7664cd9c3829 --- /dev/null +++ b/docs/_static/ltx-video_example_00010.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bcf1e084e936a75eaae73a29f60935c469b1fc34eb3f5ad89483e88b3a2eaffe +size 6193172 diff --git a/docs/_static/ltx-video_example_00011.gif b/docs/_static/ltx-video_example_00011.gif new file mode 100644 index 0000000000000000000000000000000000000000..d0274a7bd9d9f1528f89192a773dca8948b9ee03 --- /dev/null +++ b/docs/_static/ltx-video_example_00011.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:3e3d04f5763ecb416b3b80c3488e48c49991d80661c94e8f08dddd7b890b1b75 +size 5345673 diff --git a/docs/_static/ltx-video_example_00012.gif b/docs/_static/ltx-video_example_00012.gif new file mode 100644 index 0000000000000000000000000000000000000000..f9f6dc712d6faa0579d9c6b7347045c6b5d0f6fb --- /dev/null +++ b/docs/_static/ltx-video_example_00012.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:39790832fd9bff62c99a799eb4843cf99c9ab73c3f181656acbbd0d4ebf7f471 +size 7474091 diff --git a/docs/_static/ltx-video_example_00013.gif b/docs/_static/ltx-video_example_00013.gif new file mode 100644 index 0000000000000000000000000000000000000000..a32a9ca7f6d30b6759af80d52d46a0a027dc278e --- /dev/null +++ b/docs/_static/ltx-video_example_00013.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:aa7eb790b43f8a55c01d1fbed4c7a7f657fb2ca78a9685833cf9cb558d2002c1 +size 9024843 diff --git a/docs/_static/ltx-video_example_00014.gif b/docs/_static/ltx-video_example_00014.gif new file mode 100644 index 0000000000000000000000000000000000000000..adcf7c2387c11cb8efe6cd14577afe22e31a8e63 --- /dev/null +++ b/docs/_static/ltx-video_example_00014.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:4f7afc4b498a927dcc4e1492548db5c32fa76d117e0410d11e1e0b1929153e54 +size 7434241 diff --git a/docs/_static/ltx-video_example_00015.gif b/docs/_static/ltx-video_example_00015.gif new file mode 100644 index 0000000000000000000000000000000000000000..e70df3caa759cc2e7d7e33c4fd9a957eff907e99 --- /dev/null +++ b/docs/_static/ltx-video_example_00015.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:d897c9656e0cba89512ab9d2cbe2d2c0f2ddf907dcab5f7eadab4b96b1cb1efe +size 6556457 diff --git a/docs/_static/ltx-video_example_00016.gif b/docs/_static/ltx-video_example_00016.gif new file mode 100644 index 0000000000000000000000000000000000000000..34474f880704fa15910790f95f3d04decf13dadf --- /dev/null +++ b/docs/_static/ltx-video_example_00016.gif @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:c74f35e37bba01817ca4ac01dd9195863100eb83e7cb73bbea2b53e0f69a8628 +size 7412915 diff --git a/inference.py b/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..704ba6ab42ec401ec90b33797bdae594cb6945d2 --- /dev/null +++ b/inference.py @@ -0,0 +1,488 @@ +import argparse +import os +import random +from datetime import datetime +from pathlib import Path +from diffusers.utils import logging + +import imageio +import numpy as np +import torch +import torch.nn.functional as F +from PIL import Image +from transformers import T5EncoderModel, T5Tokenizer + +from ltx_video.models.autoencoders.causal_video_autoencoder import ( + CausalVideoAutoencoder, +) +from ltx_video.models.transformers.symmetric_patchifier import SymmetricPatchifier +from ltx_video.models.transformers.transformer3d import Transformer3DModel +from ltx_video.pipelines.pipeline_ltx_video import LTXVideoPipeline +from ltx_video.schedulers.rf import RectifiedFlowScheduler +from ltx_video.utils.conditioning_method import ConditioningMethod +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + +MAX_HEIGHT = 720 +MAX_WIDTH = 1280 +MAX_NUM_FRAMES = 257 + + +def get_total_gpu_memory(): + if torch.cuda.is_available(): + total_memory = torch.cuda.get_device_properties(0).total_memory / (1024**3) + return total_memory + return None + + +def load_image_to_tensor_with_resize_and_crop( + image_path, target_height=512, target_width=768 +): + image = Image.open(image_path).convert("RGB") + input_width, input_height = image.size + aspect_ratio_target = target_width / target_height + aspect_ratio_frame = input_width / input_height + if aspect_ratio_frame > aspect_ratio_target: + new_width = int(input_height * aspect_ratio_target) + new_height = input_height + x_start = (input_width - new_width) // 2 + y_start = 0 + else: + new_width = input_width + new_height = int(input_width / aspect_ratio_target) + x_start = 0 + y_start = (input_height - new_height) // 2 + + image = image.crop((x_start, y_start, x_start + new_width, y_start + new_height)) + image = image.resize((target_width, target_height)) + frame_tensor = torch.tensor(np.array(image)).permute(2, 0, 1).float() + frame_tensor = (frame_tensor / 127.5) - 1.0 + # Create 5D tensor: (batch_size=1, channels=3, num_frames=1, height, width) + return frame_tensor.unsqueeze(0).unsqueeze(2) + + +def calculate_padding( + source_height: int, source_width: int, target_height: int, target_width: int +) -> tuple[int, int, int, int]: + + # Calculate total padding needed + pad_height = target_height - source_height + pad_width = target_width - source_width + + # Calculate padding for each side + pad_top = pad_height // 2 + pad_bottom = pad_height - pad_top # Handles odd padding + pad_left = pad_width // 2 + pad_right = pad_width - pad_left # Handles odd padding + + # Return padded tensor + # Padding format is (left, right, top, bottom) + padding = (pad_left, pad_right, pad_top, pad_bottom) + return padding + + +def convert_prompt_to_filename(text: str, max_len: int = 20) -> str: + # Remove non-letters and convert to lowercase + clean_text = "".join( + char.lower() for char in text if char.isalpha() or char.isspace() + ) + + # Split into words + words = clean_text.split() + + # Build result string keeping track of length + result = [] + current_length = 0 + + for word in words: + # Add word length plus 1 for underscore (except for first word) + new_length = current_length + len(word) + + if new_length <= max_len: + result.append(word) + current_length += len(word) + else: + break + + return "-".join(result) + + +# Generate output video name +def get_unique_filename( + base: str, + ext: str, + prompt: str, + seed: int, + resolution: tuple[int, int, int], + dir: Path, + endswith=None, + index_range=1000, +) -> Path: + base_filename = f"{base}_{convert_prompt_to_filename(prompt, max_len=30)}_{seed}_{resolution[0]}x{resolution[1]}x{resolution[2]}" + for i in range(index_range): + filename = dir / f"{base_filename}_{i}{endswith if endswith else ''}{ext}" + if not os.path.exists(filename): + return filename + raise FileExistsError( + f"Could not find a unique filename after {index_range} attempts." + ) + + +def seed_everething(seed: int): + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + if torch.cuda.is_available(): + torch.cuda.manual_seed(seed) + + +def main(): + parser = argparse.ArgumentParser( + description="Load models from separate directories and run the pipeline." + ) + + # Directories + parser.add_argument( + "--ckpt_path", + type=str, + required=True, + help="Path to a safetensors file that contains all model parts.", + ) + parser.add_argument( + "--input_video_path", + type=str, + help="Path to the input video file (first frame used)", + ) + parser.add_argument( + "--input_image_path", type=str, help="Path to the input image file" + ) + parser.add_argument( + "--output_path", + type=str, + default=None, + help="Path to the folder to save output video, if None will save in outputs/ directory.", + ) + parser.add_argument("--seed", type=int, default="171198") + + # Pipeline parameters + parser.add_argument( + "--num_inference_steps", type=int, default=40, help="Number of inference steps" + ) + parser.add_argument( + "--num_images_per_prompt", + type=int, + default=1, + help="Number of images per prompt", + ) + parser.add_argument( + "--guidance_scale", + type=float, + default=3, + help="Guidance scale for the pipeline", + ) + parser.add_argument( + "--stg_scale", + type=float, + default=1, + help="Spatiotemporal guidance scale for the pipeline. 0 to disable STG.", + ) + parser.add_argument( + "--stg_rescale", + type=float, + default=0.7, + help="Spatiotemporal guidance rescaling scale for the pipeline. 1 to disable rescale.", + ) + parser.add_argument( + "--stg_mode", + type=str, + default="stg_a", + help="Spatiotemporal guidance mode for the pipeline. Can be either stg_a or stg_r.", + ) + parser.add_argument( + "--stg_skip_layers", + type=str, + default="19", + help="Attention layers to skip for spatiotemporal guidance. Comma separated list of integers.", + ) + parser.add_argument( + "--image_cond_noise_scale", + type=float, + default=0.15, + help="Amount of noise to add to the conditioned image", + ) + parser.add_argument( + "--height", + type=int, + default=480, + help="Height of the output video frames. Optional if an input image provided.", + ) + parser.add_argument( + "--width", + type=int, + default=704, + help="Width of the output video frames. If None will infer from input image.", + ) + parser.add_argument( + "--num_frames", + type=int, + default=121, + help="Number of frames to generate in the output video", + ) + parser.add_argument( + "--frame_rate", type=int, default=25, help="Frame rate for the output video" + ) + + parser.add_argument( + "--precision", + choices=["bfloat16", "mixed_precision"], + default="bfloat16", + help="Sets the precision for the transformer and tokenizer. Default is bfloat16. If 'mixed_precision' is enabled, it moves to mixed-precision.", + ) + + # VAE noise augmentation + parser.add_argument( + "--decode_timestep", + type=float, + default=0.05, + help="Timestep for decoding noise", + ) + parser.add_argument( + "--decode_noise_scale", + type=float, + default=0.025, + help="Noise level for decoding noise", + ) + + # Prompts + parser.add_argument( + "--prompt", + type=str, + help="Text prompt to guide generation", + ) + parser.add_argument( + "--negative_prompt", + type=str, + default="worst quality, inconsistent motion, blurry, jittery, distorted", + help="Negative prompt for undesired features", + ) + + parser.add_argument( + "--offload_to_cpu", + action="store_true", + help="Offloading unnecessary computations to CPU.", + ) + + logger = logging.get_logger(__name__) + + args = parser.parse_args() + + logger.warning(f"Running generation with arguments: {args}") + + seed_everething(args.seed) + + offload_to_cpu = False if not args.offload_to_cpu else get_total_gpu_memory() < 30 + + output_dir = ( + Path(args.output_path) + if args.output_path + else Path(f"outputs/{datetime.today().strftime('%Y-%m-%d')}") + ) + output_dir.mkdir(parents=True, exist_ok=True) + + # Load image + if args.input_image_path: + media_items_prepad = load_image_to_tensor_with_resize_and_crop( + args.input_image_path, args.height, args.width + ) + else: + media_items_prepad = None + + height = args.height if args.height else media_items_prepad.shape[-2] + width = args.width if args.width else media_items_prepad.shape[-1] + num_frames = args.num_frames + + if height > MAX_HEIGHT or width > MAX_WIDTH or num_frames > MAX_NUM_FRAMES: + logger.warning( + f"Input resolution or number of frames {height}x{width}x{num_frames} is too big, it is suggested to use the resolution below {MAX_HEIGHT}x{MAX_WIDTH}x{MAX_NUM_FRAMES}." + ) + + # Adjust dimensions to be divisible by 32 and num_frames to be (N * 8 + 1) + height_padded = ((height - 1) // 32 + 1) * 32 + width_padded = ((width - 1) // 32 + 1) * 32 + num_frames_padded = ((num_frames - 2) // 8 + 1) * 8 + 1 + + padding = calculate_padding(height, width, height_padded, width_padded) + + logger.warning( + f"Padded dimensions: {height_padded}x{width_padded}x{num_frames_padded}" + ) + + if media_items_prepad is not None: + media_items = F.pad( + media_items_prepad, padding, mode="constant", value=-1 + ) # -1 is the value for padding since the image is normalized to -1, 1 + else: + media_items = None + + ckpt_path = Path(args.ckpt_path) + vae = CausalVideoAutoencoder.from_pretrained(ckpt_path) + transformer = Transformer3DModel.from_pretrained(ckpt_path) + scheduler = RectifiedFlowScheduler.from_pretrained(ckpt_path) + + text_encoder = T5EncoderModel.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="text_encoder" + ) + patchifier = SymmetricPatchifier(patch_size=1) + tokenizer = T5Tokenizer.from_pretrained( + "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="tokenizer" + ) + + if torch.cuda.is_available(): + transformer = transformer.cuda() + vae = vae.cuda() + text_encoder = text_encoder.cuda() + + vae = vae.to(torch.bfloat16) + if args.precision == "bfloat16" and transformer.dtype != torch.bfloat16: + transformer = transformer.to(torch.bfloat16) + text_encoder = text_encoder.to(torch.bfloat16) + + # Set spatiotemporal guidance + skip_block_list = [int(x.strip()) for x in args.stg_skip_layers.split(",")] + skip_layer_strategy = ( + SkipLayerStrategy.Attention + if args.stg_mode.lower() == "stg_a" + else SkipLayerStrategy.Residual + ) + + # Use submodels for the pipeline + submodel_dict = { + "transformer": transformer, + "patchifier": patchifier, + "text_encoder": text_encoder, + "tokenizer": tokenizer, + "scheduler": scheduler, + "vae": vae, + } + + pipeline = LTXVideoPipeline(**submodel_dict) + if torch.cuda.is_available(): + pipeline = pipeline.to("cuda") + + # Prepare input for the pipeline + sample = { + "prompt": args.prompt, + "prompt_attention_mask": None, + "negative_prompt": args.negative_prompt, + "negative_prompt_attention_mask": None, + "media_items": media_items, + } + + generator = torch.Generator( + device="cuda" if torch.cuda.is_available() else "cpu" + ).manual_seed(args.seed) + + images = pipeline( + num_inference_steps=args.num_inference_steps, + num_images_per_prompt=args.num_images_per_prompt, + guidance_scale=args.guidance_scale, + skip_layer_strategy=skip_layer_strategy, + skip_block_list=skip_block_list, + stg_scale=args.stg_scale, + do_rescaling=args.stg_rescale != 1, + rescaling_scale=args.stg_rescale, + generator=generator, + output_type="pt", + callback_on_step_end=None, + height=height_padded, + width=width_padded, + num_frames=num_frames_padded, + frame_rate=args.frame_rate, + **sample, + is_video=True, + vae_per_channel_normalize=True, + conditioning_method=( + ConditioningMethod.FIRST_FRAME + if media_items is not None + else ConditioningMethod.UNCONDITIONAL + ), + image_cond_noise_scale=args.image_cond_noise_scale, + decode_timestep=args.decode_timestep, + decode_noise_scale=args.decode_noise_scale, + mixed_precision=(args.precision == "mixed_precision"), + offload_to_cpu=offload_to_cpu, + ).images + + # Crop the padded images to the desired resolution and number of frames + (pad_left, pad_right, pad_top, pad_bottom) = padding + pad_bottom = -pad_bottom + pad_right = -pad_right + if pad_bottom == 0: + pad_bottom = images.shape[3] + if pad_right == 0: + pad_right = images.shape[4] + images = images[:, :, :num_frames, pad_top:pad_bottom, pad_left:pad_right] + + for i in range(images.shape[0]): + # Gathering from B, C, F, H, W to C, F, H, W and then permuting to F, H, W, C + video_np = images[i].permute(1, 2, 3, 0).cpu().float().numpy() + # Unnormalizing images to [0, 255] range + video_np = (video_np * 255).astype(np.uint8) + fps = args.frame_rate + height, width = video_np.shape[1:3] + # In case a single image is generated + if video_np.shape[0] == 1: + output_filename = get_unique_filename( + f"image_output_{i}", + ".png", + prompt=args.prompt, + seed=args.seed, + resolution=(height, width, num_frames), + dir=output_dir, + ) + imageio.imwrite(output_filename, video_np[0]) + else: + if args.input_image_path: + base_filename = f"img_to_vid_{i}" + else: + base_filename = f"text_to_vid_{i}" + output_filename = get_unique_filename( + base_filename, + ".mp4", + prompt=args.prompt, + seed=args.seed, + resolution=(height, width, num_frames), + dir=output_dir, + ) + + # Write video + with imageio.get_writer(output_filename, fps=fps) as video: + for frame in video_np: + video.append_data(frame) + + # Write condition image + if args.input_image_path: + reference_image = ( + ( + media_items_prepad[0, :, 0].permute(1, 2, 0).cpu().data.numpy() + + 1.0 + ) + / 2.0 + * 255 + ) + imageio.imwrite( + get_unique_filename( + base_filename, + ".png", + prompt=args.prompt, + seed=args.seed, + resolution=(height, width, num_frames), + dir=output_dir, + endswith="_condition", + ), + reference_image.astype(np.uint8), + ) + logger.warning(f"Output saved to {output_dir}") + + +if __name__ == "__main__": + main() diff --git a/itv.py b/itv.py new file mode 100644 index 0000000000000000000000000000000000000000..1fa6fa055a28550bc7a28a9a5d5660d0052f606a --- /dev/null +++ b/itv.py @@ -0,0 +1,23 @@ +import torch +from diffusers import LTXImageToVideoPipeline +from diffusers.utils import export_to_video, load_image + +pipe = LTXImageToVideoPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) +pipe.to("cuda") + +image = load_image( + "https://huggingface.co/datasets/a-r-r-o-w/tiny-meme-dataset-captioned/resolve/main/images/8.png" +) +prompt = "A young girl stands calmly in the foreground, looking directly at the camera, as a house fire rages in the background. Flames engulf the structure, with smoke billowing into the air. Firefighters in protective gear rush to the scene, a fire truck labeled '38' visible behind them. The girl's neutral expression contrasts sharply with the chaos of the fire, creating a poignant and emotionally charged scene." +negative_prompt = "worst quality, inconsistent motion, blurry, jittery, distorted" + +video = pipe( + image=image, + prompt=prompt, + negative_prompt=negative_prompt, + width=704, + height=480, + num_frames=161, + num_inference_steps=50, +).frames[0] +export_to_video(video, "output.mp4", fps=24) diff --git a/ltx_video/__init__.py b/ltx_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx_video/models/__init__.py b/ltx_video/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx_video/models/autoencoders/__init__.py b/ltx_video/models/autoencoders/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx_video/models/autoencoders/causal_conv3d.py b/ltx_video/models/autoencoders/causal_conv3d.py new file mode 100644 index 0000000000000000000000000000000000000000..146dea19b704f6c2a9f1ecc505cbab29ac09914b --- /dev/null +++ b/ltx_video/models/autoencoders/causal_conv3d.py @@ -0,0 +1,62 @@ +from typing import Tuple, Union + +import torch +import torch.nn as nn + + +class CausalConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size: int = 3, + stride: Union[int, Tuple[int]] = 1, + dilation: int = 1, + groups: int = 1, + **kwargs, + ): + super().__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + + kernel_size = (kernel_size, kernel_size, kernel_size) + self.time_kernel_size = kernel_size[0] + + dilation = (dilation, 1, 1) + + height_pad = kernel_size[1] // 2 + width_pad = kernel_size[2] // 2 + padding = (0, height_pad, width_pad) + + self.conv = nn.Conv3d( + in_channels, + out_channels, + kernel_size, + stride=stride, + dilation=dilation, + padding=padding, + padding_mode="zeros", + groups=groups, + ) + + def forward(self, x, causal: bool = True): + if causal: + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, self.time_kernel_size - 1, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x), dim=2) + else: + first_frame_pad = x[:, :, :1, :, :].repeat( + (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) + ) + last_frame_pad = x[:, :, -1:, :, :].repeat( + (1, 1, (self.time_kernel_size - 1) // 2, 1, 1) + ) + x = torch.concatenate((first_frame_pad, x, last_frame_pad), dim=2) + x = self.conv(x) + return x + + @property + def weight(self): + return self.conv.weight diff --git a/ltx_video/models/autoencoders/causal_video_autoencoder.py b/ltx_video/models/autoencoders/causal_video_autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..c059d4abedf49213564c97ede617e8147e5f1d22 --- /dev/null +++ b/ltx_video/models/autoencoders/causal_video_autoencoder.py @@ -0,0 +1,1263 @@ +import json +import os +from functools import partial +from types import SimpleNamespace +from typing import Any, Mapping, Optional, Tuple, Union, List +from pathlib import Path + +import torch +import numpy as np +from einops import rearrange +from torch import nn +from diffusers.utils import logging +import torch.nn.functional as F +from diffusers.models.embeddings import PixArtAlphaCombinedTimestepSizeEmbeddings +from safetensors import safe_open + + +from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from ltx_video.models.autoencoders.pixel_norm import PixelNorm +from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper +from ltx_video.models.transformers.attention import Attention +from ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, + VAE_KEYS_RENAME_DICT, +) + +PER_CHANNEL_STATISTICS_PREFIX = "per_channel_statistics." +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +class CausalVideoAutoencoder(AutoencoderKLWrapper): + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_name_or_path = Path(pretrained_model_name_or_path) + if ( + pretrained_model_name_or_path.is_dir() + and (pretrained_model_name_or_path / "autoencoder.pth").exists() + ): + config_local_path = pretrained_model_name_or_path / "config.json" + config = cls.load_config(config_local_path, **kwargs) + + model_local_path = pretrained_model_name_or_path / "autoencoder.pth" + state_dict = torch.load(model_local_path, map_location=torch.device("cpu")) + + statistics_local_path = ( + pretrained_model_name_or_path / "per_channel_statistics.json" + ) + if statistics_local_path.exists(): + with open(statistics_local_path, "r") as file: + data = json.load(file) + transposed_data = list(zip(*data["data"])) + data_dict = { + col: torch.tensor(vals) + for col, vals in zip(data["columns"], transposed_data) + } + std_of_means = data_dict["std-of-means"] + mean_of_means = data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ) + state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}std-of-means"] = ( + std_of_means + ) + state_dict[f"{PER_CHANNEL_STATISTICS_PREFIX}mean-of-means"] = ( + mean_of_means + ) + + elif pretrained_model_name_or_path.is_dir(): + config_path = pretrained_model_name_or_path / "vae" / "config.json" + with open(config_path, "r") as f: + config = make_hashable_key(json.load(f)) + + assert config in diffusers_and_ours_config_mapping, ( + "Provided diffusers checkpoint config for VAE is not suppported. " + "We only support diffusers configs found in Lightricks/LTX-Video." + ) + + config = diffusers_and_ours_config_mapping[config] + + state_dict_path = ( + pretrained_model_name_or_path + / "vae" + / "diffusion_pytorch_model.safetensors" + ) + + state_dict = {} + with safe_open(state_dict_path, framework="pt", device="cpu") as f: + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + for key in list(state_dict.keys()): + new_key = key + for replace_key, rename_key in VAE_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + + state_dict[new_key] = state_dict.pop(key) + + elif pretrained_model_name_or_path.is_file() and str( + pretrained_model_name_or_path + ).endswith(".safetensors"): + state_dict = {} + with safe_open( + pretrained_model_name_or_path, framework="pt", device="cpu" + ) as f: + metadata = f.metadata() + for k in f.keys(): + state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + config = configs["vae"] + + video_vae = cls.from_config(config) + if "torch_dtype" in kwargs: + video_vae.to(kwargs["torch_dtype"]) + video_vae.load_state_dict(state_dict) + return video_vae + + @staticmethod + def from_config(config): + assert ( + config["_class_name"] == "CausalVideoAutoencoder" + ), "config must have _class_name=CausalVideoAutoencoder" + if isinstance(config["dims"], list): + config["dims"] = tuple(config["dims"]) + + assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" + + double_z = config.get("double_z", True) + latent_log_var = config.get( + "latent_log_var", "per_channel" if double_z else "none" + ) + use_quant_conv = config.get("use_quant_conv", True) + + if use_quant_conv and latent_log_var == "uniform": + raise ValueError("uniform latent_log_var requires use_quant_conv=False") + + encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + blocks=config.get("encoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + ) + + decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + blocks=config.get("decoder_blocks", config.get("blocks")), + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + causal=config.get("causal_decoder", False), + timestep_conditioning=config.get("timestep_conditioning", False), + ) + + dims = config["dims"] + return CausalVideoAutoencoder( + encoder=encoder, + decoder=decoder, + latent_channels=config["latent_channels"], + dims=dims, + use_quant_conv=use_quant_conv, + ) + + @property + def config(self): + return SimpleNamespace( + _class_name="CausalVideoAutoencoder", + dims=self.dims, + in_channels=self.encoder.conv_in.in_channels // self.encoder.patch_size**2, + out_channels=self.decoder.conv_out.out_channels + // self.decoder.patch_size**2, + latent_channels=self.decoder.conv_in.in_channels, + encoder_blocks=self.encoder.blocks_desc, + decoder_blocks=self.decoder.blocks_desc, + scaling_factor=1.0, + norm_layer=self.encoder.norm_layer, + patch_size=self.encoder.patch_size, + latent_log_var=self.encoder.latent_log_var, + use_quant_conv=self.use_quant_conv, + causal_decoder=self.decoder.causal, + timestep_conditioning=self.decoder.timestep_conditioning, + ) + + @property + def is_video_supported(self): + """ + Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. + """ + return self.dims != 2 + + @property + def spatial_downscale_factor(self): + return ( + 2 + ** len( + [ + block + for block in self.encoder.blocks_desc + if block[0] in ["compress_space", "compress_all"] + ] + ) + * self.encoder.patch_size + ) + + @property + def temporal_downscale_factor(self): + return 2 ** len( + [ + block + for block in self.encoder.blocks_desc + if block[0] in ["compress_time", "compress_all"] + ] + ) + + def to_json_string(self) -> str: + import json + + return json.dumps(self.config.__dict__) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + if any([key.startswith("vae.") for key in state_dict.keys()]): + state_dict = { + key.replace("vae.", ""): value + for key, value in state_dict.items() + if key.startswith("vae.") + } + ckpt_state_dict = { + key: value + for key, value in state_dict.items() + if not key.startswith(PER_CHANNEL_STATISTICS_PREFIX) + } + + model_keys = set(name for name, _ in self.named_parameters()) + + key_mapping = { + ".resnets.": ".res_blocks.", + "downsamplers.0": "downsample", + "upsamplers.0": "upsample", + } + converted_state_dict = {} + for key, value in ckpt_state_dict.items(): + for k, v in key_mapping.items(): + key = key.replace(k, v) + + if "norm" in key and key not in model_keys: + logger.info( + f"Removing key {key} from state_dict as it is not present in the model" + ) + continue + + converted_state_dict[key] = value + + super().load_state_dict(converted_state_dict, strict=strict) + + data_dict = { + key.removeprefix(PER_CHANNEL_STATISTICS_PREFIX): value + for key, value in state_dict.items() + if key.startswith(PER_CHANNEL_STATISTICS_PREFIX) + } + if len(data_dict) > 0: + self.register_buffer("std_of_means", data_dict["std-of-means"]) + self.register_buffer( + "mean_of_means", + data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ), + ) + + def last_layer(self): + if hasattr(self.decoder, "conv_out"): + if isinstance(self.decoder.conv_out, nn.Sequential): + last_layer = self.decoder.conv_out[-1] + else: + last_layer = self.decoder.conv_out + else: + last_layer = self.decoder.layers[-1] + return last_layer + + def set_use_tpu_flash_attention(self): + for block in self.decoder.up_blocks: + if isinstance(block, UNetMidBlock3D) and block.attention_blocks: + for attention_block in block.attention_blocks: + attention_block.set_use_tpu_flash_attention() + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], + base_channels: int = 128, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + ): + super().__init__() + self.patch_size = patch_size + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + self.blocks_desc = blocks + + in_channels = in_channels * patch_size**2 + output_channel = base_channels + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + ) + + self.down_blocks = nn.ModuleList([]) + + for block_name, block_params in blocks: + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + elif block_name == "res_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + ) + elif block_name == "compress_time": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 1, 1), + causal=True, + ) + elif block_name == "compress_space": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(1, 2, 2), + causal=True, + ) + elif block_name == "compress_all": + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + ) + elif block_name == "compress_all_x_y": + output_channel = block_params.get("multiplier", 2) * output_channel + block = make_conv_nd( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + kernel_size=3, + stride=(2, 2, 2), + causal=True, + ) + else: + raise ValueError(f"unknown block: {block_name}") + + self.down_blocks.append(block) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd( + dims, output_channel, conv_out_channels, 3, padding=1, causal=True + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)(sample) + + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + dims (`int` or `Tuple[int, int]`, *optional*, defaults to 3): + The number of dimensions to use in convolutions. + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + blocks (`List[Tuple[str, int]]`, *optional*, defaults to `[("res_x", 1)]`): + The blocks to use. Each block is a tuple of the block name and the number of layers. + base_channels (`int`, *optional*, defaults to 128): + The number of output channels for the first convolutional layer. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + causal (`bool`, *optional*, defaults to `True`): + Whether to use causal convolutions or not. + """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + blocks: List[Tuple[str, int | dict]] = [("res_x", 1)], + base_channels: int = 128, + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + causal: bool = True, + timestep_conditioning: bool = False, + ): + super().__init__() + self.patch_size = patch_size + self.layers_per_block = layers_per_block + out_channels = out_channels * patch_size**2 + self.causal = causal + self.blocks_desc = blocks + + # Compute output channel to be product of all channel-multiplier blocks + output_channel = base_channels + for block_name, block_params in list(reversed(blocks)): + block_params = block_params if isinstance(block_params, dict) else {} + if block_name == "res_x_y": + output_channel = output_channel * block_params.get("multiplier", 2) + if block_name == "compress_all": + output_channel = output_channel * block_params.get("multiplier", 1) + + self.conv_in = make_conv_nd( + dims, + in_channels, + output_channel, + kernel_size=3, + stride=1, + padding=1, + causal=True, + ) + + self.up_blocks = nn.ModuleList([]) + + for block_name, block_params in list(reversed(blocks)): + input_channel = output_channel + if isinstance(block_params, int): + block_params = {"num_layers": block_params} + + if block_name == "res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + ) + elif block_name == "attn_res_x": + block = UNetMidBlock3D( + dims=dims, + in_channels=input_channel, + num_layers=block_params["num_layers"], + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=timestep_conditioning, + attention_head_dim=block_params["attention_head_dim"], + ) + elif block_name == "res_x_y": + output_channel = output_channel // block_params.get("multiplier", 2) + block = ResnetBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + eps=1e-6, + groups=norm_num_groups, + norm_layer=norm_layer, + inject_noise=block_params.get("inject_noise", False), + timestep_conditioning=False, + ) + elif block_name == "compress_time": + block = DepthToSpaceUpsample( + dims=dims, in_channels=input_channel, stride=(2, 1, 1) + ) + elif block_name == "compress_space": + block = DepthToSpaceUpsample( + dims=dims, in_channels=input_channel, stride=(1, 2, 2) + ) + elif block_name == "compress_all": + output_channel = output_channel // block_params.get("multiplier", 1) + block = DepthToSpaceUpsample( + dims=dims, + in_channels=input_channel, + stride=(2, 2, 2), + residual=block_params.get("residual", False), + out_channels_reduction_factor=block_params.get("multiplier", 1), + ) + else: + raise ValueError(f"unknown layer: {block_name}") + + self.up_blocks.append(block) + + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=output_channel, num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + elif norm_layer == "layer_norm": + self.conv_norm_out = LayerNorm(output_channel, eps=1e-6) + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims, output_channel, out_channels, 3, padding=1, causal=True + ) + + self.gradient_checkpointing = False + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.timestep_scale_multiplier = nn.Parameter( + torch.tensor(1000.0, dtype=torch.float32) + ) + self.last_time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + output_channel * 2, 0 + ) + self.last_scale_shift_table = nn.Parameter( + torch.randn(2, output_channel) / output_channel**0.5 + ) + + def forward( + self, + sample: torch.FloatTensor, + target_shape, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + assert target_shape is not None, "target_shape must be provided" + batch_size = sample.shape[0] + + sample = self.conv_in(sample, causal=self.causal) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = sample.to(upscale_dtype) + + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + scaled_timestep = timestep * self.timestep_scale_multiplier + + for up_block in self.up_blocks: + if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)( + sample, causal=self.causal, timestep=scaled_timestep + ) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + + sample = self.conv_norm_out(sample) + + if self.timestep_conditioning: + embedded_timestep = self.last_time_embedder( + timestep=scaled_timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=sample.shape[0], + hidden_dtype=sample.dtype, + ) + embedded_timestep = embedded_timestep.view( + batch_size, embedded_timestep.shape[-1], 1, 1, 1 + ) + ada_values = self.last_scale_shift_table[ + None, ..., None, None, None + ] + embedded_timestep.reshape( + batch_size, + 2, + -1, + embedded_timestep.shape[-3], + embedded_timestep.shape[-2], + embedded_timestep.shape[-1], + ) + shift, scale = ada_values.unbind(dim=1) + sample = sample * (1 + scale) + shift + + sample = self.conv_act(sample) + sample = self.conv_out(sample, causal=self.causal) + + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + + return sample + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + inject_noise (`bool`, *optional*, defaults to `False`): + Whether to inject noise into the hidden states. + timestep_conditioning (`bool`, *optional*, defaults to `False`): + Whether to condition the hidden states on the timestep. + attention_head_dim (`int`, *optional*, defaults to -1): + The dimension of the attention head. If -1, no attention is used. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, + attention_head_dim: int = -1, + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.time_embedder = PixArtAlphaCombinedTimestepSizeEmbeddings( + in_channels * 4, 0 + ) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + inject_noise=inject_noise, + timestep_conditioning=timestep_conditioning, + ) + for _ in range(num_layers) + ] + ) + + self.attention_blocks = None + + if attention_head_dim > 0: + if attention_head_dim > in_channels: + raise ValueError( + "attention_head_dim must be less than or equal to in_channels" + ) + + self.attention_blocks = nn.ModuleList( + [ + Attention( + query_dim=in_channels, + heads=in_channels // attention_head_dim, + dim_head=attention_head_dim, + bias=True, + out_bias=True, + qk_norm="rms_norm", + residual_connection=True, + ) + for _ in range(num_layers) + ] + ) + + def forward( + self, + hidden_states: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + timestep_embed = None + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + batch_size = hidden_states.shape[0] + timestep_embed = self.time_embedder( + timestep=timestep.flatten(), + resolution=None, + aspect_ratio=None, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + timestep_embed = timestep_embed.view( + batch_size, timestep_embed.shape[-1], 1, 1, 1 + ) + + if self.attention_blocks: + for resnet, attention in zip(self.res_blocks, self.attention_blocks): + hidden_states = resnet( + hidden_states, causal=causal, timestep=timestep_embed + ) + + # Reshape the hidden states to be (batch_size, frames * height * width, channel) + batch_size, channel, frames, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, frames * height * width + ).transpose(1, 2) + + if attention.use_tpu_flash_attention: + # Pad the second dimension to be divisible by block_k_major (block in flash attention) + seq_len = hidden_states.shape[1] + block_k_major = 512 + pad_len = (block_k_major - seq_len % block_k_major) % block_k_major + if pad_len > 0: + hidden_states = F.pad( + hidden_states, (0, 0, 0, pad_len), "constant", 0 + ) + + # Create a mask with ones for the original sequence length and zeros for the padded indexes + mask = torch.ones( + (hidden_states.shape[0], seq_len), + device=hidden_states.device, + dtype=hidden_states.dtype, + ) + if pad_len > 0: + mask = F.pad(mask, (0, pad_len), "constant", 0) + + hidden_states = attention( + hidden_states, + attention_mask=( + None if not attention.use_tpu_flash_attention else mask + ), + ) + + if attention.use_tpu_flash_attention: + # Remove the padding + if pad_len > 0: + hidden_states = hidden_states[:, :-pad_len, :] + + # Reshape the hidden states back to (batch_size, channel, frames, height, width, channel) + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, frames, height, width + ) + else: + for resnet in self.res_blocks: + hidden_states = resnet( + hidden_states, causal=causal, timestep=timestep_embed + ) + + return hidden_states + + +class DepthToSpaceUpsample(nn.Module): + def __init__( + self, dims, in_channels, stride, residual=False, out_channels_reduction_factor=1 + ): + super().__init__() + self.stride = stride + self.out_channels = ( + np.prod(stride) * in_channels // out_channels_reduction_factor + ) + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=self.out_channels, + kernel_size=3, + stride=1, + causal=True, + ) + self.residual = residual + self.out_channels_reduction_factor = out_channels_reduction_factor + + def forward(self, x, causal: bool = True): + if self.residual: + # Reshape and duplicate the input to match the output shape + x_in = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + num_repeat = np.prod(self.stride) // self.out_channels_reduction_factor + x_in = x_in.repeat(1, num_repeat, 1, 1, 1) + if self.stride[0] == 2: + x_in = x_in[:, :, 1:, :, :] + x = self.conv(x, causal=causal) + x = rearrange( + x, + "b (c p1 p2 p3) d h w -> b c (d p1) (h p2) (w p3)", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + if self.stride[0] == 2: + x = x[:, :, 1:, :, :] + if self.residual: + x = x + x_in + return x + + +class LayerNorm(nn.Module): + def __init__(self, dim, eps, elementwise_affine=True) -> None: + super().__init__() + self.norm = nn.LayerNorm(dim, eps=eps, elementwise_affine=elementwise_affine) + + def forward(self, x): + x = rearrange(x, "b c d h w -> b d h w c") + x = self.norm(x) + x = rearrange(x, "b d h w c -> b c d h w") + return x + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + inject_noise: bool = False, + timestep_conditioning: bool = False, + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.inject_noise = inject_noise + + if norm_layer == "group_norm": + self.norm1 = nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm1 = LayerNorm(in_channels, eps=eps, elementwise_affine=True) + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, + in_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + ) + + if inject_noise: + self.per_channel_scale1 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + if norm_layer == "group_norm": + self.norm2 = nn.GroupNorm( + num_groups=groups, num_channels=out_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() + elif norm_layer == "layer_norm": + self.norm2 = LayerNorm(out_channels, eps=eps, elementwise_affine=True) + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, + out_channels, + out_channels, + kernel_size=3, + stride=1, + padding=1, + causal=True, + ) + + if inject_noise: + self.per_channel_scale2 = nn.Parameter(torch.zeros((in_channels, 1, 1))) + + self.conv_shortcut = ( + make_linear_nd( + dims=dims, in_channels=in_channels, out_channels=out_channels + ) + if in_channels != out_channels + else nn.Identity() + ) + + self.norm3 = ( + LayerNorm(in_channels, eps=eps, elementwise_affine=True) + if in_channels != out_channels + else nn.Identity() + ) + + self.timestep_conditioning = timestep_conditioning + + if timestep_conditioning: + self.scale_shift_table = nn.Parameter( + torch.randn(4, in_channels) / in_channels**0.5 + ) + + def _feed_spatial_noise( + self, hidden_states: torch.FloatTensor, per_channel_scale: torch.FloatTensor + ) -> torch.FloatTensor: + spatial_shape = hidden_states.shape[-2:] + device = hidden_states.device + dtype = hidden_states.dtype + + # similar to the "explicit noise inputs" method in style-gan + spatial_noise = torch.randn(spatial_shape, device=device, dtype=dtype)[None] + scaled_noise = (spatial_noise * per_channel_scale)[None, :, None, ...] + hidden_states = hidden_states + scaled_noise + + return hidden_states + + def forward( + self, + input_tensor: torch.FloatTensor, + causal: bool = True, + timestep: Optional[torch.Tensor] = None, + ) -> torch.FloatTensor: + hidden_states = input_tensor + batch_size = hidden_states.shape[0] + + hidden_states = self.norm1(hidden_states) + if self.timestep_conditioning: + assert ( + timestep is not None + ), "should pass timestep with timestep_conditioning=True" + ada_values = self.scale_shift_table[ + None, ..., None, None, None + ] + timestep.reshape( + batch_size, + 4, + -1, + timestep.shape[-3], + timestep.shape[-2], + timestep.shape[-1], + ) + shift1, scale1, shift2, scale2 = ada_values.unbind(dim=1) + + hidden_states = hidden_states * (1 + scale1) + shift1 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, self.per_channel_scale1 + ) + + hidden_states = self.norm2(hidden_states) + + if self.timestep_conditioning: + hidden_states = hidden_states * (1 + scale2) + shift2 + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states, causal=causal) + + if self.inject_noise: + hidden_states = self._feed_spatial_noise( + hidden_states, self.per_channel_scale2 + ) + + input_tensor = self.norm3(input_tensor) + + batch_size = input_tensor.shape[0] + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +def patchify(x, patch_size_hw, patch_size_t=1): + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + return x + + +def unpatchify(x, patch_size_hw, patch_size_t=1): + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + + +def create_video_autoencoder_config( + latent_channels: int = 64, +): + encoder_blocks = [ + ("res_x", {"num_layers": 4}), + ("compress_all_x_y", {"multiplier": 3}), + ("res_x", {"num_layers": 4}), + ("compress_all_x_y", {"multiplier": 2}), + ("res_x", {"num_layers": 4}), + ("compress_all", {}), + ("res_x", {"num_layers": 3}), + ("res_x", {"num_layers": 4}), + ] + decoder_blocks = [ + ("res_x", {"num_layers": 4}), + ("compress_all", {"residual": True}), + ("res_x_y", {"multiplier": 3}), + ("res_x", {"num_layers": 3}), + ("compress_all", {"residual": True}), + ("res_x_y", {"multiplier": 2}), + ("res_x", {"num_layers": 3}), + ("compress_all", {"residual": True}), + ("res_x", {"num_layers": 3}), + ("res_x", {"num_layers": 4}), + ] + return { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "encoder_blocks": encoder_blocks, + "decoder_blocks": decoder_blocks, + "latent_channels": latent_channels, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, + "timestep_conditioning": True, + } + + +def test_vae_patchify_unpatchify(): + import torch + + x = torch.randn(2, 3, 8, 64, 64) + x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) + x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) + assert torch.allclose(x, x_unpatched) + + +def demo_video_autoencoder_forward_backward(): + # Configuration for the VideoAutoencoder + config = create_video_autoencoder_config() + + # Instantiate the VideoAutoencoder with the specified configuration + video_autoencoder = CausalVideoAutoencoder.from_config(config) + + print(video_autoencoder) + video_autoencoder.eval() + # Print the total number of parameters in the video autoencoder + total_params = sum(p.numel() for p in video_autoencoder.parameters()) + print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") + + # Create a mock input tensor simulating a batch of videos + # Shape: (batch_size, channels, depth, height, width) + # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame + input_videos = torch.randn(2, 3, 17, 64, 64) + + # Forward pass: encode and decode the input videos + latent = video_autoencoder.encode(input_videos).latent_dist.mode() + print(f"input shape={input_videos.shape}") + print(f"latent shape={latent.shape}") + + timestep = torch.ones(input_videos.shape[0]) * 0.1 + reconstructed_videos = video_autoencoder.decode( + latent, target_shape=input_videos.shape, timestep=timestep + ).sample + + print(f"reconstructed shape={reconstructed_videos.shape}") + + # Validate that single image gets treated the same way as first frame + input_image = input_videos[:, :, :1, :, :] + image_latent = video_autoencoder.encode(input_image).latent_dist.mode() + _ = video_autoencoder.decode( + image_latent, target_shape=image_latent.shape, timestep=timestep + ).sample + + # first_frame_latent = latent[:, :, :1, :, :] + + # assert torch.allclose(image_latent, first_frame_latent, atol=1e-6) + # assert torch.allclose(reconstructed_image, reconstructed_videos[:, :, :1, :, :], atol=1e-6) + # assert (image_latent == first_frame_latent).all() + # assert (reconstructed_image == reconstructed_videos[:, :, :1, :, :]).all() + + # Calculate the loss (e.g., mean squared error) + loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) + + # Perform backward pass + loss.backward() + + print(f"Demo completed with loss: {loss.item()}") + + +# Ensure to call the demo function to execute the forward and backward pass +if __name__ == "__main__": + demo_video_autoencoder_forward_backward() diff --git a/ltx_video/models/autoencoders/conv_nd_factory.py b/ltx_video/models/autoencoders/conv_nd_factory.py new file mode 100644 index 0000000000000000000000000000000000000000..5bc0c2e3882f3bd70c294ea7832fa1cb5501dd3c --- /dev/null +++ b/ltx_video/models/autoencoders/conv_nd_factory.py @@ -0,0 +1,82 @@ +from typing import Tuple, Union + +import torch + +from ltx_video.models.autoencoders.dual_conv3d import DualConv3d +from ltx_video.models.autoencoders.causal_conv3d import CausalConv3d + + +def make_conv_nd( + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + kernel_size: int, + stride=1, + padding=0, + dilation=1, + groups=1, + bias=True, + causal=False, +): + if dims == 2: + return torch.nn.Conv2d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + elif dims == 3: + if causal: + return CausalConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + return torch.nn.Conv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + dilation=dilation, + groups=groups, + bias=bias, + ) + elif dims == (2, 1): + return DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=bias, + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") + + +def make_linear_nd( + dims: int, + in_channels: int, + out_channels: int, + bias=True, +): + if dims == 2: + return torch.nn.Conv2d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias + ) + elif dims == 3 or dims == (2, 1): + return torch.nn.Conv3d( + in_channels=in_channels, out_channels=out_channels, kernel_size=1, bias=bias + ) + else: + raise ValueError(f"unsupported dimensions: {dims}") diff --git a/ltx_video/models/autoencoders/dual_conv3d.py b/ltx_video/models/autoencoders/dual_conv3d.py new file mode 100644 index 0000000000000000000000000000000000000000..6bd54c0a6712857e5f9e62d26144d3a450b58571 --- /dev/null +++ b/ltx_video/models/autoencoders/dual_conv3d.py @@ -0,0 +1,195 @@ +import math +from typing import Tuple, Union + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + + +class DualConv3d(nn.Module): + def __init__( + self, + in_channels, + out_channels, + kernel_size, + stride: Union[int, Tuple[int, int, int]] = 1, + padding: Union[int, Tuple[int, int, int]] = 0, + dilation: Union[int, Tuple[int, int, int]] = 1, + groups=1, + bias=True, + ): + super(DualConv3d, self).__init__() + + self.in_channels = in_channels + self.out_channels = out_channels + # Ensure kernel_size, stride, padding, and dilation are tuples of length 3 + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size, kernel_size) + if kernel_size == (1, 1, 1): + raise ValueError( + "kernel_size must be greater than 1. Use make_linear_nd instead." + ) + if isinstance(stride, int): + stride = (stride, stride, stride) + if isinstance(padding, int): + padding = (padding, padding, padding) + if isinstance(dilation, int): + dilation = (dilation, dilation, dilation) + + # Set parameters for convolutions + self.groups = groups + self.bias = bias + + # Define the size of the channels after the first convolution + intermediate_channels = ( + out_channels if in_channels < out_channels else in_channels + ) + + # Define parameters for the first convolution + self.weight1 = nn.Parameter( + torch.Tensor( + intermediate_channels, + in_channels // groups, + 1, + kernel_size[1], + kernel_size[2], + ) + ) + self.stride1 = (1, stride[1], stride[2]) + self.padding1 = (0, padding[1], padding[2]) + self.dilation1 = (1, dilation[1], dilation[2]) + if bias: + self.bias1 = nn.Parameter(torch.Tensor(intermediate_channels)) + else: + self.register_parameter("bias1", None) + + # Define parameters for the second convolution + self.weight2 = nn.Parameter( + torch.Tensor( + out_channels, intermediate_channels // groups, kernel_size[0], 1, 1 + ) + ) + self.stride2 = (stride[0], 1, 1) + self.padding2 = (padding[0], 0, 0) + self.dilation2 = (dilation[0], 1, 1) + if bias: + self.bias2 = nn.Parameter(torch.Tensor(out_channels)) + else: + self.register_parameter("bias2", None) + + # Initialize weights and biases + self.reset_parameters() + + def reset_parameters(self): + nn.init.kaiming_uniform_(self.weight1, a=math.sqrt(5)) + nn.init.kaiming_uniform_(self.weight2, a=math.sqrt(5)) + if self.bias: + fan_in1, _ = nn.init._calculate_fan_in_and_fan_out(self.weight1) + bound1 = 1 / math.sqrt(fan_in1) + nn.init.uniform_(self.bias1, -bound1, bound1) + fan_in2, _ = nn.init._calculate_fan_in_and_fan_out(self.weight2) + bound2 = 1 / math.sqrt(fan_in2) + nn.init.uniform_(self.bias2, -bound2, bound2) + + def forward(self, x, use_conv3d=False, skip_time_conv=False): + if use_conv3d: + return self.forward_with_3d(x=x, skip_time_conv=skip_time_conv) + else: + return self.forward_with_2d(x=x, skip_time_conv=skip_time_conv) + + def forward_with_3d(self, x, skip_time_conv): + # First convolution + x = F.conv3d( + x, + self.weight1, + self.bias1, + self.stride1, + self.padding1, + self.dilation1, + self.groups, + ) + + if skip_time_conv: + return x + + # Second convolution + x = F.conv3d( + x, + self.weight2, + self.bias2, + self.stride2, + self.padding2, + self.dilation2, + self.groups, + ) + + return x + + def forward_with_2d(self, x, skip_time_conv): + b, c, d, h, w = x.shape + + # First 2D convolution + x = rearrange(x, "b c d h w -> (b d) c h w") + # Squeeze the depth dimension out of weight1 since it's 1 + weight1 = self.weight1.squeeze(2) + # Select stride, padding, and dilation for the 2D convolution + stride1 = (self.stride1[1], self.stride1[2]) + padding1 = (self.padding1[1], self.padding1[2]) + dilation1 = (self.dilation1[1], self.dilation1[2]) + x = F.conv2d(x, weight1, self.bias1, stride1, padding1, dilation1, self.groups) + + _, _, h, w = x.shape + + if skip_time_conv: + x = rearrange(x, "(b d) c h w -> b c d h w", b=b) + return x + + # Second convolution which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c d", b=b) + + # Reshape weight2 to match the expected dimensions for conv1d + weight2 = self.weight2.squeeze(-1).squeeze(-1) + # Use only the relevant dimension for stride, padding, and dilation for the 1D convolution + stride2 = self.stride2[0] + padding2 = self.padding2[0] + dilation2 = self.dilation2[0] + x = F.conv1d(x, weight2, self.bias2, stride2, padding2, dilation2, self.groups) + x = rearrange(x, "(b h w) c d -> b c d h w", b=b, h=h, w=w) + + return x + + @property + def weight(self): + return self.weight2 + + +def test_dual_conv3d_consistency(): + # Initialize parameters + in_channels = 3 + out_channels = 5 + kernel_size = (3, 3, 3) + stride = (2, 2, 2) + padding = (1, 1, 1) + + # Create an instance of the DualConv3d class + dual_conv3d = DualConv3d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + bias=True, + ) + + # Example input tensor + test_input = torch.randn(1, 3, 10, 10, 10) + + # Perform forward passes with both 3D and 2D settings + output_conv3d = dual_conv3d(test_input, use_conv3d=True) + output_2d = dual_conv3d(test_input, use_conv3d=False) + + # Assert that the outputs from both methods are sufficiently close + assert torch.allclose( + output_conv3d, output_2d, atol=1e-6 + ), "Outputs are not consistent between 3D and 2D convolutions." diff --git a/ltx_video/models/autoencoders/pixel_norm.py b/ltx_video/models/autoencoders/pixel_norm.py new file mode 100644 index 0000000000000000000000000000000000000000..9bc3ea60e8a6453e7e12a7fb5aca4de3958a2567 --- /dev/null +++ b/ltx_video/models/autoencoders/pixel_norm.py @@ -0,0 +1,12 @@ +import torch +from torch import nn + + +class PixelNorm(nn.Module): + def __init__(self, dim=1, eps=1e-8): + super(PixelNorm, self).__init__() + self.dim = dim + self.eps = eps + + def forward(self, x): + return x / torch.sqrt(torch.mean(x**2, dim=self.dim, keepdim=True) + self.eps) diff --git a/ltx_video/models/autoencoders/vae.py b/ltx_video/models/autoencoders/vae.py new file mode 100644 index 0000000000000000000000000000000000000000..0bcab5214fdea6126e9b3f67d712a064f24d3b9c --- /dev/null +++ b/ltx_video/models/autoencoders/vae.py @@ -0,0 +1,343 @@ +from typing import Optional, Union + +import torch +import inspect +import math +import torch.nn as nn +from diffusers import ConfigMixin, ModelMixin +from diffusers.models.autoencoders.vae import ( + DecoderOutput, + DiagonalGaussianDistribution, +) +from diffusers.models.modeling_outputs import AutoencoderKLOutput +from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd + + +class AutoencoderKLWrapper(ModelMixin, ConfigMixin): + """Variational Autoencoder (VAE) model with KL loss. + + VAE from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma and Max Welling. + This model is a wrapper around an encoder and a decoder, and it adds a KL loss term to the reconstruction loss. + + Args: + encoder (`nn.Module`): + Encoder module. + decoder (`nn.Module`): + Decoder module. + latent_channels (`int`, *optional*, defaults to 4): + Number of latent channels. + """ + + def __init__( + self, + encoder: nn.Module, + decoder: nn.Module, + latent_channels: int = 4, + dims: int = 2, + sample_size=512, + use_quant_conv: bool = True, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = encoder + self.use_quant_conv = use_quant_conv + + # pass init params to Decoder + quant_dims = 2 if dims == 2 else 3 + self.decoder = decoder + if use_quant_conv: + self.quant_conv = make_conv_nd( + quant_dims, 2 * latent_channels, 2 * latent_channels, 1 + ) + self.post_quant_conv = make_conv_nd( + quant_dims, latent_channels, latent_channels, 1 + ) + else: + self.quant_conv = nn.Identity() + self.post_quant_conv = nn.Identity() + self.use_z_tiling = False + self.use_hw_tiling = False + self.dims = dims + self.z_sample_size = 1 + + self.decoder_params = inspect.signature(self.decoder.forward).parameters + + # only relevant if vae tiling is enabled + self.set_tiling_params(sample_size=sample_size, overlap_factor=0.25) + + def set_tiling_params(self, sample_size: int = 512, overlap_factor: float = 0.25): + self.tile_sample_min_size = sample_size + num_blocks = len(self.encoder.down_blocks) + self.tile_latent_min_size = int(sample_size / (2 ** (num_blocks - 1))) + self.tile_overlap_factor = overlap_factor + + def enable_z_tiling(self, z_sample_size: int = 8): + r""" + Enable tiling during VAE decoding. + + When this option is enabled, the VAE will split the input tensor in tiles to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_z_tiling = z_sample_size > 1 + self.z_sample_size = z_sample_size + assert ( + z_sample_size % 8 == 0 or z_sample_size == 1 + ), f"z_sample_size must be a multiple of 8 or 1. Got {z_sample_size}." + + def disable_z_tiling(self): + r""" + Disable tiling during VAE decoding. If `use_tiling` was previously invoked, this method will go back to computing + decoding in one step. + """ + self.use_z_tiling = False + + def enable_hw_tiling(self): + r""" + Enable tiling during VAE decoding along the height and width dimension. + """ + self.use_hw_tiling = True + + def disable_hw_tiling(self): + r""" + Disable tiling during VAE decoding along the height and width dimension. + """ + self.use_hw_tiling = False + + def _hw_tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True): + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[3], overlap_size): + row = [] + for j in range(0, x.shape[4], overlap_size): + tile = x[ + :, + :, + :, + i : i + self.tile_sample_min_size, + j : j + self.tile_sample_min_size, + ] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + moments = torch.cat(result_rows, dim=3) + return moments + + def blend_z( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[2], b.shape[2], blend_extent) + for z in range(blend_extent): + b[:, :, z, :, :] = a[:, :, -blend_extent + z, :, :] * ( + 1 - z / blend_extent + ) + b[:, :, z, :, :] * (z / blend_extent) + return b + + def blend_v( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[3], b.shape[3], blend_extent) + for y in range(blend_extent): + b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * ( + 1 - y / blend_extent + ) + b[:, :, :, y, :] * (y / blend_extent) + return b + + def blend_h( + self, a: torch.Tensor, b: torch.Tensor, blend_extent: int + ) -> torch.Tensor: + blend_extent = min(a.shape[4], b.shape[4], blend_extent) + for x in range(blend_extent): + b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * ( + 1 - x / blend_extent + ) + b[:, :, :, :, x] * (x / blend_extent) + return b + + def _hw_tiled_decode(self, z: torch.FloatTensor, target_shape): + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + tile_target_shape = ( + *target_shape[:3], + self.tile_sample_min_size, + self.tile_sample_min_size, + ) + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[3], overlap_size): + row = [] + for j in range(0, z.shape[4], overlap_size): + tile = z[ + :, + :, + :, + i : i + self.tile_latent_min_size, + j : j + self.tile_latent_min_size, + ] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile, target_shape=tile_target_shape) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=4)) + + dec = torch.cat(result_rows, dim=3) + return dec + + def encode( + self, z: torch.FloatTensor, return_dict: bool = True + ) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: + num_splits = z.shape[2] // self.z_sample_size + sizes = [self.z_sample_size] * num_splits + sizes = ( + sizes + [z.shape[2] - sum(sizes)] + if z.shape[2] - sum(sizes) > 0 + else sizes + ) + tiles = z.split(sizes, dim=2) + moments_tiles = [ + ( + self._hw_tiled_encode(z_tile, return_dict) + if self.use_hw_tiling + else self._encode(z_tile) + ) + for z_tile in tiles + ] + moments = torch.cat(moments_tiles, dim=2) + + else: + moments = ( + self._hw_tiled_encode(z, return_dict) + if self.use_hw_tiling + else self._encode(z) + ) + + posterior = DiagonalGaussianDistribution(moments) + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _encode(self, x: torch.FloatTensor) -> AutoencoderKLOutput: + h = self.encoder(x) + moments = self.quant_conv(h) + return moments + + def _decode( + self, + z: torch.FloatTensor, + target_shape=None, + timestep: Optional[torch.Tensor] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + z = self.post_quant_conv(z) + if "timestep" in self.decoder_params: + dec = self.decoder(z, target_shape=target_shape, timestep=timestep) + else: + dec = self.decoder(z, target_shape=target_shape) + return dec + + def decode( + self, + z: torch.FloatTensor, + return_dict: bool = True, + target_shape=None, + timestep: Optional[torch.Tensor] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + assert target_shape is not None, "target_shape must be provided for decoding" + if self.use_z_tiling and z.shape[2] > self.z_sample_size > 1: + reduction_factor = int( + self.encoder.patch_size_t + * 2 + ** ( + len(self.encoder.down_blocks) + - 1 + - math.sqrt(self.encoder.patch_size) + ) + ) + split_size = self.z_sample_size // reduction_factor + num_splits = z.shape[2] // split_size + + # copy target shape, and divide frame dimension (=2) by the context size + target_shape_split = list(target_shape) + target_shape_split[2] = target_shape[2] // num_splits + + decoded_tiles = [ + ( + self._hw_tiled_decode(z_tile, target_shape_split) + if self.use_hw_tiling + else self._decode(z_tile, target_shape=target_shape_split) + ) + for z_tile in torch.tensor_split(z, num_splits, dim=2) + ] + decoded = torch.cat(decoded_tiles, dim=2) + else: + decoded = ( + self._hw_tiled_decode(z, target_shape) + if self.use_hw_tiling + else self._decode(z, target_shape=target_shape, timestep=timestep) + ) + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether to return a [`DecoderOutput`] instead of a plain tuple. + generator (`torch.Generator`, *optional*): + Generator used to sample from the posterior. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z, target_shape=sample.shape).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/ltx_video/models/autoencoders/vae_encode.py b/ltx_video/models/autoencoders/vae_encode.py new file mode 100644 index 0000000000000000000000000000000000000000..f584ec04747015dbaa60385c621f524906fe6115 --- /dev/null +++ b/ltx_video/models/autoencoders/vae_encode.py @@ -0,0 +1,208 @@ +import torch +from diffusers import AutoencoderKL +from einops import rearrange +from torch import Tensor + + +from ltx_video.models.autoencoders.causal_video_autoencoder import ( + CausalVideoAutoencoder, +) +from ltx_video.models.autoencoders.video_autoencoder import ( + Downsample3D, + VideoAutoencoder, +) + +try: + import torch_xla.core.xla_model as xm +except ImportError: + xm = None + + +def vae_encode( + media_items: Tensor, + vae: AutoencoderKL, + split_size: int = 1, + vae_per_channel_normalize=False, +) -> Tensor: + """ + Encodes media items (images or videos) into latent representations using a specified VAE model. + The function supports processing batches of images or video frames and can handle the processing + in smaller sub-batches if needed. + + Args: + media_items (Tensor): A torch Tensor containing the media items to encode. The expected + shape is (batch_size, channels, height, width) for images or (batch_size, channels, + frames, height, width) for videos. + vae (AutoencoderKL): An instance of the `AutoencoderKL` class from the `diffusers` library, + pre-configured and loaded with the appropriate model weights. + split_size (int, optional): The number of sub-batches to split the input batch into for encoding. + If set to more than 1, the input media items are processed in smaller batches according to + this value. Defaults to 1, which processes all items in a single batch. + + Returns: + Tensor: A torch Tensor of the encoded latent representations. The shape of the tensor is adjusted + to match the input shape, scaled by the model's configuration. + + Examples: + >>> import torch + >>> from diffusers import AutoencoderKL + >>> vae = AutoencoderKL.from_pretrained('your-model-name') + >>> images = torch.rand(10, 3, 8 256, 256) # Example tensor with 10 videos of 8 frames. + >>> latents = vae_encode(images, vae) + >>> print(latents.shape) # Output shape will depend on the model's latent configuration. + + Note: + In case of a video, the function encodes the media item frame-by frame. + """ + is_video_shaped = media_items.dim() == 5 + batch_size, channels = media_items.shape[0:2] + + if channels != 3: + raise ValueError(f"Expects tensors with 3 channels, got {channels}.") + + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + media_items = rearrange(media_items, "b c n h w -> (b n) c h w") + if split_size > 1: + if len(media_items) % split_size != 0: + raise ValueError( + "Error: The batch size must be divisible by 'train.vae_bs_split" + ) + encode_bs = len(media_items) // split_size + # latents = [vae.encode(image_batch).latent_dist.sample() for image_batch in media_items.split(encode_bs)] + latents = [] + if media_items.device.type == "xla": + xm.mark_step() + for image_batch in media_items.split(encode_bs): + latents.append(vae.encode(image_batch).latent_dist.sample()) + if media_items.device.type == "xla": + xm.mark_step() + latents = torch.cat(latents, dim=0) + else: + latents = vae.encode(media_items).latent_dist.sample() + + latents = normalize_latents(latents, vae, vae_per_channel_normalize) + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + latents = rearrange(latents, "(b n) c h w -> b c n h w", b=batch_size) + return latents + + +def vae_decode( + latents: Tensor, + vae: AutoencoderKL, + is_video: bool = True, + split_size: int = 1, + vae_per_channel_normalize=False, + timestep=None, +) -> Tensor: + is_video_shaped = latents.dim() == 5 + batch_size = latents.shape[0] + + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + latents = rearrange(latents, "b c n h w -> (b n) c h w") + if split_size > 1: + if len(latents) % split_size != 0: + raise ValueError( + "Error: The batch size must be divisible by 'train.vae_bs_split" + ) + encode_bs = len(latents) // split_size + image_batch = [ + _run_decoder( + latent_batch, vae, is_video, vae_per_channel_normalize, timestep + ) + for latent_batch in latents.split(encode_bs) + ] + images = torch.cat(image_batch, dim=0) + else: + images = _run_decoder( + latents, vae, is_video, vae_per_channel_normalize, timestep + ) + + if is_video_shaped and not isinstance( + vae, (VideoAutoencoder, CausalVideoAutoencoder) + ): + images = rearrange(images, "(b n) c h w -> b c n h w", b=batch_size) + return images + + +def _run_decoder( + latents: Tensor, + vae: AutoencoderKL, + is_video: bool, + vae_per_channel_normalize=False, + timestep=None, +) -> Tensor: + if isinstance(vae, (VideoAutoencoder, CausalVideoAutoencoder)): + *_, fl, hl, wl = latents.shape + temporal_scale, spatial_scale, _ = get_vae_size_scale_factor(vae) + latents = latents.to(vae.dtype) + vae_decode_kwargs = {} + if timestep is not None: + vae_decode_kwargs["timestep"] = timestep + image = vae.decode( + un_normalize_latents(latents, vae, vae_per_channel_normalize), + return_dict=False, + target_shape=( + 1, + 3, + fl * temporal_scale if is_video else 1, + hl * spatial_scale, + wl * spatial_scale, + ), + **vae_decode_kwargs, + )[0] + else: + image = vae.decode( + un_normalize_latents(latents, vae, vae_per_channel_normalize), + return_dict=False, + )[0] + return image + + +def get_vae_size_scale_factor(vae: AutoencoderKL) -> float: + if isinstance(vae, CausalVideoAutoencoder): + spatial = vae.spatial_downscale_factor + temporal = vae.temporal_downscale_factor + else: + down_blocks = len( + [ + block + for block in vae.encoder.down_blocks + if isinstance(block.downsample, Downsample3D) + ] + ) + spatial = vae.config.patch_size * 2**down_blocks + temporal = ( + vae.config.patch_size_t * 2**down_blocks + if isinstance(vae, VideoAutoencoder) + else 1 + ) + + return (temporal, spatial, spatial) + + +def normalize_latents( + latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False +) -> Tensor: + return ( + (latents - vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1)) + / vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + if vae_per_channel_normalize + else latents * vae.config.scaling_factor + ) + + +def un_normalize_latents( + latents: Tensor, vae: AutoencoderKL, vae_per_channel_normalize: bool = False +) -> Tensor: + return ( + latents * vae.std_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + + vae.mean_of_means.to(latents.dtype).view(1, -1, 1, 1, 1) + if vae_per_channel_normalize + else latents / vae.config.scaling_factor + ) diff --git a/ltx_video/models/autoencoders/video_autoencoder.py b/ltx_video/models/autoencoders/video_autoencoder.py new file mode 100644 index 0000000000000000000000000000000000000000..3c7926c1d3afb8188221b2e569aaaf89f7271bce --- /dev/null +++ b/ltx_video/models/autoencoders/video_autoencoder.py @@ -0,0 +1,1045 @@ +import json +import os +from functools import partial +from types import SimpleNamespace +from typing import Any, Mapping, Optional, Tuple, Union + +import torch +from einops import rearrange +from torch import nn +from torch.nn import functional + +from diffusers.utils import logging + +from ltx_video.utils.torch_utils import Identity +from ltx_video.models.autoencoders.conv_nd_factory import make_conv_nd, make_linear_nd +from ltx_video.models.autoencoders.pixel_norm import PixelNorm +from ltx_video.models.autoencoders.vae import AutoencoderKLWrapper + +logger = logging.get_logger(__name__) + + +class VideoAutoencoder(AutoencoderKLWrapper): + @classmethod + def from_pretrained( + cls, + pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + config_local_path = pretrained_model_name_or_path / "config.json" + config = cls.load_config(config_local_path, **kwargs) + video_vae = cls.from_config(config) + video_vae.to(kwargs["torch_dtype"]) + + model_local_path = pretrained_model_name_or_path / "autoencoder.pth" + ckpt_state_dict = torch.load(model_local_path) + video_vae.load_state_dict(ckpt_state_dict) + + statistics_local_path = ( + pretrained_model_name_or_path / "per_channel_statistics.json" + ) + if statistics_local_path.exists(): + with open(statistics_local_path, "r") as file: + data = json.load(file) + transposed_data = list(zip(*data["data"])) + data_dict = { + col: torch.tensor(vals) + for col, vals in zip(data["columns"], transposed_data) + } + video_vae.register_buffer("std_of_means", data_dict["std-of-means"]) + video_vae.register_buffer( + "mean_of_means", + data_dict.get( + "mean-of-means", torch.zeros_like(data_dict["std-of-means"]) + ), + ) + + return video_vae + + @staticmethod + def from_config(config): + assert ( + config["_class_name"] == "VideoAutoencoder" + ), "config must have _class_name=VideoAutoencoder" + if isinstance(config["dims"], list): + config["dims"] = tuple(config["dims"]) + + assert config["dims"] in [2, 3, (2, 1)], "dims must be 2, 3 or (2, 1)" + + double_z = config.get("double_z", True) + latent_log_var = config.get( + "latent_log_var", "per_channel" if double_z else "none" + ) + use_quant_conv = config.get("use_quant_conv", True) + + if use_quant_conv and latent_log_var == "uniform": + raise ValueError("uniform latent_log_var requires use_quant_conv=False") + + encoder = Encoder( + dims=config["dims"], + in_channels=config.get("in_channels", 3), + out_channels=config["latent_channels"], + block_out_channels=config["block_out_channels"], + patch_size=config.get("patch_size", 1), + latent_log_var=latent_log_var, + norm_layer=config.get("norm_layer", "group_norm"), + patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), + add_channel_padding=config.get("add_channel_padding", False), + ) + + decoder = Decoder( + dims=config["dims"], + in_channels=config["latent_channels"], + out_channels=config.get("out_channels", 3), + block_out_channels=config["block_out_channels"], + patch_size=config.get("patch_size", 1), + norm_layer=config.get("norm_layer", "group_norm"), + patch_size_t=config.get("patch_size_t", config.get("patch_size", 1)), + add_channel_padding=config.get("add_channel_padding", False), + ) + + dims = config["dims"] + return VideoAutoencoder( + encoder=encoder, + decoder=decoder, + latent_channels=config["latent_channels"], + dims=dims, + use_quant_conv=use_quant_conv, + ) + + @property + def config(self): + return SimpleNamespace( + _class_name="VideoAutoencoder", + dims=self.dims, + in_channels=self.encoder.conv_in.in_channels + // (self.encoder.patch_size_t * self.encoder.patch_size**2), + out_channels=self.decoder.conv_out.out_channels + // (self.decoder.patch_size_t * self.decoder.patch_size**2), + latent_channels=self.decoder.conv_in.in_channels, + block_out_channels=[ + self.encoder.down_blocks[i].res_blocks[-1].conv1.out_channels + for i in range(len(self.encoder.down_blocks)) + ], + scaling_factor=1.0, + norm_layer=self.encoder.norm_layer, + patch_size=self.encoder.patch_size, + latent_log_var=self.encoder.latent_log_var, + use_quant_conv=self.use_quant_conv, + patch_size_t=self.encoder.patch_size_t, + add_channel_padding=self.encoder.add_channel_padding, + ) + + @property + def is_video_supported(self): + """ + Check if the model supports video inputs of shape (B, C, F, H, W). Otherwise, the model only supports 2D images. + """ + return self.dims != 2 + + @property + def downscale_factor(self): + return self.encoder.downsample_factor + + def to_json_string(self) -> str: + import json + + return json.dumps(self.config.__dict__) + + def load_state_dict(self, state_dict: Mapping[str, Any], strict: bool = True): + model_keys = set(name for name, _ in self.named_parameters()) + + key_mapping = { + ".resnets.": ".res_blocks.", + "downsamplers.0": "downsample", + "upsamplers.0": "upsample", + } + + converted_state_dict = {} + for key, value in state_dict.items(): + for k, v in key_mapping.items(): + key = key.replace(k, v) + + if "norm" in key and key not in model_keys: + logger.info( + f"Removing key {key} from state_dict as it is not present in the model" + ) + continue + + converted_state_dict[key] = value + + super().load_state_dict(converted_state_dict, strict=strict) + + def last_layer(self): + if hasattr(self.decoder, "conv_out"): + if isinstance(self.decoder.conv_out, nn.Sequential): + last_layer = self.decoder.conv_out[-1] + else: + last_layer = self.decoder.conv_out + else: + last_layer = self.decoder.layers[-1] + return last_layer + + +class Encoder(nn.Module): + r""" + The `Encoder` layer of a variational autoencoder that encodes its input into a latent representation. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + latent_log_var (`str`, *optional*, defaults to `per_channel`): + The number of channels for the log variance. Can be either `per_channel`, `uniform`, or `none`. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]] = 3, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: Union[int, Tuple[int]] = 1, + norm_layer: str = "group_norm", # group_norm, pixel_norm + latent_log_var: str = "per_channel", + patch_size_t: Optional[int] = None, + add_channel_padding: Optional[bool] = False, + ): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size + self.add_channel_padding = add_channel_padding + self.layers_per_block = layers_per_block + self.norm_layer = norm_layer + self.latent_channels = out_channels + self.latent_log_var = latent_log_var + if add_channel_padding: + in_channels = in_channels * self.patch_size**3 + else: + in_channels = in_channels * self.patch_size_t * self.patch_size**2 + self.in_channels = in_channels + output_channel = block_out_channels[0] + + self.conv_in = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=output_channel, + kernel_size=3, + stride=1, + padding=1, + ) + + self.down_blocks = nn.ModuleList([]) + + for i in range(len(block_out_channels)): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = DownEncoderBlock3D( + dims=dims, + in_channels=input_channel, + out_channels=output_channel, + num_layers=self.layers_per_block, + add_downsample=not is_final_block and 2**i >= patch_size, + resnet_eps=1e-6, + downsample_padding=0, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + self.down_blocks.append(down_block) + + self.mid_block = UNetMidBlock3D( + dims=dims, + in_channels=block_out_channels[-1], + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + + # out + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[-1], + num_groups=norm_num_groups, + eps=1e-6, + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + self.conv_act = nn.SiLU() + + conv_out_channels = out_channels + if latent_log_var == "per_channel": + conv_out_channels *= 2 + elif latent_log_var == "uniform": + conv_out_channels += 1 + elif latent_log_var != "none": + raise ValueError(f"Invalid latent_log_var: {latent_log_var}") + self.conv_out = make_conv_nd( + dims, block_out_channels[-1], conv_out_channels, 3, padding=1 + ) + + self.gradient_checkpointing = False + + @property + def downscale_factor(self): + return ( + 2 + ** len( + [ + block + for block in self.down_blocks + if isinstance(block.downsample, Downsample3D) + ] + ) + * self.patch_size + ) + + def forward( + self, sample: torch.FloatTensor, return_features=False + ) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + downsample_in_time = sample.shape[2] != 1 + + # patchify + patch_size_t = self.patch_size_t if downsample_in_time else 1 + sample = patchify( + sample, + patch_size_hw=self.patch_size, + patch_size_t=patch_size_t, + add_channel_padding=self.add_channel_padding, + ) + + sample = self.conv_in(sample) + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + if return_features: + features = [] + for down_block in self.down_blocks: + sample = checkpoint_fn(down_block)( + sample, downsample_in_time=downsample_in_time + ) + if return_features: + features.append(sample) + + sample = checkpoint_fn(self.mid_block)(sample) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if self.latent_log_var == "uniform": + last_channel = sample[:, -1:, ...] + num_dims = sample.dim() + + if num_dims == 4: + # For shape (B, C, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + elif num_dims == 5: + # For shape (B, C, F, H, W) + repeated_last_channel = last_channel.repeat( + 1, sample.shape[1] - 2, 1, 1, 1 + ) + sample = torch.cat([sample, repeated_last_channel], dim=1) + else: + raise ValueError(f"Invalid input shape: {sample.shape}") + + if return_features: + features.append(sample[:, : self.latent_channels, ...]) + return sample, features + return sample + + +class Decoder(nn.Module): + r""" + The `Decoder` layer of a variational autoencoder that decodes its latent representation into an output sample. + + Args: + in_channels (`int`, *optional*, defaults to 3): + The number of input channels. + out_channels (`int`, *optional*, defaults to 3): + The number of output channels. + block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`): + The number of output channels for each block. + layers_per_block (`int`, *optional*, defaults to 2): + The number of layers per block. + norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups for normalization. + patch_size (`int`, *optional*, defaults to 1): + The patch size to use. Should be a power of 2. + norm_layer (`str`, *optional*, defaults to `group_norm`): + The normalization layer to use. Can be either `group_norm` or `pixel_norm`. + """ + + def __init__( + self, + dims, + in_channels: int = 3, + out_channels: int = 3, + block_out_channels: Tuple[int, ...] = (64,), + layers_per_block: int = 2, + norm_num_groups: int = 32, + patch_size: int = 1, + norm_layer: str = "group_norm", + patch_size_t: Optional[int] = None, + add_channel_padding: Optional[bool] = False, + ): + super().__init__() + self.patch_size = patch_size + self.patch_size_t = patch_size_t if patch_size_t is not None else patch_size + self.add_channel_padding = add_channel_padding + self.layers_per_block = layers_per_block + if add_channel_padding: + out_channels = out_channels * self.patch_size**3 + else: + out_channels = out_channels * self.patch_size_t * self.patch_size**2 + self.out_channels = out_channels + + self.conv_in = make_conv_nd( + dims, + in_channels, + block_out_channels[-1], + kernel_size=3, + stride=1, + padding=1, + ) + + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + self.mid_block = UNetMidBlock3D( + dims=dims, + in_channels=block_out_channels[-1], + num_layers=self.layers_per_block, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + + reversed_block_out_channels = list(reversed(block_out_channels)) + output_channel = reversed_block_out_channels[0] + for i in range(len(reversed_block_out_channels)): + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + + is_final_block = i == len(block_out_channels) - 1 + + up_block = UpDecoderBlock3D( + dims=dims, + num_layers=self.layers_per_block + 1, + in_channels=prev_output_channel, + out_channels=output_channel, + add_upsample=not is_final_block + and 2 ** (len(block_out_channels) - i - 1) > patch_size, + resnet_eps=1e-6, + resnet_groups=norm_num_groups, + norm_layer=norm_layer, + ) + self.up_blocks.append(up_block) + + if norm_layer == "group_norm": + self.conv_norm_out = nn.GroupNorm( + num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=1e-6 + ) + elif norm_layer == "pixel_norm": + self.conv_norm_out = PixelNorm() + + self.conv_act = nn.SiLU() + self.conv_out = make_conv_nd( + dims, block_out_channels[0], out_channels, 3, padding=1 + ) + + self.gradient_checkpointing = False + + def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: + r"""The forward method of the `Decoder` class.""" + assert target_shape is not None, "target_shape must be provided" + upsample_in_time = sample.shape[2] < target_shape[2] + + sample = self.conv_in(sample) + + upscale_dtype = next(iter(self.up_blocks.parameters())).dtype + + checkpoint_fn = ( + partial(torch.utils.checkpoint.checkpoint, use_reentrant=False) + if self.gradient_checkpointing and self.training + else lambda x: x + ) + + sample = checkpoint_fn(self.mid_block)(sample) + sample = sample.to(upscale_dtype) + + for up_block in self.up_blocks: + sample = checkpoint_fn(up_block)(sample, upsample_in_time=upsample_in_time) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + # un-patchify + patch_size_t = self.patch_size_t if upsample_in_time else 1 + sample = unpatchify( + sample, + patch_size_hw=self.patch_size, + patch_size_t=patch_size_t, + add_channel_padding=self.add_channel_padding, + ) + + return sample + + +class DownEncoderBlock3D(nn.Module): + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + add_downsample: bool = True, + downsample_padding: int = 1, + norm_layer: str = "group_norm", + ): + super().__init__() + res_blocks = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + res_blocks.append( + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + + if add_downsample: + self.downsample = Downsample3D( + dims, + out_channels, + out_channels=out_channels, + padding=downsample_padding, + ) + else: + self.downsample = Identity() + + def forward( + self, hidden_states: torch.FloatTensor, downsample_in_time + ) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + hidden_states = self.downsample( + hidden_states, downsample_in_time=downsample_in_time + ) + + return hidden_states + + +class UNetMidBlock3D(nn.Module): + """ + A 3D UNet mid-block [`UNetMidBlock3D`] with multiple residual blocks. + + Args: + in_channels (`int`): The number of input channels. + dropout (`float`, *optional*, defaults to 0.0): The dropout rate. + num_layers (`int`, *optional*, defaults to 1): The number of residual blocks. + resnet_eps (`float`, *optional*, 1e-6 ): The epsilon value for the resnet blocks. + resnet_groups (`int`, *optional*, defaults to 32): + The number of groups to use in the group normalization layers of the resnet blocks. + + Returns: + `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size, + in_channels, height, width)`. + + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + norm_layer: str = "group_norm", + ): + super().__init__() + resnet_groups = ( + resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + ) + + self.res_blocks = nn.ModuleList( + [ + ResnetBlock3D( + dims=dims, + in_channels=in_channels, + out_channels=in_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + for _ in range(num_layers) + ] + ) + + def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + return hidden_states + + +class UpDecoderBlock3D(nn.Module): + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: int, + resolution_idx: Optional[int] = None, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_groups: int = 32, + add_upsample: bool = True, + norm_layer: str = "group_norm", + ): + super().__init__() + res_blocks = [] + + for i in range(num_layers): + input_channels = in_channels if i == 0 else out_channels + + res_blocks.append( + ResnetBlock3D( + dims=dims, + in_channels=input_channels, + out_channels=out_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + norm_layer=norm_layer, + ) + ) + + self.res_blocks = nn.ModuleList(res_blocks) + + if add_upsample: + self.upsample = Upsample3D( + dims=dims, channels=out_channels, out_channels=out_channels + ) + else: + self.upsample = Identity() + + self.resolution_idx = resolution_idx + + def forward( + self, hidden_states: torch.FloatTensor, upsample_in_time=True + ) -> torch.FloatTensor: + for resnet in self.res_blocks: + hidden_states = resnet(hidden_states) + + hidden_states = self.upsample(hidden_states, upsample_in_time=upsample_in_time) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + r""" + A Resnet block. + + Parameters: + in_channels (`int`): The number of channels in the input. + out_channels (`int`, *optional*, default to be `None`): + The number of output channels for the first conv layer. If None, same as `in_channels`. + dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use. + groups (`int`, *optional*, default to `32`): The number of groups to use for the first normalization layer. + eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization. + """ + + def __init__( + self, + dims: Union[int, Tuple[int, int]], + in_channels: int, + out_channels: Optional[int] = None, + conv_shortcut: bool = False, + dropout: float = 0.0, + groups: int = 32, + eps: float = 1e-6, + norm_layer: str = "group_norm", + ): + super().__init__() + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + + if norm_layer == "group_norm": + self.norm1 = torch.nn.GroupNorm( + num_groups=groups, num_channels=in_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm1 = PixelNorm() + + self.non_linearity = nn.SiLU() + + self.conv1 = make_conv_nd( + dims, in_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + if norm_layer == "group_norm": + self.norm2 = torch.nn.GroupNorm( + num_groups=groups, num_channels=out_channels, eps=eps, affine=True + ) + elif norm_layer == "pixel_norm": + self.norm2 = PixelNorm() + + self.dropout = torch.nn.Dropout(dropout) + + self.conv2 = make_conv_nd( + dims, out_channels, out_channels, kernel_size=3, stride=1, padding=1 + ) + + self.conv_shortcut = ( + make_linear_nd( + dims=dims, in_channels=in_channels, out_channels=out_channels + ) + if in_channels != out_channels + else nn.Identity() + ) + + def forward( + self, + input_tensor: torch.FloatTensor, + ) -> torch.FloatTensor: + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + hidden_states = self.norm2(hidden_states) + + hidden_states = self.non_linearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + + hidden_states = self.conv2(hidden_states) + + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = input_tensor + hidden_states + + return output_tensor + + +class Downsample3D(nn.Module): + def __init__( + self, + dims, + in_channels: int, + out_channels: int, + kernel_size: int = 3, + padding: int = 1, + ): + super().__init__() + stride: int = 2 + self.padding = padding + self.in_channels = in_channels + self.dims = dims + self.conv = make_conv_nd( + dims=dims, + in_channels=in_channels, + out_channels=out_channels, + kernel_size=kernel_size, + stride=stride, + padding=padding, + ) + + def forward(self, x, downsample_in_time=True): + conv = self.conv + if self.padding == 0: + if self.dims == 2: + padding = (0, 1, 0, 1) + else: + padding = (0, 1, 0, 1, 0, 1 if downsample_in_time else 0) + + x = functional.pad(x, padding, mode="constant", value=0) + + if self.dims == (2, 1) and not downsample_in_time: + return conv(x, skip_time_conv=True) + + return conv(x) + + +class Upsample3D(nn.Module): + """ + An upsampling layer for 3D tensors of shape (B, C, D, H, W). + + :param channels: channels in the inputs and outputs. + """ + + def __init__(self, dims, channels, out_channels=None): + super().__init__() + self.dims = dims + self.channels = channels + self.out_channels = out_channels or channels + self.conv = make_conv_nd( + dims, channels, out_channels, kernel_size=3, padding=1, bias=True + ) + + def forward(self, x, upsample_in_time): + if self.dims == 2: + x = functional.interpolate( + x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" + ) + else: + time_scale_factor = 2 if upsample_in_time else 1 + # print("before:", x.shape) + b, c, d, h, w = x.shape + x = rearrange(x, "b c d h w -> (b d) c h w") + # height and width interpolate + x = functional.interpolate( + x, (x.shape[2] * 2, x.shape[3] * 2), mode="nearest" + ) + _, _, h, w = x.shape + + if not upsample_in_time and self.dims == (2, 1): + x = rearrange(x, "(b d) c h w -> b c d h w ", b=b, h=h, w=w) + return self.conv(x, skip_time_conv=True) + + # Second ** upsampling ** which is essentially treated as a 1D convolution across the 'd' dimension + x = rearrange(x, "(b d) c h w -> (b h w) c 1 d", b=b) + + # (b h w) c 1 d + new_d = x.shape[-1] * time_scale_factor + x = functional.interpolate(x, (1, new_d), mode="nearest") + # (b h w) c 1 new_d + x = rearrange( + x, "(b h w) c 1 new_d -> b c new_d h w", b=b, h=h, w=w, new_d=new_d + ) + # b c d h w + + # x = functional.interpolate( + # x, (x.shape[2] * time_scale_factor, x.shape[3] * 2, x.shape[4] * 2), mode="nearest" + # ) + # print("after:", x.shape) + + return self.conv(x) + + +def patchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): + if patch_size_hw == 1 and patch_size_t == 1: + return x + if x.dim() == 4: + x = rearrange( + x, "b c (h q) (w r) -> b (c r q) h w", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b c (f p) (h q) (w r) -> b (c p r q) f h w", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + else: + raise ValueError(f"Invalid input shape: {x.shape}") + + if ( + (x.dim() == 5) + and (patch_size_hw > patch_size_t) + and (patch_size_t > 1 or add_channel_padding) + ): + channels_to_pad = x.shape[1] * (patch_size_hw // patch_size_t) - x.shape[1] + padding_zeros = torch.zeros( + x.shape[0], + channels_to_pad, + x.shape[2], + x.shape[3], + x.shape[4], + device=x.device, + dtype=x.dtype, + ) + x = torch.cat([padding_zeros, x], dim=1) + + return x + + +def unpatchify(x, patch_size_hw, patch_size_t=1, add_channel_padding=False): + if patch_size_hw == 1 and patch_size_t == 1: + return x + + if ( + (x.dim() == 5) + and (patch_size_hw > patch_size_t) + and (patch_size_t > 1 or add_channel_padding) + ): + channels_to_keep = int(x.shape[1] * (patch_size_t / patch_size_hw)) + x = x[:, :channels_to_keep, :, :, :] + + if x.dim() == 4: + x = rearrange( + x, "b (c r q) h w -> b c (h q) (w r)", q=patch_size_hw, r=patch_size_hw + ) + elif x.dim() == 5: + x = rearrange( + x, + "b (c p r q) f h w -> b c (f p) (h q) (w r)", + p=patch_size_t, + q=patch_size_hw, + r=patch_size_hw, + ) + + return x + + +def create_video_autoencoder_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": ( + 2, + 1, + ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [ + 128, + 256, + 512, + 512, + ], # Number of output channels of each encoder / decoder inner block + "patch_size": 1, + } + + return config + + +def create_video_autoencoder_pathify4x4x4_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": ( + 2, + 1, + ), # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [512] + * 4, # Number of output channels of each encoder / decoder inner block + "patch_size": 4, + "latent_log_var": "uniform", + } + + return config + + +def create_video_autoencoder_pathify4x4_config( + latent_channels: int = 4, +): + config = { + "_class_name": "VideoAutoencoder", + "dims": 2, # 2 for Conv2, 3 for Conv3d, (2, 1) for Conv2d followed by Conv1d + "in_channels": 3, # Number of input color channels (e.g., RGB) + "out_channels": 3, # Number of output color channels + "latent_channels": latent_channels, # Number of channels in the latent space representation + "block_out_channels": [512] + * 4, # Number of output channels of each encoder / decoder inner block + "patch_size": 4, + "norm_layer": "pixel_norm", + } + + return config + + +def test_vae_patchify_unpatchify(): + import torch + + x = torch.randn(2, 3, 8, 64, 64) + x_patched = patchify(x, patch_size_hw=4, patch_size_t=4) + x_unpatched = unpatchify(x_patched, patch_size_hw=4, patch_size_t=4) + assert torch.allclose(x, x_unpatched) + + +def demo_video_autoencoder_forward_backward(): + # Configuration for the VideoAutoencoder + config = create_video_autoencoder_pathify4x4x4_config() + + # Instantiate the VideoAutoencoder with the specified configuration + video_autoencoder = VideoAutoencoder.from_config(config) + + print(video_autoencoder) + + # Print the total number of parameters in the video autoencoder + total_params = sum(p.numel() for p in video_autoencoder.parameters()) + print(f"Total number of parameters in VideoAutoencoder: {total_params:,}") + + # Create a mock input tensor simulating a batch of videos + # Shape: (batch_size, channels, depth, height, width) + # E.g., 4 videos, each with 3 color channels, 16 frames, and 64x64 pixels per frame + input_videos = torch.randn(2, 3, 8, 64, 64) + + # Forward pass: encode and decode the input videos + latent = video_autoencoder.encode(input_videos).latent_dist.mode() + print(f"input shape={input_videos.shape}") + print(f"latent shape={latent.shape}") + reconstructed_videos = video_autoencoder.decode( + latent, target_shape=input_videos.shape + ).sample + + print(f"reconstructed shape={reconstructed_videos.shape}") + + # Calculate the loss (e.g., mean squared error) + loss = torch.nn.functional.mse_loss(input_videos, reconstructed_videos) + + # Perform backward pass + loss.backward() + + print(f"Demo completed with loss: {loss.item()}") + + +# Ensure to call the demo function to execute the forward and backward pass +if __name__ == "__main__": + demo_video_autoencoder_forward_backward() diff --git a/ltx_video/models/transformers/__init__.py b/ltx_video/models/transformers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx_video/models/transformers/attention.py b/ltx_video/models/transformers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..96ed251478513b58742bebbf812d2176d8294f2c --- /dev/null +++ b/ltx_video/models/transformers/attention.py @@ -0,0 +1,1246 @@ +import inspect +from importlib import import_module +from typing import Any, Dict, Optional, Tuple + +import torch +import torch.nn.functional as F +from diffusers.models.activations import GEGLU, GELU, ApproximateGELU +from diffusers.models.attention import _chunked_feed_forward +from diffusers.models.attention_processor import ( + LoRAAttnAddedKVProcessor, + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + SpatialNorm, +) +from diffusers.models.lora import LoRACompatibleLinear +from diffusers.models.normalization import RMSNorm +from diffusers.utils import deprecate, logging +from diffusers.utils.torch_utils import maybe_allow_in_graph +from einops import rearrange +from torch import nn + +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + +try: + from torch_xla.experimental.custom_kernel import flash_attention +except ImportError: + # workaround for automatic tests. Currently this function is manually patched + # to the torch_xla lib on setup of container + pass + +# code adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py + +logger = logging.get_logger(__name__) + + +@maybe_allow_in_graph +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + only_cross_attention (`bool`, *optional*): + Whether to use only cross-attention layers. In this case two cross attention layers are used. + double_self_attention (`bool`, *optional*): + Whether to use two self-attention layers. In this case no cross attention layers are used. + upcast_attention (`bool`, *optional*): + Whether to upcast the attention computation to float32. This is useful for mixed precision training. + norm_elementwise_affine (`bool`, *optional*, defaults to `True`): + Whether to use learnable elementwise affine parameters for normalization. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + adaptive_norm (`str`, *optional*, defaults to `"single_scale_shift"`): + The type of adaptive norm to use. Can be `"single_scale_shift"`, `"single_scale"` or "none". + standardization_norm (`str`, *optional*, defaults to `"layer_norm"`): + The type of pre-normalization to use. Can be `"layer_norm"` or `"rms_norm"`. + final_dropout (`bool` *optional*, defaults to False): + Whether to apply a final dropout after the last feed-forward layer. + attention_type (`str`, *optional*, defaults to `"default"`): + The type of attention to use. Can be `"default"` or `"gated"` or `"gated-text-image"`. + positional_embeddings (`str`, *optional*, defaults to `None`): + The type of positional embeddings to apply to. + num_positional_embeddings (`int`, *optional*, defaults to `None`): + The maximum number of positional embeddings to apply. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, # pylint: disable=unused-argument + attention_bias: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + norm_elementwise_affine: bool = True, + adaptive_norm: str = "single_scale_shift", # 'single_scale_shift', 'single_scale' or 'none' + standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' + norm_eps: float = 1e-5, + qk_norm: Optional[str] = None, + final_dropout: bool = False, + attention_type: str = "default", # pylint: disable=unused-argument + ff_inner_dim: Optional[int] = None, + ff_bias: bool = True, + attention_out_bias: bool = True, + use_tpu_flash_attention: bool = False, + use_rope: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_tpu_flash_attention = use_tpu_flash_attention + self.adaptive_norm = adaptive_norm + + assert standardization_norm in ["layer_norm", "rms_norm"] + assert adaptive_norm in ["single_scale_shift", "single_scale", "none"] + + make_norm_layer = ( + nn.LayerNorm if standardization_norm == "layer_norm" else RMSNorm + ) + + # Define 3 blocks. Each block has its own normalization layer. + # 1. Self-Attn + self.norm1 = make_norm_layer( + dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps + ) + + self.attn1 = Attention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=use_rope, + ) + + # 2. Cross-Attn + if cross_attention_dim is not None or double_self_attention: + self.attn2 = Attention( + query_dim=dim, + cross_attention_dim=( + cross_attention_dim if not double_self_attention else None + ), + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + out_bias=attention_out_bias, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=use_rope, + ) # is self-attn if encoder_hidden_states is none + + if adaptive_norm == "none": + self.attn2_norm = make_norm_layer( + dim, norm_eps, norm_elementwise_affine + ) + else: + self.attn2 = None + self.attn2_norm = None + + self.norm2 = make_norm_layer(dim, norm_eps, norm_elementwise_affine) + + # 3. Feed-forward + self.ff = FeedForward( + dim, + dropout=dropout, + activation_fn=activation_fn, + final_dropout=final_dropout, + inner_dim=ff_inner_dim, + bias=ff_bias, + ) + + # 5. Scale-shift for PixArt-Alpha. + if adaptive_norm != "none": + num_ada_params = 4 if adaptive_norm == "single_scale" else 6 + self.scale_shift_table = nn.Parameter( + torch.randn(num_ada_params, dim) / dim**0.5 + ) + + # let chunk size default to None + self._chunk_size = None + self._chunk_dim = 0 + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU + attention kernel. + """ + self.use_tpu_flash_attention = True + self.attn1.set_use_tpu_flash_attention() + self.attn2.set_use_tpu_flash_attention() + + def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0): + # Sets chunk feed-forward + self._chunk_size = chunk_size + self._chunk_dim = dim + + def forward( + self, + hidden_states: torch.FloatTensor, + freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, + attention_mask: Optional[torch.FloatTensor] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + encoder_attention_mask: Optional[torch.FloatTensor] = None, + timestep: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + class_labels: Optional[torch.LongTensor] = None, + added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + ) -> torch.FloatTensor: + if cross_attention_kwargs is not None: + if cross_attention_kwargs.get("scale", None) is not None: + logger.warning( + "Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored." + ) + + # Notice that normalization is always applied before the real computation in the following blocks. + # 0. Self-Attention + batch_size = hidden_states.shape[0] + + norm_hidden_states = self.norm1(hidden_states) + + # Apply ada_norm_single + if self.adaptive_norm in ["single_scale_shift", "single_scale"]: + assert timestep.ndim == 3 # [batch, 1 or num_tokens, embedding_dim] + num_ada_params = self.scale_shift_table.shape[0] + ada_values = self.scale_shift_table[None, None] + timestep.reshape( + batch_size, timestep.shape[1], num_ada_params, -1 + ) + if self.adaptive_norm == "single_scale_shift": + shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = ( + ada_values.unbind(dim=2) + ) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + shift_msa + else: + scale_msa, gate_msa, scale_mlp, gate_mlp = ada_values.unbind(dim=2) + norm_hidden_states = norm_hidden_states * (1 + scale_msa) + elif self.adaptive_norm == "none": + scale_msa, gate_msa, scale_mlp, gate_mlp = None, None, None, None + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + norm_hidden_states = norm_hidden_states.squeeze( + 1 + ) # TODO: Check if this is needed + + # 1. Prepare GLIGEN inputs + cross_attention_kwargs = ( + cross_attention_kwargs.copy() if cross_attention_kwargs is not None else {} + ) + + attn_output = self.attn1( + norm_hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=( + encoder_hidden_states if self.only_cross_attention else None + ), + attention_mask=attention_mask, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **cross_attention_kwargs, + ) + if gate_msa is not None: + attn_output = gate_msa * attn_output + + hidden_states = attn_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + # 3. Cross-Attention + if self.attn2 is not None: + if self.adaptive_norm == "none": + attn_input = self.attn2_norm(hidden_states) + else: + attn_input = hidden_states + attn_output = self.attn2( + attn_input, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + attention_mask=encoder_attention_mask, + **cross_attention_kwargs, + ) + hidden_states = attn_output + hidden_states + + # 4. Feed-forward + norm_hidden_states = self.norm2(hidden_states) + if self.adaptive_norm == "single_scale_shift": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + shift_mlp + elif self.adaptive_norm == "single_scale": + norm_hidden_states = norm_hidden_states * (1 + scale_mlp) + elif self.adaptive_norm == "none": + pass + else: + raise ValueError(f"Unknown adaptive norm type: {self.adaptive_norm}") + + if self._chunk_size is not None: + # "feed_forward_chunk_size" can be used to save memory + ff_output = _chunked_feed_forward( + self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size + ) + else: + ff_output = self.ff(norm_hidden_states) + if gate_mlp is not None: + ff_output = gate_mlp * ff_output + + hidden_states = ff_output + hidden_states + if hidden_states.ndim == 4: + hidden_states = hidden_states.squeeze(1) + + return hidden_states + + +@maybe_allow_in_graph +class Attention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): + The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): + The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): + The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): + The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + upcast_attention (`bool`, *optional*, defaults to False): + Set to `True` to upcast the attention computation to `float32`. + upcast_softmax (`bool`, *optional*, defaults to False): + Set to `True` to upcast the softmax computation to `float32`. + cross_attention_norm (`str`, *optional*, defaults to `None`): + The type of normalization to use for the cross attention. Can be `None`, `layer_norm`, or `group_norm`. + cross_attention_norm_num_groups (`int`, *optional*, defaults to 32): + The number of groups to use for the group norm in the cross attention. + added_kv_proj_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the added key and value projections. If `None`, no projection is used. + norm_num_groups (`int`, *optional*, defaults to `None`): + The number of groups to use for the group norm in the attention. + spatial_norm_dim (`int`, *optional*, defaults to `None`): + The number of channels to use for the spatial normalization. + out_bias (`bool`, *optional*, defaults to `True`): + Set to `True` to use a bias in the output linear layer. + scale_qk (`bool`, *optional*, defaults to `True`): + Set to `True` to scale the query and key by `1 / sqrt(dim_head)`. + qk_norm (`str`, *optional*, defaults to None): + Set to 'layer_norm' or `rms_norm` to perform query and key normalization. + only_cross_attention (`bool`, *optional*, defaults to `False`): + Set to `True` to only use cross attention and not added_kv_proj_dim. Can only be set to `True` if + `added_kv_proj_dim` is not `None`. + eps (`float`, *optional*, defaults to 1e-5): + An additional value added to the denominator in group normalization that is used for numerical stability. + rescale_output_factor (`float`, *optional*, defaults to 1.0): + A factor to rescale the output by dividing it with this value. + residual_connection (`bool`, *optional*, defaults to `False`): + Set to `True` to add the residual connection to the output. + _from_deprecated_attn_block (`bool`, *optional*, defaults to `False`): + Set to `True` if the attention block is loaded from a deprecated state dict. + processor (`AttnProcessor`, *optional*, defaults to `None`): + The attention processor to use. If `None`, defaults to `AttnProcessor2_0` if `torch 2.x` is used and + `AttnProcessor` otherwise. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias: bool = False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + cross_attention_norm: Optional[str] = None, + cross_attention_norm_num_groups: int = 32, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + spatial_norm_dim: Optional[int] = None, + out_bias: bool = True, + scale_qk: bool = True, + qk_norm: Optional[str] = None, + only_cross_attention: bool = False, + eps: float = 1e-5, + rescale_output_factor: float = 1.0, + residual_connection: bool = False, + _from_deprecated_attn_block: bool = False, + processor: Optional["AttnProcessor"] = None, + out_dim: int = None, + use_tpu_flash_attention: bool = False, + use_rope: bool = False, + ): + super().__init__() + self.inner_dim = out_dim if out_dim is not None else dim_head * heads + self.query_dim = query_dim + self.use_bias = bias + self.is_cross_attention = cross_attention_dim is not None + self.cross_attention_dim = ( + cross_attention_dim if cross_attention_dim is not None else query_dim + ) + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + self.rescale_output_factor = rescale_output_factor + self.residual_connection = residual_connection + self.dropout = dropout + self.fused_projections = False + self.out_dim = out_dim if out_dim is not None else query_dim + self.use_tpu_flash_attention = use_tpu_flash_attention + self.use_rope = use_rope + + # we make use of this private variable to know whether this class is loaded + # with an deprecated state dict so that we can convert it on the fly + self._from_deprecated_attn_block = _from_deprecated_attn_block + + self.scale_qk = scale_qk + self.scale = dim_head**-0.5 if self.scale_qk else 1.0 + + if qk_norm is None: + self.q_norm = nn.Identity() + self.k_norm = nn.Identity() + elif qk_norm == "rms_norm": + self.q_norm = RMSNorm(dim_head * heads, eps=1e-5) + self.k_norm = RMSNorm(dim_head * heads, eps=1e-5) + elif qk_norm == "layer_norm": + self.q_norm = nn.LayerNorm(dim_head * heads, eps=1e-5) + self.k_norm = nn.LayerNorm(dim_head * heads, eps=1e-5) + else: + raise ValueError(f"Unsupported qk_norm method: {qk_norm}") + + self.heads = out_dim // dim_head if out_dim is not None else heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + + self.added_kv_proj_dim = added_kv_proj_dim + self.only_cross_attention = only_cross_attention + + if self.added_kv_proj_dim is None and self.only_cross_attention: + raise ValueError( + "`only_cross_attention` can only be set to True if `added_kv_proj_dim` is not None. Make sure to set either `only_cross_attention=False` or define `added_kv_proj_dim`." + ) + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm( + num_channels=query_dim, num_groups=norm_num_groups, eps=eps, affine=True + ) + else: + self.group_norm = None + + if spatial_norm_dim is not None: + self.spatial_norm = SpatialNorm( + f_channels=query_dim, zq_channels=spatial_norm_dim + ) + else: + self.spatial_norm = None + + if cross_attention_norm is None: + self.norm_cross = None + elif cross_attention_norm == "layer_norm": + self.norm_cross = nn.LayerNorm(self.cross_attention_dim) + elif cross_attention_norm == "group_norm": + if self.added_kv_proj_dim is not None: + # The given `encoder_hidden_states` are initially of shape + # (batch_size, seq_len, added_kv_proj_dim) before being projected + # to (batch_size, seq_len, cross_attention_dim). The norm is applied + # before the projection, so we need to use `added_kv_proj_dim` as + # the number of channels for the group norm. + norm_cross_num_channels = added_kv_proj_dim + else: + norm_cross_num_channels = self.cross_attention_dim + + self.norm_cross = nn.GroupNorm( + num_channels=norm_cross_num_channels, + num_groups=cross_attention_norm_num_groups, + eps=1e-5, + affine=True, + ) + else: + raise ValueError( + f"unknown cross_attention_norm: {cross_attention_norm}. Should be None, 'layer_norm' or 'group_norm'" + ) + + linear_cls = nn.Linear + + self.linear_cls = linear_cls + self.to_q = linear_cls(query_dim, self.inner_dim, bias=bias) + + if not self.only_cross_attention: + # only relevant for the `AddedKVProcessor` classes + self.to_k = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + self.to_v = linear_cls(self.cross_attention_dim, self.inner_dim, bias=bias) + else: + self.to_k = None + self.to_v = None + + if self.added_kv_proj_dim is not None: + self.add_k_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + self.add_v_proj = linear_cls(added_kv_proj_dim, self.inner_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(linear_cls(self.inner_dim, self.out_dim, bias=out_bias)) + self.to_out.append(nn.Dropout(dropout)) + + # set attention processor + # We use the AttnProcessor2_0 by default when torch 2.x is used which uses + # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention + # but only if it has the default `scale` argument. TODO remove scale_qk check when we move to torch 2.1 + if processor is None: + processor = AttnProcessor2_0() + self.set_processor(processor) + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object. The flag will enforce the usage of TPU attention kernel. + """ + self.use_tpu_flash_attention = True + + def set_processor(self, processor: "AttnProcessor") -> None: + r""" + Set the attention processor to use. + + Args: + processor (`AttnProcessor`): + The attention processor to use. + """ + # if current processor is in `self._modules` and if passed `processor` is not, we need to + # pop `processor` from `self._modules` + if ( + hasattr(self, "processor") + and isinstance(self.processor, torch.nn.Module) + and not isinstance(processor, torch.nn.Module) + ): + logger.info( + f"You are removing possibly trained weights of {self.processor} with {processor}" + ) + self._modules.pop("processor") + + self.processor = processor + + def get_processor( + self, return_deprecated_lora: bool = False + ) -> "AttentionProcessor": # noqa: F821 + r""" + Get the attention processor in use. + + Args: + return_deprecated_lora (`bool`, *optional*, defaults to `False`): + Set to `True` to return the deprecated LoRA attention processor. + + Returns: + "AttentionProcessor": The attention processor in use. + """ + if not return_deprecated_lora: + return self.processor + + # TODO(Sayak, Patrick). The rest of the function is needed to ensure backwards compatible + # serialization format for LoRA Attention Processors. It should be deleted once the integration + # with PEFT is completed. + is_lora_activated = { + name: module.lora_layer is not None + for name, module in self.named_modules() + if hasattr(module, "lora_layer") + } + + # 1. if no layer has a LoRA activated we can return the processor as usual + if not any(is_lora_activated.values()): + return self.processor + + # If doesn't apply LoRA do `add_k_proj` or `add_v_proj` + is_lora_activated.pop("add_k_proj", None) + is_lora_activated.pop("add_v_proj", None) + # 2. else it is not posssible that only some layers have LoRA activated + if not all(is_lora_activated.values()): + raise ValueError( + f"Make sure that either all layers or no layers have LoRA activated, but have {is_lora_activated}" + ) + + # 3. And we need to merge the current LoRA layers into the corresponding LoRA attention processor + non_lora_processor_cls_name = self.processor.__class__.__name__ + lora_processor_cls = getattr( + import_module(__name__), "LoRA" + non_lora_processor_cls_name + ) + + hidden_size = self.inner_dim + + # now create a LoRA attention processor from the LoRA layers + if lora_processor_cls in [ + LoRAAttnProcessor, + LoRAAttnProcessor2_0, + LoRAXFormersAttnProcessor, + ]: + kwargs = { + "cross_attention_dim": self.cross_attention_dim, + "rank": self.to_q.lora_layer.rank, + "network_alpha": self.to_q.lora_layer.network_alpha, + "q_rank": self.to_q.lora_layer.rank, + "q_hidden_size": self.to_q.lora_layer.out_features, + "k_rank": self.to_k.lora_layer.rank, + "k_hidden_size": self.to_k.lora_layer.out_features, + "v_rank": self.to_v.lora_layer.rank, + "v_hidden_size": self.to_v.lora_layer.out_features, + "out_rank": self.to_out[0].lora_layer.rank, + "out_hidden_size": self.to_out[0].lora_layer.out_features, + } + + if hasattr(self.processor, "attention_op"): + kwargs["attention_op"] = self.processor.attention_op + + lora_processor = lora_processor_cls(hidden_size, **kwargs) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict( + self.to_out[0].lora_layer.state_dict() + ) + elif lora_processor_cls == LoRAAttnAddedKVProcessor: + lora_processor = lora_processor_cls( + hidden_size, + cross_attention_dim=self.add_k_proj.weight.shape[0], + rank=self.to_q.lora_layer.rank, + network_alpha=self.to_q.lora_layer.network_alpha, + ) + lora_processor.to_q_lora.load_state_dict(self.to_q.lora_layer.state_dict()) + lora_processor.to_k_lora.load_state_dict(self.to_k.lora_layer.state_dict()) + lora_processor.to_v_lora.load_state_dict(self.to_v.lora_layer.state_dict()) + lora_processor.to_out_lora.load_state_dict( + self.to_out[0].lora_layer.state_dict() + ) + + # only save if used + if self.add_k_proj.lora_layer is not None: + lora_processor.add_k_proj_lora.load_state_dict( + self.add_k_proj.lora_layer.state_dict() + ) + lora_processor.add_v_proj_lora.load_state_dict( + self.add_v_proj.lora_layer.state_dict() + ) + else: + lora_processor.add_k_proj_lora = None + lora_processor.add_v_proj_lora = None + else: + raise ValueError(f"{lora_processor_cls} does not exist.") + + return lora_processor + + def forward( + self, + hidden_states: torch.FloatTensor, + freqs_cis: Optional[Tuple[torch.FloatTensor, torch.FloatTensor]] = None, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + **cross_attention_kwargs, + ) -> torch.Tensor: + r""" + The forward method of the `Attention` class. + + Args: + hidden_states (`torch.Tensor`): + The hidden states of the query. + encoder_hidden_states (`torch.Tensor`, *optional*): + The hidden states of the encoder. + attention_mask (`torch.Tensor`, *optional*): + The attention mask to use. If `None`, no mask is applied. + skip_layer_mask (`torch.Tensor`, *optional*): + The skip layer mask to use. If `None`, no mask is applied. + skip_layer_strategy (`SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers to skip for spatiotemporal guidance. + **cross_attention_kwargs: + Additional keyword arguments to pass along to the cross attention. + + Returns: + `torch.Tensor`: The output of the attention layer. + """ + # The `Attention` class can call different attention processors / attention functions + # here we simply pass along all tensors to the selected processor class + # For standard processors that are defined here, `**cross_attention_kwargs` is empty + + attn_parameters = set( + inspect.signature(self.processor.__call__).parameters.keys() + ) + unused_kwargs = [ + k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters + ] + if len(unused_kwargs) > 0: + logger.warning( + f"cross_attention_kwargs {unused_kwargs} are not expected by" + f" {self.processor.__class__.__name__} and will be ignored." + ) + cross_attention_kwargs = { + k: w for k, w in cross_attention_kwargs.items() if k in attn_parameters + } + + return self.processor( + self, + hidden_states, + freqs_cis=freqs_cis, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + **cross_attention_kwargs, + ) + + def batch_to_head_dim(self, tensor: torch.Tensor) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size // heads, seq_len, dim * heads]`. `heads` + is the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape( + batch_size // head_size, seq_len, dim * head_size + ) + return tensor + + def head_to_batch_dim(self, tensor: torch.Tensor, out_dim: int = 3) -> torch.Tensor: + r""" + Reshape the tensor from `[batch_size, seq_len, dim]` to `[batch_size, seq_len, heads, dim // heads]` `heads` is + the number of heads initialized while constructing the `Attention` class. + + Args: + tensor (`torch.Tensor`): The tensor to reshape. + out_dim (`int`, *optional*, defaults to `3`): The output dimension of the tensor. If `3`, the tensor is + reshaped to `[batch_size * heads, seq_len, dim // heads]`. + + Returns: + `torch.Tensor`: The reshaped tensor. + """ + + head_size = self.heads + if tensor.ndim == 3: + batch_size, seq_len, dim = tensor.shape + extra_dim = 1 + else: + batch_size, extra_dim, seq_len, dim = tensor.shape + tensor = tensor.reshape( + batch_size, seq_len * extra_dim, head_size, dim // head_size + ) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape( + batch_size * head_size, seq_len * extra_dim, dim // head_size + ) + + return tensor + + def get_attention_scores( + self, + query: torch.Tensor, + key: torch.Tensor, + attention_mask: torch.Tensor = None, + ) -> torch.Tensor: + r""" + Compute the attention scores. + + Args: + query (`torch.Tensor`): The query tensor. + key (`torch.Tensor`): The key tensor. + attention_mask (`torch.Tensor`, *optional*): The attention mask to use. If `None`, no mask is applied. + + Returns: + `torch.Tensor`: The attention probabilities/scores. + """ + dtype = query.dtype + if self.upcast_attention: + query = query.float() + key = key.float() + + if attention_mask is None: + baddbmm_input = torch.empty( + query.shape[0], + query.shape[1], + key.shape[1], + dtype=query.dtype, + device=query.device, + ) + beta = 0 + else: + baddbmm_input = attention_mask + beta = 1 + + attention_scores = torch.baddbmm( + baddbmm_input, + query, + key.transpose(-1, -2), + beta=beta, + alpha=self.scale, + ) + del baddbmm_input + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + del attention_scores + + attention_probs = attention_probs.to(dtype) + + return attention_probs + + def prepare_attention_mask( + self, + attention_mask: torch.Tensor, + target_length: int, + batch_size: int, + out_dim: int = 3, + ) -> torch.Tensor: + r""" + Prepare the attention mask for the attention computation. + + Args: + attention_mask (`torch.Tensor`): + The attention mask to prepare. + target_length (`int`): + The target length of the attention mask. This is the length of the attention mask after padding. + batch_size (`int`): + The batch size, which is used to repeat the attention mask. + out_dim (`int`, *optional*, defaults to `3`): + The output dimension of the attention mask. Can be either `3` or `4`. + + Returns: + `torch.Tensor`: The prepared attention mask. + """ + head_size = self.heads + if attention_mask is None: + return attention_mask + + current_length: int = attention_mask.shape[-1] + if current_length != target_length: + if attention_mask.device.type == "mps": + # HACK: MPS: Does not support padding by greater than dimension of input tensor. + # Instead, we can manually construct the padding tensor. + padding_shape = ( + attention_mask.shape[0], + attention_mask.shape[1], + target_length, + ) + padding = torch.zeros( + padding_shape, + dtype=attention_mask.dtype, + device=attention_mask.device, + ) + attention_mask = torch.cat([attention_mask, padding], dim=2) + else: + # TODO: for pipelines such as stable-diffusion, padding cross-attn mask: + # we want to instead pad by (0, remaining_length), where remaining_length is: + # remaining_length: int = target_length - current_length + # TODO: re-enable tests/models/test_models_unet_2d_condition.py#test_model_xattn_padding + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + + if out_dim == 3: + if attention_mask.shape[0] < batch_size * head_size: + attention_mask = attention_mask.repeat_interleave(head_size, dim=0) + elif out_dim == 4: + attention_mask = attention_mask.unsqueeze(1) + attention_mask = attention_mask.repeat_interleave(head_size, dim=1) + + return attention_mask + + def norm_encoder_hidden_states( + self, encoder_hidden_states: torch.Tensor + ) -> torch.Tensor: + r""" + Normalize the encoder hidden states. Requires `self.norm_cross` to be specified when constructing the + `Attention` class. + + Args: + encoder_hidden_states (`torch.Tensor`): Hidden states of the encoder. + + Returns: + `torch.Tensor`: The normalized encoder hidden states. + """ + assert ( + self.norm_cross is not None + ), "self.norm_cross must be defined to call self.norm_encoder_hidden_states" + + if isinstance(self.norm_cross, nn.LayerNorm): + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + elif isinstance(self.norm_cross, nn.GroupNorm): + # Group norm norms along the channels dimension and expects + # input to be in the shape of (N, C, *). In this case, we want + # to norm along the hidden dimension, so we need to move + # (batch_size, sequence_length, hidden_size) -> + # (batch_size, hidden_size, sequence_length) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + encoder_hidden_states = self.norm_cross(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.transpose(1, 2) + else: + assert False + + return encoder_hidden_states + + @staticmethod + def apply_rotary_emb( + input_tensor: torch.Tensor, + freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor], + ) -> Tuple[torch.Tensor, torch.Tensor]: + cos_freqs = freqs_cis[0] + sin_freqs = freqs_cis[1] + + t_dup = rearrange(input_tensor, "... (d r) -> ... d r", r=2) + t1, t2 = t_dup.unbind(dim=-1) + t_dup = torch.stack((-t2, t1), dim=-1) + input_tensor_rot = rearrange(t_dup, "... d r -> ... (d r)") + + out = input_tensor * cos_freqs + input_tensor_rot * sin_freqs + + return out + + +class AttnProcessor2_0: + r""" + Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0). + """ + + def __init__(self): + pass + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + freqs_cis: Tuple[torch.FloatTensor, torch.FloatTensor], + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + skip_layer_mask: Optional[torch.FloatTensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + *args, + **kwargs, + ) -> torch.FloatTensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + + if skip_layer_mask is not None: + skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1) + + if (attention_mask is not None) and (not attn.use_tpu_flash_attention): + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + # scaled_dot_product_attention expects attention_mask shape to be + # (batch, heads, source_length, target_length) + attention_mask = attention_mask.view( + batch_size, attn.heads, -1, attention_mask.shape[-1] + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + query = attn.q_norm(query) + + if encoder_hidden_states is not None: + if attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + key = attn.to_k(encoder_hidden_states) + key = attn.k_norm(key) + else: # if no context provided do self-attention + encoder_hidden_states = hidden_states + key = attn.to_k(hidden_states) + key = attn.k_norm(key) + if attn.use_rope: + key = attn.apply_rotary_emb(key, freqs_cis) + query = attn.apply_rotary_emb(query, freqs_cis) + + value = attn.to_v(encoder_hidden_states) + + inner_dim = key.shape[-1] + head_dim = inner_dim // attn.heads + + query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2) + + # the output of sdp = (batch, num_heads, seq_len, head_dim) + + if attn.use_tpu_flash_attention: # use tpu attention offload 'flash attention' + q_segment_indexes = None + if ( + attention_mask is not None + ): # if mask is required need to tune both segmenIds fields + # attention_mask = torch.squeeze(attention_mask).to(torch.float32) + attention_mask = attention_mask.to(torch.float32) + q_segment_indexes = torch.ones( + batch_size, query.shape[2], device=query.device, dtype=torch.float32 + ) + assert ( + attention_mask.shape[1] == key.shape[2] + ), f"ERROR: KEY SHAPE must be same as attention mask [{key.shape[2]}, {attention_mask.shape[1]}]" + + assert ( + query.shape[2] % 128 == 0 + ), f"ERROR: QUERY SHAPE must be divisible by 128 (TPU limitation) [{query.shape[2]}]" + assert ( + key.shape[2] % 128 == 0 + ), f"ERROR: KEY SHAPE must be divisible by 128 (TPU limitation) [{key.shape[2]}]" + + # run the TPU kernel implemented in jax with pallas + hidden_states_a = flash_attention( + q=query, + k=key, + v=value, + q_segment_ids=q_segment_indexes, + kv_segment_ids=attention_mask, + sm_scale=attn.scale, + ) + else: + hidden_states_a = F.scaled_dot_product_attention( + query, + key, + value, + attn_mask=attention_mask, + dropout_p=0.0, + is_causal=False, + ) + + hidden_states_a = hidden_states_a.transpose(1, 2).reshape( + batch_size, -1, attn.heads * head_dim + ) + hidden_states_a = hidden_states_a.to(query.dtype) + + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.Attention + ): + hidden_states = hidden_states_a * skip_layer_mask + hidden_states * ( + 1.0 - skip_layer_mask + ) + else: + hidden_states = hidden_states_a + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.Residual + ): + skip_layer_mask = skip_layer_mask.reshape(batch_size, 1, 1, 1) + + if attn.residual_connection: + if ( + skip_layer_mask is not None + and skip_layer_strategy == SkipLayerStrategy.Residual + ): + hidden_states = hidden_states + residual * skip_layer_mask + else: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class AttnProcessor: + r""" + Default processor for performing attention-related computations. + """ + + def __call__( + self, + attn: Attention, + hidden_states: torch.FloatTensor, + encoder_hidden_states: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + temb: Optional[torch.FloatTensor] = None, + *args, + **kwargs, + ) -> torch.Tensor: + if len(args) > 0 or kwargs.get("scale", None) is not None: + deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`." + deprecate("scale", "1.0.0", deprecation_message) + + residual = hidden_states + + if attn.spatial_norm is not None: + hidden_states = attn.spatial_norm(hidden_states, temb) + + input_ndim = hidden_states.ndim + + if input_ndim == 4: + batch_size, channel, height, width = hidden_states.shape + hidden_states = hidden_states.view( + batch_size, channel, height * width + ).transpose(1, 2) + + batch_size, sequence_length, _ = ( + hidden_states.shape + if encoder_hidden_states is None + else encoder_hidden_states.shape + ) + attention_mask = attn.prepare_attention_mask( + attention_mask, sequence_length, batch_size + ) + + if attn.group_norm is not None: + hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose( + 1, 2 + ) + + query = attn.to_q(hidden_states) + + if encoder_hidden_states is None: + encoder_hidden_states = hidden_states + elif attn.norm_cross: + encoder_hidden_states = attn.norm_encoder_hidden_states( + encoder_hidden_states + ) + + key = attn.to_k(encoder_hidden_states) + value = attn.to_v(encoder_hidden_states) + + query = attn.head_to_batch_dim(query) + key = attn.head_to_batch_dim(key) + value = attn.head_to_batch_dim(value) + + query = attn.q_norm(query) + key = attn.k_norm(key) + + attention_probs = attn.get_attention_scores(query, key, attention_mask) + hidden_states = torch.bmm(attention_probs, value) + hidden_states = attn.batch_to_head_dim(hidden_states) + + # linear proj + hidden_states = attn.to_out[0](hidden_states) + # dropout + hidden_states = attn.to_out[1](hidden_states) + + if input_ndim == 4: + hidden_states = hidden_states.transpose(-1, -2).reshape( + batch_size, channel, height, width + ) + + if attn.residual_connection: + hidden_states = hidden_states + residual + + hidden_states = hidden_states / attn.rescale_output_factor + + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + final_dropout (`bool` *optional*, defaults to False): Apply a final dropout. + bias (`bool`, defaults to True): Whether to use a bias in the linear layer. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + final_dropout: bool = False, + inner_dim=None, + bias: bool = True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + linear_cls = nn.Linear + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + elif activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + else: + raise ValueError(f"Unsupported activation function: {activation_fn}") + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(linear_cls(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor, scale: float = 1.0) -> torch.Tensor: + compatible_cls = (GEGLU, LoRACompatibleLinear) + for module in self.net: + if isinstance(module, compatible_cls): + hidden_states = module(hidden_states, scale) + else: + hidden_states = module(hidden_states) + return hidden_states diff --git a/ltx_video/models/transformers/embeddings.py b/ltx_video/models/transformers/embeddings.py new file mode 100644 index 0000000000000000000000000000000000000000..a30d6be16b4f3fe709cf24465e06eb798889ba66 --- /dev/null +++ b/ltx_video/models/transformers/embeddings.py @@ -0,0 +1,129 @@ +# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/embeddings.py +import math + +import numpy as np +import torch +from einops import rearrange +from torch import nn + + +def get_timestep_embedding( + timesteps: torch.Tensor, + embedding_dim: int, + flip_sin_to_cos: bool = False, + downscale_freq_shift: float = 1, + scale: float = 1, + max_period: int = 10000, +): + """ + This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param embedding_dim: the dimension of the output. :param max_period: controls the minimum frequency of the + embeddings. :return: an [N x dim] Tensor of positional embeddings. + """ + assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array" + + half_dim = embedding_dim // 2 + exponent = -math.log(max_period) * torch.arange( + start=0, end=half_dim, dtype=torch.float32, device=timesteps.device + ) + exponent = exponent / (half_dim - downscale_freq_shift) + + emb = torch.exp(exponent) + emb = timesteps[:, None].float() * emb[None, :] + + # scale embeddings + emb = scale * emb + + # concat sine and cosine embeddings + emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1) + + # flip sine and cosine embeddings + if flip_sin_to_cos: + emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1) + + # zero pad + if embedding_dim % 2 == 1: + emb = torch.nn.functional.pad(emb, (0, 1, 0, 0)) + return emb + + +def get_3d_sincos_pos_embed(embed_dim, grid, w, h, f): + """ + grid_size: int of the grid height and width return: pos_embed: [grid_size*grid_size, embed_dim] or + [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token) + """ + grid = rearrange(grid, "c (f h w) -> c f h w", h=h, w=w) + grid = rearrange(grid, "c f h w -> c h w f", h=h, w=w) + grid = grid.reshape([3, 1, w, h, f]) + pos_embed = get_3d_sincos_pos_embed_from_grid(embed_dim, grid) + pos_embed = pos_embed.transpose(1, 0, 2, 3) + return rearrange(pos_embed, "h w f c -> (f h w) c") + + +def get_3d_sincos_pos_embed_from_grid(embed_dim, grid): + if embed_dim % 3 != 0: + raise ValueError("embed_dim must be divisible by 3") + + # use half of dimensions to encode grid_h + emb_f = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[0]) # (H*W*T, D/3) + emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[1]) # (H*W*T, D/3) + emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 3, grid[2]) # (H*W*T, D/3) + + emb = np.concatenate([emb_h, emb_w, emb_f], axis=-1) # (H*W*T, D) + return emb + + +def get_1d_sincos_pos_embed_from_grid(embed_dim, pos): + """ + embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D) + """ + if embed_dim % 2 != 0: + raise ValueError("embed_dim must be divisible by 2") + + omega = np.arange(embed_dim // 2, dtype=np.float64) + omega /= embed_dim / 2.0 + omega = 1.0 / 10000**omega # (D/2,) + + pos_shape = pos.shape + + pos = pos.reshape(-1) + out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product + out = out.reshape([*pos_shape, -1])[0] + + emb_sin = np.sin(out) # (M, D/2) + emb_cos = np.cos(out) # (M, D/2) + + emb = np.concatenate([emb_sin, emb_cos], axis=-1) # (M, D) + return emb + + +class SinusoidalPositionalEmbedding(nn.Module): + """Apply positional information to a sequence of embeddings. + + Takes in a sequence of embeddings with shape (batch_size, seq_length, embed_dim) and adds positional embeddings to + them + + Args: + embed_dim: (int): Dimension of the positional embedding. + max_seq_length: Maximum sequence length to apply positional embeddings + + """ + + def __init__(self, embed_dim: int, max_seq_length: int = 32): + super().__init__() + position = torch.arange(max_seq_length).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, embed_dim, 2) * (-math.log(10000.0) / embed_dim) + ) + pe = torch.zeros(1, max_seq_length, embed_dim) + pe[0, :, 0::2] = torch.sin(position * div_term) + pe[0, :, 1::2] = torch.cos(position * div_term) + self.register_buffer("pe", pe) + + def forward(self, x): + _, seq_length, _ = x.shape + x = x + self.pe[:, :seq_length] + return x diff --git a/ltx_video/models/transformers/symmetric_patchifier.py b/ltx_video/models/transformers/symmetric_patchifier.py new file mode 100644 index 0000000000000000000000000000000000000000..ba1bd6c6138bf06748ec4d957ec1e924af6cd988 --- /dev/null +++ b/ltx_video/models/transformers/symmetric_patchifier.py @@ -0,0 +1,96 @@ +from abc import ABC, abstractmethod +from typing import Tuple + +import torch +from diffusers.configuration_utils import ConfigMixin +from einops import rearrange +from torch import Tensor + +from ltx_video.utils.torch_utils import append_dims + + +class Patchifier(ConfigMixin, ABC): + def __init__(self, patch_size: int): + super().__init__() + self._patch_size = (1, patch_size, patch_size) + + @abstractmethod + def patchify( + self, latents: Tensor, frame_rates: Tensor, scale_grid: bool + ) -> Tuple[Tensor, Tensor]: + pass + + @abstractmethod + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + output_num_frames: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + pass + + @property + def patch_size(self): + return self._patch_size + + def get_grid( + self, orig_num_frames, orig_height, orig_width, batch_size, scale_grid, device + ): + f = orig_num_frames // self._patch_size[0] + h = orig_height // self._patch_size[1] + w = orig_width // self._patch_size[2] + grid_h = torch.arange(h, dtype=torch.float32, device=device) + grid_w = torch.arange(w, dtype=torch.float32, device=device) + grid_f = torch.arange(f, dtype=torch.float32, device=device) + grid = torch.meshgrid(grid_f, grid_h, grid_w) + grid = torch.stack(grid, dim=0) + grid = grid.unsqueeze(0).repeat(batch_size, 1, 1, 1, 1) + + if scale_grid is not None: + for i in range(3): + if isinstance(scale_grid[i], Tensor): + scale = append_dims(scale_grid[i], grid.ndim - 1) + else: + scale = scale_grid[i] + grid[:, i, ...] = grid[:, i, ...] * scale * self._patch_size[i] + + grid = rearrange(grid, "b c f h w -> b c (f h w)", b=batch_size) + return grid + + +class SymmetricPatchifier(Patchifier): + def patchify( + self, + latents: Tensor, + ) -> Tuple[Tensor, Tensor]: + latents = rearrange( + latents, + "b c (f p1) (h p2) (w p3) -> b (f h w) (c p1 p2 p3)", + p1=self._patch_size[0], + p2=self._patch_size[1], + p3=self._patch_size[2], + ) + return latents + + def unpatchify( + self, + latents: Tensor, + output_height: int, + output_width: int, + output_num_frames: int, + out_channels: int, + ) -> Tuple[Tensor, Tensor]: + output_height = output_height // self._patch_size[1] + output_width = output_width // self._patch_size[2] + latents = rearrange( + latents, + "b (f h w) (c p q) -> b c f (h p) (w q) ", + f=output_num_frames, + h=output_height, + w=output_width, + p=self._patch_size[1], + q=self._patch_size[2], + ) + return latents diff --git a/ltx_video/models/transformers/transformer3d.py b/ltx_video/models/transformers/transformer3d.py new file mode 100644 index 0000000000000000000000000000000000000000..afc98e7da70e403ce69040df4e2785601cc92e3d --- /dev/null +++ b/ltx_video/models/transformers/transformer3d.py @@ -0,0 +1,605 @@ +# Adapted from: https://github.com/huggingface/diffusers/blob/v0.26.3/src/diffusers/models/transformers/transformer_2d.py +import math +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Literal, Union +import os +import json +import glob +from pathlib import Path + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import PixArtAlphaTextProjection +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.normalization import AdaLayerNormSingle +from diffusers.utils import BaseOutput, is_torch_version +from diffusers.utils import logging +from torch import nn +from safetensors import safe_open + + +from ltx_video.models.transformers.attention import BasicTransformerBlock +from ltx_video.models.transformers.embeddings import get_3d_sincos_pos_embed +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + +from ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, + TRANSFORMER_KEYS_RENAME_DICT, +) + + +logger = logging.get_logger(__name__) + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer3DModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + double_self_attention: bool = False, + upcast_attention: bool = False, + adaptive_norm: str = "single_scale_shift", # 'single_scale_shift' or 'single_scale' + standardization_norm: str = "layer_norm", # 'layer_norm' or 'rms_norm' + norm_elementwise_affine: bool = True, + norm_eps: float = 1e-5, + attention_type: str = "default", + caption_channels: int = None, + project_to_2d_pos: bool = False, + use_tpu_flash_attention: bool = False, # if True uses the TPU attention offload ('flash attention') + qk_norm: Optional[str] = None, + positional_embedding_type: str = "absolute", + positional_embedding_theta: Optional[float] = None, + positional_embedding_max_pos: Optional[List[int]] = None, + timestep_scale_multiplier: Optional[float] = None, + ): + super().__init__() + self.use_tpu_flash_attention = ( + use_tpu_flash_attention # FIXME: push config down to the attention modules + ) + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + self.inner_dim = inner_dim + + self.project_to_2d_pos = project_to_2d_pos + + self.patchify_proj = nn.Linear(in_channels, inner_dim, bias=True) + + self.positional_embedding_type = positional_embedding_type + self.positional_embedding_theta = positional_embedding_theta + self.positional_embedding_max_pos = positional_embedding_max_pos + self.use_rope = self.positional_embedding_type == "rope" + self.timestep_scale_multiplier = timestep_scale_multiplier + + if self.positional_embedding_type == "absolute": + embed_dim_3d = ( + math.ceil((inner_dim / 2) * 3) if project_to_2d_pos else inner_dim + ) + if self.project_to_2d_pos: + self.to_2d_proj = torch.nn.Linear(embed_dim_3d, inner_dim, bias=False) + self._init_to_2d_proj_weights(self.to_2d_proj) + elif self.positional_embedding_type == "rope": + if positional_embedding_theta is None: + raise ValueError( + "If `positional_embedding_type` type is rope, `positional_embedding_theta` must also be defined" + ) + if positional_embedding_max_pos is None: + raise ValueError( + "If `positional_embedding_type` type is rope, `positional_embedding_max_pos` must also be defined" + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + double_self_attention=double_self_attention, + upcast_attention=upcast_attention, + adaptive_norm=adaptive_norm, + standardization_norm=standardization_norm, + norm_elementwise_affine=norm_elementwise_affine, + norm_eps=norm_eps, + attention_type=attention_type, + use_tpu_flash_attention=use_tpu_flash_attention, + qk_norm=qk_norm, + use_rope=self.use_rope, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.scale_shift_table = nn.Parameter( + torch.randn(2, inner_dim) / inner_dim**0.5 + ) + self.proj_out = nn.Linear(inner_dim, self.out_channels) + + self.adaln_single = AdaLayerNormSingle( + inner_dim, use_additional_conditions=False + ) + if adaptive_norm == "single_scale": + self.adaln_single.linear = nn.Linear(inner_dim, 4 * inner_dim, bias=True) + + self.caption_projection = None + if caption_channels is not None: + self.caption_projection = PixArtAlphaTextProjection( + in_features=caption_channels, hidden_size=inner_dim + ) + + self.gradient_checkpointing = False + + def set_use_tpu_flash_attention(self): + r""" + Function sets the flag in this object and propagates down the children. The flag will enforce the usage of TPU + attention kernel. + """ + logger.info("ENABLE TPU FLASH ATTENTION -> TRUE") + self.use_tpu_flash_attention = True + # push config down to the attention modules + for block in self.transformer_blocks: + block.set_use_tpu_flash_attention() + + def create_skip_layer_mask( + self, + skip_block_list: List[int], + batch_size: int, + num_conds: int, + ptb_index: int, + ): + num_layers = len(self.transformer_blocks) + mask = torch.ones( + (num_layers, batch_size * num_conds), device=self.device, dtype=self.dtype + ) + for block_idx in skip_block_list: + mask[block_idx, ptb_index::num_conds] = 0 + return mask + + def initialize(self, embedding_std: float, mode: Literal["ltx_video", "legacy"]): + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + # Initialize timestep embedding MLP: + nn.init.normal_( + self.adaln_single.emb.timestep_embedder.linear_1.weight, std=embedding_std + ) + nn.init.normal_( + self.adaln_single.emb.timestep_embedder.linear_2.weight, std=embedding_std + ) + nn.init.normal_(self.adaln_single.linear.weight, std=embedding_std) + + if hasattr(self.adaln_single.emb, "resolution_embedder"): + nn.init.normal_( + self.adaln_single.emb.resolution_embedder.linear_1.weight, + std=embedding_std, + ) + nn.init.normal_( + self.adaln_single.emb.resolution_embedder.linear_2.weight, + std=embedding_std, + ) + if hasattr(self.adaln_single.emb, "aspect_ratio_embedder"): + nn.init.normal_( + self.adaln_single.emb.aspect_ratio_embedder.linear_1.weight, + std=embedding_std, + ) + nn.init.normal_( + self.adaln_single.emb.aspect_ratio_embedder.linear_2.weight, + std=embedding_std, + ) + + # Initialize caption embedding MLP: + nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std) + nn.init.normal_(self.caption_projection.linear_1.weight, std=embedding_std) + + for block in self.transformer_blocks: + if mode.lower() == "ltx_video": + nn.init.constant_(block.attn1.to_out[0].weight, 0) + nn.init.constant_(block.attn1.to_out[0].bias, 0) + + nn.init.constant_(block.attn2.to_out[0].weight, 0) + nn.init.constant_(block.attn2.to_out[0].bias, 0) + + if mode.lower() == "ltx_video": + nn.init.constant_(block.ff.net[2].weight, 0) + nn.init.constant_(block.ff.net[2].bias, 0) + + # Zero-out output layers: + nn.init.constant_(self.proj_out.weight, 0) + nn.init.constant_(self.proj_out.bias, 0) + + def _set_gradient_checkpointing(self, module, value=False): + if hasattr(module, "gradient_checkpointing"): + module.gradient_checkpointing = value + + @staticmethod + def _init_to_2d_proj_weights(linear_layer): + input_features = linear_layer.weight.data.size(1) + output_features = linear_layer.weight.data.size(0) + + # Start with a zero matrix + identity_like = torch.zeros((output_features, input_features)) + + # Fill the diagonal with 1's as much as possible + min_features = min(output_features, input_features) + identity_like[:min_features, :min_features] = torch.eye(min_features) + linear_layer.weight.data = identity_like.to(linear_layer.weight.data.device) + + def get_fractional_positions(self, indices_grid): + fractional_positions = torch.stack( + [ + indices_grid[:, i] / self.positional_embedding_max_pos[i] + for i in range(3) + ], + dim=-1, + ) + return fractional_positions + + def precompute_freqs_cis(self, indices_grid, spacing="exp"): + dtype = torch.float32 # We need full precision in the freqs_cis computation. + dim = self.inner_dim + theta = self.positional_embedding_theta + + fractional_positions = self.get_fractional_positions(indices_grid) + + start = 1 + end = theta + device = fractional_positions.device + if spacing == "exp": + indices = theta ** ( + torch.linspace( + math.log(start, theta), + math.log(end, theta), + dim // 6, + device=device, + dtype=dtype, + ) + ) + indices = indices.to(dtype=dtype) + elif spacing == "exp_2": + indices = 1.0 / theta ** (torch.arange(0, dim, 6, device=device) / dim) + indices = indices.to(dtype=dtype) + elif spacing == "linear": + indices = torch.linspace(start, end, dim // 6, device=device, dtype=dtype) + elif spacing == "sqrt": + indices = torch.linspace( + start**2, end**2, dim // 6, device=device, dtype=dtype + ).sqrt() + + indices = indices * math.pi / 2 + + if spacing == "exp_2": + freqs = ( + (indices * fractional_positions.unsqueeze(-1)) + .transpose(-1, -2) + .flatten(2) + ) + else: + freqs = ( + (indices * (fractional_positions.unsqueeze(-1) * 2 - 1)) + .transpose(-1, -2) + .flatten(2) + ) + + cos_freq = freqs.cos().repeat_interleave(2, dim=-1) + sin_freq = freqs.sin().repeat_interleave(2, dim=-1) + if dim % 6 != 0: + cos_padding = torch.ones_like(cos_freq[:, :, : dim % 6]) + sin_padding = torch.zeros_like(cos_freq[:, :, : dim % 6]) + cos_freq = torch.cat([cos_padding, cos_freq], dim=-1) + sin_freq = torch.cat([sin_padding, sin_freq], dim=-1) + return cos_freq.to(self.dtype), sin_freq.to(self.dtype) + + def load_state_dict( + self, + state_dict: Dict, + *args, + **kwargs, + ): + if any([key.startswith("model.diffusion_model.") for key in state_dict.keys()]): + state_dict = { + key.replace("model.diffusion_model.", ""): value + for key, value in state_dict.items() + if key.startswith("model.diffusion_model.") + } + super().load_state_dict(state_dict, **kwargs) + + @classmethod + def from_pretrained( + cls, + pretrained_model_path: Optional[Union[str, os.PathLike]], + *args, + **kwargs, + ): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_dir(): + config_path = pretrained_model_path / "transformer" / "config.json" + with open(config_path, "r") as f: + config = make_hashable_key(json.load(f)) + + assert config in diffusers_and_ours_config_mapping, ( + "Provided diffusers checkpoint config for transformer is not suppported. " + "We only support diffusers configs found in Lightricks/LTX-Video." + ) + + config = diffusers_and_ours_config_mapping[config] + state_dict = {} + ckpt_paths = ( + pretrained_model_path + / "transformer" + / "diffusion_pytorch_model*.safetensors" + ) + dict_list = glob.glob(str(ckpt_paths)) + for dict_path in dict_list: + part_dict = {} + with safe_open(dict_path, framework="pt", device="cpu") as f: + for k in f.keys(): + part_dict[k] = f.get_tensor(k) + state_dict.update(part_dict) + + for key in list(state_dict.keys()): + new_key = key + for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items(): + new_key = new_key.replace(replace_key, rename_key) + state_dict[new_key] = state_dict.pop(key) + + transformer = cls.from_config(config) + transformer.load_state_dict(state_dict, strict=True) + elif pretrained_model_path.is_file() and str(pretrained_model_path).endswith( + ".safetensors" + ): + comfy_single_file_state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + comfy_single_file_state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + transformer_config = configs["transformer"] + transformer = Transformer3DModel.from_config(transformer_config) + transformer.load_state_dict(comfy_single_file_state_dict) + return transformer + + def forward( + self, + hidden_states: torch.Tensor, + indices_grid: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + skip_layer_mask: Optional[torch.Tensor] = None, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + return_dict: bool = True, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + indices_grid (`torch.LongTensor` of shape `(batch size, 3, num latent pixels)`): + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + cross_attention_kwargs ( `Dict[str, Any]`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under + `self.processor` in + [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py). + attention_mask ( `torch.Tensor`, *optional*): + An attention mask of shape `(batch, key_tokens)` is applied to `encoder_hidden_states`. If `1` the mask + is kept, otherwise if `0` it is discarded. Mask will be converted into a bias, which adds large + negative values to the attention scores corresponding to "discard" tokens. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + skip_layer_mask ( `torch.Tensor`, *optional*): + A mask of shape `(num_layers, batch)` that indicates which layers to skip. `0` at position + `layer, batch_idx` indicates that the layer should be skipped for the corresponding batch index. + skip_layer_strategy ( `SkipLayerStrategy`, *optional*, defaults to `None`): + Controls which layers are skipped when calculating a perturbed latent for spatiotemporal guidance. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # for tpu attention offload 2d token masks are used. No need to transform. + if not self.use_tpu_flash_attention: + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = ( + 1 - encoder_attention_mask.to(hidden_states.dtype) + ) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + hidden_states = self.patchify_proj(hidden_states) + + if self.timestep_scale_multiplier: + timestep = self.timestep_scale_multiplier * timestep + + if self.positional_embedding_type == "absolute": + pos_embed_3d = self.get_absolute_pos_embed(indices_grid).to( + hidden_states.device + ) + if self.project_to_2d_pos: + pos_embed = self.to_2d_proj(pos_embed_3d) + hidden_states = (hidden_states + pos_embed).to(hidden_states.dtype) + freqs_cis = None + elif self.positional_embedding_type == "rope": + freqs_cis = self.precompute_freqs_cis(indices_grid) + + batch_size = hidden_states.shape[0] + timestep, embedded_timestep = self.adaln_single( + timestep.flatten(), + {"resolution": None, "aspect_ratio": None}, + batch_size=batch_size, + hidden_dtype=hidden_states.dtype, + ) + # Second dimension is 1 or number of tokens (if timestep_per_token) + timestep = timestep.view(batch_size, -1, timestep.shape[-1]) + embedded_timestep = embedded_timestep.view( + batch_size, -1, embedded_timestep.shape[-1] + ) + + if skip_layer_mask is None: + skip_layer_mask = torch.ones( + len(self.transformer_blocks), batch_size, device=hidden_states.device + ) + + # 2. Blocks + if self.caption_projection is not None: + batch_size = hidden_states.shape[0] + encoder_hidden_states = self.caption_projection(encoder_hidden_states) + encoder_hidden_states = encoder_hidden_states.view( + batch_size, -1, hidden_states.shape[-1] + ) + + for block_idx, block in enumerate(self.transformer_blocks): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + ckpt_kwargs: Dict[str, Any] = ( + {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {} + ) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(block), + hidden_states, + freqs_cis, + attention_mask, + encoder_hidden_states, + encoder_attention_mask, + timestep, + cross_attention_kwargs, + class_labels, + skip_layer_mask[block_idx], + skip_layer_strategy, + **ckpt_kwargs, + ) + else: + hidden_states = block( + hidden_states, + freqs_cis=freqs_cis, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + skip_layer_mask=skip_layer_mask[block_idx], + skip_layer_strategy=skip_layer_strategy, + ) + + # 3. Output + scale_shift_values = ( + self.scale_shift_table[None, None] + embedded_timestep[:, :, None] + ) + shift, scale = scale_shift_values[:, :, 0], scale_shift_values[:, :, 1] + hidden_states = self.norm_out(hidden_states) + # Modulation + hidden_states = hidden_states * (1 + scale) + shift + hidden_states = self.proj_out(hidden_states) + if not return_dict: + return (hidden_states,) + + return Transformer3DModelOutput(sample=hidden_states) + + def get_absolute_pos_embed(self, grid): + grid_np = grid[0].cpu().numpy() + embed_dim_3d = ( + math.ceil((self.inner_dim / 2) * 3) + if self.project_to_2d_pos + else self.inner_dim + ) + pos_embed = get_3d_sincos_pos_embed( # (f h w) + embed_dim_3d, + grid_np, + h=int(max(grid_np[1]) + 1), + w=int(max(grid_np[2]) + 1), + f=int(max(grid_np[0] + 1)), + ) + return torch.from_numpy(pos_embed).float().unsqueeze(0) diff --git a/ltx_video/pipelines/__init__.py b/ltx_video/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx_video/pipelines/pipeline_ltx_video.py b/ltx_video/pipelines/pipeline_ltx_video.py new file mode 100644 index 0000000000000000000000000000000000000000..19f191d00d9bf1d3b7cbc158e482ef80e059754b --- /dev/null +++ b/ltx_video/pipelines/pipeline_ltx_video.py @@ -0,0 +1,1274 @@ +# Adapted from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +import html +import inspect +import math +import re +import urllib.parse as ul +from typing import Callable, Dict, List, Optional, Tuple, Union + + +import torch +import torch.nn.functional as F +from contextlib import nullcontext +from diffusers.image_processor import VaeImageProcessor +from diffusers.models import AutoencoderKL +from diffusers.pipelines.pipeline_utils import DiffusionPipeline, ImagePipelineOutput +from diffusers.schedulers import DPMSolverMultistepScheduler +from diffusers.utils import ( + BACKENDS_MAPPING, + deprecate, + is_bs4_available, + is_ftfy_available, + logging, +) +from diffusers.utils.torch_utils import randn_tensor +from einops import rearrange +from transformers import T5EncoderModel, T5Tokenizer + +from ltx_video.models.transformers.transformer3d import Transformer3DModel +from ltx_video.models.transformers.symmetric_patchifier import Patchifier +from ltx_video.models.autoencoders.vae_encode import ( + get_vae_size_scale_factor, + vae_decode, + vae_encode, +) +from ltx_video.models.autoencoders.causal_video_autoencoder import ( + CausalVideoAutoencoder, +) +from ltx_video.schedulers.rf import TimestepShifter +from ltx_video.utils.conditioning_method import ConditioningMethod +from ltx_video.utils.skip_layer_strategy import SkipLayerStrategy + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +if is_bs4_available(): + from bs4 import BeautifulSoup + +if is_ftfy_available(): + import ftfy + +ASPECT_RATIO_1024_BIN = { + "0.25": [512.0, 2048.0], + "0.28": [512.0, 1856.0], + "0.32": [576.0, 1792.0], + "0.33": [576.0, 1728.0], + "0.35": [576.0, 1664.0], + "0.4": [640.0, 1600.0], + "0.42": [640.0, 1536.0], + "0.48": [704.0, 1472.0], + "0.5": [704.0, 1408.0], + "0.52": [704.0, 1344.0], + "0.57": [768.0, 1344.0], + "0.6": [768.0, 1280.0], + "0.68": [832.0, 1216.0], + "0.72": [832.0, 1152.0], + "0.78": [896.0, 1152.0], + "0.82": [896.0, 1088.0], + "0.88": [960.0, 1088.0], + "0.94": [960.0, 1024.0], + "1.0": [1024.0, 1024.0], + "1.07": [1024.0, 960.0], + "1.13": [1088.0, 960.0], + "1.21": [1088.0, 896.0], + "1.29": [1152.0, 896.0], + "1.38": [1152.0, 832.0], + "1.46": [1216.0, 832.0], + "1.67": [1280.0, 768.0], + "1.75": [1344.0, 768.0], + "2.0": [1408.0, 704.0], + "2.09": [1472.0, 704.0], + "2.4": [1536.0, 640.0], + "2.5": [1600.0, 640.0], + "3.0": [1728.0, 576.0], + "4.0": [2048.0, 512.0], +} + +ASPECT_RATIO_512_BIN = { + "0.25": [256.0, 1024.0], + "0.28": [256.0, 928.0], + "0.32": [288.0, 896.0], + "0.33": [288.0, 864.0], + "0.35": [288.0, 832.0], + "0.4": [320.0, 800.0], + "0.42": [320.0, 768.0], + "0.48": [352.0, 736.0], + "0.5": [352.0, 704.0], + "0.52": [352.0, 672.0], + "0.57": [384.0, 672.0], + "0.6": [384.0, 640.0], + "0.68": [416.0, 608.0], + "0.72": [416.0, 576.0], + "0.78": [448.0, 576.0], + "0.82": [448.0, 544.0], + "0.88": [480.0, 544.0], + "0.94": [480.0, 512.0], + "1.0": [512.0, 512.0], + "1.07": [512.0, 480.0], + "1.13": [544.0, 480.0], + "1.21": [544.0, 448.0], + "1.29": [576.0, 448.0], + "1.38": [576.0, 416.0], + "1.46": [608.0, 416.0], + "1.67": [640.0, 384.0], + "1.75": [672.0, 384.0], + "2.0": [704.0, 352.0], + "2.09": [736.0, 352.0], + "2.4": [768.0, 320.0], + "2.5": [800.0, 320.0], + "3.0": [864.0, 288.0], + "4.0": [1024.0, 256.0], +} + + +# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps +def retrieve_timesteps( + scheduler, + num_inference_steps: Optional[int] = None, + device: Optional[Union[str, torch.device]] = None, + timesteps: Optional[List[int]] = None, + **kwargs, +): + """ + Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles + custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`. + + Args: + scheduler (`SchedulerMixin`): + The scheduler to get timesteps from. + num_inference_steps (`int`): + The number of diffusion steps used when generating samples with a pre-trained model. If used, + `timesteps` must be `None`. + device (`str` or `torch.device`, *optional*): + The device to which the timesteps should be moved to. If `None`, the timesteps are not moved. + timesteps (`List[int]`, *optional*): + Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default + timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps` + must be `None`. + + Returns: + `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the + second element is the number of inference steps. + """ + if timesteps is not None: + accepts_timesteps = "timesteps" in set( + inspect.signature(scheduler.set_timesteps).parameters.keys() + ) + if not accepts_timesteps: + raise ValueError( + f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom" + f" timestep schedules. Please check whether you are using the correct scheduler." + ) + scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs) + timesteps = scheduler.timesteps + num_inference_steps = len(timesteps) + else: + scheduler.set_timesteps(num_inference_steps, device=device, **kwargs) + timesteps = scheduler.timesteps + return timesteps, num_inference_steps + + +class LTXVideoPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using LTX-Video. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`T5EncoderModel`]): + Frozen text-encoder. This uses + [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel), specifically the + [t5-v1_1-xxl](https://huggingface.co/PixArt-alpha/PixArt-alpha/tree/main/t5-v1_1-xxl) variant. + tokenizer (`T5Tokenizer`): + Tokenizer of class + [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer). + transformer ([`Transformer2DModel`]): + A text conditioned `Transformer2DModel` to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `transformer` to denoise the encoded image latents. + """ + + bad_punct_regex = re.compile( + r"[" + + "#®•©™&@·º½¾¿¡§~" + + r"\)" + + r"\(" + + r"\]" + + r"\[" + + r"\}" + + r"\{" + + r"\|" + + "\\" + + r"\/" + + r"\*" + + r"]{1,}" + ) # noqa + + _optional_components = ["tokenizer", "text_encoder"] + model_cpu_offload_seq = "text_encoder->transformer->vae" + + def __init__( + self, + tokenizer: T5Tokenizer, + text_encoder: T5EncoderModel, + vae: AutoencoderKL, + transformer: Transformer3DModel, + scheduler: DPMSolverMultistepScheduler, + patchifier: Patchifier, + ): + super().__init__() + + self.register_modules( + tokenizer=tokenizer, + text_encoder=text_encoder, + vae=vae, + transformer=transformer, + scheduler=scheduler, + patchifier=patchifier, + ) + + self.video_scale_factor, self.vae_scale_factor, _ = get_vae_size_scale_factor( + self.vae + ) + self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor) + + def mask_text_embeddings(self, emb, mask): + if emb.shape[0] == 1: + keep_index = mask.sum().item() + return emb[:, :, :keep_index, :], keep_index + else: + masked_feature = emb * mask[:, None, :, None] + return masked_feature, emb.shape[2] + + # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt + def encode_prompt( + self, + prompt: Union[str, List[str]], + do_classifier_free_guidance: bool = True, + negative_prompt: str = "", + num_images_per_prompt: int = 1, + device: Optional[torch.device] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + clean_caption: bool = False, + **kwargs, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + negative_prompt (`str` or `List[str]`, *optional*): + The prompt not to guide the image generation. If not defined, one has to pass `negative_prompt_embeds` + instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). For + This should be "". + do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): + whether to use classifier free guidance or not + num_images_per_prompt (`int`, *optional*, defaults to 1): + number of images that should be generated per prompt + device: (`torch.device`, *optional*): + torch device to place the resulting embeddings on + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. + clean_caption (bool, defaults to `False`): + If `True`, the function will preprocess and clean the provided caption before encoding. + """ + + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + if device is None: + device = self._execution_device + + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + # See Section 3.1. of the paper. + # FIXME: to be configured in config not hardecoded. Fix in separate PR with rest of config + max_length = 128 # TPU supports only lengths multiple of 128 + text_enc_device = next(self.text_encoder.parameters()).device + if prompt_embeds is None: + prompt = self._text_preprocessing(prompt, clean_caption=clean_caption) + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=max_length, + truncation=True, + add_special_tokens=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer( + prompt, padding="longest", return_tensors="pt" + ).input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[ + -1 + ] and not torch.equal(text_input_ids, untruncated_ids): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {max_length} tokens: {removed_text}" + ) + + prompt_attention_mask = text_inputs.attention_mask + prompt_attention_mask = prompt_attention_mask.to(text_enc_device) + prompt_attention_mask = prompt_attention_mask.to(device) + + prompt_embeds = self.text_encoder( + text_input_ids.to(text_enc_device), attention_mask=prompt_attention_mask + ) + prompt_embeds = prompt_embeds[0] + + if self.text_encoder is not None: + dtype = self.text_encoder.dtype + elif self.transformer is not None: + dtype = self.transformer.dtype + else: + dtype = None + + prompt_embeds = prompt_embeds.to(dtype=dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view( + bs_embed * num_images_per_prompt, seq_len, -1 + ) + prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt) + prompt_attention_mask = prompt_attention_mask.view( + bs_embed * num_images_per_prompt, -1 + ) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens = [negative_prompt] * batch_size + uncond_tokens = self._text_preprocessing( + uncond_tokens, clean_caption=clean_caption + ) + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_attention_mask=True, + add_special_tokens=True, + return_tensors="pt", + ) + negative_prompt_attention_mask = uncond_input.attention_mask + negative_prompt_attention_mask = negative_prompt_attention_mask.to( + text_enc_device + ) + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(text_enc_device), + attention_mask=negative_prompt_attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to( + dtype=dtype, device=device + ) + + negative_prompt_embeds = negative_prompt_embeds.repeat( + 1, num_images_per_prompt, 1 + ) + negative_prompt_embeds = negative_prompt_embeds.view( + batch_size * num_images_per_prompt, seq_len, -1 + ) + + negative_prompt_attention_mask = negative_prompt_attention_mask.repeat( + 1, num_images_per_prompt + ) + negative_prompt_attention_mask = negative_prompt_attention_mask.view( + bs_embed * num_images_per_prompt, -1 + ) + else: + negative_prompt_embeds = None + negative_prompt_attention_mask = None + + return ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set( + inspect.signature(self.scheduler.step).parameters.keys() + ) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + negative_prompt, + prompt_embeds=None, + negative_prompt_embeds=None, + prompt_attention_mask=None, + negative_prompt_attention_mask=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError( + f"`height` and `width` have to be divisible by 8 but are {height} and {width}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and ( + not isinstance(prompt, str) and not isinstance(prompt, list) + ): + raise ValueError( + f"`prompt` has to be of type `str` or `list` but is {type(prompt)}" + ) + + if prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and prompt_attention_mask is None: + raise ValueError( + "Must provide `prompt_attention_mask` when specifying `prompt_embeds`." + ) + + if ( + negative_prompt_embeds is not None + and negative_prompt_attention_mask is None + ): + raise ValueError( + "Must provide `negative_prompt_attention_mask` when specifying `negative_prompt_embeds`." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + if prompt_attention_mask.shape != negative_prompt_attention_mask.shape: + raise ValueError( + "`prompt_attention_mask` and `negative_prompt_attention_mask` must have the same shape when passed directly, but" + f" got: `prompt_attention_mask` {prompt_attention_mask.shape} != `negative_prompt_attention_mask`" + f" {negative_prompt_attention_mask.shape}." + ) + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._text_preprocessing + def _text_preprocessing(self, text, clean_caption=False): + if clean_caption and not is_bs4_available(): + logger.warn( + BACKENDS_MAPPING["bs4"][-1].format("Setting `clean_caption=True`") + ) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if clean_caption and not is_ftfy_available(): + logger.warn( + BACKENDS_MAPPING["ftfy"][-1].format("Setting `clean_caption=True`") + ) + logger.warn("Setting `clean_caption` to False...") + clean_caption = False + + if not isinstance(text, (tuple, list)): + text = [text] + + def process(text: str): + if clean_caption: + text = self._clean_caption(text) + text = self._clean_caption(text) + else: + text = text.lower().strip() + return text + + return [process(t) for t in text] + + # Copied from diffusers.pipelines.deepfloyd_if.pipeline_if.IFPipeline._clean_caption + def _clean_caption(self, caption): + caption = str(caption) + caption = ul.unquote_plus(caption) + caption = caption.strip().lower() + caption = re.sub("", "person", caption) + # urls: + caption = re.sub( + r"\b((?:https?:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + caption = re.sub( + r"\b((?:www:(?:\/{1,3}|[a-zA-Z0-9%])|[a-zA-Z0-9.\-]+[.](?:com|co|ru|net|org|edu|gov|it)[\w/-]*\b\/?(?!@)))", # noqa + "", + caption, + ) # regex for urls + # html: + caption = BeautifulSoup(caption, features="html.parser").text + + # @ + caption = re.sub(r"@[\w\d]+\b", "", caption) + + # 31C0—31EF CJK Strokes + # 31F0—31FF Katakana Phonetic Extensions + # 3200—32FF Enclosed CJK Letters and Months + # 3300—33FF CJK Compatibility + # 3400—4DBF CJK Unified Ideographs Extension A + # 4DC0—4DFF Yijing Hexagram Symbols + # 4E00—9FFF CJK Unified Ideographs + caption = re.sub(r"[\u31c0-\u31ef]+", "", caption) + caption = re.sub(r"[\u31f0-\u31ff]+", "", caption) + caption = re.sub(r"[\u3200-\u32ff]+", "", caption) + caption = re.sub(r"[\u3300-\u33ff]+", "", caption) + caption = re.sub(r"[\u3400-\u4dbf]+", "", caption) + caption = re.sub(r"[\u4dc0-\u4dff]+", "", caption) + caption = re.sub(r"[\u4e00-\u9fff]+", "", caption) + ####################################################### + + # все виды тире / all types of dash --> "-" + caption = re.sub( + r"[\u002D\u058A\u05BE\u1400\u1806\u2010-\u2015\u2E17\u2E1A\u2E3A\u2E3B\u2E40\u301C\u3030\u30A0\uFE31\uFE32\uFE58\uFE63\uFF0D]+", # noqa + "-", + caption, + ) + + # кавычки к одному стандарту + caption = re.sub(r"[`´«»“”¨]", '"', caption) + caption = re.sub(r"[‘’]", "'", caption) + + # " + caption = re.sub(r""?", "", caption) + # & + caption = re.sub(r"&", "", caption) + + # ip adresses: + caption = re.sub(r"\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}", " ", caption) + + # article ids: + caption = re.sub(r"\d:\d\d\s+$", "", caption) + + # \n + caption = re.sub(r"\\n", " ", caption) + + # "#123" + caption = re.sub(r"#\d{1,3}\b", "", caption) + # "#12345.." + caption = re.sub(r"#\d{5,}\b", "", caption) + # "123456.." + caption = re.sub(r"\b\d{6,}\b", "", caption) + # filenames: + caption = re.sub( + r"[\S]+\.(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)", "", caption + ) + + # + caption = re.sub(r"[\"\']{2,}", r'"', caption) # """AUSVERKAUFT""" + caption = re.sub(r"[\.]{2,}", r" ", caption) # """AUSVERKAUFT""" + + caption = re.sub( + self.bad_punct_regex, r" ", caption + ) # ***AUSVERKAUFT***, #AUSVERKAUFT + caption = re.sub(r"\s+\.\s+", r" ", caption) # " . " + + # this-is-my-cute-cat / this_is_my_cute_cat + regex2 = re.compile(r"(?:\-|\_)") + if len(re.findall(regex2, caption)) > 3: + caption = re.sub(regex2, " ", caption) + + caption = ftfy.fix_text(caption) + caption = html.unescape(html.unescape(caption)) + + caption = re.sub(r"\b[a-zA-Z]{1,3}\d{3,15}\b", "", caption) # jc6640 + caption = re.sub(r"\b[a-zA-Z]+\d+[a-zA-Z]+\b", "", caption) # jc6640vc + caption = re.sub(r"\b\d+[a-zA-Z]+\d+\b", "", caption) # 6640vc231 + + caption = re.sub(r"(worldwide\s+)?(free\s+)?shipping", "", caption) + caption = re.sub(r"(free\s)?download(\sfree)?", "", caption) + caption = re.sub(r"\bclick\b\s(?:for|on)\s\w+", "", caption) + caption = re.sub( + r"\b(?:png|jpg|jpeg|bmp|webp|eps|pdf|apk|mp4)(\simage[s]?)?", "", caption + ) + caption = re.sub(r"\bpage\s+\d+\b", "", caption) + + caption = re.sub( + r"\b\d*[a-zA-Z]+\d+[a-zA-Z]+\d+[a-zA-Z\d]*\b", r" ", caption + ) # j2d1a2a... + + caption = re.sub(r"\b\d+\.?\d*[xх×]\d+\.?\d*\b", "", caption) + + caption = re.sub(r"\b\s+\:\s+", r": ", caption) + caption = re.sub(r"(\D[,\./])\b", r"\1 ", caption) + caption = re.sub(r"\s+", " ", caption) + + caption.strip() + + caption = re.sub(r"^[\"\']([\w\W]+)[\"\']$", r"\1", caption) + caption = re.sub(r"^[\'\_,\-\:;]", r"", caption) + caption = re.sub(r"[\'\_,\-\:\-\+]$", r"", caption) + caption = re.sub(r"^\.\S+$", "", caption) + + return caption.strip() + + def image_cond_noise_update( + self, + t, + init_latents, + latents, + noise_scale, + conditiong_mask, + generator, + ): + noise = randn_tensor( + latents.shape, + generator=generator, + device=latents.device, + dtype=latents.dtype, + ) + latents = (init_latents + noise_scale * noise * (t**2)) * conditiong_mask[ + ..., None + ] + latents * (1 - conditiong_mask[..., None]) + return latents + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents + def prepare_latents( + self, + batch_size, + num_latent_channels, + num_patches, + dtype, + device, + generator, + latents=None, + latents_mask=None, + ): + shape = ( + batch_size, + num_patches // math.prod(self.patchifier.patch_size), + num_latent_channels, + ) + + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor( + shape, generator=generator, device=generator.device, dtype=dtype + ) + elif latents_mask is not None: + noise = randn_tensor( + shape, generator=generator, device=generator.device, dtype=dtype + ) + latents = latents * latents_mask[..., None] + noise * ( + 1 - latents_mask[..., None] + ) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @staticmethod + def classify_height_width_bin( + height: int, width: int, ratios: dict + ) -> Tuple[int, int]: + """Returns binned height and width.""" + ar = float(height / width) + closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - ar)) + default_hw = ratios[closest_ratio] + return int(default_hw[0]), int(default_hw[1]) + + @staticmethod + def resize_and_crop_tensor( + samples: torch.Tensor, new_width: int, new_height: int + ) -> torch.Tensor: + n_frames, orig_height, orig_width = samples.shape[-3:] + + # Check if resizing is needed + if orig_height != new_height or orig_width != new_width: + ratio = max(new_height / orig_height, new_width / orig_width) + resized_width = int(orig_width * ratio) + resized_height = int(orig_height * ratio) + + # Resize + samples = rearrange(samples, "b c n h w -> (b n) c h w") + samples = F.interpolate( + samples, + size=(resized_height, resized_width), + mode="bilinear", + align_corners=False, + ) + samples = rearrange(samples, "(b n) c h w -> b c n h w", n=n_frames) + + # Center Crop + start_x = (resized_width - new_width) // 2 + end_x = start_x + new_width + start_y = (resized_height - new_height) // 2 + end_y = start_y + new_height + samples = samples[..., start_y:end_y, start_x:end_x] + + return samples + + @torch.no_grad() + def __call__( + self, + height: int, + width: int, + num_frames: int, + frame_rate: float, + prompt: Union[str, List[str]] = None, + negative_prompt: str = "", + num_inference_steps: int = 20, + timesteps: List[int] = None, + guidance_scale: float = 4.5, + skip_layer_strategy: Optional[SkipLayerStrategy] = None, + skip_block_list: List[int] = None, + stg_scale: float = 1.0, + do_rescaling: bool = True, + rescaling_scale: float = 0.7, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + prompt_attention_mask: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_attention_mask: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None, + clean_caption: bool = True, + media_items: Optional[torch.FloatTensor] = None, + decode_timestep: Union[List[float], float] = 0.0, + decode_noise_scale: Optional[List[float]] = None, + mixed_precision: bool = False, + offload_to_cpu: bool = False, + **kwargs, + ) -> Union[ImagePipelineOutput, Tuple]: + """ + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + num_inference_steps (`int`, *optional*, defaults to 100): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + timesteps (`List[int]`, *optional*): + Custom timesteps to use for the denoising process. If not defined, equal spaced `num_inference_steps` + timesteps are used. Must be in descending order. + guidance_scale (`float`, *optional*, defaults to 4.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + height (`int`, *optional*, defaults to self.unet.config.sample_size): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size): + The width in pixels of the generated image. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + prompt_attention_mask (`torch.FloatTensor`, *optional*): Pre-generated attention mask for text embeddings. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. This negative prompt should be "". If not + provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. + negative_prompt_attention_mask (`torch.FloatTensor`, *optional*): + Pre-generated attention mask for negative text embeddings. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.IFPipelineOutput`] instead of a plain tuple. + callback_on_step_end (`Callable`, *optional*): + A function that calls at the end of each denoising steps during the inference. The function is called + with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, + callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by + `callback_on_step_end_tensor_inputs`. + clean_caption (`bool`, *optional*, defaults to `True`): + Whether or not to clean the caption before creating embeddings. Requires `beautifulsoup4` and `ftfy` to + be installed. If the dependencies are not installed, the embeddings will be created from the raw + prompt. + use_resolution_binning (`bool` defaults to `True`): + If set to `True`, the requested height and width are first mapped to the closest resolutions using + `ASPECT_RATIO_1024_BIN`. After the produced latents are decoded into images, they are resized back to + the requested resolution. Useful for generating non-square images. + + Examples: + + Returns: + [`~pipelines.ImagePipelineOutput`] or `tuple`: + If `return_dict` is `True`, [`~pipelines.ImagePipelineOutput`] is returned, otherwise a `tuple` is + returned where the first element is a list with the generated images + """ + if "mask_feature" in kwargs: + deprecation_message = "The use of `mask_feature` is deprecated. It is no longer used in any computation and that doesn't affect the end results. It will be removed in a future version." + deprecate("mask_feature", "1.0.0", deprecation_message, standard_warn=False) + + is_video = kwargs.get("is_video", False) + self.check_inputs( + prompt, + height, + width, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + prompt_attention_mask, + negative_prompt_attention_mask, + ) + + # 2. Default height and width to transformer + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + do_spatio_temporal_guidance = stg_scale > 0.0 + + num_conds = 1 + if do_classifier_free_guidance: + num_conds += 1 + if do_spatio_temporal_guidance: + num_conds += 1 + + skip_layer_mask = None + if do_spatio_temporal_guidance: + skip_layer_mask = self.transformer.create_skip_layer_mask( + skip_block_list, batch_size, num_conds, 2 + ) + + # 3. Encode input prompt + self.text_encoder = self.text_encoder.to(self._execution_device) + + ( + prompt_embeds, + prompt_attention_mask, + negative_prompt_embeds, + negative_prompt_attention_mask, + ) = self.encode_prompt( + prompt, + do_classifier_free_guidance, + negative_prompt=negative_prompt, + num_images_per_prompt=num_images_per_prompt, + device=device, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + prompt_attention_mask=prompt_attention_mask, + negative_prompt_attention_mask=negative_prompt_attention_mask, + clean_caption=clean_caption, + ) + + if offload_to_cpu: + self.text_encoder = self.text_encoder.cpu() + + self.transformer = self.transformer.to(self._execution_device) + + prompt_embeds_batch = prompt_embeds + prompt_attention_mask_batch = prompt_attention_mask + if do_classifier_free_guidance: + prompt_embeds_batch = torch.cat( + [negative_prompt_embeds, prompt_embeds], dim=0 + ) + prompt_attention_mask_batch = torch.cat( + [negative_prompt_attention_mask, prompt_attention_mask], dim=0 + ) + if do_spatio_temporal_guidance: + prompt_embeds_batch = torch.cat([prompt_embeds_batch, prompt_embeds], dim=0) + prompt_attention_mask_batch = torch.cat( + [ + prompt_attention_mask_batch, + prompt_attention_mask, + ], + dim=0, + ) + + # 3b. Encode and prepare conditioning data + self.video_scale_factor = self.video_scale_factor if is_video else 1 + conditioning_method = kwargs.get("conditioning_method", None) + vae_per_channel_normalize = kwargs.get("vae_per_channel_normalize", False) + image_cond_noise_scale = kwargs.get("image_cond_noise_scale", 0.0) + init_latents, conditioning_mask = self.prepare_conditioning( + media_items, + num_frames, + height, + width, + conditioning_method, + vae_per_channel_normalize, + ) + + # 4. Prepare latents. + latent_height = height // self.vae_scale_factor + latent_width = width // self.vae_scale_factor + latent_num_frames = num_frames // self.video_scale_factor + if isinstance(self.vae, CausalVideoAutoencoder) and is_video: + latent_num_frames += 1 + latent_frame_rate = frame_rate / self.video_scale_factor + num_latent_patches = latent_height * latent_width * latent_num_frames + latents = self.prepare_latents( + batch_size=batch_size * num_images_per_prompt, + num_latent_channels=self.transformer.config.in_channels, + num_patches=num_latent_patches, + dtype=prompt_embeds_batch.dtype, + device=device, + generator=generator, + latents=init_latents, + latents_mask=conditioning_mask, + ) + orig_conditiong_mask = conditioning_mask + if conditioning_mask is not None and is_video: + assert num_images_per_prompt == 1 + conditioning_mask = ( + torch.cat([conditioning_mask] * num_conds) + if num_conds > 1 + else conditioning_mask + ) + + # 5. Prepare timesteps + retrieve_timesteps_kwargs = {} + if isinstance(self.scheduler, TimestepShifter): + retrieve_timesteps_kwargs["samples"] = latents + timesteps, num_inference_steps = retrieve_timesteps( + self.scheduler, + num_inference_steps, + device, + timesteps, + **retrieve_timesteps_kwargs, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = max( + len(timesteps) - num_inference_steps * self.scheduler.order, 0 + ) + + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + if conditioning_method == ConditioningMethod.FIRST_FRAME: + latents = self.image_cond_noise_update( + t, + init_latents, + latents, + image_cond_noise_scale, + orig_conditiong_mask, + generator, + ) + + latent_model_input = ( + torch.cat([latents] * num_conds) if num_conds > 1 else latents + ) + latent_model_input = self.scheduler.scale_model_input( + latent_model_input, t + ) + + latent_frame_rates = ( + torch.ones( + latent_model_input.shape[0], 1, device=latent_model_input.device + ) + * latent_frame_rate + ) + + current_timestep = t + if not torch.is_tensor(current_timestep): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = latent_model_input.device.type == "mps" + if isinstance(current_timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + current_timestep = torch.tensor( + [current_timestep], + dtype=dtype, + device=latent_model_input.device, + ) + elif len(current_timestep.shape) == 0: + current_timestep = current_timestep[None].to( + latent_model_input.device + ) + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + current_timestep = current_timestep.expand( + latent_model_input.shape[0] + ).unsqueeze(-1) + scale_grid = ( + ( + 1 / latent_frame_rates, + self.vae_scale_factor, + self.vae_scale_factor, + ) + if self.transformer.use_rope + else None + ) + indices_grid = self.patchifier.get_grid( + orig_num_frames=latent_num_frames, + orig_height=latent_height, + orig_width=latent_width, + batch_size=latent_model_input.shape[0], + scale_grid=scale_grid, + device=latents.device, + ) + + if conditioning_mask is not None: + current_timestep = current_timestep * (1 - conditioning_mask) + # Choose the appropriate context manager based on `mixed_precision` + if mixed_precision: + if "xla" in device.type: + raise NotImplementedError( + "Mixed precision is not supported yet on XLA devices." + ) + + context_manager = torch.autocast(device.type, dtype=torch.bfloat16) + else: + context_manager = nullcontext() # Dummy context manager + + # predict noise model_output + with context_manager: + noise_pred = self.transformer( + latent_model_input.to(self.transformer.dtype), + indices_grid, + encoder_hidden_states=prompt_embeds_batch.to( + self.transformer.dtype + ), + encoder_attention_mask=prompt_attention_mask_batch, + timestep=current_timestep, + skip_layer_mask=skip_layer_mask, + skip_layer_strategy=skip_layer_strategy, + return_dict=False, + )[0] + + # perform guidance + if do_spatio_temporal_guidance: + noise_pred_text_perturb = noise_pred[-1:] + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred[:2].chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * ( + noise_pred_text - noise_pred_uncond + ) + if do_spatio_temporal_guidance: + noise_pred = noise_pred + stg_scale * ( + noise_pred_text - noise_pred_text_perturb + ) + if do_rescaling: + factor = noise_pred_text.std() / noise_pred.std() + factor = rescaling_scale * factor + (1 - rescaling_scale) + noise_pred = noise_pred * factor + + current_timestep = current_timestep[:1] + # learned sigma + if ( + self.transformer.config.out_channels // 2 + == self.transformer.config.in_channels + ): + noise_pred = noise_pred.chunk(2, dim=1)[0] + + # compute previous image: x_t -> x_t-1 + latents = self.scheduler.step( + noise_pred, + t if current_timestep is None else current_timestep, + latents, + **extra_step_kwargs, + return_dict=False, + )[0] + + # call the callback, if provided + if i == len(timesteps) - 1 or ( + (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0 + ): + progress_bar.update() + + if callback_on_step_end is not None: + callback_on_step_end(self, i, t, {}) + + if offload_to_cpu: + self.transformer = self.transformer.cpu() + if self._execution_device == "cuda": + torch.cuda.empty_cache() + + latents = self.patchifier.unpatchify( + latents=latents, + output_height=latent_height, + output_width=latent_width, + output_num_frames=latent_num_frames, + out_channels=self.transformer.in_channels + // math.prod(self.patchifier.patch_size), + ) + if output_type != "latent": + if self.vae.decoder.timestep_conditioning: + noise = torch.randn_like(latents) + if not isinstance(decode_timestep, list): + decode_timestep = [decode_timestep] * latents.shape[0] + if decode_noise_scale is None: + decode_noise_scale = decode_timestep + elif not isinstance(decode_noise_scale, list): + decode_noise_scale = [decode_noise_scale] * latents.shape[0] + + decode_timestep = torch.tensor(decode_timestep).to(latents.device) + decode_noise_scale = torch.tensor(decode_noise_scale).to( + latents.device + )[:, None, None, None, None] + latents = ( + latents * (1 - decode_noise_scale) + noise * decode_noise_scale + ) + else: + decode_timestep = None + image = vae_decode( + latents, + self.vae, + is_video, + vae_per_channel_normalize=kwargs["vae_per_channel_normalize"], + timestep=decode_timestep, + ) + image = self.image_processor.postprocess(image, output_type=output_type) + + else: + image = latents + + # Offload all models + self.maybe_free_model_hooks() + + if not return_dict: + return (image,) + + return ImagePipelineOutput(images=image) + + def prepare_conditioning( + self, + media_items: torch.Tensor, + num_frames: int, + height: int, + width: int, + method: ConditioningMethod = ConditioningMethod.UNCONDITIONAL, + vae_per_channel_normalize: bool = False, + ) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Prepare the conditioning data for the video generation. If an input media item is provided, encode it + and set the conditioning_mask to indicate which tokens to condition on. Input media item should have + the same height and width as the generated video. + + Args: + media_items (torch.Tensor): media items to condition on (images or videos) + num_frames (int): number of frames to generate + height (int): height of the generated video + width (int): width of the generated video + method (ConditioningMethod, optional): conditioning method to use. Defaults to ConditioningMethod.UNCONDITIONAL. + vae_per_channel_normalize (bool, optional): whether to normalize the input to the VAE per channel. Defaults to False. + + Returns: + Tuple[torch.Tensor, torch.Tensor]: the conditioning latents and the conditioning mask + """ + if media_items is None or method == ConditioningMethod.UNCONDITIONAL: + return None, None + + assert media_items.ndim == 5 + assert height == media_items.shape[-2] and width == media_items.shape[-1] + + # Encode the input video and repeat to the required number of frame-tokens + init_latents = vae_encode( + media_items.to(dtype=self.vae.dtype, device=self.vae.device), + self.vae, + vae_per_channel_normalize=vae_per_channel_normalize, + ).float() + + init_len, target_len = ( + init_latents.shape[2], + num_frames // self.video_scale_factor, + ) + if isinstance(self.vae, CausalVideoAutoencoder): + target_len += 1 + init_latents = init_latents[:, :, :target_len] + if target_len > init_len: + repeat_factor = (target_len + init_len - 1) // init_len # Ceiling division + init_latents = init_latents.repeat(1, 1, repeat_factor, 1, 1)[ + :, :, :target_len + ] + + # Prepare the conditioning mask (1.0 = condition on this token) + b, n, f, h, w = init_latents.shape + conditioning_mask = torch.zeros([b, 1, f, h, w], device=init_latents.device) + if method == ConditioningMethod.FIRST_FRAME: + conditioning_mask[:, :, 0] = 1.0 + + # Patchify the init latents and the mask + conditioning_mask = self.patchifier.patchify(conditioning_mask).squeeze(-1) + init_latents = self.patchifier.patchify(latents=init_latents) + return init_latents, conditioning_mask diff --git a/ltx_video/schedulers/__init__.py b/ltx_video/schedulers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx_video/schedulers/rf.py b/ltx_video/schedulers/rf.py new file mode 100644 index 0000000000000000000000000000000000000000..892266d05878e54e3953cf0b46a594616155d993 --- /dev/null +++ b/ltx_video/schedulers/rf.py @@ -0,0 +1,331 @@ +import math +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Callable, Optional, Tuple, Union +import json +import os +from pathlib import Path + +import torch +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.schedulers.scheduling_utils import SchedulerMixin +from diffusers.utils import BaseOutput +from torch import Tensor +from safetensors import safe_open + + +from ltx_video.utils.torch_utils import append_dims + +from ltx_video.utils.diffusers_config_mapping import ( + diffusers_and_ours_config_mapping, + make_hashable_key, +) + + +def simple_diffusion_resolution_dependent_timestep_shift( + samples: Tensor, + timesteps: Tensor, + n: int = 32 * 32, +) -> Tensor: + if len(samples.shape) == 3: + _, m, _ = samples.shape + elif len(samples.shape) in [4, 5]: + m = math.prod(samples.shape[2:]) + else: + raise ValueError( + "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)" + ) + snr = (timesteps / (1 - timesteps)) ** 2 + shift_snr = torch.log(snr) + 2 * math.log(m / n) + shifted_timesteps = torch.sigmoid(0.5 * shift_snr) + + return shifted_timesteps + + +def time_shift(mu: float, sigma: float, t: Tensor): + return math.exp(mu) / (math.exp(mu) + (1 / t - 1) ** sigma) + + +def get_normal_shift( + n_tokens: int, + min_tokens: int = 1024, + max_tokens: int = 4096, + min_shift: float = 0.95, + max_shift: float = 2.05, +) -> Callable[[float], float]: + m = (max_shift - min_shift) / (max_tokens - min_tokens) + b = min_shift - m * min_tokens + return m * n_tokens + b + + +def strech_shifts_to_terminal(shifts: Tensor, terminal=0.1): + """ + Stretch a function (given as sampled shifts) so that its final value matches the given terminal value + using the provided formula. + + Parameters: + - shifts (Tensor): The samples of the function to be stretched (PyTorch Tensor). + - terminal (float): The desired terminal value (value at the last sample). + + Returns: + - Tensor: The stretched shifts such that the final value equals `terminal`. + """ + if shifts.numel() == 0: + raise ValueError("The 'shifts' tensor must not be empty.") + + # Ensure terminal value is valid + if terminal <= 0 or terminal >= 1: + raise ValueError("The terminal value must be between 0 and 1 (exclusive).") + + # Transform the shifts using the given formula + one_minus_z = 1 - shifts + scale_factor = one_minus_z[-1] / (1 - terminal) + stretched_shifts = 1 - (one_minus_z / scale_factor) + + return stretched_shifts + + +def sd3_resolution_dependent_timestep_shift( + samples: Tensor, timesteps: Tensor, target_shift_terminal: Optional[float] = None +) -> Tensor: + """ + Shifts the timestep schedule as a function of the generated resolution. + + In the SD3 paper, the authors empirically how to shift the timesteps based on the resolution of the target images. + For more details: https://arxiv.org/pdf/2403.03206 + + In Flux they later propose a more dynamic resolution dependent timestep shift, see: + https://github.com/black-forest-labs/flux/blob/87f6fff727a377ea1c378af692afb41ae84cbe04/src/flux/sampling.py#L66 + + + Args: + samples (Tensor): A batch of samples with shape (batch_size, channels, height, width) or + (batch_size, channels, frame, height, width). + timesteps (Tensor): A batch of timesteps with shape (batch_size,). + target_shift_terminal (float): The target terminal value for the shifted timesteps. + + Returns: + Tensor: The shifted timesteps. + """ + if len(samples.shape) == 3: + _, m, _ = samples.shape + elif len(samples.shape) in [4, 5]: + m = math.prod(samples.shape[2:]) + else: + raise ValueError( + "Samples must have shape (b, t, c), (b, c, h, w) or (b, c, f, h, w)" + ) + + shift = get_normal_shift(m) + time_shifts = time_shift(shift, 1, timesteps) + if target_shift_terminal is not None: # Stretch the shifts to the target terminal + time_shifts = strech_shifts_to_terminal(time_shifts, target_shift_terminal) + return time_shifts + + +class TimestepShifter(ABC): + @abstractmethod + def shift_timesteps(self, samples: Tensor, timesteps: Tensor) -> Tensor: + pass + + +@dataclass +class RectifiedFlowSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +class RectifiedFlowScheduler(SchedulerMixin, ConfigMixin, TimestepShifter): + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps=1000, + shifting: Optional[str] = None, + base_resolution: int = 32**2, + target_shift_terminal: Optional[float] = None, + ): + super().__init__() + self.init_noise_sigma = 1.0 + self.num_inference_steps = None + self.timesteps = self.sigmas = torch.linspace( + 1, 1 / num_train_timesteps, num_train_timesteps + ) + self.delta_timesteps = self.timesteps - torch.cat( + [self.timesteps[1:], torch.zeros_like(self.timesteps[-1:])] + ) + self.shifting = shifting + self.base_resolution = base_resolution + self.target_shift_terminal = target_shift_terminal + + def shift_timesteps(self, samples: Tensor, timesteps: Tensor) -> Tensor: + if self.shifting == "SD3": + return sd3_resolution_dependent_timestep_shift( + samples, timesteps, self.target_shift_terminal + ) + elif self.shifting == "SimpleDiffusion": + return simple_diffusion_resolution_dependent_timestep_shift( + samples, timesteps, self.base_resolution + ) + return timesteps + + def set_timesteps( + self, + num_inference_steps: int, + samples: Tensor, + device: Union[str, torch.device] = None, + ): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): The number of diffusion steps used when generating samples. + samples (`Tensor`): A batch of samples with shape. + device (`Union[str, torch.device]`, *optional*): The device to which the timesteps tensor will be moved. + """ + num_inference_steps = min(self.config.num_train_timesteps, num_inference_steps) + timesteps = torch.linspace(1, 1 / num_inference_steps, num_inference_steps).to( + device + ) + self.timesteps = self.shift_timesteps(samples, timesteps) + self.delta_timesteps = self.timesteps - torch.cat( + [self.timesteps[1:], torch.zeros_like(self.timesteps[-1:])] + ) + self.num_inference_steps = num_inference_steps + self.sigmas = self.timesteps + + @staticmethod + def from_pretrained(pretrained_model_path: Union[str, os.PathLike]): + pretrained_model_path = Path(pretrained_model_path) + if pretrained_model_path.is_file(): + comfy_single_file_state_dict = {} + with safe_open(pretrained_model_path, framework="pt", device="cpu") as f: + metadata = f.metadata() + for k in f.keys(): + comfy_single_file_state_dict[k] = f.get_tensor(k) + configs = json.loads(metadata["config"]) + config = configs["scheduler"] + del comfy_single_file_state_dict + + elif pretrained_model_path.is_dir(): + diffusers_noise_scheduler_config_path = ( + pretrained_model_path / "scheduler" / "scheduler_config.json" + ) + + with open(diffusers_noise_scheduler_config_path, "r") as f: + scheduler_config = json.load(f) + hashable_config = make_hashable_key(scheduler_config) + if hashable_config in diffusers_and_ours_config_mapping: + config = diffusers_and_ours_config_mapping[hashable_config] + return RectifiedFlowScheduler.from_config(config) + + def scale_model_input( + self, sample: torch.FloatTensor, timestep: Optional[int] = None + ) -> torch.FloatTensor: + # pylint: disable=unused-argument + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def step( + self, + model_output: torch.FloatTensor, + timestep: torch.FloatTensor, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + ) -> Union[RectifiedFlowSchedulerOutput, Tuple]: + # pylint: disable=unused-argument + """ + Predict the sample from the previous timestep by reversing the SDE. This function propagates the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): + The direct output from learned diffusion model. + timestep (`float`): + The current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + A current instance of a sample created by the diffusion process. + eta (`float`): + The weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`, defaults to `False`): + If `True`, computes "corrected" `model_output` from the clipped predicted original sample. Necessary + because predicted original sample is clipped to [-1, 1] when `self.config.clip_sample` is `True`. If no + clipping has happened, "corrected" `model_output` would coincide with the one provided as input and + `use_clipped_model_output` has no effect. + generator (`torch.Generator`, *optional*): + A random number generator. + variance_noise (`torch.FloatTensor`): + Alternative to generating noise with `generator` by directly providing the noise for the variance + itself. Useful for methods such as [`CycleDiffusion`]. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~schedulers.scheduling_ddim.DDIMSchedulerOutput`] or `tuple`. + + Returns: + [`~schedulers.scheduling_utils.RectifiedFlowSchedulerOutput`] or `tuple`: + If return_dict is `True`, [`~schedulers.rf_scheduler.RectifiedFlowSchedulerOutput`] is returned, + otherwise a tuple is returned where the first element is the sample tensor. + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + if timestep.ndim == 0: + # Global timestep + current_index = (self.timesteps - timestep).abs().argmin() + dt = self.delta_timesteps.gather(0, current_index.unsqueeze(0)) + else: + # Timestep per token + assert timestep.ndim == 2 + current_index = ( + (self.timesteps[:, None, None] - timestep[None]).abs().argmin(dim=0) + ) + dt = self.delta_timesteps[current_index] + # Special treatment for zero timestep tokens - set dt to 0 so prev_sample = sample + dt = torch.where(timestep == 0.0, torch.zeros_like(dt), dt)[..., None] + + prev_sample = sample - dt * model_output + + if not return_dict: + return (prev_sample,) + + return RectifiedFlowSchedulerOutput(prev_sample=prev_sample) + + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.FloatTensor, + ) -> torch.FloatTensor: + sigmas = timesteps + sigmas = append_dims(sigmas, original_samples.ndim) + alphas = 1 - sigmas + noisy_samples = alphas * original_samples + sigmas * noise + return noisy_samples diff --git a/ltx_video/utils/__init__.py b/ltx_video/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/ltx_video/utils/conditioning_method.py b/ltx_video/utils/conditioning_method.py new file mode 100644 index 0000000000000000000000000000000000000000..20befcb747d51632dc3a9218aba9bf625e14d2e9 --- /dev/null +++ b/ltx_video/utils/conditioning_method.py @@ -0,0 +1,6 @@ +from enum import Enum + + +class ConditioningMethod(Enum): + UNCONDITIONAL = "unconditional" + FIRST_FRAME = "first_frame" diff --git a/ltx_video/utils/diffusers_config_mapping.py b/ltx_video/utils/diffusers_config_mapping.py new file mode 100644 index 0000000000000000000000000000000000000000..53c0082d182617f6f84eab9c849f7ef0224becb8 --- /dev/null +++ b/ltx_video/utils/diffusers_config_mapping.py @@ -0,0 +1,174 @@ +def make_hashable_key(dict_key): + def convert_value(value): + if isinstance(value, list): + return tuple(value) + elif isinstance(value, dict): + return tuple(sorted((k, convert_value(v)) for k, v in value.items())) + else: + return value + + return tuple(sorted((k, convert_value(v)) for k, v in dict_key.items())) + + +DIFFUSERS_SCHEDULER_CONFIG = { + "_class_name": "FlowMatchEulerDiscreteScheduler", + "_diffusers_version": "0.32.0.dev0", + "base_image_seq_len": 1024, + "base_shift": 0.95, + "invert_sigmas": False, + "max_image_seq_len": 4096, + "max_shift": 2.05, + "num_train_timesteps": 1000, + "shift": 1.0, + "shift_terminal": 0.1, + "use_beta_sigmas": False, + "use_dynamic_shifting": True, + "use_exponential_sigmas": False, + "use_karras_sigmas": False, +} +DIFFUSERS_TRANSFORMER_CONFIG = { + "_class_name": "LTXVideoTransformer3DModel", + "_diffusers_version": "0.32.0.dev0", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_out_bias": True, + "caption_channels": 4096, + "cross_attention_dim": 2048, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "num_attention_heads": 32, + "num_layers": 28, + "out_channels": 128, + "patch_size": 1, + "patch_size_t": 1, + "qk_norm": "rms_norm_across_heads", +} +DIFFUSERS_VAE_CONFIG = { + "_class_name": "AutoencoderKLLTXVideo", + "_diffusers_version": "0.32.0.dev0", + "block_out_channels": [128, 256, 512, 512], + "decoder_causal": False, + "encoder_causal": True, + "in_channels": 3, + "latent_channels": 128, + "layers_per_block": [4, 3, 3, 3, 4], + "out_channels": 3, + "patch_size": 4, + "patch_size_t": 1, + "resnet_norm_eps": 1e-06, + "scaling_factor": 1.0, + "spatio_temporal_scaling": [True, True, True, False], +} + +OURS_SCHEDULER_CONFIG = { + "_class_name": "RectifiedFlowScheduler", + "_diffusers_version": "0.25.1", + "num_train_timesteps": 1000, + "shifting": "SD3", + "base_resolution": None, + "target_shift_terminal": 0.1, +} + +OURS_TRANSFORMER_CONFIG = { + "_class_name": "Transformer3DModel", + "_diffusers_version": "0.25.1", + "_name_or_path": "PixArt-alpha/PixArt-XL-2-256x256", + "activation_fn": "gelu-approximate", + "attention_bias": True, + "attention_head_dim": 64, + "attention_type": "default", + "caption_channels": 4096, + "cross_attention_dim": 2048, + "double_self_attention": False, + "dropout": 0.0, + "in_channels": 128, + "norm_elementwise_affine": False, + "norm_eps": 1e-06, + "norm_num_groups": 32, + "num_attention_heads": 32, + "num_embeds_ada_norm": 1000, + "num_layers": 28, + "num_vector_embeds": None, + "only_cross_attention": False, + "out_channels": 128, + "project_to_2d_pos": True, + "upcast_attention": False, + "use_linear_projection": False, + "qk_norm": "rms_norm", + "standardization_norm": "rms_norm", + "positional_embedding_type": "rope", + "positional_embedding_theta": 10000.0, + "positional_embedding_max_pos": [20, 2048, 2048], + "timestep_scale_multiplier": 1000, +} +OURS_VAE_CONFIG = { + "_class_name": "CausalVideoAutoencoder", + "dims": 3, + "in_channels": 3, + "out_channels": 3, + "latent_channels": 128, + "blocks": [ + ["res_x", 4], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x_y", 1], + ["res_x", 3], + ["compress_all", 1], + ["res_x", 3], + ["res_x", 4], + ], + "scaling_factor": 1.0, + "norm_layer": "pixel_norm", + "patch_size": 4, + "latent_log_var": "uniform", + "use_quant_conv": False, + "causal_decoder": False, +} + + +diffusers_and_ours_config_mapping = { + make_hashable_key(DIFFUSERS_SCHEDULER_CONFIG): OURS_SCHEDULER_CONFIG, + make_hashable_key(DIFFUSERS_TRANSFORMER_CONFIG): OURS_TRANSFORMER_CONFIG, + make_hashable_key(DIFFUSERS_VAE_CONFIG): OURS_VAE_CONFIG, +} + + +TRANSFORMER_KEYS_RENAME_DICT = { + "proj_in": "patchify_proj", + "time_embed": "adaln_single", + "norm_q": "q_norm", + "norm_k": "k_norm", +} + + +VAE_KEYS_RENAME_DICT = { + "decoder.up_blocks.3.conv_in": "decoder.up_blocks.7", + "decoder.up_blocks.3.upsamplers.0": "decoder.up_blocks.8", + "decoder.up_blocks.3": "decoder.up_blocks.9", + "decoder.up_blocks.2.upsamplers.0": "decoder.up_blocks.5", + "decoder.up_blocks.2.conv_in": "decoder.up_blocks.4", + "decoder.up_blocks.2": "decoder.up_blocks.6", + "decoder.up_blocks.1.upsamplers.0": "decoder.up_blocks.2", + "decoder.up_blocks.1": "decoder.up_blocks.3", + "decoder.up_blocks.0": "decoder.up_blocks.1", + "decoder.mid_block": "decoder.up_blocks.0", + "encoder.down_blocks.3": "encoder.down_blocks.8", + "encoder.down_blocks.2.downsamplers.0": "encoder.down_blocks.7", + "encoder.down_blocks.2": "encoder.down_blocks.6", + "encoder.down_blocks.1.downsamplers.0": "encoder.down_blocks.4", + "encoder.down_blocks.1.conv_out": "encoder.down_blocks.5", + "encoder.down_blocks.1": "encoder.down_blocks.3", + "encoder.down_blocks.0.conv_out": "encoder.down_blocks.2", + "encoder.down_blocks.0.downsamplers.0": "encoder.down_blocks.1", + "encoder.down_blocks.0": "encoder.down_blocks.0", + "encoder.mid_block": "encoder.down_blocks.9", + "conv_shortcut.conv": "conv_shortcut", + "resnets": "res_blocks", + "norm3": "norm3.norm", + "latents_mean": "per_channel_statistics.mean-of-means", + "latents_std": "per_channel_statistics.std-of-means", +} diff --git a/ltx_video/utils/skip_layer_strategy.py b/ltx_video/utils/skip_layer_strategy.py new file mode 100644 index 0000000000000000000000000000000000000000..fd4c133053c590ee9761aea83e37764467110d21 --- /dev/null +++ b/ltx_video/utils/skip_layer_strategy.py @@ -0,0 +1,6 @@ +from enum import Enum, auto + + +class SkipLayerStrategy(Enum): + Attention = auto() + Residual = auto() diff --git a/ltx_video/utils/torch_utils.py b/ltx_video/utils/torch_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..991b07c36269ef4dafb88a85834f2596647ba816 --- /dev/null +++ b/ltx_video/utils/torch_utils.py @@ -0,0 +1,25 @@ +import torch +from torch import nn + + +def append_dims(x: torch.Tensor, target_dims: int) -> torch.Tensor: + """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" + dims_to_append = target_dims - x.ndim + if dims_to_append < 0: + raise ValueError( + f"input has {x.ndim} dims but target_dims is {target_dims}, which is less" + ) + elif dims_to_append == 0: + return x + return x[(...,) + (None,) * dims_to_append] + + +class Identity(nn.Module): + """A placeholder identity operator that is argument-insensitive.""" + + def __init__(self, *args, **kwargs) -> None: # pylint: disable=unused-argument + super().__init__() + + # pylint: disable=unused-argument + def forward(self, x: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return x diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..589ec9c94dbf51aa6b4798a7dbc51875fa164203 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,33 @@ +[build-system] +requires = ["setuptools>=42", "wheel"] +build-backend = "setuptools.build_meta" + +[project] +name = "ltx-video" +version = "0.1.2" +description = "A package for LTX-Video model" +authors = [ + { name = "Sapir Weissbuch", email = "sapir@lightricks.com" } +] +requires-python = ">=3.10" +readme = "README.md" +classifiers = [ + "Programming Language :: Python :: 3", + "Operating System :: OS Independent" +] +dependencies = [ + "torch>=2.1.0", + "diffusers>=0.28.2", + "transformers>=4.44.2", + "sentencepiece>=0.1.96", + "huggingface-hub~=0.25.2", + "einops" +] + +[project.optional-dependencies] +# Instead of thinking of them as optional, think of them as specific modes +inference-script = [ + "accelerate", + "matplotlib", + "imageio[ffmpeg]" +] diff --git a/ttv.py b/ttv.py new file mode 100644 index 0000000000000000000000000000000000000000..e01b0d3ea4badf3722fd61034d01fb55a974fd39 --- /dev/null +++ b/ttv.py @@ -0,0 +1,21 @@ +import torch +from diffusers import LTXPipeline +from diffusers.utils import export_to_video +import os + +def generate_video(prompt, negative_prompt): + pipe = LTXPipeline.from_pretrained("Lightricks/LTX-Video", torch_dtype=torch.bfloat16) + pipe.to("cuda") + + video = pipe( + prompt=prompt, + negative_prompt=negative_prompt, + width=704, + height=480, + num_frames=161, + num_inference_steps=50, + ).frames[0] + + output_path = f"output_{len(os.listdir('.'))}.mp4" + export_to_video(video, output_path, fps=24) + return output_path