Luo-Yihang committed on
Commit
4c35d22
·
1 Parent(s): 7748d72

initial code

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -1
  2. .gitignore +147 -0
  3. .gitmodules +3 -0
  4. LICENSE +13 -0
  5. README.md +4 -4
  6. app.py +588 -0
  7. assets/examples/i2mv/Mario_New_Super_Mario_Bros_U_Deluxe.png +0 -0
  8. assets/examples/i2mv/boy.jpg +0 -0
  9. assets/examples/i2mv/cake.png +0 -0
  10. assets/examples/i2mv/cup.png +0 -0
  11. assets/examples/i2mv/dragontoy.jpg +0 -0
  12. assets/examples/i2mv/gso_rabbit.jpg +0 -0
  13. assets/examples/i2mv/house2.png +0 -0
  14. assets/examples/i2mv/mannequin.png +0 -0
  15. assets/examples/i2mv/sea_turtle.png +0 -0
  16. assets/examples/i2mv/skull.png +0 -0
  17. assets/examples/mv_lq/cake/00.png +0 -0
  18. assets/examples/mv_lq/cake/01.png +0 -0
  19. assets/examples/mv_lq/cake/02.png +0 -0
  20. assets/examples/mv_lq/cake/03.png +0 -0
  21. assets/examples/mv_lq/fish/00.png +0 -0
  22. assets/examples/mv_lq/fish/01.png +0 -0
  23. assets/examples/mv_lq/fish/02.png +0 -0
  24. assets/examples/mv_lq/fish/03.png +0 -0
  25. assets/examples/mv_lq/gascan/00.png +0 -0
  26. assets/examples/mv_lq/gascan/01.png +0 -0
  27. assets/examples/mv_lq/gascan/02.png +0 -0
  28. assets/examples/mv_lq/gascan/03.png +0 -0
  29. assets/examples/mv_lq/house/00.png +0 -0
  30. assets/examples/mv_lq/house/01.png +0 -0
  31. assets/examples/mv_lq/house/02.png +0 -0
  32. assets/examples/mv_lq/house/03.png +0 -0
  33. assets/examples/mv_lq/lamp/00.png +0 -0
  34. assets/examples/mv_lq/lamp/01.png +0 -0
  35. assets/examples/mv_lq/lamp/02.png +0 -0
  36. assets/examples/mv_lq/lamp/03.png +0 -0
  37. assets/examples/mv_lq/mario/00.png +0 -0
  38. assets/examples/mv_lq/mario/01.png +0 -0
  39. assets/examples/mv_lq/mario/02.png +0 -0
  40. assets/examples/mv_lq/mario/03.png +0 -0
  41. assets/examples/mv_lq/oldman/00.png +0 -0
  42. assets/examples/mv_lq/oldman/01.png +0 -0
  43. assets/examples/mv_lq/oldman/02.png +0 -0
  44. assets/examples/mv_lq/oldman/03.png +0 -0
  45. assets/examples/mv_lq/tower/00.png +0 -0
  46. assets/examples/mv_lq/tower/01.png +0 -0
  47. assets/examples/mv_lq/tower/02.png +0 -0
  48. assets/examples/mv_lq/tower/03.png +0 -0
  49. assets/examples/mv_lq/truck/00.png +0 -0
  50. assets/examples/mv_lq/truck/01.png +0 -0
.gitattributes CHANGED
@@ -32,4 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
32
  *.xz filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,147 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ # Pipfile.lock
96
+
97
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
98
+ __pypackages__/
99
+
100
+ # Celery stuff
101
+ celerybeat-schedule
102
+ celerybeat.pid
103
+
104
+ # SageMath parsed files
105
+ *.sage.py
106
+
107
+ # Environments
108
+ .env
109
+ .venv
110
+ env/
111
+ venv/
112
+ ENV/
113
+ env.bak/
114
+ venv.bak/
115
+
116
+ # Spyder project settings
117
+ .spyderproject
118
+ .spyproject
119
+
120
+ # Rope project settings
121
+ .ropeproject
122
+
123
+ # mkdocs documentation
124
+ /site
125
+
126
+ # mypy
127
+ .mypy_cache/
128
+ .dmypy.json
129
+ dmypy.json
130
+
131
+ # Pyre type checker
132
+ .pyre/
133
+
134
+ # pytype static type analyzer
135
+ .pytype/
136
+
137
+ # Cython debug symbols
138
+ cython_debug
139
+ .idea/
140
+ cloud_tools/
141
+
142
+ output
143
+ pretrained_models
144
+ results
145
+ develop
146
+ gradio_results
147
+ demo
.gitmodules ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [submodule "extern/LGM"]
2
+ path = extern/LGM
3
+ url = https://github.com/3DTopia/LGM.git
LICENSE ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ S-Lab License 1.0
2
+ Copyright 2024 S-Lab
3
+
4
+ Redistribution and use for non-commercial purpose in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
7
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
8
+
9
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
10
+ IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
11
+ DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
12
+
13
+ In the event that redistribution and/or use for commercial purpose in source or binary forms, with or without modification is required, please contact the contributor(s) of the work.
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
  title: 3DEnhancer
3
- emoji: 📚
4
- colorFrom: green
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 5.20.1
8
  app_file: app.py
9
  pinned: false
10
  license: other
 
1
  ---
2
  title: 3DEnhancer
3
+ emoji: 🔆
4
+ colorFrom: red
5
+ colorTo: green
6
  sdk: gradio
7
+ sdk_version: 4.44.1
8
  app_file: app.py
9
  pinned: false
10
  license: other
app.py ADDED
@@ -0,0 +1,588 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ warnings.filterwarnings('ignore')
3
+
4
+ import spaces
5
+
6
+ import os
7
+ import tyro
8
+ import imageio
9
+ import numpy as np
10
+ import tqdm
11
+ import cv2
12
+ import torch
13
+ import torch.nn.functional as F
14
+ from torchvision import transforms as T
15
+ import torchvision.transforms.functional as TF
16
+ from safetensors.torch import load_file
17
+ import kiui
18
+ from kiui.op import recenter
19
+ from kiui.cam import orbit_camera
20
+ import rembg
21
+ import gradio as gr
22
+ from gradio_imageslider import ImageSlider
23
+
24
+ import sys
25
+ sys.path.insert(0, "src")
26
+ from src.enhancer import Enhancer
27
+ from src.utils.camera import get_c2ws
28
+
29
+ # import LGM
30
+ sys.path.insert(0, "extern/LGM")
31
+ from core.options import AllConfigs
32
+ from core.models import LGM
33
+ from mvdream.pipeline_mvdream import MVDreamPipeline
34
+
35
+
36
+ # download checkpoints
37
+ from huggingface_hub import hf_hub_download
38
+ hf_hub_download(repo_id="ashawkey/LGM", filename="model_fp16_fixrot.safetensors", local_dir='pretrained_models/LGM')
39
+ hf_hub_download(repo_id="Luo-Yihang/3DEnhancer", filename="model.safetensors", local_dir='pretrained_models/3DEnhancer')
40
+
41
+
42
+ ### Title and Description ###
43
+ #### Description ####
44
+ title = r"""<h1 align="center">3DEnhancer: Consistent Multi-View Diffusion for 3D Enhancement</h1>"""
45
+
46
+ important_link = r"""
47
+ <div align='center'>
48
+ <a href='https://arxiv.org/abs/2412.18565'>[arxiv]</a>
49
+ &ensp; <a href='https://Luo-Yihang.github.io/projects/3DEnhancer'>[Project Page]</a>
50
+ &ensp; <a href='https://github.com/Luo-Yihang/3DEnhancer'>[Code]</a>
51
+ </div>
52
+ """
53
+
54
+ authors = r"""
55
+ <div align='center'>
56
+ <a href='https://github.com/Luo-Yihang'>Yihang Luo</a>
57
+ &ensp; <a href='https://shangchenzhou.com/'>Shangchen Zhou</a>
58
+ &ensp; <a href='https://nirvanalan.github.io/'>Yushi Lan</a>
59
+ &ensp; <a href='https://xingangpan.github.io/'>Xingang Pan</a>
60
+ &ensp; <a href='https://www.mmlab-ntu.com/person/ccloy/index.html'>Chen Change Loy</a>
61
+ </div>
62
+ """
63
+
64
+ affiliation = r"""
65
+ <div align='center'>
66
+ <a href='https://www.mmlab-ntu.com/'>S-Lab, NTU Singapore</a>
67
+ </div>
68
+ """
69
+
70
+ description = r"""
71
+ <b>Official Gradio demo</b> for <a href='https://yihangluo.com/projects/3DEnhancer' target='_blank'><b>3DEnhancer: Consistent Multi-View Diffusion for 3D Enhancement</b></a>.<br>
72
+ 🔥 3DEnhancer employs a multi-view diffusion model to enhance multi-view images, thus improving 3D models. Our contributions include a robust data augmentation pipeline, and the view-consistent blocks that integrate multi-view row attention and near-view epipolar aggregation modules to promote view consistency. <br>
73
+ """
74
+
75
+ article = r"""
76
+ <br>If 3DEnhancer is helpful, please help to ⭐ the <a href='https://github.com/Luo-Yihang/3DEnhancer' target='_blank'>Github Repo</a>. Thanks!
77
+ [![GitHub Stars](https://img.shields.io/github/stars/Luo-Yihang/3DEnhancer)](https://github.com/Luo-Yihang/3DEnhancer)
78
+ ---
79
+ 📝 **License**
80
+ <br>
81
+ This project is licensed under <a href="https://github.com/Luo-Yihang/3DEnhancer/blob/main/LICENSE">S-Lab License 1.0</a>,
82
+ Redistribution and use for non-commercial purposes should follow this license.
83
+ <br>
84
+ 📝 **Citation**
85
+ <br>
86
+ If our work is useful for your research, please consider citing:
87
+ ```bibtex
88
+ @article{luo20243denhancer,
89
+ title={3DEnhancer: Consistent Multi-View Diffusion for 3D Enhancement},
90
+ author={Yihang Luo and Shangchen Zhou and Yushi Lan and Xingang Pan and Chen Change Loy},
91
+ booktitle={arXiv preprint arXiv:2412.18565},
92
+ year={2024},
93
+ }
94
+ ```
95
+ 📧 **Contact**
96
+ <br>
97
+ If you have any questions, please feel free to reach me out at <b>luo_yihang@outlook.com</b>.
98
+ """
99
+
100
+
101
+ IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
102
+ IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
103
+ BASE_SAVE_PATH = 'gradio_results'
104
+ GRADIO_VIDEO_PATH = f'{BASE_SAVE_PATH}/gradio_output.mp4'
105
+ GRADIO_PLY_PATH = f'{BASE_SAVE_PATH}/gradio_output.ply'
106
+ GRADIO_ENHANCED_VIDEO_PATH = f'{BASE_SAVE_PATH}/gradio_enhanced_output.mp4'
107
+ GRADIO_ENHANCED_PLY_PATH = f'{BASE_SAVE_PATH}/gradio_enhanced_output.ply'
108
+ DEFAULT_NEG_PROMPT = "ugly, blurry, pixelated obscure, unnatural colors, poor lighting, dull, unclear, cropped, lowres, low quality, artifacts, duplicate"
109
+ DEFAULT_SEED = 0
110
+ os.makedirs(BASE_SAVE_PATH, exist_ok=True)
111
+
112
+
113
+ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
114
+
115
+ # load dreams
116
+ pipe_text = MVDreamPipeline.from_pretrained(
117
+ 'ashawkey/mvdream-sd2.1-diffusers', # remote weights
118
+ torch_dtype=torch.float16,
119
+ trust_remote_code=True
120
+ )
121
+ pipe_text = pipe_text.to(device)
122
+
123
+ pipe_image = MVDreamPipeline.from_pretrained(
124
+ "ashawkey/imagedream-ipmv-diffusers", # remote weights
125
+ torch_dtype=torch.float16,
126
+ trust_remote_code=True
127
+ )
128
+ pipe_image = pipe_image.to(device)
129
+
130
+ # load lgm
131
+ lgm_opt = tyro.cli(AllConfigs, args=["big"])
132
+
133
+ tan_half_fov = np.tan(0.5 * np.deg2rad(lgm_opt.fovy))
134
+ proj_matrix = torch.zeros(4, 4, dtype=torch.float32, device=device)
135
+ proj_matrix[0, 0] = 1 / tan_half_fov
136
+ proj_matrix[1, 1] = 1 / tan_half_fov
137
+ proj_matrix[2, 2] = (lgm_opt.zfar + lgm_opt.znear) / (lgm_opt.zfar - lgm_opt.znear)
138
+ proj_matrix[3, 2] = - (lgm_opt.zfar * lgm_opt.znear) / (lgm_opt.zfar - lgm_opt.znear)
139
+ proj_matrix[2, 3] = 1
140
+
141
+ lgm_model = LGM(lgm_opt)
142
+ lgm_model = lgm_model.half().to(device)
143
+ ckpt = load_file("pretrained_models/LGM/model_fp16_fixrot.safetensors", device='cpu')
144
+ lgm_model.load_state_dict(ckpt, strict=False)
145
+ lgm_model.eval()
146
+
147
+ # load 3denhancer
148
+ enhancer = Enhancer(
149
+ model_path = "pretrained_models/3DEnhancer/model.safetensors",
150
+ config_path = "src/configs/config.py",
151
+ )
152
+
153
+ # load rembg
154
+ bg_remover = rembg.new_session()
155
+
156
+ @torch.no_grad()
157
+ @spaces.GPU
158
+ def gen_mv(ref_image, ref_text):
159
+ kiui.seed_everything(DEFAULT_SEED)
160
+
161
+ # text-conditioned
162
+ if ref_image is None:
163
+ mv_image_uint8 = pipe_text(ref_text, negative_prompt=DEFAULT_NEG_PROMPT, num_inference_steps=30, guidance_scale=7.5, elevation=0)
164
+ mv_image_uint8 = (mv_image_uint8 * 255).astype(np.uint8)
165
+ # bg removal
166
+ mv_image = []
167
+ for i in range(4):
168
+ image = rembg.remove(mv_image_uint8[i], session=bg_remover) # [H, W, 4]
169
+ # to white bg
170
+ image = image.astype(np.float32) / 255
171
+ image = recenter(image, image[..., 0] > 0, border_ratio=0.2)
172
+ image = image[..., :3] * image[..., -1:] + (1 - image[..., -1:])
173
+ mv_image.append(image)
174
+ # image-conditioned (may also input text, but no text usually works too)
175
+ else:
176
+ ref_image = np.array(ref_image) # uint8
177
+ # bg removal
178
+ carved_image = rembg.remove(ref_image, session=bg_remover) # [H, W, 4]
179
+ mask = carved_image[..., -1] > 0
180
+ image = recenter(carved_image, mask, border_ratio=0.2)
181
+ image = image.astype(np.float32) / 255.0
182
+ image = image[..., :3] * image[..., 3:4] + (1 - image[..., 3:4])
183
+ mv_image = pipe_image(ref_text, image, negative_prompt=DEFAULT_NEG_PROMPT, num_inference_steps=30, guidance_scale=5.0, elevation=0)
184
+
185
+ # mv_image, a list of 4 np_arrays in shape (256, 256, 3) in range (0.0, 1.0)
186
+ mv_image_512 = []
187
+ for i in range(len(mv_image)):
188
+ mv_image_512.append(cv2.resize(mv_image[i], (512, 512), interpolation=cv2.INTER_LINEAR))
189
+
190
+ return mv_image_512[0], mv_image_512[1], mv_image_512[2], mv_image_512[3], ref_text, 120
191
+
192
+
193
+ @torch.no_grad()
194
+ @spaces.GPU
195
+ def gen_3d(image_0, image_1, image_2, image_3, elevation, output_video_path, output_ply_path):
196
+ kiui.seed_everything(DEFAULT_SEED)
197
+
198
+ mv_image = [image_0, image_1, image_2, image_3]
199
+ for i in range(len(mv_image)):
200
+ if type(mv_image[i]) is tuple:
201
+ mv_image[i] = mv_image[i][1]
202
+ mv_image[i] = np.array(mv_image[i]).astype(np.float32) / 255.0
203
+ mv_image[i] = cv2.resize(mv_image[i], (256, 256), interpolation=cv2.INTER_AREA)
204
+
205
+ # generate gaussians
206
+ input_image = np.stack([mv_image[1], mv_image[2], mv_image[3], mv_image[0]], axis=0) # [4, 256, 256, 3], float32
207
+ input_image = torch.from_numpy(input_image).permute(0, 3, 1, 2).float().to(device) # [4, 3, 256, 256]
208
+ input_image = F.interpolate(input_image, size=(lgm_opt.input_size, lgm_opt.input_size), mode='bilinear', align_corners=False)
209
+ input_image = TF.normalize(input_image, IMAGENET_DEFAULT_MEAN, IMAGENET_DEFAULT_STD)
210
+
211
+ rays_embeddings = lgm_model.prepare_default_rays(device, elevation=elevation)
212
+ input_image = torch.cat([input_image, rays_embeddings], dim=1).unsqueeze(0) # [1, 4, 9, H, W]
213
+
214
+ with torch.no_grad():
215
+ with torch.autocast(device_type='cuda', dtype=torch.float16):
216
+ # generate gaussians
217
+ gaussians = lgm_model.forward_gaussians(input_image)
218
+ lgm_model.gs.save_ply(gaussians, output_ply_path)
219
+
220
+ # render 360 video
221
+ images = []
222
+ elevation = 0
223
+ if lgm_opt.fancy_video:
224
+ azimuth = np.arange(0, 720, 4, dtype=np.int32)
225
+ for azi in tqdm.tqdm(azimuth):
226
+ cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=lgm_opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
227
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
228
+
229
+ # cameras needed by gaussian rasterizer
230
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
231
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
232
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
233
+
234
+ scale = min(azi / 360, 1)
235
+
236
+ image = lgm_model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=scale)['image']
237
+ images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
238
+ else:
239
+ azimuth = np.arange(0, 360, 2, dtype=np.int32)
240
+ for azi in tqdm.tqdm(azimuth):
241
+ cam_poses = torch.from_numpy(orbit_camera(elevation, azi, radius=lgm_opt.cam_radius, opengl=True)).unsqueeze(0).to(device)
242
+ cam_poses[:, :3, 1:3] *= -1 # invert up & forward direction
243
+
244
+ # cameras needed by gaussian rasterizer
245
+ cam_view = torch.inverse(cam_poses).transpose(1, 2) # [V, 4, 4]
246
+ cam_view_proj = cam_view @ proj_matrix # [V, 4, 4]
247
+ cam_pos = - cam_poses[:, :3, 3] # [V, 3]
248
+
249
+ image = lgm_model.gs.render(gaussians, cam_view.unsqueeze(0), cam_view_proj.unsqueeze(0), cam_pos.unsqueeze(0), scale_modifier=1)['image']
250
+ images.append((image.squeeze(1).permute(0,2,3,1).contiguous().float().cpu().numpy() * 255).astype(np.uint8))
251
+
252
+ images = np.concatenate(images, axis=0)
253
+ imageio.mimwrite(output_video_path, images, fps=30)
254
+
255
+ return output_video_path, output_ply_path
256
+
257
+
258
+ @torch.no_grad()
259
+ @spaces.GPU
260
+ def enhance(image_0, image_1, image_2, image_3, prompt, elevation, noise_level, cfg_scale, steps, seed, color_shift):
261
+ kiui.seed_everything(seed)
262
+
263
+ mv_image = [image_0, image_1, image_2, image_3]
264
+ img_tensor_list = []
265
+ for image in mv_image:
266
+ img_tensor_list.append(T.ToTensor()(image))
267
+
268
+ img_tensors = torch.stack(img_tensor_list)
269
+
270
+ color_shift = None if color_shift=="disabled" else color_shift
271
+ output_img_tensors = enhancer.inference(
272
+ mv_imgs=img_tensors,
273
+ c2ws=get_c2ws(elevations=[elevation]*4, amuziths=[0,90,180,270]),
274
+ prompt=prompt,
275
+ noise_level=noise_level,
276
+ cfg_scale=cfg_scale,
277
+ sample_steps=steps,
278
+ color_shift=color_shift,
279
+ )
280
+
281
+ mv_image_512 = output_img_tensors.permute(0,2,3,1).cpu().numpy()
282
+
283
+ # return to the image slider component
284
+ return (image_0, mv_image_512[0]), (image_1, mv_image_512[1]), (image_2, mv_image_512[2]), (image_3, mv_image_512[3])
285
+
286
+
287
+ def check_video(input_video):
288
+ if input_video:
289
+ return gr.update(interactive=True)
290
+ return gr.update(interactive=False)
291
+
292
+
293
+ i2mv_examples = [
294
+ ["assets/examples/i2mv/cake.png", "cake"],
295
+ ["assets/examples/i2mv/skull.png", "skull"],
296
+ ["assets/examples/i2mv/sea_turtle.png", "sea turtle"],
297
+ ["assets/examples/i2mv/house2.png", "house"],
298
+ ["assets/examples/i2mv/cup.png", "cup"],
299
+ ["assets/examples/i2mv/mannequin.png", "mannequin"],
300
+ ["assets/examples/i2mv/boy.jpg", "boy"],
301
+ ["assets/examples/i2mv/dragontoy.jpg", "dragon toy"],
302
+ ["assets/examples/i2mv/gso_rabbit.jpg", "rabbit car"],
303
+ ["assets/examples/i2mv/Mario_New_Super_Mario_Bros_U_Deluxe.png", "standing Mario"],
304
+ ]
305
+
306
+ t2mv_examples = [
307
+ "teddy bear",
308
+ "hamburger",
309
+ "oldman's head sculpture",
310
+ "headphone",
311
+ "mech suit",
312
+ "wooden barrel",
313
+ "scary zombie"
314
+ ]
315
+
316
+ mv_examples = [
317
+ [
318
+ "assets/examples/mv_lq_prerendered/vase.mp4",
319
+ "assets/examples/mv_lq/vase/00.png",
320
+ "assets/examples/mv_lq/vase/01.png",
321
+ "assets/examples/mv_lq/vase/02.png",
322
+ "assets/examples/mv_lq/vase/03.png",
323
+ "vase",
324
+ 0
325
+ ],
326
+ [
327
+ "assets/examples/mv_lq_prerendered/tower.mp4",
328
+ "assets/examples/mv_lq/tower/00.png",
329
+ "assets/examples/mv_lq/tower/01.png",
330
+ "assets/examples/mv_lq/tower/02.png",
331
+ "assets/examples/mv_lq/tower/03.png",
332
+ "brick tower",
333
+ 0
334
+ ],
335
+ [
336
+ "assets/examples/mv_lq_prerendered/truck.mp4",
337
+ "assets/examples/mv_lq/truck/00.png",
338
+ "assets/examples/mv_lq/truck/01.png",
339
+ "assets/examples/mv_lq/truck/02.png",
340
+ "assets/examples/mv_lq/truck/03.png",
341
+ "truck",
342
+ 0
343
+ ],
344
+ [
345
+ "assets/examples/mv_lq_prerendered/gascan.mp4",
346
+ "assets/examples/mv_lq/gascan/00.png",
347
+ "assets/examples/mv_lq/gascan/01.png",
348
+ "assets/examples/mv_lq/gascan/02.png",
349
+ "assets/examples/mv_lq/gascan/03.png",
350
+ "gas can",
351
+ 0
352
+ ],
353
+ [
354
+ "assets/examples/mv_lq_prerendered/fish.mp4",
355
+ "assets/examples/mv_lq/fish/00.png",
356
+ "assets/examples/mv_lq/fish/01.png",
357
+ "assets/examples/mv_lq/fish/02.png",
358
+ "assets/examples/mv_lq/fish/03.png",
359
+ "sea fish with eyes",
360
+ 0
361
+ ],
362
+ [
363
+ "assets/examples/mv_lq_prerendered/tshirt.mp4",
364
+ "assets/examples/mv_lq/tshirt/00.png",
365
+ "assets/examples/mv_lq/tshirt/01.png",
366
+ "assets/examples/mv_lq/tshirt/02.png",
367
+ "assets/examples/mv_lq/tshirt/03.png",
368
+ "t-shirt",
369
+ 0
370
+ ],
371
+ [
372
+ "assets/examples/mv_lq_prerendered/turtle.mp4",
373
+ "assets/examples/mv_lq/turtle/00.png",
374
+ "assets/examples/mv_lq/turtle/01.png",
375
+ "assets/examples/mv_lq/turtle/02.png",
376
+ "assets/examples/mv_lq/turtle/03.png",
377
+ "sea turtle",
378
+ 200
379
+ ],
380
+ [
381
+ "assets/examples/mv_lq_prerendered/cake.mp4",
382
+ "assets/examples/mv_lq/cake/00.png",
383
+ "assets/examples/mv_lq/cake/01.png",
384
+ "assets/examples/mv_lq/cake/02.png",
385
+ "assets/examples/mv_lq/cake/03.png",
386
+ "cake",
387
+ 120
388
+ ],
389
+ [
390
+ "assets/examples/mv_lq_prerendered/lamp.mp4",
391
+ "assets/examples/mv_lq/lamp/00.png",
392
+ "assets/examples/mv_lq/lamp/01.png",
393
+ "assets/examples/mv_lq/lamp/02.png",
394
+ "assets/examples/mv_lq/lamp/03.png",
395
+ "lamp",
396
+ 0
397
+ ],
398
+ [
399
+ "assets/examples/mv_lq_prerendered/oldman.mp4",
400
+ "assets/examples/mv_lq/oldman/00.png",
401
+ "assets/examples/mv_lq/oldman/01.png",
402
+ "assets/examples/mv_lq/oldman/02.png",
403
+ "assets/examples/mv_lq/oldman/03.png",
404
+ "old man sculpture",
405
+ 120
406
+ ],
407
+ [
408
+ "assets/examples/mv_lq_prerendered/mario.mp4",
409
+ "assets/examples/mv_lq/mario/00.png",
410
+ "assets/examples/mv_lq/mario/01.png",
411
+ "assets/examples/mv_lq/mario/02.png",
412
+ "assets/examples/mv_lq/mario/03.png",
413
+ "standing mario",
414
+ 120
415
+ ],
416
+ [
417
+ "assets/examples/mv_lq_prerendered/house.mp4",
418
+ "assets/examples/mv_lq/house/00.png",
419
+ "assets/examples/mv_lq/house/01.png",
420
+ "assets/examples/mv_lq/house/02.png",
421
+ "assets/examples/mv_lq/house/03.png",
422
+ "house",
423
+ 120
424
+ ],
425
+ ]
426
+
427
+
428
+ # gradio UI
429
+ demo = gr.Blocks().queue()
430
+ with demo:
431
+ gr.Markdown(title)
432
+ gr.Markdown(authors)
433
+ gr.Markdown(affiliation)
434
+ gr.Markdown(important_link)
435
+ gr.Markdown(description)
436
+
437
+ original_video_path = gr.State(GRADIO_VIDEO_PATH)
438
+ original_ply_path = gr.State(GRADIO_PLY_PATH)
439
+ enhanced_video_path = gr.State(GRADIO_ENHANCED_VIDEO_PATH)
440
+ enhanced_ply_path = gr.State(GRADIO_ENHANCED_PLY_PATH)
441
+
442
+ with gr.Column(variant='panel'):
443
+ with gr.Accordion("Generate Multi Views (LGM)", open=False):
444
+ gr.Markdown("*Don't have multi-view images on hand? Generate them here using a single image, text, or a combination of both.*")
445
+ with gr.Row():
446
+ with gr.Column():
447
+ ref_image = gr.Image(label="Reference Image", type='pil', height=400, interactive=True)
448
+ ref_text = gr.Textbox(label="Prompt", value="", interactive=True)
449
+ with gr.Column():
450
+ gr.Examples(
451
+ examples=i2mv_examples,
452
+ inputs=[ref_image, ref_text],
453
+ examples_per_page=3,
454
+ label='Image-to-Multiviews Examples',
455
+ )
456
+
457
+ gr.Examples(
458
+ examples=t2mv_examples,
459
+ inputs=[ref_text],
460
+ outputs=[ref_image, ref_text],
461
+ cache_examples=False,
462
+ run_on_click=True,
463
+ fn=lambda x: (None, x),
464
+ label='Text-to-Multiviews Examples',
465
+ )
466
+
467
+ with gr.Row():
468
+ gr.Column() # Empty column for spacing
469
+ button_gen_mv = gr.Button("Generate Multi Views", scale=1)
470
+ gr.Column() # Empty column for spacing
471
+
472
+ with gr.Column():
473
+ gr.Markdown("Let's enhance!")
474
+ with gr.Row():
475
+ with gr.Column(scale=2):
476
+ with gr.Tab("Multi Views"):
477
+ gr.Markdown("*Upload your multi-view images and enhance them with 3DEnhancer. You can also generate 3D model using LGM.*")
478
+ with gr.Row():
479
+ input_image_0 = gr.Image(label="[Input] view-0", type='pil', height=320)
480
+ input_image_1 = gr.Image(label="[Input] view-1", type='pil', height=320)
481
+ input_image_2 = gr.Image(label="[Input] view-2", type='pil', height=320)
482
+ input_image_3 = gr.Image(label="[Input] view-3", type='pil', height=320)
483
+ gr.Markdown("---")
484
+ gr.Markdown("Enhanced Output")
485
+ with gr.Row():
486
+ enhanced_image_0 = ImageSlider(label="[Enhanced] view-0", type='pil', height=350, interactive=False)
487
+ enhanced_image_1 = ImageSlider(label="[Enhanced] view-1", type='pil', height=350, interactive=False)
488
+ enhanced_image_2 = ImageSlider(label="[Enhanced] view-2", type='pil', height=350, interactive=False)
489
+ enhanced_image_3 = ImageSlider(label="[Enhanced] view-3", type='pil', height=350, interactive=False)
490
+ with gr.Tab("Generated 3D"):
491
+ gr.Markdown("Coarse Input")
492
+ with gr.Column():
493
+ with gr.Row():
494
+ gr.Column() # Empty column for spacing
495
+ with gr.Column():
496
+ input_3d_video = gr.Video(label="[Input] Rendered Video", height=300, scale=1, interactive=False)
497
+ with gr.Row():
498
+ button_gen_3d = gr.Button("Render 3D")
499
+ button_download_3d = gr.DownloadButton("Download Ply", interactive=False)
500
+ # button_download_3d = gr.File(label="Download Ply", interactive=False, height=50)
501
+ gr.Column() # Empty column for spacing
502
+ gr.Markdown("---")
503
+ gr.Markdown("Enhanced Output")
504
+ with gr.Row():
505
+ gr.Column() # Empty column for spacing
506
+ with gr.Column():
507
+ enhanced_3d_video = gr.Video(label="[Enhanced] Rendered Video", height=300, scale=1, interactive=False)
508
+ with gr.Row():
509
+ enhanced_button_gen_3d = gr.Button("Render 3D")
510
+ enhanced_button_download_3d = gr.DownloadButton("Download Ply", interactive=False)
511
+ gr.Column() # Empty column for spacing
512
+
513
+ with gr.Column():
514
+ with gr.Row():
515
+ enhancer_text = gr.Textbox(label="Prompt", value="", scale=1)
516
+ enhancer_noise_level = gr.Slider(label="enhancer noise level", minimum=0, maximum=300, step=1, value=0, interactive=True)
517
+ with gr.Accordion("Advanced Setting", open=False):
518
+ with gr.Column():
519
+ with gr.Row():
520
+ with gr.Column():
521
+ elevation = gr.Slider(label="elevation", minimum=-90, maximum=90, step=1, value=0)
522
+ cfg_scale = gr.Slider(label="cfg scale", minimum=0, maximum=10, step=0.1, value=4.5)
523
+ with gr.Column():
524
+ seed = gr.Slider(label="random seed", minimum=0, maximum=100000, step=1, value=0)
525
+ steps = gr.Slider(label="inference steps", minimum=1, maximum=100, step=1, value=20)
526
+ with gr.Row():
527
+ color_shift = gr.Radio(label="color shift", value="disabled", choices=["disabled", "adain", "wavelet"])
528
+ with gr.Row():
529
+ gr.Column() # Empty column for spacing
530
+ button_enhance = gr.Button("Enhance", scale=1, variant="primary")
531
+ gr.Column() # Empty column for spacing
532
+
533
+ gr.Examples(
534
+ examples=mv_examples,
535
+ inputs=[input_3d_video, input_image_0, input_image_1, input_image_2, input_image_3, enhancer_text, enhancer_noise_level],
536
+ examples_per_page=3,
537
+ label='Multiviews Examples',
538
+ )
539
+
540
+ gr.Markdown("*Don't have multi-view images on hand but want to generate your own multi-views? Generate them in the `Generate Multi Views (LGM)` section above.*")
541
+
542
+ gr.Markdown(article)
543
+
544
+ button_gen_mv.click(
545
+ gen_mv,
546
+ inputs=[ref_image, ref_text],
547
+ outputs=[input_image_0, input_image_1, input_image_2, input_image_3, enhancer_text, enhancer_noise_level]
548
+ )
549
+
550
+ button_gen_3d.click(
551
+ gen_3d,
552
+ inputs=[input_image_0, input_image_1, input_image_2, input_image_3, elevation, original_video_path, original_ply_path],
553
+ outputs=[input_3d_video, button_download_3d]
554
+ ).success(
555
+ lambda: gr.Button(interactive=True),
556
+ outputs=[button_download_3d],
557
+ )
558
+
559
+ enhanced_button_gen_3d.click(
560
+ gen_3d,
561
+ inputs=[enhanced_image_0, enhanced_image_1, enhanced_image_2, enhanced_image_3, elevation, enhanced_video_path, enhanced_ply_path],
562
+ outputs=[enhanced_3d_video, enhanced_button_download_3d]
563
+ ).success(
564
+ lambda: gr.Button(interactive=True),
565
+ outputs=[enhanced_button_download_3d],
566
+ )
567
+
568
+ button_enhance.click(
569
+ enhance,
570
+ inputs=[input_image_0, input_image_1, input_image_2, input_image_3, enhancer_text, elevation, enhancer_noise_level, cfg_scale, steps, seed, color_shift],
571
+ outputs=[enhanced_image_0, enhanced_image_1, enhanced_image_2, enhanced_image_3]
572
+ ).success(
573
+ gen_3d,
574
+ inputs=[input_image_0, input_image_1, input_image_2, input_image_3, elevation, original_video_path, original_ply_path],
575
+ outputs=[input_3d_video, button_download_3d]
576
+ ).success(
577
+ lambda: gr.Button(interactive=True),
578
+ outputs=[button_download_3d],
579
+ ).success(
580
+ gen_3d,
581
+ inputs=[enhanced_image_0, enhanced_image_1, enhanced_image_2, enhanced_image_3, elevation, enhanced_video_path, enhanced_ply_path],
582
+ outputs=[enhanced_3d_video, enhanced_button_download_3d]
583
+ ).success(
584
+ lambda: gr.Button(interactive=True),
585
+ outputs=[enhanced_button_download_3d],
586
+ )
587
+
588
+ demo.launch()
assets/examples/i2mv/Mario_New_Super_Mario_Bros_U_Deluxe.png ADDED
assets/examples/i2mv/boy.jpg ADDED
assets/examples/i2mv/cake.png ADDED
assets/examples/i2mv/cup.png ADDED
assets/examples/i2mv/dragontoy.jpg ADDED
assets/examples/i2mv/gso_rabbit.jpg ADDED
assets/examples/i2mv/house2.png ADDED
assets/examples/i2mv/mannequin.png ADDED
assets/examples/i2mv/sea_turtle.png ADDED
assets/examples/i2mv/skull.png ADDED
assets/examples/mv_lq/cake/00.png ADDED
assets/examples/mv_lq/cake/01.png ADDED
assets/examples/mv_lq/cake/02.png ADDED
assets/examples/mv_lq/cake/03.png ADDED
assets/examples/mv_lq/fish/00.png ADDED
assets/examples/mv_lq/fish/01.png ADDED
assets/examples/mv_lq/fish/02.png ADDED
assets/examples/mv_lq/fish/03.png ADDED
assets/examples/mv_lq/gascan/00.png ADDED
assets/examples/mv_lq/gascan/01.png ADDED
assets/examples/mv_lq/gascan/02.png ADDED
assets/examples/mv_lq/gascan/03.png ADDED
assets/examples/mv_lq/house/00.png ADDED
assets/examples/mv_lq/house/01.png ADDED
assets/examples/mv_lq/house/02.png ADDED
assets/examples/mv_lq/house/03.png ADDED
assets/examples/mv_lq/lamp/00.png ADDED
assets/examples/mv_lq/lamp/01.png ADDED
assets/examples/mv_lq/lamp/02.png ADDED
assets/examples/mv_lq/lamp/03.png ADDED
assets/examples/mv_lq/mario/00.png ADDED
assets/examples/mv_lq/mario/01.png ADDED
assets/examples/mv_lq/mario/02.png ADDED
assets/examples/mv_lq/mario/03.png ADDED
assets/examples/mv_lq/oldman/00.png ADDED
assets/examples/mv_lq/oldman/01.png ADDED
assets/examples/mv_lq/oldman/02.png ADDED
assets/examples/mv_lq/oldman/03.png ADDED
assets/examples/mv_lq/tower/00.png ADDED
assets/examples/mv_lq/tower/01.png ADDED
assets/examples/mv_lq/tower/02.png ADDED
assets/examples/mv_lq/tower/03.png ADDED
assets/examples/mv_lq/truck/00.png ADDED
assets/examples/mv_lq/truck/01.png ADDED