Commit 491eded
Parent(s): init
Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- .gitattributes +44 -0
- .gitignore +6 -0
- LICENSE +22 -0
- NOTICE +15 -0
- README.md +13 -0
- app.py +184 -0
- app_utils.py +412 -0
- assets/example_data/Batman.png +3 -0
- assets/example_data/astronaut.png +3 -0
- assets/example_data/car.png +3 -0
- assets/example_data/crossbow.jpg +0 -0
- assets/example_data/knight.png +3 -0
- assets/example_data/robot.jpg +0 -0
- assets/example_data/robot1.jpeg +3 -0
- assets/example_data/robot_dog.jpg +0 -0
- assets/example_data/ship.jpg +0 -0
- assets/example_data/snake.png +3 -0
- assets/example_data/warhammer.png +3 -0
- configs/bbox_gen.yaml +34 -0
- modules/PartField/configs/final/correspondence_demo.yaml +44 -0
- modules/PartField/configs/final/demo.yaml +28 -0
- modules/PartField/partfield/config/__init__.py +26 -0
- modules/PartField/partfield/config/defaults.py +92 -0
- modules/PartField/partfield/model/PVCNN/conv_pointnet.py +251 -0
- modules/PartField/partfield/model/PVCNN/dnnlib_util.py +1074 -0
- modules/PartField/partfield/model/PVCNN/encoder_pc.py +243 -0
- modules/PartField/partfield/model/PVCNN/pc_encoder.py +90 -0
- modules/PartField/partfield/model/PVCNN/pv_module/__init__.py +2 -0
- modules/PartField/partfield/model/PVCNN/pv_module/ball_query.py +34 -0
- modules/PartField/partfield/model/PVCNN/pv_module/frustum.py +141 -0
- modules/PartField/partfield/model/PVCNN/pv_module/functional/__init__.py +1 -0
- modules/PartField/partfield/model/PVCNN/pv_module/functional/devoxelization.py +12 -0
- modules/PartField/partfield/model/PVCNN/pv_module/loss.py +10 -0
- modules/PartField/partfield/model/PVCNN/pv_module/pointnet.py +113 -0
- modules/PartField/partfield/model/PVCNN/pv_module/pvconv.py +38 -0
- modules/PartField/partfield/model/PVCNN/pv_module/shared_mlp.py +35 -0
- modules/PartField/partfield/model/PVCNN/pv_module/voxelization.py +80 -0
- modules/PartField/partfield/model/PVCNN/unet_3daware.py +427 -0
- modules/PartField/partfield/model/UNet/buildingblocks.py +546 -0
- modules/PartField/partfield/model/UNet/model.py +170 -0
- modules/PartField/partfield/model/model_utils.py +54 -0
- modules/PartField/partfield/model/triplane.py +331 -0
- modules/PartField/partfield/model_trainer_pvcnn_only_demo.py +283 -0
- modules/PartField/partfield/partfield_encoder.py +103 -0
- modules/PartField/partfield/utils.py +5 -0
- modules/bbox_gen/config.py +57 -0
- modules/bbox_gen/models/autogressive_bbox_gen.py +305 -0
- modules/bbox_gen/models/bbox_gen_models.py +215 -0
- modules/bbox_gen/models/bboxopt.py +221 -0
- modules/bbox_gen/models/image_encoder.py +41 -0
.gitattributes
ADDED
@@ -0,0 +1,44 @@
+*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+assets/example_data/Batman.png filter=lfs diff=lfs merge=lfs -text
+assets/example_data/astronaut.png filter=lfs diff=lfs merge=lfs -text
+assets/example_data/car.png filter=lfs diff=lfs merge=lfs -text
+assets/example_data/knight.png filter=lfs diff=lfs merge=lfs -text
+assets/example_data/robot1.jpeg filter=lfs diff=lfs merge=lfs -text
+assets/example_data/snake.png filter=lfs diff=lfs merge=lfs -text
+assets/example_data/warhammer.png filter=lfs diff=lfs merge=lfs -text
+modules/part_synthesis/representations/mesh/flexicubes/images/block_init.png filter=lfs diff=lfs merge=lfs -text
+modules/part_synthesis/representations/mesh/flexicubes/images/teaser_top.png filter=lfs diff=lfs merge=lfs -text
.gitignore
ADDED
@@ -0,0 +1,6 @@
+__pycache__/
+output/
+ckpt/
+.DS_Store
+tmp/
+debug_images/
LICENSE
ADDED
@@ -0,0 +1,22 @@
+# MIT License
+
+# Copyright (c) 2025 VAST-AI-Research and contributors.
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE
+
NOTICE
ADDED
@@ -0,0 +1,15 @@
+OmniPart
+Copyright (c) 2025 VAST-AI-Research and contributors
+
+This project includes code from the following open source projects:
+
+RMBG
+Copyright (c) BRIA AI
+License: bria-rmbg-2.0
+Source: https://huggingface.co/briaai/RMBG-2.0
+
+This software contains code derived from 🤗 Diffusers (https://github.com/huggingface/diffusers), available under the Apache License 2.0.
+
+This software contains code derived from TRELLIS (https://github.com/Microsoft/TRELLIS), available under the MIT License.
+
+This software contains code derived from PartPacker (https://github.com/NVlabs/PartPacker), available under the NVIDIA Source Code License.
README.md
ADDED
@@ -0,0 +1,13 @@
+---
+title: OmniPart
+emoji: 📚
+colorFrom: green
+colorTo: green
+sdk: gradio
+sdk_version: 5.35.0
+app_file: app.py
+pinned: false
+license: mit
+---
+
+Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py
ADDED
@@ -0,0 +1,184 @@
+import gradio as gr
+import spaces
+import os
+import shutil
+os.environ['SPCONV_ALGO'] = 'native'
+from huggingface_hub import hf_hub_download
+
+from app_utils import (
+    generate_parts,
+    prepare_models,
+    process_image,
+    apply_merge,
+    DEFAULT_SIZE_TH,
+    TMP_ROOT,
+)
+
+EXAMPLES = [
+    ["assets/example_data/knight.png", 1800, "6,0,26,20,7;13,1,22,11,12,2,21,27,3,24,23;5,18;4,17;19,16,14,25,28", 42],
+    ["assets/example_data/car.png", 2000, "12,10,2,11;1,7", 42],
+    ["assets/example_data/warhammer.png", 1800, "7,1,0,8", 0],
+    ["assets/example_data/snake.png", 3000, "2,3;0,1;4,5,6,7", 42],
+    ["assets/example_data/Batman.png", 1800, "4,5", 42],
+    ["assets/example_data/robot1.jpeg", 1600, "0,5;10,14,3;1,12,2;13,11,4;7,15", 42],
+    ["assets/example_data/astronaut.png", 2000, "0,4,6;1,8,9,7;2,5", 42],
+    ["assets/example_data/crossbow.jpg", 2000, "2,9;10,12,0,7,11,8,13;4,3", 42],
+    ["assets/example_data/robot.jpg", 1600, "7,19;15,0;6,18", 42],
+    ["assets/example_data/robot_dog.jpg", 1000, "21,9;2,12,10,15,17;11,7;1,0;13,19;4,16", 0],
+    ["assets/example_data/crossbow.jpg", 1600, "9,2;10,15,13;7,14,8,11;0,12,16;5,3,1", 42],
+    ["assets/example_data/robot.jpg", 1800, "1,2,3,5,4,16,17;11,7,19;10,14;18,6,0,15;13,9;12,8", 0],
+    ["assets/example_data/robot_dog.jpg", 1000, "2,12,10,15,17,8,3,5,13,19,6,14;11,7;1,0,21,9,11;4,16", 0],
+]
+
+HEADER = """
+
+# OmniPart: Part-Aware 3D Generation with Semantic Decoupling and Structural Cohesion
+
+🔮 Generate **part-aware 3D content** from a single 2D image with **2D mask control**.
+
+## How to Use
+
+**🚀 Quick Start**: Select an example below and click **"▶️ Run Example"**
+
+
+**📋 Custom Image Processing**:
+1. **Upload Image** - Select your image file
+2. **Click "Segment Image"** - Get initial 2D segmentation
+3. **Merge Segments** - Enter merge groups like `0,1;3,4` and click **"Apply Merge"** (Recommend keeping **2-15 parts**)
+4. **Click "Generate 3D Model"** - Create the final 3D results
+"""
+
+
+def start_session(req: gr.Request):
+    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
+    os.makedirs(user_dir, exist_ok=True)
+
+
+def end_session(req: gr.Request):
+    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
+    shutil.rmtree(user_dir)
+
+
+with gr.Blocks(title="OmniPart") as demo:
+    gr.Markdown(HEADER)
+
+    state = gr.State({})
+
+    with gr.Row():
+        with gr.Column(scale=1):
+            gr.Markdown("<div style='text-align: center'>\n\n## Input\n\n</div>")
+
+            input_image = gr.Image(label="Upload Image", type="filepath", height=250, width=250)
+
+            with gr.Row():
+                segment_btn = gr.Button("Segment Image", variant="primary", size="lg")
+                run_example_btn = gr.Button("▶️ Run Example", variant="secondary", size="lg")
+
+            size_threshold = gr.Slider(
+                minimum=600,
+                maximum=4000,
+                value=DEFAULT_SIZE_TH,
+                step=200,
+                label="Minimum Segment Size (pixels)",
+                info="Segments smaller than this will be ignored"
+            )
+
+            gr.Markdown("### Merge Controls")
+            merge_input = gr.Textbox(
+                label="Merge Groups",
+                placeholder="0,1;3,4",
+                lines=2,
+                info="Specify which segments to merge (e.g., '0,1;3,4' merges segments 0&1 together and 3&4 together)"
+            )
+            merge_btn = gr.Button("Apply Merge", variant="primary", size="lg")
+
+            gr.Markdown("### 3D Generation Controls")
+
+            seed_slider = gr.Slider(
+                minimum=0,
+                maximum=10000,
+                value=42,
+                step=1,
+                label="Generation Seed",
+                info="Random seed for 3D model generation"
+            )
+
+            cfg_slider = gr.Slider(
+                minimum=0.0,
+                maximum=15.0,
+                value=7.5,
+                step=0.5,
+                label="CFG Strength",
+                info="Classifier-Free Guidance strength"
+            )
+
+            generate_mesh_btn = gr.Button("Generate 3D Model", variant="secondary", size="lg")
+
+        with gr.Column(scale=2):
+            gr.Markdown("<div style='text-align: center'>\n\n## Results Display\n\n</div>")
+
+            with gr.Row():
+                initial_seg = gr.Image(label="Init Seg", height=220, width=220)
+                pre_merge_vis = gr.Image(label="Pre-merge", height=220, width=220)
+                merged_seg = gr.Image(label="Merged Seg", height=220, width=220)
+
+            with gr.Row():
+                bbox_mesh = gr.Model3D(label="Bounding Boxes", height=350)
+                whole_mesh = gr.Model3D(label="Combined Parts", height=350)
+                exploded_mesh = gr.Model3D(label="Exploded Parts", height=350)
+
+            with gr.Row():
+                combined_gs = gr.Model3D(label="Combined 3D Gaussians", clear_color=(0.0, 0.0, 0.0, 0.0), height=350)
+                exploded_gs = gr.Model3D(label="Exploded 3D Gaussians", clear_color=(0.0, 0.0, 0.0, 0.0), height=350)
+
+    with gr.Row():
+        examples = gr.Examples(
+            examples=EXAMPLES,
+            inputs=[input_image, size_threshold, merge_input, seed_slider],
+            cache_examples=False,
+        )
+
+    demo.load(start_session)
+    demo.unload(end_session)
+
+    segment_btn.click(
+        process_image,
+        inputs=[input_image, size_threshold],
+        outputs=[initial_seg, pre_merge_vis, state]
+    )
+
+    merge_btn.click(
+        apply_merge,
+        inputs=[merge_input, state],
+        outputs=[merged_seg, state]
+    )
+
+    generate_mesh_btn.click(
+        generate_parts,
+        inputs=[state, seed_slider, cfg_slider],
+        outputs=[bbox_mesh, whole_mesh, exploded_mesh, combined_gs, exploded_gs]
+    )
+
+    run_example_btn.click(
+        fn=process_image,
+        inputs=[input_image, size_threshold],
+        outputs=[initial_seg, pre_merge_vis, state]
+    ).then(
+        fn=apply_merge,
+        inputs=[merge_input, state],
+        outputs=[merged_seg, state]
+    ).then(
+        fn=generate_parts,
+        inputs=[state, seed_slider, cfg_slider],
+        outputs=[bbox_mesh, whole_mesh, exploded_mesh, combined_gs, exploded_gs]
+    )
+
+if __name__ == "__main__":
+    os.makedirs("ckpt", exist_ok=True)
+    sam_ckpt_path = hf_hub_download(repo_id="omnipart/OmniPart_modules", filename="sam_vit_h_4b8939.pth", local_dir="ckpt")
+    partfield_ckpt_path = hf_hub_download(repo_id="omnipart/OmniPart_modules", filename="partfield_encoder.ckpt", local_dir="ckpt")
+    bbox_gen_ckpt_path = hf_hub_download(repo_id="omnipart/OmniPart_modules", filename="bbox_gen.ckpt", local_dir="ckpt")
+
+    prepare_models(sam_ckpt_path, partfield_ckpt_path, bbox_gen_ckpt_path)
+
+    demo.launch()
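The demo wires three callbacks together: segmentation, mask merging, and 3D generation. The merge string passed to apply_merge (see app_utils.py below) uses ';' to separate groups and ',' to separate segment IDs within a group. A minimal standalone sketch of that format, not part of the commit (the helper name parse_merge_groups is illustrative):

# Illustrative sketch only: how a merge string such as "0,1;3,4" maps to groups
# of segment IDs, mirroring the split-on-';' then split-on-',' parsing done by
# apply_merge in app_utils.py.
def parse_merge_groups(merge_input: str) -> list[list[int]]:
    groups = []
    for group_set in merge_input.split(";"):
        ids = [int(x) for x in group_set.split(",") if x.strip()]
        if ids:
            groups.append(ids)
    return groups

print(parse_merge_groups("0,1;3,4"))  # [[0, 1], [3, 4]]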
app_utils.py
ADDED
@@ -0,0 +1,412 @@
+import gradio as gr
+import spaces
+import os
+import numpy as np
+import trimesh
+import time
+import traceback
+import torch
+from PIL import Image
+import cv2
+import shutil
+from segment_anything import SamAutomaticMaskGenerator, build_sam
+from omegaconf import OmegaConf
+
+from modules.bbox_gen.models.autogressive_bbox_gen import BboxGen
+from modules.part_synthesis.process_utils import save_parts_outputs
+from modules.inference_utils import load_img_mask, prepare_bbox_gen_input, prepare_part_synthesis_input, gen_mesh_from_bounds, vis_voxel_coords, merge_parts
+from modules.part_synthesis.pipelines import OmniPartImageTo3DPipeline
+from modules.label_2d_mask.visualizer import Visualizer
+from transformers import AutoModelForImageSegmentation
+
+from modules.label_2d_mask.label_parts import (
+    prepare_image,
+    get_sam_mask,
+    get_mask,
+    clean_segment_edges,
+    resize_and_pad_to_square,
+    size_th as DEFAULT_SIZE_TH
+)
+
+# Constants
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+DTYPE = torch.float16
+MAX_SEED = np.iinfo(np.int32).max
+TMP_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
+os.makedirs(TMP_ROOT, exist_ok=True)
+
+sam_mask_generator = None
+rmbg_model = None
+bbox_gen_model = None
+part_synthesis_pipeline = None
+
+size_th = DEFAULT_SIZE_TH
+
+
+def prepare_models(sam_ckpt_path, partfield_ckpt_path, bbox_gen_ckpt_path):
+    global sam_mask_generator, rmbg_model, bbox_gen_model, part_synthesis_pipeline
+    if sam_mask_generator is None:
+        print("Loading SAM model...")
+        sam_model = build_sam(checkpoint=sam_ckpt_path).to(device=DEVICE)
+        sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
+
+    if rmbg_model is None:
+        print("Loading BriaRMBG 2.0 model...")
+        rmbg_model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
+        rmbg_model.to(DEVICE)
+        rmbg_model.eval()
+
+    if part_synthesis_pipeline is None:
+        print("Loading PartSynthesis model...")
+        part_synthesis_pipeline = OmniPartImageTo3DPipeline.from_pretrained('omnipart/OmniPart')
+        part_synthesis_pipeline.to(DEVICE)
+
+    if bbox_gen_model is None:
+        print("Loading BboxGen model...")
+        bbox_gen_config = OmegaConf.load("configs/bbox_gen.yaml").model.args
+        bbox_gen_config.partfield_encoder_path = partfield_ckpt_path
+        bbox_gen_model = BboxGen(bbox_gen_config)
+        bbox_gen_model.load_state_dict(torch.load(bbox_gen_ckpt_path), strict=False)
+        bbox_gen_model.to(DEVICE)
+        bbox_gen_model.eval().half()
+
+    print("Models ready")
+
+
+@spaces.GPU
+def process_image(image_path, threshold, req: gr.Request):
+    """Process image and generate initial segmentation"""
+    global size_th
+
+    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
+    os.makedirs(user_dir, exist_ok=True)
+
+    img_name = os.path.basename(image_path).split(".")[0]
+
+    size_th = threshold
+
+    img = Image.open(image_path).convert("RGB")
+    processed_image = prepare_image(img, rmbg_net=rmbg_model.to(DEVICE))
+
+    processed_image = resize_and_pad_to_square(processed_image)
+    white_bg = Image.new("RGBA", processed_image.size, (255, 255, 255, 255))
+    white_bg_img = Image.alpha_composite(white_bg, processed_image.convert("RGBA"))
+    image = np.array(white_bg_img.convert('RGB'))
+
+    rgba_path = os.path.join(user_dir, f"{img_name}_processed.png")
+    processed_image.save(rgba_path)
+
+    print("Generating raw SAM masks without post-processing...")
+    raw_masks = sam_mask_generator.generate(image)
+
+    raw_sam_vis = np.copy(image)
+    raw_sam_vis = np.ones_like(image) * 255
+
+    sorted_masks = sorted(raw_masks, key=lambda x: x["area"], reverse=True)
+
+    for i, mask_data in enumerate(sorted_masks):
+        if mask_data["area"] < size_th:
+            continue
+
+        color_r = (i * 50 + 80) % 256
+        color_g = (i * 120 + 40) % 256
+        color_b = (i * 180 + 20) % 256
+        color = np.array([color_r, color_g, color_b])
+
+        mask = mask_data["segmentation"]
+        raw_sam_vis[mask] = color
+
+    visual = Visualizer(image)
+
+    group_ids, pre_merge_im = get_sam_mask(
+        image,
+        sam_mask_generator,
+        visual,
+        merge_groups=None,
+        rgba_image=processed_image,
+        img_name=img_name,
+        save_dir=user_dir,
+        size_threshold=size_th
+    )
+
+    pre_merge_path = os.path.join(user_dir, f"{img_name}_mask_pre_merge.png")
+    Image.fromarray(pre_merge_im).save(pre_merge_path)
+    pre_split_vis = np.ones_like(image) * 255
+
+    unique_ids = np.unique(group_ids)
+    unique_ids = unique_ids[unique_ids >= 0]
+
+    for i, unique_id in enumerate(unique_ids):
+        color_r = (i * 50 + 80) % 256
+        color_g = (i * 120 + 40) % 256
+        color_b = (i * 180 + 20) % 256
+        color = np.array([color_r, color_g, color_b])
+
+        mask = (group_ids == unique_id)
+        pre_split_vis[mask] = color
+
+        y_indices, x_indices = np.where(mask)
+        if len(y_indices) > 0:
+            center_y = int(np.mean(y_indices))
+            center_x = int(np.mean(x_indices))
+            cv2.putText(pre_split_vis, str(unique_id),
+                        (center_x, center_y), cv2.FONT_HERSHEY_SIMPLEX,
+                        0.5, (0, 0, 0), 1, cv2.LINE_AA)
+
+    pre_split_path = os.path.join(user_dir, f"{img_name}_pre_split.png")
+    Image.fromarray(pre_split_vis).save(pre_split_path)
+    print(f"Pre-split segmentation (before disconnected parts handling) saved to {pre_split_path}")
+
+    get_mask(group_ids, image, ids=2, img_name=img_name, save_dir=user_dir)
+
+    init_seg_path = os.path.join(user_dir, f"{img_name}_mask_segments_2.png")
+
+    seg_img = Image.open(init_seg_path)
+    if seg_img.mode == 'RGBA':
+        white_bg = Image.new('RGBA', seg_img.size, (255, 255, 255, 255))
+        seg_img = Image.alpha_composite(white_bg, seg_img)
+        seg_img.save(init_seg_path)
+
+    state = {
+        "image": image.tolist(),
+        "processed_image": rgba_path,
+        "group_ids": group_ids.tolist() if isinstance(group_ids, np.ndarray) else group_ids,
+        "original_group_ids": group_ids.tolist() if isinstance(group_ids, np.ndarray) else group_ids,
+        "img_name": img_name,
+        "pre_split_path": pre_split_path,
+    }
+
+    return init_seg_path, pre_merge_path, state
+
+
+def apply_merge(merge_input, state, req: gr.Request):
+    """Apply merge parameters and generate merged segmentation"""
+    global sam_mask_generator
+
+    if not state:
+        return None, None, state
+
+    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
+
+    # Convert back from list to numpy array
+    image = np.array(state["image"])
+    # Use original group IDs instead of the most recent ones
+    group_ids = np.array(state["original_group_ids"])
+    img_name = state["img_name"]
+
+    # Load processed image from path
+    processed_image = Image.open(state["processed_image"])
+
+    # Display the original IDs before merging, SORTED for easier reading
+    unique_ids = np.unique(group_ids)
+    unique_ids = unique_ids[unique_ids >= 0]  # Exclude background
+    print(f"Original segment IDs (used for merging): {sorted(unique_ids.tolist())}")
+
+    # Parse merge groups
+    merge_groups = None
+    try:
+        if merge_input:
+            merge_groups = []
+            group_sets = merge_input.split(';')
+            for group_set in group_sets:
+                ids = [int(x) for x in group_set.split(',')]
+                if ids:
+                    # Validate if these IDs exist in the segmentation
+                    existing_ids = [id for id in ids if id in unique_ids]
+                    missing_ids = [id for id in ids if id not in unique_ids]
+
+                    if missing_ids:
+                        print(f"Warning: These IDs don't exist in the segmentation: {missing_ids}")
+
+                    # Only add group if it has valid IDs
+                    if existing_ids:
+                        merge_groups.append(ids)
+                        print(f"Valid merge group: {ids} (missing: {missing_ids if missing_ids else 'none'})")
+                    else:
+                        print(f"Skipping merge group with no valid IDs: {ids}")
+
+            print(f"Using merge groups: {merge_groups}")
+    except Exception as e:
+        print(f"Error parsing merge groups: {e}")
+        return None, None, state
+
+    # Initialize visualizer
+    visual = Visualizer(image)
+
+    # Generate merged segmentation starting from original IDs
+    # Add skip_split=True to prevent splitting after merging
+    new_group_ids, merged_im = get_sam_mask(
+        image,
+        sam_mask_generator,
+        visual,
+        merge_groups=merge_groups,
+        existing_group_ids=group_ids,
+        rgba_image=processed_image,
+        skip_split=True,
+        img_name=img_name,
+        save_dir=user_dir,
+        size_threshold=size_th
+    )
+
+    # Display the new IDs after merging for future reference
+    new_unique_ids = np.unique(new_group_ids)
+    new_unique_ids = new_unique_ids[new_unique_ids >= 0]  # Exclude background
+    print(f"New segment IDs (after merging): {new_unique_ids.tolist()}")
+
+    # Clean edges
+    new_group_ids = clean_segment_edges(new_group_ids)
+
+    # Save merged segmentation visualization
+    get_mask(new_group_ids, image, ids=3, img_name=img_name, save_dir=user_dir)
+
+    # Path to merged segmentation
+    merged_seg_path = os.path.join(user_dir, f"{img_name}_mask_segments_3.png")
+
+    save_mask = new_group_ids + 1
+    save_mask = save_mask.reshape(518, 518, 1).repeat(3, axis=-1)
+    cv2.imwrite(os.path.join(user_dir, f"{img_name}_mask.exr"), save_mask.astype(np.float32))
+
+    # Update state with the new group IDs but keep original IDs unchanged
+    state["group_ids"] = new_group_ids.tolist() if isinstance(new_group_ids, np.ndarray) else new_group_ids
+    state["save_mask_path"] = os.path.join(user_dir, f"{img_name}_mask.exr")
+
+    return merged_seg_path, state
+
+
+def explode_mesh(mesh, explosion_scale=0.4):
+
+    if isinstance(mesh, trimesh.Scene):
+        scene = mesh
+    elif isinstance(mesh, trimesh.Trimesh):
+        print("Warning: Single mesh provided, can't create exploded view")
+        scene = trimesh.Scene(mesh)
+        return scene
+    else:
+        print(f"Warning: Unexpected mesh type: {type(mesh)}")
+        scene = mesh
+
+    if len(scene.geometry) <= 1:
+        print("Only one geometry found - nothing to explode")
+        return scene
+
+    print(f"[EXPLODE_MESH] Starting mesh explosion with scale {explosion_scale}")
+    print(f"[EXPLODE_MESH] Processing {len(scene.geometry)} parts")
+
+    exploded_scene = trimesh.Scene()
+
+    part_centers = []
+    geometry_names = []
+
+    for geometry_name, geometry in scene.geometry.items():
+        if hasattr(geometry, 'vertices'):
+            transform = scene.graph[geometry_name][0]
+            vertices_global = trimesh.transformations.transform_points(
+                geometry.vertices, transform)
+            center = np.mean(vertices_global, axis=0)
+            part_centers.append(center)
+            geometry_names.append(geometry_name)
+            print(f"[EXPLODE_MESH] Part {geometry_name}: center = {center}")
+
+    if not part_centers:
+        print("No valid geometries with vertices found")
+        return scene
+
+    part_centers = np.array(part_centers)
+    global_center = np.mean(part_centers, axis=0)
+
+    print(f"[EXPLODE_MESH] Global center: {global_center}")
+
+    for i, (geometry_name, geometry) in enumerate(scene.geometry.items()):
+        if hasattr(geometry, 'vertices'):
+            if i < len(part_centers):
+                part_center = part_centers[i]
+                direction = part_center - global_center
+
+                direction_norm = np.linalg.norm(direction)
+                if direction_norm > 1e-6:
+                    direction = direction / direction_norm
+                else:
+                    direction = np.random.randn(3)
+                    direction = direction / np.linalg.norm(direction)
+
+                offset = direction * explosion_scale
+            else:
+                offset = np.zeros(3)
+
+            original_transform = scene.graph[geometry_name][0].copy()
+
+            new_transform = original_transform.copy()
+            new_transform[:3, 3] = new_transform[:3, 3] + offset
+
+            exploded_scene.add_geometry(
+                geometry,
+                transform=new_transform,
+                geom_name=geometry_name
+            )
+
+            print(f"[EXPLODE_MESH] Part {geometry_name}: moved by {np.linalg.norm(offset):.4f}")
+
+    print("[EXPLODE_MESH] Mesh explosion complete")
+    return exploded_scene
+
+@spaces.GPU(duration=90)
+def generate_parts(state, seed, cfg_strength, req: gr.Request):
+    explode_factor = 0.3
+    img_path = state["processed_image"]
+    mask_path = state["save_mask_path"]
+    user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
+    img_white_bg, img_black_bg, ordered_mask_input, img_mask_vis = load_img_mask(img_path, mask_path)
+    img_mask_vis.save(os.path.join(user_dir, "img_mask_vis.png"))
+
+    voxel_coords = part_synthesis_pipeline.get_coords(img_black_bg, num_samples=1, seed=seed, sparse_structure_sampler_params={"steps": 25, "cfg_strength": 7.5})
+    voxel_coords = voxel_coords.cpu().numpy()
+    np.save(os.path.join(user_dir, "voxel_coords.npy"), voxel_coords)
+    voxel_coords_ply = vis_voxel_coords(voxel_coords)
+    voxel_coords_ply.export(os.path.join(user_dir, "voxel_coords_vis.ply"))
+    print("[INFO] Voxel coordinates saved")
+
+    bbox_gen_input = prepare_bbox_gen_input(os.path.join(user_dir, "voxel_coords.npy"), img_white_bg, ordered_mask_input)
+    bbox_gen_output = bbox_gen_model.generate(bbox_gen_input)
+    np.save(os.path.join(user_dir, "bboxes.npy"), bbox_gen_output['bboxes'][0])
+    bboxes_vis = gen_mesh_from_bounds(bbox_gen_output['bboxes'][0])
+    bboxes_vis.export(os.path.join(user_dir, "bboxes_vis.glb"))
+    print("[INFO] BboxGen output saved")
+
+
+    part_synthesis_input = prepare_part_synthesis_input(os.path.join(user_dir, "voxel_coords.npy"), os.path.join(user_dir, "bboxes.npy"), ordered_mask_input)
+
+    torch.cuda.empty_cache()
+
+    part_synthesis_output = part_synthesis_pipeline.get_slat(
+        img_black_bg,
+        part_synthesis_input['coords'],
+        [part_synthesis_input['part_layouts']],
+        part_synthesis_input['masks'],
+        seed=seed,
+        slat_sampler_params={"steps": 25, "cfg_strength": cfg_strength},
+        formats=['mesh', 'gaussian'],
+        preprocess_image=False,
+    )
+    save_parts_outputs(
+        part_synthesis_output,
+        output_dir=user_dir,
+        simplify_ratio=0.0,
+        save_video=False,
+        save_glb=True,
+        textured=False,
+    )
+    merge_parts(user_dir)
+    print("[INFO] PartSynthesis output saved")
+
+    bbox_mesh_path = os.path.join(user_dir, "bboxes_vis.glb")
+    whole_mesh_path = os.path.join(user_dir, "mesh_segment.glb")
+
+    combined_mesh = trimesh.load(whole_mesh_path)
+    exploded_mesh_result = explode_mesh(combined_mesh, explosion_scale=explode_factor)
+    exploded_mesh_result.export(os.path.join(user_dir, "exploded_parts.glb"))
+
+    exploded_mesh_path = os.path.join(user_dir, "exploded_parts.glb")
+    combined_gs_path = os.path.join(user_dir, "merged_gs.ply")
+    exploded_gs_path = os.path.join(user_dir, "exploded_gs.ply")
+
+    return bbox_mesh_path, whole_mesh_path, exploded_mesh_path, combined_gs_path, exploded_gs_path
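In generate_parts above, explode_mesh pushes each part away from the assembly's centroid by explosion_scale along the centroid-to-part direction. A toy NumPy sketch of that offset rule, for illustration only (not part of the commit):

# Illustrative sketch only: the per-part offset rule used by explode_mesh,
# applied to toy part centers.
import numpy as np

part_centers = np.array([[0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
global_center = part_centers.mean(axis=0)
explosion_scale = 0.4

for center in part_centers:
    direction = center - global_center
    norm = np.linalg.norm(direction)
    if norm > 1e-6:
        direction = direction / norm
    else:
        # degenerate case: part sits at the centroid, pick a random direction
        direction = np.random.randn(3)
        direction = direction / np.linalg.norm(direction)
    offset = direction * explosion_scale
    print(np.round(offset, 3))  # every part moves exactly 0.4 units outward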
assets/example_data/Batman.png
ADDED (image; Git LFS)
assets/example_data/astronaut.png
ADDED (image; Git LFS)
assets/example_data/car.png
ADDED (image; Git LFS)
assets/example_data/crossbow.jpg
ADDED (image)
assets/example_data/knight.png
ADDED (image; Git LFS)
assets/example_data/robot.jpg
ADDED (image)
assets/example_data/robot1.jpeg
ADDED (image; Git LFS)
assets/example_data/robot_dog.jpg
ADDED (image)
assets/example_data/ship.jpg
ADDED (image)
assets/example_data/snake.png
ADDED (image; Git LFS)
assets/example_data/warhammer.png
ADDED (image; Git LFS)
configs/bbox_gen.yaml
ADDED
@@ -0,0 +1,34 @@
+model:
+  name: bbox_gen
+  args:
+    encoder_dim_feat: 448
+    encoder_dim: 64
+    encoder_heads: 4
+    encoder_token_num: 2048
+    encoder_qkv_bias: false
+    encoder_use_ln_post: true
+    encoder_use_checkpoint: true
+    encoder_num_embed_freqs: 8
+    encoder_embed_include_pi: false
+    encoder_init_scale: 0.25
+    encoder_random_fps: true
+    encoder_learnable_query: true
+    encoder_layers: 8
+
+    max_group_size: 50
+
+    vocab_size: 67
+    decoder_hidden_size: 1024
+    decoder_num_hidden_layers: 24
+    decoder_ffn_dim: 4096
+    decoder_heads: 16
+    decoder_use_flash_attention: true
+    decoder_gradient_checkpointing: false
+
+    bins: 64
+    BOS_id: 64
+    EOS_id: 65
+    PAD_id: 66
+    max_length: 2187
+    voxel_token_length: 1886
+    voxel_token_placeholder: -1
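For reference, prepare_models in app_utils.py consumes this file by loading it with OmegaConf, taking model.args, and injecting the PartField encoder checkpoint path before constructing BboxGen. A condensed sketch of that load path (the checkpoint path below is a placeholder):

# Condensed from prepare_models in app_utils.py; the checkpoint path is a placeholder.
from omegaconf import OmegaConf
from modules.bbox_gen.models.autogressive_bbox_gen import BboxGen

bbox_gen_config = OmegaConf.load("configs/bbox_gen.yaml").model.args
bbox_gen_config.partfield_encoder_path = "ckpt/partfield_encoder.ckpt"  # injected at runtime
bbox_gen_model = BboxGen(bbox_gen_config)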
modules/PartField/configs/final/correspondence_demo.yaml
ADDED
@@ -0,0 +1,44 @@
+result_name: partfield_features/correspondence_demo
+
+continue_ckpt: model/model.ckpt
+
+triplane_channels_low: 128
+triplane_channels_high: 512
+triplane_resolution: 128
+
+vertex_feature: True
+n_point_per_face: 1000
+n_sample_each: 10000
+is_pc: True
+remesh_demo: False
+correspondence_demo: True
+
+preprocess_mesh: True
+
+dataset:
+  type: "Mix"
+  data_path: data/DenseCorr3D
+  train_batch_size: 1
+  val_batch_size: 1
+  train_num_workers: 8
+  all_files:
+    # pairs of example to run correspondence
+    - animals/071b8_toy_animals_017/simple_mesh.obj
+    - animals/bdfd0_toy_animals_016/simple_mesh.obj
+    - animals/2d6b3_toy_animals_009/simple_mesh.obj
+    - animals/96615_toy_animals_018/simple_mesh.obj
+    - chairs/063d1_chair_006/simple_mesh.obj
+    - chairs/bea57_chair_012/simple_mesh.obj
+    - chairs/fe0fe_chair_004/simple_mesh.obj
+    - chairs/288dc_chair_011/simple_mesh.obj
+    # consider decimating animals/../color_mesh.obj yourself for better mesh topology than the provided simple_mesh.obj
+    # (e.g. <50k vertices for functional map efficiency).
+
+loss:
+  triplet: 1.0
+
+use_2d_feat: False
+pvcnn:
+  point_encoder_type: 'pvcnn'
+  z_triplane_channels: 256
+  z_triplane_resolution: 128
modules/PartField/configs/final/demo.yaml
ADDED
@@ -0,0 +1,28 @@
+result_name: demo_test
+
+continue_ckpt: model/model.ckpt
+
+triplane_channels_low: 128
+triplane_channels_high: 512
+triplane_resolution: 128
+
+n_point_per_face: 1000
+n_sample_each: 10000
+is_pc : True
+remesh_demo : False
+
+dataset:
+  type: "Mix"
+  data_path: "objaverse_data"
+  train_batch_size: 1
+  val_batch_size: 1
+  train_num_workers: 8
+
+loss:
+  triplet: 1.0
+
+use_2d_feat: False
+pvcnn:
+  point_encoder_type: 'pvcnn'
+  z_triplane_channels: 256
+  z_triplane_resolution: 128
modules/PartField/partfield/config/__init__.py
ADDED
@@ -0,0 +1,26 @@
+import argparse
+import os.path as osp
+from datetime import datetime
+import pytz
+
+def default_argument_parser(add_help=True, default_config_file=""):
+    parser = argparse.ArgumentParser(add_help=add_help)
+    parser.add_argument("--config-file", '-c', default=default_config_file, metavar="FILE", help="path to config file")
+    parser.add_argument(
+        "--opts",
+        help="Modify config options using the command-line",
+        default=None,
+        nargs=argparse.REMAINDER,
+    )
+    return parser
+
+def setup(args, freeze=True):
+    from .defaults import _C as cfg
+    cfg = cfg.clone()
+    cfg.merge_from_file(args.config_file)
+    cfg.merge_from_list(args.opts)
+    dt = datetime.now(pytz.timezone('America/Los_Angeles')).strftime('%y%m%d-%H%M%S')
+    cfg.output_dir = osp.join(cfg.output_dir, cfg.name, dt)
+    if freeze:
+        cfg.freeze()
+    return cfg
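A minimal usage sketch for the two helpers above. The import path and argument values are assumptions for illustration; the effective merge order is yacs defaults, then the YAML file, then any --opts overrides:

# Illustrative sketch only: combining the yacs defaults, a YAML config, and
# command-line overrides via default_argument_parser and setup.
from partfield.config import default_argument_parser, setup  # import path is an assumption

parser = default_argument_parser(default_config_file="configs/final/demo.yaml")
args = parser.parse_args(["--opts", "continue_ckpt", "model/model.ckpt"])
cfg = setup(args, freeze=True)  # clones defaults, merges the YAML, applies --opts, then freezes
print(cfg.result_name, cfg.output_dir)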
modules/PartField/partfield/config/defaults.py
ADDED
@@ -0,0 +1,92 @@
+from yacs.config import CfgNode as CN
+
+_C = CN()
+_C.seed = 0
+_C.output_dir = "results"
+_C.result_name = "test_all"
+
+_C.triplet_sampling = "random"
+_C.load_original_mesh = False
+
+_C.num_pos = 64
+_C.num_neg_random = 256
+_C.num_neg_hard_pc = 128
+_C.num_neg_hard_emb = 128
+
+_C.vertex_feature = False  # if true, sample feature on vertices; if false, sample feature on faces
+_C.n_point_per_face = 2000
+_C.n_sample_each = 10000
+_C.preprocess_mesh = False
+
+_C.regress_2d_feat = False
+
+_C.is_pc = False
+
+_C.cut_manifold = False
+_C.remesh_demo = False
+_C.correspondence_demo = False
+
+_C.save_every_epoch = 10
+_C.training_epochs = 30
+_C.continue_training = False
+
+_C.continue_ckpt = None
+_C.epoch_selected = "epoch=50.ckpt"
+
+_C.triplane_resolution = 128
+_C.triplane_channels_low = 128
+_C.triplane_channels_high = 512
+_C.lr = 1e-3
+_C.train = True
+_C.test = False
+
+_C.inference_save_pred_sdf_to_mesh = True
+_C.inference_save_feat_pca = True
+_C.name = "test"
+_C.test_subset = False
+_C.test_corres = False
+_C.test_partobjaversetiny = False
+
+_C.dataset = CN()
+_C.dataset.type = "Demo_Dataset"
+_C.dataset.data_path = "objaverse_data/"
+_C.dataset.train_num_workers = 64
+_C.dataset.val_num_workers = 32
+_C.dataset.train_batch_size = 2
+_C.dataset.val_batch_size = 2
+_C.dataset.all_files = []  # only used for correspondence demo
+
+_C.voxel2triplane = CN()
+_C.voxel2triplane.transformer_dim = 1024
+_C.voxel2triplane.transformer_layers = 6
+_C.voxel2triplane.transformer_heads = 8
+_C.voxel2triplane.triplane_low_res = 32
+_C.voxel2triplane.triplane_high_res = 256
+_C.voxel2triplane.triplane_dim = 64
+_C.voxel2triplane.normalize_vox_feat = False
+
+
+_C.loss = CN()
+_C.loss.triplet = 0.0
+_C.loss.sdf = 1.0
+_C.loss.feat = 10.0
+_C.loss.l1 = 0.0
+
+_C.use_pvcnn = False
+_C.use_pvcnnonly = True
+
+_C.pvcnn = CN()
+_C.pvcnn.point_encoder_type = 'pvcnn'
+_C.pvcnn.use_point_scatter = True
+_C.pvcnn.z_triplane_channels = 64
+_C.pvcnn.z_triplane_resolution = 256
+_C.pvcnn.unet_cfg = CN()
+_C.pvcnn.unet_cfg.depth = 3
+_C.pvcnn.unet_cfg.enabled = True
+_C.pvcnn.unet_cfg.rolled = True
+_C.pvcnn.unet_cfg.use_3d_aware = True
+_C.pvcnn.unet_cfg.start_hidden_channels = 32
+_C.pvcnn.unet_cfg.use_initial_conv = False
+
+_C.use_2d_feat = False
+_C.inference_metrics_only = False
modules/PartField/partfield/model/PVCNN/conv_pointnet.py
ADDED
@@ -0,0 +1,251 @@
+"""
+Taken from gensdf
+https://github.com/princeton-computational-imaging/gensdf
+"""
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+# from dnnlib.util import printarr
+try:
+    from torch_scatter import scatter_mean, scatter_max
+except:
+    pass
+# from .unet import UNet
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+# Resnet Blocks
+class ResnetBlockFC(nn.Module):
+    ''' Fully connected ResNet Block class.
+    Args:
+        size_in (int): input dimension
+        size_out (int): output dimension
+        size_h (int): hidden dimension
+    '''
+
+    def __init__(self, size_in, size_out=None, size_h=None):
+        super().__init__()
+        # Attributes
+        if size_out is None:
+            size_out = size_in
+
+        if size_h is None:
+            size_h = min(size_in, size_out)
+
+        self.size_in = size_in
+        self.size_h = size_h
+        self.size_out = size_out
+        # Submodules
+        self.fc_0 = nn.Linear(size_in, size_h)
+        self.fc_1 = nn.Linear(size_h, size_out)
+        self.actvn = nn.ReLU()
+
+        if size_in == size_out:
+            self.shortcut = None
+        else:
+            self.shortcut = nn.Linear(size_in, size_out, bias=False)
+        # Initialization
+        nn.init.zeros_(self.fc_1.weight)
+
+    def forward(self, x):
+        net = self.fc_0(self.actvn(x))
+        dx = self.fc_1(self.actvn(net))
+
+        if self.shortcut is not None:
+            x_s = self.shortcut(x)
+        else:
+            x_s = x
+
+        return x_s + dx
+
+
+class ConvPointnet(nn.Module):
+    ''' PointNet-based encoder network with ResNet blocks for each point.
+    Number of input points are fixed.
+
+    Args:
+        c_dim (int): dimension of latent code c
+        dim (int): input points dimension
+        hidden_dim (int): hidden dimension of the network
+        scatter_type (str): feature aggregation when doing local pooling
+        unet (bool): weather to use U-Net
+        unet_kwargs (str): U-Net parameters
+        plane_resolution (int): defined resolution for plane feature
+        plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
+        padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
+        n_blocks (int): number of blocks ResNetBlockFC layers
+    '''
+
+    def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
+                 # unet=False, unet_kwargs=None,
+                 plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
+        super().__init__()
+        self.c_dim = c_dim
+
+        self.fc_pos = nn.Linear(dim, 2*hidden_dim)
+        self.blocks = nn.ModuleList([
+            ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
+        ])
+        self.fc_c = nn.Linear(hidden_dim, c_dim)
+
+        self.actvn = nn.ReLU()
+        self.hidden_dim = hidden_dim
+
+        # if unet:
+        #     self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
+        # else:
+        #     self.unet = None
+
+        self.reso_plane = plane_resolution
+        self.plane_type = plane_type
+        self.padding = padding
+
+        if scatter_type == 'max':
+            self.scatter = scatter_max
+        elif scatter_type == 'mean':
+            self.scatter = scatter_mean
+
+
+    # takes in "p": point cloud and "query": sdf_xyz
+    # sample plane features for unlabeled_query as well
+    def forward(self, p):  #, query2):
+        batch_size, T, D = p.size()
+
+        # acquire the index for each point
+        coord = {}
+        index = {}
+        if 'xz' in self.plane_type:
+            coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
+            index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
+        if 'xy' in self.plane_type:
+            coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
+            index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
+        if 'yz' in self.plane_type:
+            coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
+            index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)
+
+
+        net = self.fc_pos(p)
+
+        net = self.blocks[0](net)
+        for block in self.blocks[1:]:
+            pooled = self.pool_local(coord, index, net)
+            net = torch.cat([net, pooled], dim=2)
+            net = block(net)
+
+        c = self.fc_c(net)
+
+        fea = {}
+        plane_feat_sum = 0
+        #second_sum = 0
+        if 'xz' in self.plane_type:
+            fea['xz'] = self.generate_plane_features(p, c, plane='xz')  # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
+            # plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')
+            #second_sum += self.sample_plane_feature(query2, fea['xz'], 'xz')
+        if 'xy' in self.plane_type:
+            fea['xy'] = self.generate_plane_features(p, c, plane='xy')
+            # plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')
+            #second_sum += self.sample_plane_feature(query2, fea['xy'], 'xy')
+        if 'yz' in self.plane_type:
+            fea['yz'] = self.generate_plane_features(p, c, plane='yz')
+            # plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')
+            #second_sum += self.sample_plane_feature(query2, fea['yz'], 'yz')
+        return fea
+
+        # return plane_feat_sum.transpose(2,1)#, second_sum.transpose(2,1)
+
+
+    def normalize_coordinate(self, p, padding=0.1, plane='xz'):
+        ''' Normalize coordinate to [0, 1] for unit cube experiments
+
+        Args:
+            p (tensor): point
+            padding (float): conventional padding paramter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
+            plane (str): plane feature type, ['xz', 'xy', 'yz']
+        '''
+        if plane == 'xz':
+            xy = p[:, :, [0, 2]]
+        elif plane == 'xy':
+            xy = p[:, :, [0, 1]]
+        else:
+            xy = p[:, :, [1, 2]]
+
+        xy_new = xy / (1 + padding + 10e-6)  # (-0.5, 0.5)
+        xy_new = xy_new + 0.5  # range (0, 1)
+
+        # f there are outliers out of the range
+        if xy_new.max() >= 1:
+            xy_new[xy_new >= 1] = 1 - 10e-6
+        if xy_new.min() < 0:
+            xy_new[xy_new < 0] = 0.0
+        return xy_new
+
+
+    def coordinate2index(self, x, reso):
+        ''' Normalize coordinate to [0, 1] for unit cube experiments.
+        Corresponds to our 3D model
+
+        Args:
+            x (tensor): coordinate
+            reso (int): defined resolution
+            coord_type (str): coordinate type
+        '''
+        x = (x * reso).long()
+        index = x[:, :, 0] + reso * x[:, :, 1]
+        index = index[:, None, :]
+        return index
+
+
+    # xy is the normalized coordinates of the point cloud of each plane
+    # I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input
+    def pool_local(self, xy, index, c):
+        bs, fea_dim = c.size(0), c.size(2)
+        keys = xy.keys()
+
+        c_out = 0
+        for key in keys:
+            # scatter plane features from points
+            fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
+            if self.scatter == scatter_max:
+                fea = fea[0]
+            # gather feature back to points
+            fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
+            c_out += fea
+        return c_out.permute(0, 2, 1)
+
+
+    def generate_plane_features(self, p, c, plane='xz'):
+        # acquire indices of features in plane
+        xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding)  # normalize to the range of (0, 1)
+        index = self.coordinate2index(xy, self.reso_plane)
+
+        # scatter plane features from points
+        fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
+        c = c.permute(0, 2, 1)  # B x 512 x T
+        fea_plane = scatter_mean(c, index, out=fea_plane)  # B x 512 x reso^2
+        fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane)  # sparce matrix (B x 512 x reso x reso)
+
+        # printarr(fea_plane, c, p, xy, index)
+        # import pdb; pdb.set_trace()
+
+        # process the plane features with UNet
+        # if self.unet is not None:
+        #     fea_plane = self.unet(fea_plane)
+
+        return fea_plane
+
+
+    # sample_plane_feature function copied from /src/conv_onet/models/decoder.py
+    # uses values from plane_feature and pixel locations from vgrid to interpolate feature
+    def sample_plane_feature(self, query, plane_feature, plane):
+        xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)
+        xy = xy[:, :, None].float()
+        vgrid = 2.0 * xy - 1.0  # normalize to (-1, 1)
+        sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)
+        return sampled_feat
+
+
+
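A small usage sketch for ConvPointnet as defined above. It assumes torch_scatter is installed and that the module is importable from this path; each returned plane feature has shape (batch, c_dim, plane_resolution, plane_resolution):

# Illustrative sketch only: encode a random point cloud into triplane features.
# Assumes torch_scatter is available and this import path is on sys.path.
import torch
from modules.PartField.partfield.model.PVCNN.conv_pointnet import ConvPointnet

encoder = ConvPointnet(c_dim=64, dim=3, hidden_dim=32, plane_resolution=64)
points = torch.rand(2, 1024, 3) - 0.5   # B x N x 3, roughly in [-0.5, 0.5]
planes = encoder(points)                # {'xz': ..., 'xy': ..., 'yz': ...}
print({k: tuple(v.shape) for k, v in planes.items()})  # each plane: (2, 64, 64, 64)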
modules/PartField/partfield/model/PVCNN/dnnlib_util.py
ADDED
@@ -0,0 +1,1074 @@
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
"""Miscellaneous utility classes and functions."""
|
10 |
+
from collections import namedtuple
|
11 |
+
import time
|
12 |
+
import ctypes
|
13 |
+
import fnmatch
|
14 |
+
import importlib
|
15 |
+
import inspect
|
16 |
+
import numpy as np
|
17 |
+
import json
|
18 |
+
import os
|
19 |
+
import shutil
|
20 |
+
import sys
|
21 |
+
import types
|
22 |
+
import io
|
23 |
+
import pickle
|
24 |
+
import re
|
25 |
+
import requests  # used below by is_url() and open_url()
|
26 |
+
import html
|
27 |
+
import hashlib
|
28 |
+
import glob
|
29 |
+
import tempfile
|
30 |
+
import urllib
|
31 |
+
import urllib.request
|
32 |
+
import uuid
|
33 |
+
import boto3
|
34 |
+
import threading
|
35 |
+
from contextlib import ContextDecorator
|
36 |
+
from contextlib import contextmanager, nullcontext
|
37 |
+
|
38 |
+
from distutils.util import strtobool
|
39 |
+
from typing import Any, List, Tuple, Union
|
40 |
+
import importlib
|
41 |
+
from loguru import logger
|
42 |
+
# import wandb
|
43 |
+
import torch
|
44 |
+
import psutil
|
45 |
+
import subprocess
|
46 |
+
|
47 |
+
import random
|
48 |
+
import string
|
49 |
+
import pdb
import math  # needed by calmsize() further below
|
50 |
+
|
51 |
+
# Util classes
|
52 |
+
# ------------------------------------------------------------------------------------------
|
53 |
+
|
54 |
+
|
55 |
+
class EasyDict(dict):
|
56 |
+
"""Convenience class that behaves like a dict but allows access with the attribute syntax."""
|
57 |
+
|
58 |
+
def __getattr__(self, name: str) -> Any:
|
59 |
+
try:
|
60 |
+
return self[name]
|
61 |
+
except KeyError:
|
62 |
+
raise AttributeError(name)
|
63 |
+
|
64 |
+
def __setattr__(self, name: str, value: Any) -> None:
|
65 |
+
self[name] = value
|
66 |
+
|
67 |
+
def __delattr__(self, name: str) -> None:
|
68 |
+
del self[name]
|
69 |
+
|
70 |
+
|
71 |
+
class Logger(object):
|
72 |
+
"""Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
|
73 |
+
|
74 |
+
def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
|
75 |
+
self.file = None
|
76 |
+
|
77 |
+
if file_name is not None:
|
78 |
+
self.file = open(file_name, file_mode)
|
79 |
+
|
80 |
+
self.should_flush = should_flush
|
81 |
+
self.stdout = sys.stdout
|
82 |
+
self.stderr = sys.stderr
|
83 |
+
|
84 |
+
sys.stdout = self
|
85 |
+
sys.stderr = self
|
86 |
+
|
87 |
+
def __enter__(self) -> "Logger":
|
88 |
+
return self
|
89 |
+
|
90 |
+
def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
|
91 |
+
self.close()
|
92 |
+
|
93 |
+
def write(self, text: Union[str, bytes]) -> None:
|
94 |
+
"""Write text to stdout (and a file) and optionally flush."""
|
95 |
+
if isinstance(text, bytes):
|
96 |
+
text = text.decode()
|
97 |
+
if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
|
98 |
+
return
|
99 |
+
|
100 |
+
if self.file is not None:
|
101 |
+
self.file.write(text)
|
102 |
+
|
103 |
+
self.stdout.write(text)
|
104 |
+
|
105 |
+
if self.should_flush:
|
106 |
+
self.flush()
|
107 |
+
|
108 |
+
def flush(self) -> None:
|
109 |
+
"""Flush written text to both stdout and a file, if open."""
|
110 |
+
if self.file is not None:
|
111 |
+
self.file.flush()
|
112 |
+
|
113 |
+
self.stdout.flush()
|
114 |
+
|
115 |
+
def close(self) -> None:
|
116 |
+
"""Flush, close possible files, and remove stdout/stderr mirroring."""
|
117 |
+
self.flush()
|
118 |
+
|
119 |
+
# if using multiple loggers, prevent closing in wrong order
|
120 |
+
if sys.stdout is self:
|
121 |
+
sys.stdout = self.stdout
|
122 |
+
if sys.stderr is self:
|
123 |
+
sys.stderr = self.stderr
|
124 |
+
|
125 |
+
if self.file is not None:
|
126 |
+
self.file.close()
|
127 |
+
self.file = None
|
128 |
+
|
129 |
+
|
130 |
+
# Cache directories
|
131 |
+
# ------------------------------------------------------------------------------------------
|
132 |
+
|
133 |
+
_dnnlib_cache_dir = None
|
134 |
+
|
135 |
+
|
136 |
+
def set_cache_dir(path: str) -> None:
|
137 |
+
global _dnnlib_cache_dir
|
138 |
+
_dnnlib_cache_dir = path
|
139 |
+
|
140 |
+
|
141 |
+
def make_cache_dir_path(*paths: str) -> str:
|
142 |
+
if _dnnlib_cache_dir is not None:
|
143 |
+
return os.path.join(_dnnlib_cache_dir, *paths)
|
144 |
+
if 'DNNLIB_CACHE_DIR' in os.environ:
|
145 |
+
return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
|
146 |
+
if 'HOME' in os.environ:
|
147 |
+
return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
|
148 |
+
if 'USERPROFILE' in os.environ:
|
149 |
+
return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
|
150 |
+
return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
|
151 |
+
|
152 |
+
|
153 |
+
# Small util functions
|
154 |
+
# ------------------------------------------------------------------------------------------
|
155 |
+
|
156 |
+
|
157 |
+
def format_time(seconds: Union[int, float]) -> str:
|
158 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
159 |
+
s = int(np.rint(seconds))
|
160 |
+
|
161 |
+
if s < 60:
|
162 |
+
return "{0}s".format(s)
|
163 |
+
elif s < 60 * 60:
|
164 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
165 |
+
elif s < 24 * 60 * 60:
|
166 |
+
return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
|
167 |
+
else:
|
168 |
+
return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
|
169 |
+
|
170 |
+
|
171 |
+
def format_time_brief(seconds: Union[int, float]) -> str:
|
172 |
+
"""Convert the seconds to human readable string with days, hours, minutes and seconds."""
|
173 |
+
s = int(np.rint(seconds))
|
174 |
+
|
175 |
+
if s < 60:
|
176 |
+
return "{0}s".format(s)
|
177 |
+
elif s < 60 * 60:
|
178 |
+
return "{0}m {1:02}s".format(s // 60, s % 60)
|
179 |
+
elif s < 24 * 60 * 60:
|
180 |
+
return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
|
181 |
+
else:
|
182 |
+
return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
|
183 |
+
|
184 |
+
|
185 |
+
def ask_yes_no(question: str) -> bool:
|
186 |
+
"""Ask the user the question until the user inputs a valid answer."""
|
187 |
+
while True:
|
188 |
+
try:
|
189 |
+
print("{0} [y/n]".format(question))
|
190 |
+
return strtobool(input().lower())
|
191 |
+
except ValueError:
|
192 |
+
pass
|
193 |
+
|
194 |
+
|
195 |
+
def tuple_product(t: Tuple) -> Any:
|
196 |
+
"""Calculate the product of the tuple elements."""
|
197 |
+
result = 1
|
198 |
+
|
199 |
+
for v in t:
|
200 |
+
result *= v
|
201 |
+
|
202 |
+
return result
|
203 |
+
|
204 |
+
|
205 |
+
_str_to_ctype = {
|
206 |
+
"uint8": ctypes.c_ubyte,
|
207 |
+
"uint16": ctypes.c_uint16,
|
208 |
+
"uint32": ctypes.c_uint32,
|
209 |
+
"uint64": ctypes.c_uint64,
|
210 |
+
"int8": ctypes.c_byte,
|
211 |
+
"int16": ctypes.c_int16,
|
212 |
+
"int32": ctypes.c_int32,
|
213 |
+
"int64": ctypes.c_int64,
|
214 |
+
"float32": ctypes.c_float,
|
215 |
+
"float64": ctypes.c_double
|
216 |
+
}
|
217 |
+
|
218 |
+
|
219 |
+
def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
|
220 |
+
"""Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
|
221 |
+
type_str = None
|
222 |
+
|
223 |
+
if isinstance(type_obj, str):
|
224 |
+
type_str = type_obj
|
225 |
+
elif hasattr(type_obj, "__name__"):
|
226 |
+
type_str = type_obj.__name__
|
227 |
+
elif hasattr(type_obj, "name"):
|
228 |
+
type_str = type_obj.name
|
229 |
+
else:
|
230 |
+
raise RuntimeError("Cannot infer type name from input")
|
231 |
+
|
232 |
+
assert type_str in _str_to_ctype.keys()
|
233 |
+
|
234 |
+
my_dtype = np.dtype(type_str)
|
235 |
+
my_ctype = _str_to_ctype[type_str]
|
236 |
+
|
237 |
+
assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
|
238 |
+
|
239 |
+
return my_dtype, my_ctype
|
240 |
+
|
241 |
+
|
242 |
+
def is_pickleable(obj: Any) -> bool:
|
243 |
+
try:
|
244 |
+
with io.BytesIO() as stream:
|
245 |
+
pickle.dump(obj, stream)
|
246 |
+
return True
|
247 |
+
except:
|
248 |
+
return False
|
249 |
+
|
250 |
+
|
251 |
+
# Functionality to import modules/objects by name, and call functions by name
|
252 |
+
# ------------------------------------------------------------------------------------------
|
253 |
+
|
254 |
+
def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
|
255 |
+
"""Searches for the underlying module behind the name to some python object.
|
256 |
+
Returns the module and the object name (original name with module part removed)."""
|
257 |
+
|
258 |
+
# allow convenience shorthands, substitute them by full names
|
259 |
+
obj_name = re.sub("^np.", "numpy.", obj_name)
|
260 |
+
obj_name = re.sub("^tf.", "tensorflow.", obj_name)
|
261 |
+
|
262 |
+
# list alternatives for (module_name, local_obj_name)
|
263 |
+
parts = obj_name.split(".")
|
264 |
+
name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
|
265 |
+
|
266 |
+
# try each alternative in turn
|
267 |
+
for module_name, local_obj_name in name_pairs:
|
268 |
+
try:
|
269 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
270 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
271 |
+
return module, local_obj_name
|
272 |
+
except:
|
273 |
+
pass
|
274 |
+
|
275 |
+
# maybe some of the modules themselves contain errors?
|
276 |
+
for module_name, _local_obj_name in name_pairs:
|
277 |
+
try:
|
278 |
+
importlib.import_module(module_name) # may raise ImportError
|
279 |
+
except ImportError:
|
280 |
+
if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
|
281 |
+
raise
|
282 |
+
|
283 |
+
# maybe the requested attribute is missing?
|
284 |
+
for module_name, local_obj_name in name_pairs:
|
285 |
+
try:
|
286 |
+
module = importlib.import_module(module_name) # may raise ImportError
|
287 |
+
get_obj_from_module(module, local_obj_name) # may raise AttributeError
|
288 |
+
except ImportError:
|
289 |
+
pass
|
290 |
+
|
291 |
+
# we are out of luck, but we have no idea why
|
292 |
+
raise ImportError(obj_name)
|
293 |
+
|
294 |
+
|
295 |
+
def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
|
296 |
+
"""Traverses the object name and returns the last (rightmost) python object."""
|
297 |
+
if obj_name == '':
|
298 |
+
return module
|
299 |
+
obj = module
|
300 |
+
for part in obj_name.split("."):
|
301 |
+
obj = getattr(obj, part)
|
302 |
+
return obj
|
303 |
+
|
304 |
+
|
305 |
+
def get_obj_by_name(name: str) -> Any:
|
306 |
+
"""Finds the python object with the given name."""
|
307 |
+
module, obj_name = get_module_from_obj_name(name)
|
308 |
+
return get_obj_from_module(module, obj_name)
|
309 |
+
|
310 |
+
|
311 |
+
def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
|
312 |
+
"""Finds the python object with the given name and calls it as a function."""
|
313 |
+
assert func_name is not None
|
314 |
+
func_obj = get_obj_by_name(func_name)
|
315 |
+
assert callable(func_obj)
|
316 |
+
return func_obj(*args, **kwargs)
|
317 |
+
|
318 |
+
|
319 |
+
def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
|
320 |
+
"""Finds the python class with the given name and constructs it with the given arguments."""
|
321 |
+
return call_func_by_name(*args, func_name=class_name, **kwargs)
|
322 |
+
|
323 |
+
|
324 |
+
def get_module_dir_by_obj_name(obj_name: str) -> str:
|
325 |
+
"""Get the directory path of the module containing the given object name."""
|
326 |
+
module, _ = get_module_from_obj_name(obj_name)
|
327 |
+
return os.path.dirname(inspect.getfile(module))
|
328 |
+
|
329 |
+
|
330 |
+
def is_top_level_function(obj: Any) -> bool:
|
331 |
+
"""Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
|
332 |
+
return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
|
333 |
+
|
334 |
+
|
335 |
+
def get_top_level_function_name(obj: Any) -> str:
|
336 |
+
"""Return the fully-qualified name of a top-level function."""
|
337 |
+
assert is_top_level_function(obj)
|
338 |
+
module = obj.__module__
|
339 |
+
if module == '__main__':
|
340 |
+
module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
|
341 |
+
return module + "." + obj.__name__
|
342 |
+
|
343 |
+
|
344 |
+
# File system helpers
|
345 |
+
# ------------------------------------------------------------------------------------------
|
346 |
+
|
347 |
+
def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
|
348 |
+
"""List all files recursively in a given directory while ignoring given file and directory names.
|
349 |
+
Returns list of tuples containing both absolute and relative paths."""
|
350 |
+
assert os.path.isdir(dir_path)
|
351 |
+
base_name = os.path.basename(os.path.normpath(dir_path))
|
352 |
+
|
353 |
+
if ignores is None:
|
354 |
+
ignores = []
|
355 |
+
|
356 |
+
result = []
|
357 |
+
|
358 |
+
for root, dirs, files in os.walk(dir_path, topdown=True):
|
359 |
+
for ignore_ in ignores:
|
360 |
+
dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
|
361 |
+
|
362 |
+
# dirs need to be edited in-place
|
363 |
+
for d in dirs_to_remove:
|
364 |
+
dirs.remove(d)
|
365 |
+
|
366 |
+
files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
|
367 |
+
|
368 |
+
absolute_paths = [os.path.join(root, f) for f in files]
|
369 |
+
relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
|
370 |
+
|
371 |
+
if add_base_to_relative:
|
372 |
+
relative_paths = [os.path.join(base_name, p) for p in relative_paths]
|
373 |
+
|
374 |
+
assert len(absolute_paths) == len(relative_paths)
|
375 |
+
result += zip(absolute_paths, relative_paths)
|
376 |
+
|
377 |
+
return result
|
378 |
+
|
379 |
+
|
380 |
+
def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
|
381 |
+
"""Takes in a list of tuples of (src, dst) paths and copies files.
|
382 |
+
Will create all necessary directories."""
|
383 |
+
for file in files:
|
384 |
+
target_dir_name = os.path.dirname(file[1])
|
385 |
+
|
386 |
+
# will create all intermediate-level directories
|
387 |
+
if not os.path.exists(target_dir_name):
|
388 |
+
os.makedirs(target_dir_name)
|
389 |
+
|
390 |
+
shutil.copyfile(file[0], file[1])
|
391 |
+
|
392 |
+
|
393 |
+
# URL helpers
|
394 |
+
# ------------------------------------------------------------------------------------------
|
395 |
+
|
396 |
+
def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
|
397 |
+
"""Determine whether the given object is a valid URL string."""
|
398 |
+
if not isinstance(obj, str) or not "://" in obj:
|
399 |
+
return False
|
400 |
+
if allow_file_urls and obj.startswith('file://'):
|
401 |
+
return True
|
402 |
+
try:
|
403 |
+
res = requests.compat.urlparse(obj)
|
404 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
405 |
+
return False
|
406 |
+
res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
|
407 |
+
if not res.scheme or not res.netloc or not "." in res.netloc:
|
408 |
+
return False
|
409 |
+
except:
|
410 |
+
return False
|
411 |
+
return True
|
412 |
+
|
413 |
+
|
414 |
+
def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
|
415 |
+
"""Download the given URL and return a binary-mode file object to access the data."""
|
416 |
+
assert num_attempts >= 1
|
417 |
+
assert not (return_filename and (not cache))
|
418 |
+
|
419 |
+
# Doesn't look like a URL scheme, so interpret it as a local filename.
|
420 |
+
if not re.match('^[a-z]+://', url):
|
421 |
+
return url if return_filename else open(url, "rb")
|
422 |
+
|
423 |
+
# Handle file URLs. This code handles unusual file:// patterns that
|
424 |
+
# arise on Windows:
|
425 |
+
#
|
426 |
+
# file:///c:/foo.txt
|
427 |
+
#
|
428 |
+
# which would translate to a local '/c:/foo.txt' filename that's
|
429 |
+
# invalid. Drop the forward slash for such pathnames.
|
430 |
+
#
|
431 |
+
# If you touch this code path, you should test it on both Linux and
|
432 |
+
# Windows.
|
433 |
+
#
|
434 |
+
# Some internet resources suggest using urllib.request.url2pathname() but
|
435 |
+
# that converts forward slashes to backslashes and this causes
|
436 |
+
# its own set of problems.
|
437 |
+
if url.startswith('file://'):
|
438 |
+
filename = urllib.parse.urlparse(url).path
|
439 |
+
if re.match(r'^/[a-zA-Z]:', filename):
|
440 |
+
filename = filename[1:]
|
441 |
+
return filename if return_filename else open(filename, "rb")
|
442 |
+
|
443 |
+
assert is_url(url)
|
444 |
+
|
445 |
+
# Lookup from cache.
|
446 |
+
if cache_dir is None:
|
447 |
+
cache_dir = make_cache_dir_path('downloads')
|
448 |
+
|
449 |
+
url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
|
450 |
+
if cache:
|
451 |
+
cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
|
452 |
+
if len(cache_files) == 1:
|
453 |
+
filename = cache_files[0]
|
454 |
+
return filename if return_filename else open(filename, "rb")
|
455 |
+
|
456 |
+
# Download.
|
457 |
+
url_name = None
|
458 |
+
url_data = None
|
459 |
+
with requests.Session() as session:
|
460 |
+
if verbose:
|
461 |
+
print("Downloading %s ..." % url, end="", flush=True)
|
462 |
+
for attempts_left in reversed(range(num_attempts)):
|
463 |
+
try:
|
464 |
+
with session.get(url) as res:
|
465 |
+
res.raise_for_status()
|
466 |
+
if len(res.content) == 0:
|
467 |
+
raise IOError("No data received")
|
468 |
+
|
469 |
+
if len(res.content) < 8192:
|
470 |
+
content_str = res.content.decode("utf-8")
|
471 |
+
if "download_warning" in res.headers.get("Set-Cookie", ""):
|
472 |
+
links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
|
473 |
+
if len(links) == 1:
|
474 |
+
url = requests.compat.urljoin(url, links[0])
|
475 |
+
raise IOError("Google Drive virus checker nag")
|
476 |
+
if "Google Drive - Quota exceeded" in content_str:
|
477 |
+
raise IOError("Google Drive download quota exceeded -- please try again later")
|
478 |
+
|
479 |
+
match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
|
480 |
+
url_name = match[1] if match else url
|
481 |
+
url_data = res.content
|
482 |
+
if verbose:
|
483 |
+
print(" done")
|
484 |
+
break
|
485 |
+
except KeyboardInterrupt:
|
486 |
+
raise
|
487 |
+
except:
|
488 |
+
if not attempts_left:
|
489 |
+
if verbose:
|
490 |
+
print(" failed")
|
491 |
+
raise
|
492 |
+
if verbose:
|
493 |
+
print(".", end="", flush=True)
|
494 |
+
|
495 |
+
# Save to cache.
|
496 |
+
if cache:
|
497 |
+
safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
|
498 |
+
cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
|
499 |
+
temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
|
500 |
+
os.makedirs(cache_dir, exist_ok=True)
|
501 |
+
with open(temp_file, "wb") as f:
|
502 |
+
f.write(url_data)
|
503 |
+
os.replace(temp_file, cache_file) # atomic
|
504 |
+
if return_filename:
|
505 |
+
return cache_file
|
506 |
+
|
507 |
+
# Return data as file object.
|
508 |
+
assert not return_filename
|
509 |
+
return io.BytesIO(url_data)
|
510 |
+
|
511 |
+
# ------------------------------------------------------------------------------------------
|
512 |
+
# util function modified from https://github.com/nv-tlabs/LION/blob/0467d2199076e95a7e88bafd99dcd7d48a04b4a7/utils/model_helper.py
|
513 |
+
def import_class(model_str):
|
514 |
+
from torch_utils.dist_utils import is_rank0
|
515 |
+
if is_rank0():
|
516 |
+
logger.info('import: {}', model_str)
|
517 |
+
p, m = model_str.rsplit('.', 1)
|
518 |
+
mod = importlib.import_module(p)
|
519 |
+
Model = getattr(mod, m)
|
520 |
+
return Model
|
521 |
+
|
522 |
+
class ScopedTorchProfiler(ContextDecorator):
|
523 |
+
"""
|
524 |
+
Marks ranges for both nvtx profiling (with nsys) and torch autograd profiler
|
525 |
+
"""
|
526 |
+
__global_counts = {}
|
527 |
+
enabled=False
|
528 |
+
|
529 |
+
def __init__(self, unique_name: str):
|
530 |
+
"""
|
531 |
+
Names must be unique!
|
532 |
+
"""
|
533 |
+
ScopedTorchProfiler.__global_counts[unique_name] = 0
|
534 |
+
self._name = unique_name
|
535 |
+
self._autograd_scope = torch.profiler.record_function(unique_name)
|
536 |
+
|
537 |
+
def __enter__(self):
|
538 |
+
if ScopedTorchProfiler.enabled:
|
539 |
+
torch.cuda.nvtx.range_push(self._name)
|
540 |
+
self._autograd_scope.__enter__()
|
541 |
+
|
542 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
543 |
+
self._autograd_scope.__exit__(exc_type, exc_value, traceback)
|
544 |
+
if ScopedTorchProfiler.enabled:
|
545 |
+
torch.cuda.nvtx.range_pop()
|
546 |
+
|
547 |
+
class TimingsMonitor():
|
548 |
+
CUDATimer = namedtuple('CUDATimer', ['start', 'end'])
|
549 |
+
def __init__(self, device, enabled=True, timing_names:List[str]=[], cuda_timing_names:List[str]=[]):
|
550 |
+
"""
|
551 |
+
Usage:
|
552 |
+
tmonitor = TimingsMonitor(device)
|
553 |
+
for i in range(n_iter):
|
554 |
+
# Record arbitrary scopes
|
555 |
+
with tmonitor.timing_scope('regular_scope_name'):
|
556 |
+
...
|
557 |
+
with tmonitor.cuda_timing_scope('nested_scope_name'):
|
558 |
+
...
|
559 |
+
with tmonitor.cuda_timing_scope('cuda_scope_name'):
|
560 |
+
...
|
561 |
+
tmonitor.record_timing('duration_name', end_time - start_time)
|
562 |
+
|
563 |
+
# Gather timings
|
564 |
+
tmonitor.record_all_cuda_timings()
|
565 |
+
tmonitor.update_all_averages()
|
566 |
+
averages = tmonitor.get_average_timings()
|
567 |
+
all_timings = tmonitor.get_timings()
|
568 |
+
|
569 |
+
Two types of timers, standard report timing and cuda timings.
|
570 |
+
Cuda timing supports scoped context manager cuda_event_scope.
|
571 |
+
Args:
|
572 |
+
device: device to time on (needed for cuda timers)
|
573 |
+
# enabled: HACK to only report timings from rank 0, set enabled=(global_rank==0)
|
574 |
+
timing_names: timings to report optional (will auto add new names)
|
575 |
+
cuda_timing_names: cuda periods to time optional (will auto add new names)
|
576 |
+
"""
|
577 |
+
self.enabled=enabled
|
578 |
+
self.device = device
|
579 |
+
|
580 |
+
# Normal timing
|
581 |
+
# self.all_timings_dict = {k:None for k in timing_names + cuda_timing_names}
|
582 |
+
self.all_timings_dict = {}
|
583 |
+
self.avg_meter_dict = {}
|
584 |
+
|
585 |
+
# Cuda event timers to measure time spent on pushing data to gpu and on training step
|
586 |
+
self.cuda_event_timers = {}
|
587 |
+
|
588 |
+
for k in timing_names:
|
589 |
+
self.add_new_timing(k)
|
590 |
+
|
591 |
+
for k in cuda_timing_names:
|
592 |
+
self.add_new_cuda_timing(k)
|
593 |
+
|
594 |
+
# Running averages
|
595 |
+
# self.avg_meter_dict = {k:AverageMeter() for k in self.all_timings_dict}
|
596 |
+
|
597 |
+
def add_new_timing(self, name):
|
598 |
+
self.avg_meter_dict[name] = AverageMeter()
|
599 |
+
self.all_timings_dict[name] = None
|
600 |
+
|
601 |
+
def add_new_cuda_timing(self, name):
|
602 |
+
start_event = torch.cuda.Event(enable_timing=True)
|
603 |
+
end_event = torch.cuda.Event(enable_timing=True)
|
604 |
+
self.cuda_event_timers[name] = self.CUDATimer(start=start_event, end=end_event)
|
605 |
+
self.add_new_timing(name)
|
606 |
+
|
607 |
+
def clear_timings(self):
|
608 |
+
self.all_timings_dict = {k:None for k in self.all_timings_dict}
|
609 |
+
|
610 |
+
def get_timings(self):
|
611 |
+
return self.all_timings_dict
|
612 |
+
|
613 |
+
def get_average_timings(self):
|
614 |
+
return {k:v.avg for k,v in self.avg_meter_dict.items()}
|
615 |
+
|
616 |
+
def update_all_averages(self):
|
617 |
+
"""
|
618 |
+
Once per iter, when timings have been finished recording, one should
|
619 |
+
call update_average_iter to keep running average of timings.
|
620 |
+
"""
|
621 |
+
for k,v in self.all_timings_dict.items():
|
622 |
+
if v is None:
|
623 |
+
print("none_timing", k)
|
624 |
+
continue
|
625 |
+
self.avg_meter_dict[k].update(v)
|
626 |
+
|
627 |
+
def record_timing(self, name, value):
|
628 |
+
if name not in self.all_timings_dict: self.add_new_timing(name)
|
629 |
+
# assert name in self.all_timings_dict
|
630 |
+
self.all_timings_dict[name] = value
|
631 |
+
|
632 |
+
def _record_cuda_event_start(self, name):
|
633 |
+
if name in self.cuda_event_timers:
|
634 |
+
self.cuda_event_timers[name].start.record(
|
635 |
+
torch.cuda.current_stream(self.device))
|
636 |
+
|
637 |
+
def _record_cuda_event_end(self, name):
|
638 |
+
if name in self.cuda_event_timers:
|
639 |
+
self.cuda_event_timers[name].end.record(
|
640 |
+
torch.cuda.current_stream(self.device))
|
641 |
+
|
642 |
+
@contextmanager
|
643 |
+
def cuda_timing_scope(self, name, profile=True):
|
644 |
+
if name not in self.all_timings_dict: self.add_new_cuda_timing(name)
|
645 |
+
with ScopedTorchProfiler(name) if profile else nullcontext():
|
646 |
+
self._record_cuda_event_start(name)
|
647 |
+
try:
|
648 |
+
yield
|
649 |
+
finally:
|
650 |
+
self._record_cuda_event_end(name)
|
651 |
+
|
652 |
+
@contextmanager
|
653 |
+
def timing_scope(self, name, profile=True):
|
654 |
+
if name not in self.all_timings_dict: self.add_new_timing(name)
|
655 |
+
with ScopedTorchProfiler(name) if profile else nullcontext():
|
656 |
+
start_time = time.time()
|
657 |
+
try:
|
658 |
+
yield
|
659 |
+
finally:
|
660 |
+
self.record_timing(name, time.time()-start_time)
|
661 |
+
|
662 |
+
def record_all_cuda_timings(self):
|
663 |
+
""" After all the cuda events call this to synchronize and record down the cuda timings. """
|
664 |
+
for k, events in self.cuda_event_timers.items():
|
665 |
+
with torch.no_grad():
|
666 |
+
events.end.synchronize()
|
667 |
+
# Convert to seconds
|
668 |
+
time_elapsed = events.start.elapsed_time(events.end)/1000.
|
669 |
+
self.all_timings_dict[k] = time_elapsed
|
670 |
+
|
671 |
+
def init_s3(config_file):
|
672 |
+
config = json.load(open(config_file, 'r'))
|
673 |
+
s3_client = boto3.client("s3", **config)
|
674 |
+
return s3_client
|
675 |
+
|
676 |
+
def download_from_s3(file_path, target_path, cfg):
|
677 |
+
tic = time.time()
|
678 |
+
s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init
|
679 |
+
bucket_name = file_path.split('/')[2]
|
680 |
+
file_key = file_path.split(bucket_name+'/')[-1]
|
681 |
+
print(bucket_name, file_key)
|
682 |
+
s3_client.download_file(bucket_name, file_key, target_path)
|
683 |
+
logger.info(f'finished download from s3://{bucket_name}/{file_key} to {target_path} in {time.time() - tic:.1f} sec')
|
685 |
+
|
686 |
+
def upload_to_s3(buffer, bucket_name, key, config_dict):
|
687 |
+
logger.info(f'start upload_to_s3! bucket_name={bucket_name}, key={key}')
|
688 |
+
tic = time.time()
|
689 |
+
s3 = boto3.client('s3', **config_dict)
|
690 |
+
s3.put_object(Bucket=bucket_name, Key=key, Body=buffer.getvalue())
|
691 |
+
logger.info(f'finish upload_to_s3! s3://{bucket_name}/{key} %.1f sec'%(time.time() - tic))
|
692 |
+
|
693 |
+
def write_ckpt_to_s3(cfg, all_model_dict, ckpt_name):
|
694 |
+
buffer = io.BytesIO()
|
695 |
+
tic = time.time()
|
696 |
+
torch.save(all_model_dict, buffer) # take ~0.25 sec
|
697 |
+
# logger.info('write ckpt to buffer: %.2f sec'%(time.time() - tic))
|
698 |
+
group, name = cfg.outdir.rstrip("/").split("/")[-2:]
|
699 |
+
key = f"checkpoints/{group}/{name}/ckpt/{ckpt_name}"
|
700 |
+
bucket_name = cfg.checkpoint.write_s3_bucket
|
701 |
+
|
702 |
+
s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init
|
703 |
+
|
704 |
+
config_dict = json.load(open(cfg.checkpoint.write_s3_config, 'r'))
|
705 |
+
upload_thread = threading.Thread(target=upload_to_s3, args=(buffer, bucket_name, key, config_dict))
|
706 |
+
upload_thread.start()
|
707 |
+
path = f"s3://{bucket_name}/{key}"
|
708 |
+
return path
|
709 |
+
|
710 |
+
def upload_file_to_s3(cfg, file_path, key_name=None):
|
711 |
+
# file_path is the local file path, can be a yaml file
|
712 |
+
# this function is used to upload the checkpoint only
|
713 |
+
tic = time.time()
|
714 |
+
group, name = cfg.outdir.rstrip("/").split("/")[-2:]
|
715 |
+
key = os.path.basename(file_path) if key_name is None else key_name
|
717 |
+
key = f"checkpoints/{group}/{name}/{key}"
|
718 |
+
bucket_name = cfg.checkpoint.write_s3_bucket
|
719 |
+
s3_client = init_s3(cfg.checkpoint.write_s3_config)
|
720 |
+
# Upload the file
|
721 |
+
with open(file_path, 'rb') as f:
|
722 |
+
s3_client.upload_fileobj(f, bucket_name, key)
|
723 |
+
full_s3_path = f"s3://{bucket_name}/{key}"
|
724 |
+
logger.info(f'upload_to_s3: {file_path} {full_s3_path} | use time: {time.time()-tic}')
|
725 |
+
|
726 |
+
return full_s3_path
|
727 |
+
|
728 |
+
|
729 |
+
def load_from_s3(file_path, cfg, load_fn):
|
730 |
+
"""
|
731 |
+
ckpt_path example:
|
732 |
+
s3://xzeng/checkpoints/2023_0413/vae_kl_5e-1/ckpt/snapshot_epo000163_iter164000.pt
|
733 |
+
"""
|
734 |
+
s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init
|
735 |
+
bucket_name = file_path.split("s3://")[-1].split('/')[0]
|
736 |
+
key = file_path.split(f'{bucket_name}/')[-1]
|
737 |
+
# logger.info(f"-> try to load s3://{bucket_name}/{key} ")
|
738 |
+
tic = time.time()
|
739 |
+
for attempt in range(10):
|
740 |
+
try:
|
741 |
+
# Download the state dict from S3 into memory (as a binary stream)
|
742 |
+
with io.BytesIO() as buffer:
|
743 |
+
s3_client.download_fileobj(bucket_name, key, buffer)
|
744 |
+
buffer.seek(0)
|
745 |
+
|
746 |
+
# Load the state dict into a PyTorch model
|
747 |
+
# out = torch.load(buffer, map_location=torch.device("cpu"))
|
748 |
+
out = load_fn(buffer)
|
749 |
+
break
|
750 |
+
except:
|
751 |
+
logger.info(f"fail to load s3://{bucket_name}/{key} attemp: {attemp}")
|
752 |
+
from torch_utils.dist_utils import is_rank0
|
753 |
+
if is_rank0():
|
754 |
+
logger.info(f'loaded {file_path} | use time: {time.time()-tic:.1f} sec')
|
755 |
+
return out
|
756 |
+
|
757 |
+
def load_torch_dict_from_s3(ckpt_path, cfg):
|
758 |
+
"""
|
759 |
+
ckpt_path example:
|
760 |
+
s3://xzeng/checkpoints/2023_0413/vae_kl_5e-1/ckpt/snapshot_epo000163_iter164000.pt
|
761 |
+
"""
|
762 |
+
s3_client = init_s3(cfg.checkpoint.write_s3_config) # use to test the s3_client can be init
|
763 |
+
bucket_name = ckpt_path.split("s3://")[-1].split('/')[0]
|
764 |
+
key = ckpt_path.split(f'{bucket_name}/')[-1]
|
765 |
+
for attempt in range(10):
|
766 |
+
try:
|
767 |
+
# Download the state dict from S3 into memory (as a binary stream)
|
768 |
+
with io.BytesIO() as buffer:
|
769 |
+
s3_client.download_fileobj(bucket_name, key, buffer)
|
770 |
+
buffer.seek(0)
|
771 |
+
|
772 |
+
# Load the state dict into a PyTorch model
|
773 |
+
out = torch.load(buffer, map_location=torch.device("cpu"))
|
774 |
+
break
|
775 |
+
except:
|
776 |
+
logger.info(f"fail to load s3://{bucket_name}/{key} attemp: {attemp}")
|
777 |
+
return out
|
778 |
+
|
779 |
+
def count_parameters_in_M(model):
|
780 |
+
return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6
|
781 |
+
|
782 |
+
def printarr(*arrs, float_width=6, **kwargs):
|
783 |
+
"""
|
784 |
+
Print a pretty table giving name, shape, dtype, type, and content information for input tensors or scalars.
|
785 |
+
|
786 |
+
Call like: printarr(my_arr, some_other_arr, maybe_a_scalar). Accepts a variable number of arguments.
|
787 |
+
|
788 |
+
Inputs can be:
|
789 |
+
- Numpy tensor arrays
|
790 |
+
- Pytorch tensor arrays
|
791 |
+
- Jax tensor arrays
|
792 |
+
- Python ints / floats
|
793 |
+
- None
|
794 |
+
|
795 |
+
It may also work with other array-like types, but they have not been tested.
|
796 |
+
|
797 |
+
Use the `float_width` option specify the precision to which floating point types are printed.
|
798 |
+
|
799 |
+
Author: Nicholas Sharp (nmwsharp.com)
|
800 |
+
Canonical source: https://gist.github.com/nmwsharp/54d04af87872a4988809f128e1a1d233
|
801 |
+
License: This snippet may be used under an MIT license, and it is also released into the public domain.
|
802 |
+
Please retain this docstring as a reference.
|
803 |
+
"""
|
804 |
+
|
805 |
+
frame = inspect.currentframe().f_back
|
806 |
+
default_name = "[temporary]"
|
807 |
+
|
808 |
+
## helpers to gather data about each array
|
809 |
+
def name_from_outer_scope(a):
|
810 |
+
if a is None:
|
811 |
+
return '[None]'
|
812 |
+
name = default_name
|
813 |
+
for k, v in frame.f_locals.items():
|
814 |
+
if v is a:
|
815 |
+
name = k
|
816 |
+
break
|
817 |
+
return name
|
818 |
+
|
819 |
+
def type_strip(type_str):
|
820 |
+
return type_str.lstrip('<class ').rstrip('>').replace('torch.', '').strip("'")
|
821 |
+
|
822 |
+
def dtype_str(a):
|
823 |
+
if a is None:
|
824 |
+
return 'None'
|
825 |
+
if isinstance(a, int):
|
826 |
+
return 'int'
|
827 |
+
if isinstance(a, float):
|
828 |
+
return 'float'
|
829 |
+
if isinstance(a, list) and len(a)>0:
|
830 |
+
return type_strip(str(type(a[0])))
|
831 |
+
if hasattr(a, 'dtype'):
|
832 |
+
return type_strip(str(a.dtype))
|
833 |
+
else:
|
834 |
+
return ''
|
835 |
+
def shape_str(a):
|
836 |
+
if a is None:
|
837 |
+
return 'N/A'
|
838 |
+
if isinstance(a, int):
|
839 |
+
return 'scalar'
|
840 |
+
if isinstance(a, float):
|
841 |
+
return 'scalar'
|
842 |
+
if isinstance(a, list):
|
843 |
+
return f"[{shape_str(a[0]) if len(a)>0 else '?'}]*{len(a)}"
|
844 |
+
if hasattr(a, 'shape'):
|
845 |
+
return str(tuple(a.shape))
|
846 |
+
else:
|
847 |
+
return ''
|
848 |
+
def type_str(a):
|
849 |
+
return type_strip(str(type(a))) # TODO this is is weird... what's the better way?
|
850 |
+
def device_str(a):
|
851 |
+
if hasattr(a, 'device'):
|
852 |
+
device_str = str(a.device)
|
853 |
+
if len(device_str) < 10:
|
854 |
+
# heuristic: jax returns some goofy long string we don't want, ignore it
|
855 |
+
return device_str
|
856 |
+
return ""
|
857 |
+
def format_float(x):
|
858 |
+
return f"{x:{float_width}g}"
|
859 |
+
def minmaxmean_str(a):
|
860 |
+
if a is None:
|
861 |
+
return ('N/A', 'N/A', 'N/A', 'N/A')
|
862 |
+
if isinstance(a, int) or isinstance(a, float):
|
863 |
+
return (format_float(a),)*4
|
864 |
+
|
865 |
+
# compute min/max/mean. if anything goes wrong, just print 'N/A'
|
866 |
+
min_str = "N/A"
|
867 |
+
try: min_str = format_float(a.min())
|
868 |
+
except: pass
|
869 |
+
max_str = "N/A"
|
870 |
+
try: max_str = format_float(a.max())
|
871 |
+
except: pass
|
872 |
+
mean_str = "N/A"
|
873 |
+
try: mean_str = format_float(a.mean())
|
874 |
+
except: pass
|
875 |
+
try: median_str = format_float(a.median())
|
876 |
+
except:
|
877 |
+
try: median_str = format_float(np.median(np.array(a)))
|
878 |
+
except: median_str = 'N/A'
|
879 |
+
return (min_str, max_str, mean_str, median_str)
|
880 |
+
|
881 |
+
def get_prop_dict(a,k=None):
|
882 |
+
minmaxmean = minmaxmean_str(a)
|
883 |
+
props = {
|
884 |
+
'name' : name_from_outer_scope(a) if k is None else k,
|
885 |
+
# 'type' : str(type(a)).replace('torch.',''),
|
886 |
+
'dtype' : dtype_str(a),
|
887 |
+
'shape' : shape_str(a),
|
888 |
+
'type' : type_str(a),
|
889 |
+
'device' : device_str(a),
|
890 |
+
'min' : minmaxmean[0],
|
891 |
+
'max' : minmaxmean[1],
|
892 |
+
'mean' : minmaxmean[2],
|
893 |
+
'median': minmaxmean[3]
|
894 |
+
}
|
895 |
+
return props
|
896 |
+
|
897 |
+
try:
|
898 |
+
|
899 |
+
props = ['name', 'type', 'dtype', 'shape', 'device', 'min', 'max', 'mean', 'median']
|
900 |
+
|
901 |
+
# precompute all of the properties for each input
|
902 |
+
str_props = []
|
903 |
+
for a in arrs:
|
904 |
+
str_props.append(get_prop_dict(a))
|
905 |
+
for k,a in kwargs.items():
|
906 |
+
str_props.append(get_prop_dict(a, k=k))
|
907 |
+
|
908 |
+
# for each property, compute its length
|
909 |
+
maxlen = {}
|
910 |
+
for p in props: maxlen[p] = 0
|
911 |
+
for sp in str_props:
|
912 |
+
for p in props:
|
913 |
+
maxlen[p] = max(maxlen[p], len(sp[p]))
|
914 |
+
|
915 |
+
# if any property got all empty strings, don't bother printing it, remove it from the list
|
916 |
+
props = [p for p in props if maxlen[p] > 0]
|
917 |
+
|
918 |
+
# print a header
|
919 |
+
header_str = ""
|
920 |
+
for p in props:
|
921 |
+
prefix = "" if p == 'name' else " | "
|
922 |
+
fmt_key = ">" if p == 'name' else "<"
|
923 |
+
header_str += f"{prefix}{p:{fmt_key}{maxlen[p]}}"
|
924 |
+
print(header_str)
|
925 |
+
print("-"*len(header_str))
|
926 |
+
|
927 |
+
# now print the actual arrays
|
928 |
+
for strp in str_props:
|
929 |
+
for p in props:
|
930 |
+
prefix = "" if p == 'name' else " | "
|
931 |
+
fmt_key = ">" if p == 'name' else "<"
|
932 |
+
print(f"{prefix}{strp[p]:{fmt_key}{maxlen[p]}}", end='')
|
933 |
+
print("")
|
934 |
+
|
935 |
+
finally:
|
936 |
+
del frame
|
937 |
+
|
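For orientation, printarr defined above can be exercised as follows (illustrative only, not part of the file; the tensor and scalar are made up and the call assumes torch is available as imported at the top of this module):

# prints one table row per argument: name, type, dtype, shape, device, min/max/mean/median
x = torch.randn(4, 128)
scale = 0.5
printarr(x, scale)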
938 |
+
def debug_print_all_tensor_sizes(min_tot_size = 0):
|
939 |
+
import gc
|
940 |
+
print("---------------------------------------"*3)
|
941 |
+
for obj in gc.get_objects():
|
942 |
+
try:
|
943 |
+
if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
|
944 |
+
if np.prod(obj.size())>=min_tot_size:
|
945 |
+
print(type(obj), obj.size())
|
946 |
+
except:
|
947 |
+
pass
|
948 |
+
def print_cpu_usage():
|
949 |
+
|
950 |
+
# Get current CPU usage as a percentage
|
951 |
+
cpu_usage = psutil.cpu_percent()
|
952 |
+
|
953 |
+
# Get current memory usage
|
954 |
+
memory_usage = psutil.virtual_memory().used
|
955 |
+
|
956 |
+
# Convert memory usage to a human-readable format
|
957 |
+
memory_usage_str = psutil._common.bytes2human(memory_usage)
|
958 |
+
|
959 |
+
# Print CPU and memory usage
|
960 |
+
msg = f"Current CPU usage: {cpu_usage}% | "
|
961 |
+
msg += f"Current memory usage: {memory_usage_str}"
|
962 |
+
return msg
|
963 |
+
|
964 |
+
def calmsize(num_bytes):
|
965 |
+
if math.isnan(num_bytes):
|
966 |
+
return ''
|
967 |
+
for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
|
968 |
+
if abs(num_bytes) < 1024.0:
|
969 |
+
return "{:.1f}{}B".format(num_bytes, unit)
|
970 |
+
num_bytes /= 1024.0
|
971 |
+
return "{:.1f}{}B".format(num_bytes, 'Y')
|
972 |
+
|
973 |
+
def readable_size(num_bytes: int) -> str:
|
974 |
+
return calmsize(num_bytes) ## '' if math.isnan(num_bytes) else '{:.1f}'.format(calmsize(num_bytes))
|
975 |
+
|
976 |
+
def get_gpu_memory():
|
977 |
+
"""
|
978 |
+
Get the current GPU memory usage for each device as a dictionary
|
979 |
+
"""
|
980 |
+
output = subprocess.check_output(["nvidia-smi", "--query-gpu=memory.used", "--format=csv"])
|
981 |
+
output = output.decode("utf-8")
|
982 |
+
gpu_memory_values = output.split("\n")[1:-1]
|
983 |
+
gpu_memory_values = [int(x.strip().split()[0]) for x in gpu_memory_values]
|
984 |
+
gpu_memory = dict(zip(range(len(gpu_memory_values)), gpu_memory_values))
|
985 |
+
return gpu_memory
|
986 |
+
|
987 |
+
def get_gpu_util():
|
988 |
+
"""
|
989 |
+
Get the current GPU utilization for each device as a dictionary
|
990 |
+
"""
|
991 |
+
output = subprocess.check_output(["nvidia-smi", "--query-gpu=utilization.gpu", "--format=csv"])
|
992 |
+
output = output.decode("utf-8")
|
993 |
+
gpu_memory_values = output.split("\n")[1:-1]
|
994 |
+
gpu_memory_values = [int(x.strip().split()[0]) for x in gpu_memory_values]
|
995 |
+
gpu_util = dict(zip(range(len(gpu_memory_values)), gpu_memory_values))
|
996 |
+
return gpu_util
|
997 |
+
|
998 |
+
|
999 |
+
def print_gpu_usage():
|
1000 |
+
usage = get_gpu_memory()
|
1001 |
+
msg = f" | GPU usage: "
|
1002 |
+
for k, v in usage.items():
|
1003 |
+
msg += f"{k}: {v} MB "
|
1004 |
+
# utilization = get_gpu_util()
|
1005 |
+
# msg + ' | util '
|
1006 |
+
# for k, v in utilization.items():
|
1007 |
+
# msg += f"{k}: {v} % "
|
1008 |
+
return msg
|
1009 |
+
|
1010 |
+
class AverageMeter(object):
|
1011 |
+
|
1012 |
+
def __init__(self):
|
1013 |
+
self.reset()
|
1014 |
+
|
1015 |
+
def reset(self):
|
1016 |
+
self.avg = 0
|
1017 |
+
self.sum = 0
|
1018 |
+
self.cnt = 0
|
1019 |
+
|
1020 |
+
def update(self, val, n=1):
|
1021 |
+
self.sum += val * n
|
1022 |
+
self.cnt += n
|
1023 |
+
self.avg = self.sum / self.cnt
|
1024 |
+
|
1025 |
+
|
1026 |
+
def generate_random_string(length):
|
1027 |
+
# This script will generate a string of 10 random ASCII letters (both lowercase and uppercase).
|
1028 |
+
# You can adjust the length parameter to fit your needs.
|
1029 |
+
letters = string.ascii_letters
|
1030 |
+
return ''.join(random.choice(letters) for _ in range(length))
|
1031 |
+
|
1032 |
+
|
1033 |
+
class ForkedPdb(pdb.Pdb):
|
1034 |
+
"""
|
1035 |
+
PDB Subclass for debugging multi-processed code
|
1036 |
+
Suggested in: https://stackoverflow.com/questions/4716533/how-to-attach-debugger-to-a-python-subproccess
|
1037 |
+
"""
|
1038 |
+
def interaction(self, *args, **kwargs):
|
1039 |
+
_stdin = sys.stdin
|
1040 |
+
try:
|
1041 |
+
sys.stdin = open('/dev/stdin')
|
1042 |
+
pdb.Pdb.interaction(self, *args, **kwargs)
|
1043 |
+
finally:
|
1044 |
+
sys.stdin = _stdin
|
1045 |
+
|
1046 |
+
def check_exist_in_s3(file_path, s3_config):
|
1047 |
+
s3 = init_s3(s3_config)
|
1048 |
+
bucket_name, object_name = s3path_to_bucket_key(file_path)
|
1049 |
+
|
1050 |
+
try:
|
1051 |
+
s3.head_object(Bucket=bucket_name, Key=object_name)
|
1052 |
+
return 1
|
1053 |
+
except:
|
1054 |
+
logger.info(f'file not found: s3://{bucket_name}/{object_name}')
|
1055 |
+
return 0
|
1056 |
+
|
1057 |
+
def s3path_to_bucket_key(file_path):
|
1058 |
+
bucket_name = file_path.split('/')[2]
|
1059 |
+
object_name = file_path.split(bucket_name + '/')[-1]
|
1060 |
+
return bucket_name, object_name
|
1061 |
+
|
1062 |
+
def copy_file_to_s3(cfg, file_path_local, file_path_s3):
|
1063 |
+
# work similar as upload_file_to_s3, but not trying to parse the file path
|
1064 |
+
# file_path_s3: s3://{bucket}/{key}
|
1065 |
+
bucket_name, key = s3path_to_bucket_key(file_path_s3)
|
1066 |
+
tic = time.time()
|
1067 |
+
s3_client = init_s3(cfg.checkpoint.write_s3_config)
|
1068 |
+
|
1069 |
+
# Upload the file
|
1070 |
+
with open(file_path_local, 'rb') as f:
|
1071 |
+
s3_client.upload_fileobj(f, bucket_name, key)
|
1072 |
+
full_s3_path = f"s3://{bucket_name}/{key}"
|
1073 |
+
logger.info(f'copy file: {file_path_local} {full_s3_path} | use time: {time.time()-tic}')
|
1074 |
+
return full_s3_path
|
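Before the next file, here is a brief usage sketch (illustrative, not part of the diff) for a few of the helpers defined in dnnlib_util.py above; it assumes the package is importable on the PYTHONPATH and that its heavier dependencies (torch, boto3, loguru, psutil) are installed.

from modules.PartField.partfield.model.PVCNN import dnnlib_util as util

meter = util.AverageMeter()
for step_time in [0.52, 0.47, 0.51]:           # made-up per-step durations in seconds
    meter.update(step_time)
print(f"avg step: {meter.avg:.3f}s")

print(util.format_time(12345))                 # 12345 s -> "3h 25m 45s"

cfg = util.EasyDict(lr=1e-4, betas=(0.9, 0.99))
cfg.weight_decay = 0.01                        # attribute access reads/writes plain dict keys
assert cfg["lr"] == cfg.lr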
modules/PartField/partfield/model/PVCNN/encoder_pc.py
ADDED
@@ -0,0 +1,243 @@
1 |
+
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
2 |
+
#
|
3 |
+
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
4 |
+
# and proprietary rights in and to this software, related documentation
|
5 |
+
# and any modifications thereto. Any use, reproduction, disclosure or
|
6 |
+
# distribution of this software and related documentation without an express
|
7 |
+
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
8 |
+
|
9 |
+
from typing import Dict
|
10 |
+
import math
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
import torch.nn.functional as F
|
16 |
+
from torch_scatter import scatter_mean #, scatter_max
|
17 |
+
|
18 |
+
from .unet_3daware import setup_unet #UNetTriplane3dAware
|
19 |
+
from .conv_pointnet import ConvPointnet
|
20 |
+
|
21 |
+
from .pc_encoder import PVCNNEncoder #PointNet
|
22 |
+
|
23 |
+
import einops
|
24 |
+
|
25 |
+
from .dnnlib_util import ScopedTorchProfiler, printarr
|
26 |
+
|
27 |
+
def generate_plane_features(p, c, resolution, plane='xz'):
|
28 |
+
"""
|
29 |
+
Args:
|
30 |
+
p: (B,3,n_p)
|
31 |
+
c: (B,C,n_p)
|
32 |
+
"""
|
33 |
+
padding = 0.
|
34 |
+
c_dim = c.size(1)
|
35 |
+
# acquire indices of features in plane
|
36 |
+
xy = normalize_coordinate(p.clone(), plane=plane, padding=padding) # normalize to the range of (0, 1)
|
37 |
+
index = coordinate2index(xy, resolution)
|
38 |
+
|
39 |
+
# scatter plane features from points
|
40 |
+
fea_plane = c.new_zeros(p.size(0), c_dim, resolution**2)
|
41 |
+
fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
|
42 |
+
fea_plane = fea_plane.reshape(p.size(0), c_dim, resolution, resolution) # sparse matrix (B x 512 x reso x reso)
|
43 |
+
return fea_plane
|
44 |
+
|
45 |
+
def normalize_coordinate(p, padding=0.1, plane='xz'):
|
46 |
+
''' Normalize coordinate to [0, 1] for unit cube experiments
|
47 |
+
|
48 |
+
Args:
|
49 |
+
p (tensor): point
|
50 |
+
padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
|
51 |
+
plane (str): plane feature type, ['xz', 'xy', 'yz']
|
52 |
+
'''
|
53 |
+
if plane == 'xz':
|
54 |
+
xy = p[:, :, [0, 2]]
|
55 |
+
elif plane =='xy':
|
56 |
+
xy = p[:, :, [0, 1]]
|
57 |
+
else:
|
58 |
+
xy = p[:, :, [1, 2]]
|
59 |
+
|
60 |
+
xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
|
61 |
+
xy_new = xy_new + 0.5 # range (0, 1)
|
62 |
+
|
63 |
+
# if there are outliers out of the range
|
64 |
+
if xy_new.max() >= 1:
|
65 |
+
xy_new[xy_new >= 1] = 1 - 10e-6
|
66 |
+
if xy_new.min() < 0:
|
67 |
+
xy_new[xy_new < 0] = 0.0
|
68 |
+
return xy_new
|
69 |
+
|
70 |
+
|
71 |
+
def coordinate2index(x, resolution):
|
72 |
+
''' Convert coordinates normalized to [0, 1] into flattened 2D plane indices.
|
73 |
+
Corresponds to our 3D model
|
74 |
+
|
75 |
+
Args:
|
76 |
+
x (tensor): coordinate
|
77 |
+
reso (int): defined resolution
|
78 |
+
coord_type (str): coordinate type
|
79 |
+
'''
|
80 |
+
x = (x * resolution).long()
|
81 |
+
index = x[:, :, 0] + resolution * x[:, :, 1]
|
82 |
+
index = index[:, None, :]
|
83 |
+
return index
|
84 |
+
|
85 |
+
def softclip(x, min, max, hardness=5):
|
86 |
+
# Soft clipping for the logsigma
|
87 |
+
x = min + F.softplus(hardness*(x - min))/hardness
|
88 |
+
x = max - F.softplus(-hardness*(x - max))/hardness
|
89 |
+
return x
|
90 |
+
|
91 |
+
|
92 |
+
def sample_triplane_feat(feature_triplane, normalized_pos):
|
93 |
+
'''
|
94 |
+
normalized_pos [-1, 1]
|
95 |
+
'''
|
96 |
+
tri_plane = torch.unbind(feature_triplane, dim=1)
|
97 |
+
|
98 |
+
x_feat = F.grid_sample(
|
99 |
+
tri_plane[0],
|
100 |
+
torch.cat(
|
101 |
+
[normalized_pos[:, :, 0:1], normalized_pos[:, :, 1:2]],
|
102 |
+
dim=-1).unsqueeze(dim=1), padding_mode='border',
|
103 |
+
align_corners=True)
|
104 |
+
y_feat = F.grid_sample(
|
105 |
+
tri_plane[1],
|
106 |
+
torch.cat(
|
107 |
+
[normalized_pos[:, :, 1:2], normalized_pos[:, :, 2:3]],
|
108 |
+
dim=-1).unsqueeze(dim=1), padding_mode='border',
|
109 |
+
align_corners=True)
|
110 |
+
|
111 |
+
z_feat = F.grid_sample(
|
112 |
+
tri_plane[2],
|
113 |
+
torch.cat(
|
114 |
+
[normalized_pos[:, :, 0:1], normalized_pos[:, :, 2:3]],
|
115 |
+
dim=-1).unsqueeze(dim=1), padding_mode='border',
|
116 |
+
align_corners=True)
|
117 |
+
final_feat = (x_feat + y_feat + z_feat)
|
118 |
+
final_feat = final_feat.squeeze(dim=2).permute(0, 2, 1) # 32dimension
|
119 |
+
return final_feat
|
120 |
+
|
121 |
+
|
122 |
+
# @persistence.persistent_class
|
123 |
+
class TriPlanePC2Encoder(torch.nn.Module):
|
124 |
+
# Encoder that encode point cloud to triplane feature vector similar to ConvOccNet
|
125 |
+
def __init__(
|
126 |
+
self,
|
127 |
+
cfg,
|
128 |
+
device='cuda',
|
129 |
+
shape_min=-1.0,
|
130 |
+
shape_length=2.0,
|
131 |
+
use_2d_feat=False,
|
132 |
+
# point_encoder='pvcnn',
|
133 |
+
# use_point_scatter=False
|
134 |
+
):
|
135 |
+
"""
|
136 |
+
Outputs latent triplane from PC input
|
137 |
+
Configs:
|
138 |
+
max_logsigma: (float) Soft clip upper range for logsigm
|
139 |
+
min_logsigma: (float)
|
140 |
+
point_encoder_type: (str) one of ['pvcnn', 'pointnet']
|
141 |
+
pvcnn_flatten_voxels: (bool) for pvcnn whether to reduce voxel
|
142 |
+
features (instead of scattering point features)
|
143 |
+
unet_cfg: (dict)
|
144 |
+
z_triplane_channels: (int) output latent triplane
|
145 |
+
z_triplane_resolution: (int)
|
146 |
+
Args:
|
147 |
+
|
148 |
+
"""
|
149 |
+
# assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
|
150 |
+
super().__init__()
|
151 |
+
self.device = device
|
152 |
+
|
153 |
+
self.cfg = cfg
|
154 |
+
|
155 |
+
self.shape_min = shape_min
|
156 |
+
self.shape_length = shape_length
|
157 |
+
|
158 |
+
self.z_triplane_resolution = cfg.z_triplane_resolution
|
159 |
+
z_triplane_channels = cfg.z_triplane_channels
|
160 |
+
|
161 |
+
point_encoder_out_dim = z_triplane_channels #* 2
|
162 |
+
|
163 |
+
in_channels = 6
|
164 |
+
# self.resample_filter=[1, 3, 3, 1]
|
165 |
+
if cfg.point_encoder_type == 'pvcnn':
|
166 |
+
self.pc_encoder = PVCNNEncoder(point_encoder_out_dim,
|
167 |
+
device=self.device, in_channels=in_channels, use_2d_feat=use_2d_feat) # Encode it to a volume vector.
|
168 |
+
elif cfg.point_encoder_type == 'pointnet':
|
169 |
+
# TODO the pointnet was buggy, investigate
|
170 |
+
self.pc_encoder = ConvPointnet(c_dim=point_encoder_out_dim,
|
171 |
+
dim=in_channels, hidden_dim=32,
|
172 |
+
plane_resolution=self.z_triplane_resolution,
|
173 |
+
padding=0)
|
174 |
+
else:
|
175 |
+
raise NotImplementedError(f"Point encoder {cfg.point_encoder_type} not implemented")
|
176 |
+
|
177 |
+
if cfg.unet_cfg.enabled:
|
178 |
            self.unet_encoder = setup_unet(
                output_channels=point_encoder_out_dim,
                input_channels=point_encoder_out_dim,
                unet_cfg=cfg.unet_cfg)
        else:
            self.unet_encoder = None

    # @ScopedTorchProfiler('encode')
    def encode(self, point_cloud_xyz, point_cloud_feature, mv_feat=None, pc2pc_idx=None) -> Dict:
        # output = AttrDict()
        point_cloud_xyz = (point_cloud_xyz - self.shape_min) / self.shape_length  # [0, 1]
        point_cloud_xyz = point_cloud_xyz - 0.5  # [-0.5, 0.5]
        point_cloud = torch.cat([point_cloud_xyz, point_cloud_feature], dim=-1)

        if self.cfg.point_encoder_type == 'pvcnn':
            if mv_feat is not None:
                pc_feat, points_feat = self.pc_encoder(point_cloud, mv_feat, pc2pc_idx)
            else:
                pc_feat, points_feat = self.pc_encoder(point_cloud)  # 3D feature volume: BxDx32x32x32
            if self.cfg.use_point_scatter:
                # Scattering from PVCNN point features
                points_feat_ = points_feat[0]
                # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
                pc_feat_1 = generate_plane_features(point_cloud_xyz, points_feat_,
                                                    resolution=self.z_triplane_resolution, plane='xy')
                pc_feat_2 = generate_plane_features(point_cloud_xyz, points_feat_,
                                                    resolution=self.z_triplane_resolution, plane='yz')
                pc_feat_3 = generate_plane_features(point_cloud_xyz, points_feat_,
                                                    resolution=self.z_triplane_resolution, plane='xz')
                pc_feat = pc_feat[0]

            else:
                pc_feat = pc_feat[0]
                sf = self.z_triplane_resolution // 32  # 32 is PVCNN's voxel dim

                pc_feat_1 = torch.mean(pc_feat, dim=-1)  # xy_plane, average over z
                pc_feat_2 = torch.mean(pc_feat, dim=-3)  # yz_plane, average over x
                pc_feat_3 = torch.mean(pc_feat, dim=-2)  # xz_plane, average over y

                # nearest upsample
                pc_feat_1 = einops.repeat(pc_feat_1, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
                pc_feat_2 = einops.repeat(pc_feat_2, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
                pc_feat_3 = einops.repeat(pc_feat_3, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
        elif self.cfg.point_encoder_type == 'pointnet':
            assert self.cfg.use_point_scatter
            # Run ConvPointnet
            pc_feat = self.pc_encoder(point_cloud)
            pc_feat_1 = pc_feat['xy']
            pc_feat_2 = pc_feat['yz']
            pc_feat_3 = pc_feat['xz']
        else:
            raise NotImplementedError()

        if self.unet_encoder is not None:
            # TODO eval adding a skip connection
            # Unet expects B, 3, C, H, W
            pc_feat_tri_plane_stack_pre = torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)
            # dpc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre)
            # pc_feat_tri_plane_stack = pc_feat_tri_plane_stack_pre + dpc_feat_tri_plane_stack
            pc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre)
            pc_feat_1, pc_feat_2, pc_feat_3 = torch.unbind(pc_feat_tri_plane_stack, dim=1)

        return torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)

    def forward(self, point_cloud_xyz, point_cloud_feature=None, mv_feat=None, pc2pc_idx=None):
        return self.encode(point_cloud_xyz, point_cloud_feature=point_cloud_feature, mv_feat=mv_feat, pc2pc_idx=pc2pc_idx)
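A standalone sketch of the non-scatter branch above (toy tensors, not part of the module): a BxCx32x32x32 voxel grid is collapsed into three planes by averaging over one axis and then nearest-upsampled by the resolution ratio.

import torch
import einops

vox = torch.randn(2, 16, 32, 32, 32)        # B, C, X, Y, Z (toy values)
sf = 128 // 32                               # z_triplane_resolution // PVCNN voxel dim

xy = torch.mean(vox, dim=-1)                 # average over z -> (B, C, 32, 32)
yz = torch.mean(vox, dim=-3)                 # average over x
xz = torch.mean(vox, dim=-2)                 # average over y

xy = einops.repeat(xy, 'b c h w -> b c (h hm) (w wm)', hm=sf, wm=sf)
print(xy.shape)                              # torch.Size([2, 16, 128, 128])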
modules/PartField/partfield/model/PVCNN/pc_encoder.py
ADDED
@@ -0,0 +1,90 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import functools

from .pv_module import SharedMLP, PVConv

def create_pointnet_components(
        blocks, in_channels, with_se=False, normalize=True, eps=0,
        width_multiplier=1, voxel_resolution_multiplier=1, scale_pvcnn=False, device='cuda'):
    r, vr = width_multiplier, voxel_resolution_multiplier
    layers, concat_channels = [], 0
    for out_channels, num_blocks, voxel_resolution in blocks:
        out_channels = int(r * out_channels)
        if voxel_resolution is None:
            block = functools.partial(SharedMLP, device=device)
        else:
            block = functools.partial(
                PVConv, kernel_size=3, resolution=int(vr * voxel_resolution),
                with_se=with_se, normalize=normalize, eps=eps, scale_pvcnn=scale_pvcnn, device=device)
        for _ in range(num_blocks):
            layers.append(block(in_channels, out_channels))
            in_channels = out_channels
            concat_channels += out_channels
    return layers, in_channels, concat_channels

class PCMerger(nn.Module):
    # merge surface sampled PC and rendering backprojected PC (w/ 2D features):
    def __init__(self, in_channels=204, device="cuda"):
        super(PCMerger, self).__init__()
        self.mlp_normal = SharedMLP(3, [128, 128], device=device)
        self.mlp_rgb = SharedMLP(3, [128, 128], device=device)
        self.mlp_sam = SharedMLP(204 - 6, [128, 128], device=device)

    def forward(self, feat, mv_feat, pc2pc_idx):
        mv_feat_normal = self.mlp_normal(mv_feat[:, :3, :])
        mv_feat_rgb = self.mlp_rgb(mv_feat[:, 3:6, :])
        mv_feat_sam = self.mlp_sam(mv_feat[:, 6:, :])

        mv_feat_normal = mv_feat_normal.permute(0, 2, 1)
        mv_feat_rgb = mv_feat_rgb.permute(0, 2, 1)
        mv_feat_sam = mv_feat_sam.permute(0, 2, 1)
        feat = feat.permute(0, 2, 1)

        for i in range(mv_feat.shape[0]):
            mask = (pc2pc_idx[i] != -1).reshape(-1)
            idx = pc2pc_idx[i][mask].reshape(-1)
            feat[i][mask] += mv_feat_normal[i][idx] + mv_feat_rgb[i][idx] + mv_feat_sam[i][idx]

        return feat.permute(0, 2, 1)


class PVCNNEncoder(nn.Module):
    def __init__(self, pvcnn_feat_dim, device='cuda', in_channels=3, use_2d_feat=False):
        super(PVCNNEncoder, self).__init__()
        self.device = device
        self.blocks = ((pvcnn_feat_dim, 1, 32), (128, 2, 16), (256, 1, 8))
        self.use_2d_feat = use_2d_feat
        if in_channels == 6:
            self.append_channel = 2
        elif in_channels == 3:
            self.append_channel = 1
        else:
            raise NotImplementedError
        layers, channels_point, concat_channels_point = create_pointnet_components(
            blocks=self.blocks, in_channels=in_channels + self.append_channel, with_se=False, normalize=False,
            width_multiplier=1, voxel_resolution_multiplier=1, scale_pvcnn=True,
            device=device
        )
        self.encoder = nn.ModuleList(layers)  # .to(self.device)
        if self.use_2d_feat:
            self.merger = PCMerger()

    def forward(self, input_pc, mv_feat=None, pc2pc_idx=None):
        features = input_pc.permute(0, 2, 1) * 2  # make point cloud [-1, 1]
        coords = features[:, :3, :]
        out_features_list = []
        voxel_feature_list = []
        zero_padding = torch.zeros(features.shape[0], self.append_channel, features.shape[-1], device=features.device, dtype=features.dtype)
        features = torch.cat([features, zero_padding], dim=1)

        for i in range(len(self.encoder)):
            features, _, voxel_feature = self.encoder[i]((features, coords))
            if i == 0 and mv_feat is not None:
                features = self.merger(features, mv_feat.permute(0, 2, 1), pc2pc_idx)
            out_features_list.append(features)
            voxel_feature_list.append(voxel_feature)
        return voxel_feature_list, out_features_list
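A minimal usage sketch for PVCNNEncoder on CPU with a toy point cloud in [-0.5, 0.5]; it assumes the repo root is on PYTHONPATH and a PyTorch version that accepts the device= keyword on layer constructors. The shapes in the comments follow from the blocks spec ((64, 1, 32), (128, 2, 16), (256, 1, 8)) when pvcnn_feat_dim=64.

import torch
from modules.PartField.partfield.model.PVCNN.pc_encoder import PVCNNEncoder

enc = PVCNNEncoder(pvcnn_feat_dim=64, device='cpu', in_channels=3)
pts = torch.rand(2, 1024, 3) - 0.5          # B, N, xyz in [-0.5, 0.5]
voxel_feats, point_feats = enc(pts)
print([tuple(v.shape) for v in voxel_feats])  # e.g. (2,64,32,32,32), (2,128,16,16,16), (2,128,16,16,16), (2,256,8,8,8)
print(point_feats[0].shape)                   # torch.Size([2, 64, 1024])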
modules/PartField/partfield/model/PVCNN/pv_module/__init__.py
ADDED
@@ -0,0 +1,2 @@
from .pvconv import PVConv
from .shared_mlp import SharedMLP
modules/PartField/partfield/model/PVCNN/pv_module/ball_query.py
ADDED
@@ -0,0 +1,34 @@
import torch
import torch.nn as nn

from . import functional as F

__all__ = ['BallQuery']


class BallQuery(nn.Module):
    def __init__(self, radius, num_neighbors, include_coordinates=True):
        super().__init__()
        self.radius = radius
        self.num_neighbors = num_neighbors
        self.include_coordinates = include_coordinates

    def forward(self, points_coords, centers_coords, points_features=None):
        points_coords = points_coords.contiguous()
        centers_coords = centers_coords.contiguous()
        neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
        neighbor_coordinates = F.grouping(points_coords, neighbor_indices)
        neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)

        if points_features is None:
            assert self.include_coordinates, 'No Features For Grouping'
            neighbor_features = neighbor_coordinates
        else:
            neighbor_features = F.grouping(points_features, neighbor_indices)
            if self.include_coordinates:
                neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1)
        return neighbor_features

    def extra_repr(self):
        return 'radius={}, num_neighbors={}{}'.format(
            self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
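The functional package added below only exports trilinear_devoxelize, so F.ball_query and F.grouping used here come from elsewhere. As a rough illustration of the intended semantics (indices of up to K points within `radius` of each center, padded with the first hit), a naive torch-only reference could look like the following; this is an assumption about the behavior, not the repo's kernel.

import torch

def naive_ball_query(centers, points, radius, k):
    # centers: (B, 3, M), points: (B, 3, N) -> (B, M, K) long indices
    d = torch.cdist(centers.transpose(1, 2), points.transpose(1, 2))  # (B, M, N) pairwise distances
    idx = d.argsort(dim=-1)[..., :k]                                  # K nearest candidates per center
    first = idx[..., :1].expand_as(idx)                               # pad value: the closest point
    within = torch.gather(d, -1, idx) <= radius
    return torch.where(within, idx, first)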
modules/PartField/partfield/model/PVCNN/pv_module/frustum.py
ADDED
@@ -0,0 +1,141 @@
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from . import functional as PF

__all__ = ['FrustumPointNetLoss', 'get_box_corners_3d']


class FrustumPointNetLoss(nn.Module):
    def __init__(
            self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0,
            corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0):
        super().__init__()
        self.box_loss_weight = box_loss_weight
        self.corners_loss_weight = corners_loss_weight
        self.heading_residual_loss_weight = heading_residual_loss_weight
        self.size_residual_loss_weight = size_residual_loss_weight

        self.num_heading_angle_bins = num_heading_angle_bins
        self.num_size_templates = num_size_templates
        self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3))
        self.register_buffer(
            'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
        )

    def forward(self, inputs, targets):
        mask_logits = inputs['mask_logits']  # (B, 2, N)
        center_reg = inputs['center_reg']  # (B, 3)
        center = inputs['center']  # (B, 3)
        heading_scores = inputs['heading_scores']  # (B, NH)
        heading_residuals_normalized = inputs['heading_residuals_normalized']  # (B, NH)
        heading_residuals = inputs['heading_residuals']  # (B, NH)
        size_scores = inputs['size_scores']  # (B, NS)
        size_residuals_normalized = inputs['size_residuals_normalized']  # (B, NS, 3)
        size_residuals = inputs['size_residuals']  # (B, NS, 3)

        mask_logits_target = targets['mask_logits']  # (B, N)
        center_target = targets['center']  # (B, 3)
        heading_bin_id_target = targets['heading_bin_id']  # (B, )
        heading_residual_target = targets['heading_residual']  # (B, )
        size_template_id_target = targets['size_template_id']  # (B, )
        size_residual_target = targets['size_residual']  # (B, 3)

        batch_size = center.size(0)
        batch_id = torch.arange(batch_size, device=center.device)

        # Basic Classification and Regression losses
        mask_loss = F.cross_entropy(mask_logits, mask_logits_target)
        heading_loss = F.cross_entropy(heading_scores, heading_bin_id_target)
        size_loss = F.cross_entropy(size_scores, size_template_id_target)
        center_loss = PF.huber_loss(torch.norm(center_target - center, dim=-1), delta=2.0)
        center_reg_loss = PF.huber_loss(torch.norm(center_target - center_reg, dim=-1), delta=1.0)

        # Refinement losses for size/heading
        heading_residuals_normalized = heading_residuals_normalized[batch_id, heading_bin_id_target]  # (B, )
        heading_residual_normalized_target = heading_residual_target / (np.pi / self.num_heading_angle_bins)
        heading_residual_normalized_loss = PF.huber_loss(
            heading_residuals_normalized - heading_residual_normalized_target, delta=1.0
        )
        size_residuals_normalized = size_residuals_normalized[batch_id, size_template_id_target]  # (B, 3)
        size_residual_normalized_target = size_residual_target / self.size_templates[size_template_id_target]
        size_residual_normalized_loss = PF.huber_loss(
            torch.norm(size_residual_normalized_target - size_residuals_normalized, dim=-1), delta=1.0
        )

        # Bounding box losses
        heading = (heading_residuals[batch_id, heading_bin_id_target]
                   + self.heading_angle_bin_centers[heading_bin_id_target])  # (B, )
        # Warning: in the original code, size_residuals are added twice (issues #43 and #49 in charlesq34/frustum-pointnets)
        size = (size_residuals[batch_id, size_template_id_target]
                + self.size_templates[size_template_id_target])  # (B, 3)
        corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False)  # (B, 3, 8)
        heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target  # (B, )
        size_target = self.size_templates[size_template_id_target] + size_residual_target  # (B, 3)
        corners_target, corners_target_flip = get_box_corners_3d(
            centers=center_target, headings=heading_target,
            sizes=size_target, with_flip=True)  # (B, 3, 8)
        corners_loss = PF.huber_loss(
            torch.min(
                torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)
            ), delta=1.0)
        # Summing up
        loss = mask_loss + self.box_loss_weight * (
            center_loss + center_reg_loss + heading_loss + size_loss
            + self.heading_residual_loss_weight * heading_residual_normalized_loss
            + self.size_residual_loss_weight * size_residual_normalized_loss
            + self.corners_loss_weight * corners_loss
        )

        return loss


def get_box_corners_3d(centers, headings, sizes, with_flip=False):
    """
    :param centers: coords of box centers, FloatTensor[N, 3]
    :param headings: heading angles, FloatTensor[N, ]
    :param sizes: box sizes, FloatTensor[N, 3]
    :param with_flip: bool, whether to return flipped box (headings + np.pi)
    :return:
        coords of box corners, FloatTensor[N, 3, 8]
        NOTE: corner points are in counter clockwise order, e.g.,
          2--1
        3--0 5
          7--4
    """
    l = sizes[:, 0]  # (N,)
    w = sizes[:, 1]  # (N,)
    h = sizes[:, 2]  # (N,)
    x_corners = torch.stack([l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2], dim=1)  # (N, 8)
    y_corners = torch.stack([h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2], dim=1)  # (N, 8)
    z_corners = torch.stack([w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2], dim=1)  # (N, 8)

    c = torch.cos(headings)  # (N,)
    s = torch.sin(headings)  # (N,)
    o = torch.ones_like(headings)  # (N,)
    z = torch.zeros_like(headings)  # (N,)

    centers = centers.unsqueeze(-1)  # (B, 3, 1)
    corners = torch.stack([x_corners, y_corners, z_corners], dim=1)  # (N, 3, 8)
    R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3)  # roty matrix: (N, 3, 3)
    if with_flip:
        R_flip = torch.stack([-c, z, -s, z, o, z, s, z, -c], dim=1).view(-1, 3, 3)
        return torch.matmul(R, corners) + centers, torch.matmul(R_flip, corners) + centers
    else:
        return torch.matmul(R, corners) + centers

    # centers = centers.unsqueeze(1)  # (B, 1, 3)
    # corners = torch.stack([x_corners, y_corners, z_corners], dim=-1)  # (N, 8, 3)
    # RT = torch.stack([c, z, -s, z, o, z, s, z, c], dim=1).view(-1, 3, 3)  # (N, 3, 3)
    # if with_flip:
    #     RT_flip = torch.stack([-c, z, s, z, o, z, -s, z, -c], dim=1).view(-1, 3, 3)  # (N, 3, 3)
    #     return torch.matmul(corners, RT) + centers, torch.matmul(corners, RT_flip) + centers  # (N, 8, 3)
    # else:
    #     return torch.matmul(corners, RT) + centers  # (N, 8, 3)

    # corners = torch.stack([x_corners, y_corners, z_corners], dim=1)  # (N, 3, 8)
    # R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3)  # (N, 3, 3)
    # corners = torch.matmul(R, corners) + centers.unsqueeze(2)  # (N, 3, 8)
    # corners = corners.transpose(1, 2)  # (N, 8, 3)
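A quick sanity sketch for get_box_corners_3d, assuming it is imported from this file: a unit box at the origin with zero heading should put corner 0 at +0.5 along each axis, with l mapped to x, h to y, and w to z as in the code above.

import torch

centers = torch.zeros(1, 3)
headings = torch.zeros(1)
sizes = torch.ones(1, 3)                                  # (l, w, h)
corners = get_box_corners_3d(centers, headings, sizes)    # (1, 3, 8)
print(corners[0, :, 0])                                   # tensor([0.5000, 0.5000, 0.5000])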
modules/PartField/partfield/model/PVCNN/pv_module/functional/__init__.py
ADDED
@@ -0,0 +1 @@
from .devoxelization import trilinear_devoxelize
modules/PartField/partfield/model/PVCNN/pv_module/functional/devoxelization.py
ADDED
@@ -0,0 +1,12 @@
from torch.autograd import Function
import torch
import torch.nn.functional as F

__all__ = ['trilinear_devoxelize']

def trilinear_devoxelize(c, coords, r, training=None):
    coords = (coords * 2 + 1.0) / r - 1.0
    coords = coords.permute(0, 2, 1).reshape(c.shape[0], 1, 1, -1, 3)
    f = F.grid_sample(input=c, grid=coords, padding_mode='border', align_corners=False)
    f = f.squeeze(dim=2).squeeze(dim=2)
    return f
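The (coords * 2 + 1) / r - 1 mapping sends voxel index i to the center of cell i in grid_sample's [-1, 1] convention (align_corners=False). A quick check for r = 4:

import torch
i = torch.arange(4.0)
print((i * 2 + 1.0) / 4 - 1.0)   # tensor([-0.7500, -0.2500,  0.2500,  0.7500])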
modules/PartField/partfield/model/PVCNN/pv_module/loss.py
ADDED
@@ -0,0 +1,10 @@
import torch.nn as nn

from . import functional as F

__all__ = ['KLLoss']


class KLLoss(nn.Module):
    def forward(self, x, y):
        return F.kl_loss(x, y)
modules/PartField/partfield/model/PVCNN/pv_module/pointnet.py
ADDED
@@ -0,0 +1,113 @@
import torch
import torch.nn as nn

from . import functional as F
from .ball_query import BallQuery
from .shared_mlp import SharedMLP

__all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule']


class PointNetAModule(nn.Module):
    def __init__(self, in_channels, out_channels, include_coordinates=True):
        super().__init__()
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [[out_channels]]
        elif not isinstance(out_channels[0], (list, tuple)):
            out_channels = [out_channels]

        mlps = []
        total_out_channels = 0
        for _out_channels in out_channels:
            mlps.append(
                SharedMLP(
                    in_channels=in_channels + (3 if include_coordinates else 0),
                    out_channels=_out_channels, dim=1)
            )
            total_out_channels += _out_channels[-1]

        self.include_coordinates = include_coordinates
        self.out_channels = total_out_channels
        self.mlps = nn.ModuleList(mlps)

    def forward(self, inputs):
        features, coords = inputs
        if self.include_coordinates:
            features = torch.cat([features, coords], dim=1)
        coords = torch.zeros((coords.size(0), 3, 1), device=coords.device)
        if len(self.mlps) > 1:
            features_list = []
            for mlp in self.mlps:
                features_list.append(mlp(features).max(dim=-1, keepdim=True).values)
            return torch.cat(features_list, dim=1), coords
        else:
            return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords

    def extra_repr(self):
        return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'


class PointNetSAModule(nn.Module):
    def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True):
        super().__init__()
        if not isinstance(radius, (list, tuple)):
            radius = [radius]
        if not isinstance(num_neighbors, (list, tuple)):
            num_neighbors = [num_neighbors] * len(radius)
        assert len(radius) == len(num_neighbors)
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [[out_channels]] * len(radius)
        elif not isinstance(out_channels[0], (list, tuple)):
            out_channels = [out_channels] * len(radius)
        assert len(radius) == len(out_channels)

        groupers, mlps = [], []
        total_out_channels = 0
        for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
            groupers.append(
                BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
            )
            mlps.append(
                SharedMLP(
                    in_channels=in_channels + (3 if include_coordinates else 0),
                    out_channels=_out_channels, dim=2)
            )
            total_out_channels += _out_channels[-1]

        self.num_centers = num_centers
        self.out_channels = total_out_channels
        self.groupers = nn.ModuleList(groupers)
        self.mlps = nn.ModuleList(mlps)

    def forward(self, inputs):
        features, coords = inputs
        centers_coords = F.furthest_point_sample(coords, self.num_centers)
        features_list = []
        for grouper, mlp in zip(self.groupers, self.mlps):
            features_list.append(mlp(grouper(coords, centers_coords, features)).max(dim=-1).values)
        if len(features_list) > 1:
            return torch.cat(features_list, dim=1), centers_coords
        else:
            return features_list[0], centers_coords

    def extra_repr(self):
        return f'num_centers={self.num_centers}, out_channels={self.out_channels}'


class PointNetFPModule(nn.Module):
    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1)

    def forward(self, inputs):
        if len(inputs) == 3:
            points_coords, centers_coords, centers_features = inputs
            points_features = None
        else:
            points_coords, centers_coords, centers_features, points_features = inputs
        interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
        if points_features is not None:
            interpolated_features = torch.cat(
                [interpolated_features, points_features], dim=1
            )
        return self.mlp(interpolated_features), points_coords
modules/PartField/partfield/model/PVCNN/pv_module/pvconv.py
ADDED
@@ -0,0 +1,38 @@
import torch.nn as nn

from . import functional as F
from .voxelization import Voxelization
from .shared_mlp import SharedMLP
import torch

__all__ = ['PVConv']


class PVConv(nn.Module):
    def __init__(
            self, in_channels, out_channels, kernel_size, resolution, with_se=False, normalize=True, eps=0, scale_pvcnn=False,
            device='cuda'):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.resolution = resolution
        self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps, scale_pvcnn=scale_pvcnn)
        voxel_layers = [
            nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2, device=device),
            nn.InstanceNorm3d(out_channels, eps=1e-4, device=device),
            nn.LeakyReLU(0.1, True),
            nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2, device=device),
            nn.InstanceNorm3d(out_channels, eps=1e-4, device=device),
            nn.LeakyReLU(0.1, True),
        ]
        self.voxel_layers = nn.Sequential(*voxel_layers)
        self.point_features = SharedMLP(in_channels, out_channels, device=device)

    def forward(self, inputs):
        features, coords = inputs
        voxel_features, voxel_coords = self.voxelization(features, coords)
        voxel_features = self.voxel_layers(voxel_features)
        devoxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training)
        fused_features = devoxel_features + self.point_features(features)
        return fused_features, coords, voxel_features
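A minimal sketch of a single PVConv step on CPU: point features are voxelized at resolution 16, run through the 3D convs, devoxelized back, and fused with the per-point MLP path. It assumes the repo root is on PYTHONPATH; coordinates are expected in [-1, 1] when scale_pvcnn=True.

import torch
from modules.PartField.partfield.model.PVCNN.pv_module import PVConv

conv = PVConv(in_channels=8, out_channels=16, kernel_size=3, resolution=16,
              normalize=False, scale_pvcnn=True, device='cpu')
feats = torch.randn(2, 8, 512)                   # B, C, N per-point features
coords = torch.rand(2, 3, 512) * 2 - 1           # B, 3, N coordinates in [-1, 1]
fused, coords_out, voxels = conv((feats, coords))
print(fused.shape, voxels.shape)                 # torch.Size([2, 16, 512]) torch.Size([2, 16, 16, 16, 16])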
modules/PartField/partfield/model/PVCNN/pv_module/shared_mlp.py
ADDED
@@ -0,0 +1,35 @@
import torch.nn as nn

__all__ = ['SharedMLP']


class SharedMLP(nn.Module):
    def __init__(self, in_channels, out_channels, dim=1, device='cuda'):
        super().__init__()
        # print('==> SharedMLP device: ', device)
        if dim == 1:
            conv = nn.Conv1d
            bn = nn.InstanceNorm1d
        elif dim == 2:
            conv = nn.Conv2d
            bn = nn.InstanceNorm2d  # the 2-D branch needs the 2-D norm; InstanceNorm1d fails on 4-D inputs
        else:
            raise ValueError
        if not isinstance(out_channels, (list, tuple)):
            out_channels = [out_channels]
        layers = []
        for oc in out_channels:
            layers.extend(
                [
                    conv(in_channels, oc, 1, device=device),
                    bn(oc, device=device),
                    nn.ReLU(True),
                ])
            in_channels = oc
        self.layers = nn.Sequential(*layers)

    def forward(self, inputs):
        if isinstance(inputs, (list, tuple)):
            return (self.layers(inputs[0]), *inputs[1:])
        else:
            return self.layers(inputs)
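SharedMLP is a per-point stack of 1x1 convolutions with instance norm and ReLU. A toy call on (B, C, N) features, assuming the repo root is on PYTHONPATH:

import torch
from modules.PartField.partfield.model.PVCNN.pv_module import SharedMLP

mlp = SharedMLP(3, [16, 32], device='cpu')    # two 1x1 Conv1d + InstanceNorm1d + ReLU stages
print(mlp(torch.randn(2, 3, 100)).shape)      # torch.Size([2, 32, 100])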
modules/PartField/partfield/model/PVCNN/pv_module/voxelization.py
ADDED
@@ -0,0 +1,80 @@
import torch
import torch.nn as nn

from . import functional as F

__all__ = ['Voxelization']


def my_voxelization(features, coords, resolution):
    b, c, _ = features.shape
    result = torch.zeros(b, c + 1, resolution * resolution * resolution, device=features.device, dtype=features.dtype)
    r = resolution
    r2 = resolution * resolution
    coords = coords.long()
    indices = coords[:, 0] * r2 + coords[:, 1] * r + coords[:, 2]

    # print(r, r2, coords[:, 0].max(), coords[:, 1].max(), coords[:, 2].max())

    # print(f"Resolution: {resolution}")
    # print(f"Coords shape: {coords.shape}")
    # print(f"Coords max per dim: x={coords[:, 0].max()}, y={coords[:, 1].max()}, z={coords[:, 2].max()}")
    # print(f"Coords min per dim: x={coords[:, 0].min()}, y={coords[:, 1].min()}, z={coords[:, 2].min()}")
    # print(f"Indices shape: {indices.shape}")
    # print(f"Indices max: {indices.max()}, min: {indices.min()}")
    # print(f"Expected max index: {resolution * resolution * resolution - 1}")

    # # check for out-of-range indices
    # max_valid_index = resolution * resolution * resolution - 1
    # invalid_mask = (indices > max_valid_index) | (indices < 0)
    # if invalid_mask.any():
    #     print(f"Found {invalid_mask.sum()} invalid indices!")
    #     print(f"Invalid indices: {indices[invalid_mask]}")
    #     # locate the corresponding coordinates
    #     invalid_coords = coords[:, :, invalid_mask.any(dim=0)]
    #     print(f"Invalid coords shape: {invalid_coords.shape}")
    #     if invalid_coords.numel() > 0:
    #         print(f"Sample invalid coords: {invalid_coords[:, :, :5]}")  # show the first 5 invalid coords

    indices = indices.unsqueeze(dim=1).expand(-1, result.shape[1], -1)
    features = torch.cat([features, torch.ones(features.shape[0], 1, features.shape[2], device=features.device, dtype=features.dtype)], dim=1)
    out_feature = result.scatter_(index=indices.long(), src=features, dim=2, reduce='add')
    cnt = out_feature[:, -1:, :]
    zero_mask = (cnt == 0).to(features.dtype)
    cnt = cnt * (1 - zero_mask) + zero_mask * 1e-5
    vox_feature = out_feature[:, :-1, :] / cnt
    return vox_feature.view(b, c, resolution, resolution, resolution)

class Voxelization(nn.Module):
    def __init__(self, resolution, normalize=True, eps=0, scale_pvcnn=False):
        super().__init__()
        self.r = int(resolution)
        self.normalize = normalize
        self.eps = eps
        self.scale_pvcnn = scale_pvcnn
        assert not normalize

    def forward(self, features, coords):
        # import pdb; pdb.set_trace()
        with torch.no_grad():
            coords = coords.detach()

            if self.normalize:
                norm_coords = norm_coords / (norm_coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps) + 0.5
            else:
                if self.scale_pvcnn:
                    norm_coords = (coords + 1) / 2.0  # [0, 1]
                    # print(norm_coords.shape, norm_coords.max(), norm_coords.min())
                else:
                    # norm_coords = (norm_coords + 1) / 2.0
                    norm_coords = (coords + 1) / 2.0
            norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
            # print(norm_coords.shape, norm_coords.max(), norm_coords.min())
            vox_coords = torch.round(norm_coords)
            # print(vox_coords.shape, vox_coords.max(), vox_coords.min())
            # print(features.shape)
        new_vox_feat = my_voxelization(features, vox_coords, self.r)
        return new_vox_feat, norm_coords

    def extra_repr(self):
        return 'resolution={}{}'.format(self.r, ', normalized eps = {}'.format(self.eps) if self.normalize else '')
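my_voxelization scatter-averages point features into voxel cells, so two points landing in the same cell are averaged. A toy check at resolution 2, assuming my_voxelization above is importable:

import torch
from modules.PartField.partfield.model.PVCNN.pv_module.voxelization import my_voxelization

feats = torch.tensor([[[1.0, 3.0]]])              # B=1, C=1, N=2
coords = torch.zeros(1, 3, 2)                     # both points fall into cell (0, 0, 0)
vox = my_voxelization(feats, coords, resolution=2)
print(vox[0, 0, 0, 0, 0])                         # tensor(2.) -> mean of 1 and 3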
modules/PartField/partfield/model/PVCNN/unet_3daware.py
ADDED
@@ -0,0 +1,427 @@
import numpy as np

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.nn import init

import einops

def conv3x3(in_channels, out_channels, stride=1,
            padding=1, bias=True, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=3,
        stride=stride,
        padding=padding,
        bias=bias,
        groups=groups)

def upconv2x2(in_channels, out_channels, mode='transpose'):
    if mode == 'transpose':
        return nn.ConvTranspose2d(
            in_channels,
            out_channels,
            kernel_size=2,
            stride=2)
    else:
        # out_channels is always going to be the same as in_channels
        return nn.Sequential(
            nn.Upsample(mode='bilinear', scale_factor=2),
            conv1x1(in_channels, out_channels))

def conv1x1(in_channels, out_channels, groups=1):
    return nn.Conv2d(
        in_channels,
        out_channels,
        kernel_size=1,
        groups=groups,
        stride=1)

class ConvTriplane3dAware(nn.Module):
    """ 3D aware triplane conv (as described in RODIN) """
    def __init__(self, internal_conv_f, in_channels, out_channels, order='xz'):
        """
        Args:
            internal_conv_f: function that should return a 2D convolution Module
                given in and out channels
            order: if triplane input is in 'xz' order
        """
        super(ConvTriplane3dAware, self).__init__()
        # Need 3 separate convolutions
        self.in_channels = in_channels
        self.out_channels = out_channels
        assert order in ['xz', 'zx']
        self.order = order
        # Going to stack from other planes
        self.plane_convs = nn.ModuleList([
            internal_conv_f(3 * self.in_channels, self.out_channels) for _ in range(3)])

    def forward(self, triplanes_list):
        """
        Args:
            triplanes_list: [(B,Ci,H,W)]*3 in xy,yz,(zx or xz) depending on order
        Returns:
            out_triplanes_list: [(B,Co,H,W)]*3 in xy,yz,(zx or xz) depending on order
        """
        inps = list(triplanes_list)
        xp = 1  # (yz)
        yp = 2  # (zx)
        zp = 0  # (xy)

        if self.order == 'xz':
            # get into zx order
            inps[yp] = einops.rearrange(inps[yp], 'b c x z -> b c z x')

        oplanes = [None] * 3
        # order shouldn't matter
        for iplane in [zp, xp, yp]:
            # i_plane -> (j,k)
            # need to average out i and convert to (j,k)
            # j_plane -> (k,i)
            # k_plane -> (i,j)
            jplane = (iplane + 1) % 3
            kplane = (iplane + 2) % 3

            ifeat = inps[iplane]
            # need to average out the nonshared dim (average pool across it)

            # j_plane -> (k,i) -> (k,1) -> (1,k) -> (j,k)
            # b c k i -> b c k 1
            jpool = torch.mean(inps[jplane], dim=3, keepdim=True)
            jpool = einops.rearrange(jpool, 'b c k 1 -> b c 1 k')
            jpool = einops.repeat(jpool, 'b c 1 k -> b c j k', j=ifeat.size(2))

            # k_plane -> (i,j) -> (1,j) -> (j,1) -> (j,k)
            # b c i j -> b c 1 j
            kpool = torch.mean(inps[kplane], dim=2, keepdim=True)
            kpool = einops.rearrange(kpool, 'b c 1 j -> b c j 1')
            kpool = einops.repeat(kpool, 'b c j 1 -> b c j k', k=ifeat.size(3))

            # b c h w
            # jpool = jpool.expand_as(ifeat)
            # kpool = kpool.expand_as(ifeat)

            # concat and conv on feature dim
            catfeat = torch.cat([ifeat, jpool, kpool], dim=1)
            oplane = self.plane_convs[iplane](catfeat)
            oplanes[iplane] = oplane

        if self.order == 'xz':
            # get back into xz order
            oplanes[yp] = einops.rearrange(oplanes[yp], 'b c z x -> b c x z')

        return oplanes

def roll_triplanes(triplanes_list):
    # B, C, tri, h, w
    tristack = torch.stack(triplanes_list, dim=2)
    return einops.rearrange(tristack, 'b c tri h w -> b c (tri h) w', tri=3)

def unroll_triplanes(rolled_triplane):
    # B, C, tri*h, w
    tristack = einops.rearrange(rolled_triplane, 'b c (tri h) w -> b c tri h w', tri=3)
    return torch.unbind(tristack, dim=2)

def conv1x1triplane3daware(in_channels, out_channels, order='xz', **kwargs):
    return ConvTriplane3dAware(lambda inp, out: conv1x1(inp, out, **kwargs),
                               in_channels, out_channels, order=order)

def Normalize(in_channels, num_groups=32):
    num_groups = min(in_channels, num_groups)  # avoid error if in_channels < 32
    return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)

def nonlinearity(x):
    # return F.relu(x)
    # Swish
    return x * torch.sigmoid(x)

class Upsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=1,
                                        padding=1)

    def forward(self, x):
        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
        if self.with_conv:
            x = self.conv(x)
        return x

class Downsample(nn.Module):
    def __init__(self, in_channels, with_conv):
        super().__init__()
        self.with_conv = with_conv
        if self.with_conv:
            # no asymmetric padding in torch conv, must do it ourselves
            self.conv = torch.nn.Conv2d(in_channels,
                                        in_channels,
                                        kernel_size=3,
                                        stride=2,
                                        padding=0)

    def forward(self, x):
        if self.with_conv:
            pad = (0, 1, 0, 1)
            x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
            x = self.conv(x)
        else:
            x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
        return x

class ResnetBlock3dAware(nn.Module):
    def __init__(self, in_channels, out_channels=None):
        # , conv_shortcut=False):
        super().__init__()
        self.in_channels = in_channels
        out_channels = in_channels if out_channels is None else out_channels
        self.out_channels = out_channels
        # self.use_conv_shortcut = conv_shortcut

        self.norm1 = Normalize(in_channels)
        self.conv1 = conv3x3(self.in_channels, self.out_channels)

        self.norm_mid = Normalize(out_channels)
        self.conv_3daware = conv1x1triplane3daware(self.out_channels, self.out_channels)

        self.norm2 = Normalize(out_channels)
        self.conv2 = conv3x3(self.out_channels, self.out_channels)

        if self.in_channels != self.out_channels:
            self.nin_shortcut = torch.nn.Conv2d(in_channels,
                                                out_channels,
                                                kernel_size=1,
                                                stride=1,
                                                padding=0)

    def forward(self, x):
        # 3x3 plane comm
        h = x
        h = self.norm1(h)
        h = nonlinearity(h)
        h = self.conv1(h)

        # 1x1 3d aware, crossplane comm
        h = self.norm_mid(h)
        h = nonlinearity(h)
        h = unroll_triplanes(h)
        h = self.conv_3daware(h)
        h = roll_triplanes(h)

        # 3x3 plane comm
        h = self.norm2(h)
        h = nonlinearity(h)
        h = self.conv2(h)

        if self.in_channels != self.out_channels:
            x = self.nin_shortcut(x)

        return x + h

class DownConv3dAware(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 MaxPool.
    A ReLU activation follows each convolution.
    """
    def __init__(self, in_channels, out_channels, downsample=True, with_conv=False):
        super(DownConv3dAware, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels

        self.block = ResnetBlock3dAware(in_channels=in_channels,
                                        out_channels=out_channels)

        self.do_downsample = downsample
        self.downsample = Downsample(out_channels, with_conv=with_conv)

    def forward(self, x):
        """
        rolled input, rolled output
        Args:
            x: rolled (b c (tri*h) w)
        """
        x = self.block(x)
        before_pool = x
        # if self.pooling:
        #     x = self.pool(x)
        if self.do_downsample:
            # unroll and cat channel-wise (to prevent pooling across triplane boundaries)
            x = einops.rearrange(x, 'b c (tri h) w -> b (c tri) h w', tri=3)
            x = self.downsample(x)
            # undo
            x = einops.rearrange(x, 'b (c tri) h w -> b c (tri h) w', tri=3)
        return x, before_pool

class UpConv3dAware(nn.Module):
    """
    A helper Module that performs 2 convolutions and 1 UpConvolution.
    A ReLU activation follows each convolution.
    """
    def __init__(self, in_channels, out_channels,
                 merge_mode='concat', with_conv=False):  # up_mode='transpose',
        super(UpConv3dAware, self).__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.merge_mode = merge_mode

        self.upsample = Upsample(in_channels, with_conv)

        if self.merge_mode == 'concat':
            self.norm1 = Normalize(in_channels + out_channels)
            self.block = ResnetBlock3dAware(in_channels=in_channels + out_channels,
                                            out_channels=out_channels)
        else:
            self.norm1 = Normalize(in_channels)
            self.block = ResnetBlock3dAware(in_channels=in_channels,
                                            out_channels=out_channels)

    def forward(self, from_down, from_up):
        """ Forward pass
        rolled inputs, rolled output
        rolled (b c (tri*h) w)
        Arguments:
            from_down: tensor from the encoder pathway
            from_up: upconv'd tensor from the decoder pathway
        """
        # from_up = self.upconv(from_up)
        from_up = self.upsample(from_up)
        if self.merge_mode == 'concat':
            x = torch.cat((from_up, from_down), 1)
        else:
            x = from_up + from_down

        x = self.norm1(x)
        x = self.block(x)
        return x

class UNetTriplane3dAware(nn.Module):
    def __init__(self, out_channels, in_channels=3, depth=5,
                 start_filts=64,  # up_mode='transpose',
                 use_initial_conv=False,
                 merge_mode='concat', **kwargs):
        """
        Arguments:
            in_channels: int, number of channels in the input tensor.
                Default is 3 for RGB images.
            depth: int, number of MaxPools in the U-Net.
            start_filts: int, number of convolutional filters for the
                first conv.
        """
        super(UNetTriplane3dAware, self).__init__()

        self.out_channels = out_channels
        self.in_channels = in_channels
        self.start_filts = start_filts
        self.depth = depth

        self.use_initial_conv = use_initial_conv
        if use_initial_conv:
            self.conv_initial = conv1x1(self.in_channels, self.start_filts)

        self.down_convs = []
        self.up_convs = []

        # create the encoder pathway and add to a list
        for i in range(depth):
            if i == 0:
                ins = self.start_filts if use_initial_conv else self.in_channels
            else:
                ins = outs
            outs = self.start_filts * (2 ** i)
            downsamp_it = True if i < depth - 1 else False

            down_conv = DownConv3dAware(ins, outs, downsample=downsamp_it)
            self.down_convs.append(down_conv)

        for i in range(depth - 1):
            ins = outs
            outs = ins // 2
            up_conv = UpConv3dAware(ins, outs,
                                    merge_mode=merge_mode)
            self.up_convs.append(up_conv)

        # add the list of modules to current module
        self.down_convs = nn.ModuleList(self.down_convs)
        self.up_convs = nn.ModuleList(self.up_convs)

        self.norm_out = Normalize(outs)
        self.conv_final = conv1x1(outs, self.out_channels)

        self.reset_params()

    @staticmethod
    def weight_init(m):
        if isinstance(m, nn.Conv2d):
            # init.xavier_normal_(m.weight, gain=0.1)
            init.xavier_normal_(m.weight)
            init.constant_(m.bias, 0)

    def reset_params(self):
        for i, m in enumerate(self.modules()):
            self.weight_init(m)

    def forward(self, x):
        """
        Args:
            x: Stacked triplane expected to be in (B,3,C,H,W)
        """
        # Roll
        x = einops.rearrange(x, 'b tri c h w -> b c (tri h) w', tri=3)

        if self.use_initial_conv:
            x = self.conv_initial(x)

        encoder_outs = []
        # encoder pathway, save outputs for merging
        for i, module in enumerate(self.down_convs):
            x, before_pool = module(x)
            encoder_outs.append(before_pool)

        # Spend a block in the middle
        # x = self.block_mid(x)

        for i, module in enumerate(self.up_convs):
            before_pool = encoder_outs[-(i + 2)]
            x = module(before_pool, x)

        x = self.norm_out(x)

        # No softmax is applied here, so use nn.CrossEntropyLoss in your
        # training script (it applies log-softmax internally).
        x = self.conv_final(nonlinearity(x))

        # Unroll
        x = einops.rearrange(x, 'b c (tri h) w -> b tri c h w', tri=3)
        return x


def setup_unet(output_channels, input_channels, unet_cfg):
    if unet_cfg['use_3d_aware']:
        assert(unet_cfg['rolled'])
        unet = UNetTriplane3dAware(
            out_channels=output_channels,
            in_channels=input_channels,
            depth=unet_cfg['depth'],
            use_initial_conv=unet_cfg['use_initial_conv'],
            start_filts=unet_cfg['start_hidden_channels'],)
    else:
        raise NotImplementedError
    return unet
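A minimal sketch of building the 3D-aware triplane U-Net from a plain config dict and running it on a stacked triplane tensor (B, 3, C, H, W). The config keys mirror what setup_unet above reads; the dict values here are illustrative, not the repo's configuration, and the import path assumes the repo root is on PYTHONPATH.

import torch
from modules.PartField.partfield.model.PVCNN.unet_3daware import setup_unet

unet_cfg = {'use_3d_aware': True, 'rolled': True, 'depth': 3,
            'use_initial_conv': False, 'start_hidden_channels': 32}
unet = setup_unet(output_channels=16, input_channels=8, unet_cfg=unet_cfg)
tri = torch.randn(1, 3, 8, 64, 64)       # B, 3 planes, C, H, W
print(unet(tri).shape)                   # torch.Size([1, 3, 16, 64, 64])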
modules/PartField/partfield/model/UNet/buildingblocks.py
ADDED
@@ -0,0 +1,546 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#https://github.com/wolny/pytorch-3dunet/blob/master/pytorch3dunet/unet3d/buildingblocks.py
|
2 |
+
# MIT License
|
3 |
+
|
4 |
+
# Copyright (c) 2018 Adrian Wolny
|
5 |
+
|
6 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
7 |
+
# of this software and associated documentation files (the "Software"), to deal
|
8 |
+
# in the Software without restriction, including without limitation the rights
|
9 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10 |
+
# copies of the Software, and to permit persons to whom the Software is
|
11 |
+
# furnished to do so, subject to the following conditions:
|
12 |
+
|
13 |
+
# The above copyright notice and this permission notice shall be included in all
|
14 |
+
# copies or substantial portions of the Software.
|
15 |
+
|
16 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
+
# SOFTWARE.
|
23 |
+
|
24 |
+
from functools import partial
|
25 |
+
|
26 |
+
import torch
|
27 |
+
from torch import nn as nn
|
28 |
+
from torch.nn import functional as F
|
29 |
+
|
30 |
+
# from pytorch3dunet.unet3d.se import ChannelSELayer3D, ChannelSpatialSELayer3D, SpatialSELayer3D
|
31 |
+
|
32 |
+
|
33 |
+
def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding,
|
34 |
+
dropout_prob, is3d):
|
35 |
+
"""
|
36 |
+
Create a list of modules with together constitute a single conv layer with non-linearity
|
37 |
+
and optional batchnorm/groupnorm.
|
38 |
+
|
39 |
+
Args:
|
40 |
+
in_channels (int): number of input channels
|
41 |
+
out_channels (int): number of output channels
|
42 |
+
kernel_size(int or tuple): size of the convolving kernel
|
43 |
+
order (string): order of things, e.g.
|
44 |
+
'cr' -> conv + ReLU
|
45 |
+
'gcr' -> groupnorm + conv + ReLU
|
46 |
+
'cl' -> conv + LeakyReLU
|
47 |
+
'ce' -> conv + ELU
|
48 |
+
'bcr' -> batchnorm + conv + ReLU
|
49 |
+
'cbrd' -> conv + batchnorm + ReLU + dropout
|
50 |
+
'cbrD' -> conv + batchnorm + ReLU + dropout2d
|
51 |
+
num_groups (int): number of groups for the GroupNorm
|
52 |
+
padding (int or tuple): add zero-padding added to all three sides of the input
|
53 |
+
dropout_prob (float): dropout probability
|
54 |
+
is3d (bool): is3d (bool): if True use Conv3d, otherwise use Conv2d
|
55 |
+
Return:
|
56 |
+
list of tuple (name, module)
|
57 |
+
"""
|
58 |
+
assert 'c' in order, "Conv layer MUST be present"
|
59 |
+
assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
|
60 |
+
|
61 |
+
modules = []
|
62 |
+
for i, char in enumerate(order):
|
63 |
+
if char == 'r':
|
64 |
+
modules.append(('ReLU', nn.ReLU(inplace=True)))
|
65 |
+
elif char == 'l':
|
66 |
+
modules.append(('LeakyReLU', nn.LeakyReLU(inplace=True)))
|
67 |
+
elif char == 'e':
|
68 |
+
modules.append(('ELU', nn.ELU(inplace=True)))
|
69 |
+
elif char == 'c':
|
70 |
+
# add learnable bias only in the absence of batchnorm/groupnorm
|
71 |
+
bias = not ('g' in order or 'b' in order)
|
72 |
+
if is3d:
|
73 |
+
conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
|
74 |
+
else:
|
75 |
+
conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
|
76 |
+
|
77 |
+
modules.append(('conv', conv))
|
78 |
+
elif char == 'g':
|
79 |
+
is_before_conv = i < order.index('c')
|
80 |
+
if is_before_conv:
|
81 |
+
num_channels = in_channels
|
82 |
+
else:
|
83 |
+
num_channels = out_channels
|
84 |
+
|
85 |
+
# use only one group if the given number of groups is greater than the number of channels
|
86 |
+
if num_channels < num_groups:
|
87 |
+
num_groups = 1
|
88 |
+
|
89 |
+
assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
|
90 |
+
modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
|
91 |
+
elif char == 'b':
|
92 |
+
is_before_conv = i < order.index('c')
|
93 |
+
if is3d:
|
94 |
+
bn = nn.BatchNorm3d
|
95 |
+
else:
|
96 |
+
bn = nn.BatchNorm2d
|
97 |
+
|
98 |
+
if is_before_conv:
|
99 |
+
modules.append(('batchnorm', bn(in_channels)))
|
100 |
+
else:
|
101 |
+
modules.append(('batchnorm', bn(out_channels)))
|
102 |
+
elif char == 'd':
|
103 |
+
modules.append(('dropout', nn.Dropout(p=dropout_prob)))
|
104 |
+
elif char == 'D':
|
105 |
+
modules.append(('dropout2d', nn.Dropout2d(p=dropout_prob)))
|
106 |
+
else:
|
107 |
+
raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c', 'd', 'D']")
|
108 |
+
|
109 |
+
return modules
|
110 |
+
|
111 |
+
|
112 |
+
class SingleConv(nn.Sequential):
|
113 |
+
"""
|
114 |
+
Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
|
115 |
+
of operations can be specified via the `order` parameter
|
116 |
+
|
117 |
+
Args:
|
118 |
+
in_channels (int): number of input channels
|
119 |
+
out_channels (int): number of output channels
|
120 |
+
kernel_size (int or tuple): size of the convolving kernel
|
121 |
+
order (string): determines the order of layers, e.g.
|
122 |
+
'cr' -> conv + ReLU
|
123 |
+
'crg' -> conv + ReLU + groupnorm
|
124 |
+
'cl' -> conv + LeakyReLU
|
125 |
+
'ce' -> conv + ELU
|
126 |
+
num_groups (int): number of groups for the GroupNorm
|
127 |
+
padding (int or tuple): add zero-padding
|
128 |
+
dropout_prob (float): dropout probability, default 0.1
|
129 |
+
is3d (bool): if True use Conv3d, otherwise use Conv2d
|
130 |
+
"""
|
131 |
+
|
132 |
+
def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8,
|
133 |
+
padding=1, dropout_prob=0.1, is3d=True):
|
134 |
+
super(SingleConv, self).__init__()
|
135 |
+
|
136 |
+
for name, module in create_conv(in_channels, out_channels, kernel_size, order,
|
137 |
+
num_groups, padding, dropout_prob, is3d):
|
138 |
+
self.add_module(name, module)
|
139 |
+
|
140 |
+
|
141 |
+
class DoubleConv(nn.Sequential):
|
142 |
+
"""
|
143 |
+
A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
|
144 |
+
We use (Conv3d+ReLU+GroupNorm3d) by default.
|
145 |
+
This can be changed however by providing the 'order' argument, e.g. in order
|
146 |
+
to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
|
147 |
+
Use padded convolutions to make sure that the output (H_out, W_out) is the same
|
148 |
+
as (H_in, W_in), so that you don't have to crop in the decoder path.
|
149 |
+
|
150 |
+
Args:
|
151 |
+
in_channels (int): number of input channels
|
152 |
+
out_channels (int): number of output channels
|
153 |
+
encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
|
154 |
+
kernel_size (int or tuple): size of the convolving kernel
|
155 |
+
order (string): determines the order of layers, e.g.
|
156 |
+
'cr' -> conv + ReLU
|
157 |
+
'crg' -> conv + ReLU + groupnorm
|
158 |
+
'cl' -> conv + LeakyReLU
|
159 |
+
'ce' -> conv + ELU
|
160 |
+
num_groups (int): number of groups for the GroupNorm
|
161 |
+
padding (int or tuple): add zero-padding added to all three sides of the input
|
162 |
+
upscale (int): number of the convolution to upscale in encoder if DoubleConv, default: 2
|
163 |
+
dropout_prob (float or tuple): dropout probability for each convolution, default 0.1
|
164 |
+
is3d (bool): if True use Conv3d instead of Conv2d layers
|
165 |
+
"""
|
166 |
+
|
167 |
+
def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr',
|
168 |
+
num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
|
169 |
+
super(DoubleConv, self).__init__()
|
170 |
+
if encoder:
|
171 |
+
# we're in the encoder path
|
172 |
+
conv1_in_channels = in_channels
|
173 |
+
if upscale == 1:
|
174 |
+
conv1_out_channels = out_channels
|
175 |
+
else:
|
176 |
+
conv1_out_channels = out_channels // 2
|
177 |
+
if conv1_out_channels < in_channels:
|
178 |
+
conv1_out_channels = in_channels
|
179 |
+
conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
|
180 |
+
else:
|
181 |
+
# we're in the decoder path, decrease the number of channels in the 1st convolution
|
182 |
+
conv1_in_channels, conv1_out_channels = in_channels, out_channels
|
183 |
+
conv2_in_channels, conv2_out_channels = out_channels, out_channels
|
184 |
+
|
185 |
+
# check if dropout_prob is a tuple and if so
|
186 |
+
# split it for different dropout probabilities for each convolution.
|
187 |
+
if isinstance(dropout_prob, list) or isinstance(dropout_prob, tuple):
|
188 |
+
dropout_prob1 = dropout_prob[0]
|
189 |
+
dropout_prob2 = dropout_prob[1]
|
190 |
+
else:
|
191 |
+
dropout_prob1 = dropout_prob2 = dropout_prob
|
192 |
+
|
193 |
+
# conv1
|
194 |
+
self.add_module('SingleConv1',
|
195 |
+
SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups,
|
196 |
+
padding=padding, dropout_prob=dropout_prob1, is3d=is3d))
|
197 |
+
# conv2
|
198 |
+
self.add_module('SingleConv2',
|
199 |
+
SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups,
|
200 |
+
padding=padding, dropout_prob=dropout_prob2, is3d=is3d))
|
201 |
+
|
202 |
+
|
203 |
+
class ResNetBlock(nn.Module):
|
204 |
+
"""
|
205 |
+
Residual block that can be used instead of standard DoubleConv in the Encoder module.
|
206 |
+
Motivated by: https://arxiv.org/pdf/1706.00120.pdf
|
207 |
+
|
208 |
+
Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
|
209 |
+
"""
|
210 |
+
|
211 |
+
def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, is3d=True, **kwargs):
|
212 |
+
super(ResNetBlock, self).__init__()
|
213 |
+
|
214 |
+
if in_channels != out_channels:
|
215 |
+
# conv1x1 for increasing the number of channels
|
216 |
+
if is3d:
|
217 |
+
self.conv1 = nn.Conv3d(in_channels, out_channels, 1)
|
218 |
+
else:
|
219 |
+
self.conv1 = nn.Conv2d(in_channels, out_channels, 1)
|
220 |
+
else:
|
221 |
+
self.conv1 = nn.Identity()
|
222 |
+
|
223 |
+
self.conv2 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups,
|
224 |
+
is3d=is3d)
|
225 |
+
# remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
|
226 |
+
n_order = order
|
227 |
+
for c in 'rel':
|
228 |
+
n_order = n_order.replace(c, '')
|
229 |
+
self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
|
230 |
+
num_groups=num_groups, is3d=is3d)
|
231 |
+
|
232 |
+
# create non-linearity separately
|
233 |
+
if 'l' in order:
|
234 |
+
self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
|
235 |
+
elif 'e' in order:
|
236 |
+
self.non_linearity = nn.ELU(inplace=True)
|
237 |
+
else:
|
238 |
+
self.non_linearity = nn.ReLU(inplace=True)
|
239 |
+
|
240 |
+
def forward(self, x):
|
241 |
+
# apply first convolution to bring the number of channels to out_channels
|
242 |
+
residual = self.conv1(x)
|
243 |
+
|
244 |
+
out = self.conv2(x)
|
245 |
+
out = self.conv3(out)
|
246 |
+
|
247 |
+
out += residual
|
248 |
+
out = self.non_linearity(out)
|
249 |
+
|
250 |
+
return out
|
251 |
+
|
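The `order` string documented above drives how each `SingleConv` is assembled (e.g. `'gcr'` = GroupNorm -> Conv3d -> ReLU). A minimal sketch, assuming the classes defined in this file are importable and using illustrative channel counts (not values taken from this repo), of how an encoder-side `DoubleConv` splits its channels:

```python
import torch

# Illustrative sizes only. With encoder=True and the default upscale=2,
# conv1 maps 16 -> 16 (out_channels // 2, floored at in_channels) and
# conv2 maps 16 -> 32; padded 3x3x3 convs keep the spatial size unchanged.
block = DoubleConv(in_channels=16, out_channels=32, encoder=True,
                   order='gcr', num_groups=8, is3d=True)
x = torch.randn(1, 16, 32, 32, 32)   # (N, C, D, H, W)
print(block(x).shape)                 # torch.Size([1, 32, 32, 32, 32])
```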
252 |
+
class Encoder(nn.Module):
|
253 |
+
"""
|
254 |
+
A single module from the encoder path consisting of the optional max
|
255 |
+
pooling layer (one may specify the MaxPool kernel_size to be different
|
256 |
+
from the standard (2,2,2), e.g. if the volumetric data is anisotropic
|
257 |
+
(make sure to use complementary scale_factor in the decoder path)) followed by
|
258 |
+
a basic module (DoubleConv or ResNetBlock).
|
259 |
+
|
260 |
+
Args:
|
261 |
+
in_channels (int): number of input channels
|
262 |
+
out_channels (int): number of output channels
|
263 |
+
conv_kernel_size (int or tuple): size of the convolving kernel
|
264 |
+
apply_pooling (bool): if True use MaxPool3d before DoubleConv
|
265 |
+
pool_kernel_size (int or tuple): the size of the window
|
266 |
+
pool_type (str): pooling layer: 'max' or 'avg'
|
267 |
+
basic_module(nn.Module): either ResNetBlock or DoubleConv
|
268 |
+
conv_layer_order (string): determines the order of layers
|
269 |
+
in `DoubleConv` module. See `DoubleConv` for more info.
|
270 |
+
num_groups (int): number of groups for the GroupNorm
|
271 |
+
padding (int or tuple): add zero-padding added to all three sides of the input
|
272 |
+
upscale (int): number of the convolution to upscale in encoder if DoubleConv, default: 2
|
273 |
+
dropout_prob (float or tuple): dropout probability, default 0.1
|
274 |
+
is3d (bool): use 3d or 2d convolutions/pooling operation
|
275 |
+
"""
|
276 |
+
|
277 |
+
def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
|
278 |
+
pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr',
|
279 |
+
num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
|
280 |
+
super(Encoder, self).__init__()
|
281 |
+
assert pool_type in ['max', 'avg']
|
282 |
+
if apply_pooling:
|
283 |
+
if pool_type == 'max':
|
284 |
+
if is3d:
|
285 |
+
self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
|
286 |
+
else:
|
287 |
+
self.pooling = nn.MaxPool2d(kernel_size=pool_kernel_size)
|
288 |
+
else:
|
289 |
+
if is3d:
|
290 |
+
self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
|
291 |
+
else:
|
292 |
+
self.pooling = nn.AvgPool2d(kernel_size=pool_kernel_size)
|
293 |
+
else:
|
294 |
+
self.pooling = None
|
295 |
+
|
296 |
+
self.basic_module = basic_module(in_channels, out_channels,
|
297 |
+
encoder=True,
|
298 |
+
kernel_size=conv_kernel_size,
|
299 |
+
order=conv_layer_order,
|
300 |
+
num_groups=num_groups,
|
301 |
+
padding=padding,
|
302 |
+
upscale=upscale,
|
303 |
+
dropout_prob=dropout_prob,
|
304 |
+
is3d=is3d)
|
305 |
+
|
306 |
+
def forward(self, x):
|
307 |
+
if self.pooling is not None:
|
308 |
+
x = self.pooling(x)
|
309 |
+
x = self.basic_module(x)
|
310 |
+
return x
|
311 |
+
|
312 |
+
|
313 |
+
class Decoder(nn.Module):
|
314 |
+
"""
|
315 |
+
A single module for decoder path consisting of the upsampling layer
|
316 |
+
(either learned ConvTranspose3d or nearest neighbor interpolation)
|
317 |
+
followed by a basic module (DoubleConv or ResNetBlock).
|
318 |
+
|
319 |
+
Args:
|
320 |
+
in_channels (int): number of input channels
|
321 |
+
out_channels (int): number of output channels
|
322 |
+
conv_kernel_size (int or tuple): size of the convolving kernel
|
323 |
+
scale_factor (int or tuple): used as the multiplier for the image H/W/D in
|
324 |
+
case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation
|
325 |
+
from the corresponding encoder
|
326 |
+
basic_module(nn.Module): either ResNetBlock or DoubleConv
|
327 |
+
conv_layer_order (string): determines the order of layers
|
328 |
+
in `DoubleConv` module. See `DoubleConv` for more info.
|
329 |
+
num_groups (int): number of groups for the GroupNorm
|
330 |
+
padding (int or tuple): add zero-padding added to all three sides of the input
|
331 |
+
upsample (str): algorithm used for upsampling:
|
332 |
+
InterpolateUpsampling: 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
|
333 |
+
TransposeConvUpsampling: 'deconv'
|
334 |
+
No upsampling: None
|
335 |
+
Default: 'default' (chooses automatically)
|
336 |
+
dropout_prob (float or tuple): dropout probability, default 0.1
|
337 |
+
"""
|
338 |
+
|
339 |
+
def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=2, basic_module=DoubleConv,
|
340 |
+
conv_layer_order='gcr', num_groups=8, padding=1, upsample='default',
|
341 |
+
dropout_prob=0.1, is3d=True):
|
342 |
+
super(Decoder, self).__init__()
|
343 |
+
|
344 |
+
# perform concat joining per default
|
345 |
+
concat = True
|
346 |
+
|
347 |
+
# don't adapt channels after join operation
|
348 |
+
adapt_channels = False
|
349 |
+
|
350 |
+
if upsample is not None and upsample != 'none':
|
351 |
+
if upsample == 'default':
|
352 |
+
if basic_module == DoubleConv:
|
353 |
+
upsample = 'nearest' # use nearest neighbor interpolation for upsampling
|
354 |
+
concat = True # use concat joining
|
355 |
+
adapt_channels = False # don't adapt channels
|
356 |
+
elif basic_module == ResNetBlock: #or basic_module == ResNetBlockSE:
|
357 |
+
upsample = 'deconv' # use deconvolution upsampling
|
358 |
+
concat = False # use summation joining
|
359 |
+
adapt_channels = True # adapt channels after joining
|
360 |
+
|
361 |
+
# perform deconvolution upsampling if mode is deconv
|
362 |
+
if upsample == 'deconv':
|
363 |
+
self.upsampling = TransposeConvUpsampling(in_channels=in_channels, out_channels=out_channels,
|
364 |
+
kernel_size=conv_kernel_size, scale_factor=scale_factor,
|
365 |
+
is3d=is3d)
|
366 |
+
else:
|
367 |
+
self.upsampling = InterpolateUpsampling(mode=upsample)
|
368 |
+
else:
|
369 |
+
# no upsampling
|
370 |
+
self.upsampling = NoUpsampling()
|
371 |
+
# concat joining
|
372 |
+
self.joining = partial(self._joining, concat=True)
|
373 |
+
|
374 |
+
# perform joining operation
|
375 |
+
self.joining = partial(self._joining, concat=concat)
|
376 |
+
|
377 |
+
# adapt the number of in_channels for the ResNetBlock
|
378 |
+
if adapt_channels is True:
|
379 |
+
in_channels = out_channels
|
380 |
+
|
381 |
+
self.basic_module = basic_module(in_channels, out_channels,
|
382 |
+
encoder=False,
|
383 |
+
kernel_size=conv_kernel_size,
|
384 |
+
order=conv_layer_order,
|
385 |
+
num_groups=num_groups,
|
386 |
+
padding=padding,
|
387 |
+
dropout_prob=dropout_prob,
|
388 |
+
is3d=is3d)
|
389 |
+
|
390 |
+
def forward(self, encoder_features, x):
|
391 |
+
x = self.upsampling(encoder_features=encoder_features, x=x)
|
392 |
+
x = self.joining(encoder_features, x)
|
393 |
+
x = self.basic_module(x)
|
394 |
+
return x
|
395 |
+
|
396 |
+
@staticmethod
|
397 |
+
def _joining(encoder_features, x, concat):
|
398 |
+
if concat:
|
399 |
+
return torch.cat((encoder_features, x), dim=1)
|
400 |
+
else:
|
401 |
+
return encoder_features + x
|
402 |
+
|
403 |
+
|
404 |
+
def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding,
|
405 |
+
conv_upscale, dropout_prob,
|
406 |
+
layer_order, num_groups, pool_kernel_size, is3d):
|
407 |
+
# create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
|
408 |
+
encoders = []
|
409 |
+
for i, out_feature_num in enumerate(f_maps):
|
410 |
+
if i == 0:
|
411 |
+
# apply conv_coord only in the first encoder if any
|
412 |
+
encoder = Encoder(in_channels, out_feature_num,
|
413 |
+
apply_pooling=False, # skip pooling in the first encoder
|
414 |
+
basic_module=basic_module,
|
415 |
+
conv_layer_order=layer_order,
|
416 |
+
conv_kernel_size=conv_kernel_size,
|
417 |
+
num_groups=num_groups,
|
418 |
+
padding=conv_padding,
|
419 |
+
upscale=conv_upscale,
|
420 |
+
dropout_prob=dropout_prob,
|
421 |
+
is3d=is3d)
|
422 |
+
else:
|
423 |
+
encoder = Encoder(f_maps[i - 1], out_feature_num,
|
424 |
+
basic_module=basic_module,
|
425 |
+
conv_layer_order=layer_order,
|
426 |
+
conv_kernel_size=conv_kernel_size,
|
427 |
+
num_groups=num_groups,
|
428 |
+
pool_kernel_size=pool_kernel_size,
|
429 |
+
padding=conv_padding,
|
430 |
+
upscale=conv_upscale,
|
431 |
+
dropout_prob=dropout_prob,
|
432 |
+
is3d=is3d)
|
433 |
+
|
434 |
+
encoders.append(encoder)
|
435 |
+
|
436 |
+
return nn.ModuleList(encoders)
|
437 |
+
|
438 |
+
|
439 |
+
def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order,
|
440 |
+
num_groups, upsample, dropout_prob, is3d):
|
441 |
+
# create decoder path consisting of the Decoder modules. The length of the decoder list is equal to `len(f_maps) - 1`
|
442 |
+
decoders = []
|
443 |
+
reversed_f_maps = list(reversed(f_maps[1:]))
|
444 |
+
for i in range(len(reversed_f_maps) - 1):
|
445 |
+
if basic_module == DoubleConv and upsample != 'deconv':
|
446 |
+
in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
|
447 |
+
else:
|
448 |
+
in_feature_num = reversed_f_maps[i]
|
449 |
+
|
450 |
+
out_feature_num = reversed_f_maps[i + 1]
|
451 |
+
|
452 |
+
decoder = Decoder(in_feature_num, out_feature_num,
|
453 |
+
basic_module=basic_module,
|
454 |
+
conv_layer_order=layer_order,
|
455 |
+
conv_kernel_size=conv_kernel_size,
|
456 |
+
num_groups=num_groups,
|
457 |
+
padding=conv_padding,
|
458 |
+
upsample=upsample,
|
459 |
+
dropout_prob=dropout_prob,
|
460 |
+
is3d=is3d)
|
461 |
+
decoders.append(decoder)
|
462 |
+
return nn.ModuleList(decoders)
|
463 |
+
|
464 |
+
|
465 |
+
class AbstractUpsampling(nn.Module):
|
466 |
+
"""
|
467 |
+
Abstract class for upsampling. A given implementation should upsample a given 5D input tensor using either
|
468 |
+
interpolation or learned transposed convolution.
|
469 |
+
"""
|
470 |
+
|
471 |
+
def __init__(self, upsample):
|
472 |
+
super(AbstractUpsampling, self).__init__()
|
473 |
+
self.upsample = upsample
|
474 |
+
|
475 |
+
def forward(self, encoder_features, x):
|
476 |
+
# get the spatial dimensions of the output given the encoder_features
|
477 |
+
output_size = encoder_features.size()[2:]
|
478 |
+
# upsample the input and return
|
479 |
+
return self.upsample(x, output_size)
|
480 |
+
|
481 |
+
|
482 |
+
class InterpolateUpsampling(AbstractUpsampling):
|
483 |
+
"""
|
484 |
+
Args:
|
485 |
+
mode (str): algorithm used for upsampling:
|
486 |
+
'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
|
487 |
+
used only if transposed_conv is False
|
488 |
+
"""
|
489 |
+
|
490 |
+
def __init__(self, mode='nearest'):
|
491 |
+
upsample = partial(self._interpolate, mode=mode)
|
492 |
+
super().__init__(upsample)
|
493 |
+
|
494 |
+
@staticmethod
|
495 |
+
def _interpolate(x, size, mode):
|
496 |
+
return F.interpolate(x, size=size, mode=mode)
|
497 |
+
|
498 |
+
|
499 |
+
class TransposeConvUpsampling(AbstractUpsampling):
|
500 |
+
"""
|
501 |
+
Args:
|
502 |
+
in_channels (int): number of input channels for transposed conv
|
503 |
+
used only if transposed_conv is True
|
504 |
+
out_channels (int): number of output channels for transpose conv
|
505 |
+
used only if transposed_conv is True
|
506 |
+
kernel_size (int or tuple): size of the convolving kernel
|
507 |
+
used only if transposed_conv is True
|
508 |
+
scale_factor (int or tuple): stride of the convolution
|
509 |
+
used only if transposed_conv is True
|
510 |
+
is3d (bool): if True use ConvTranspose3d, otherwise use ConvTranspose2d
|
511 |
+
"""
|
512 |
+
|
513 |
+
class Upsample(nn.Module):
|
514 |
+
"""
|
515 |
+
Workaround the 'ValueError: requested an output size...' in the `_output_padding` method in
|
516 |
+
transposed convolution. It performs transposed conv followed by the interpolation to the correct size if necessary.
|
517 |
+
"""
|
518 |
+
|
519 |
+
def __init__(self, conv_transposed, is3d):
|
520 |
+
super().__init__()
|
521 |
+
self.conv_transposed = conv_transposed
|
522 |
+
self.is3d = is3d
|
523 |
+
|
524 |
+
def forward(self, x, size):
|
525 |
+
x = self.conv_transposed(x)
|
526 |
+
return F.interpolate(x, size=size)
|
527 |
+
|
528 |
+
def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=2, is3d=True):
|
529 |
+
# make sure that the output size reverses the MaxPool3d from the corresponding encoder
|
530 |
+
if is3d is True:
|
531 |
+
conv_transposed = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size,
|
532 |
+
stride=scale_factor, padding=1, bias=False)
|
533 |
+
else:
|
534 |
+
conv_transposed = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
|
535 |
+
stride=scale_factor, padding=1, bias=False)
|
536 |
+
upsample = self.Upsample(conv_transposed, is3d)
|
537 |
+
super().__init__(upsample)
|
538 |
+
|
539 |
+
|
540 |
+
class NoUpsampling(AbstractUpsampling):
|
541 |
+
def __init__(self):
|
542 |
+
super().__init__(self._no_upsampling)
|
543 |
+
|
544 |
+
@staticmethod
|
545 |
+
def _no_upsampling(x, size):
|
546 |
+
return x
|
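As a quick illustration of the two joining modes that `Decoder._joining` switches between (concatenation for `DoubleConv`, summation for `ResNetBlock`), here is a small standalone sketch; the tensors are random placeholders, not repo data:

```python
import torch

encoder_features = torch.randn(1, 32, 8, 8, 8)          # skip connection from the encoder
x = torch.randn(1, 32, 8, 8, 8)                          # upsampled decoder activation
concat_join = torch.cat((encoder_features, x), dim=1)   # DoubleConv path: 64 channels
sum_join = encoder_features + x                          # ResNetBlock path: 32 channels
print(concat_join.shape, sum_join.shape)
```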
modules/PartField/partfield/model/UNet/model.py
ADDED
@@ -0,0 +1,170 @@
1 |
+
# https://github.com/wolny/pytorch-3dunet/blob/master/pytorch3dunet/unet3d/buildingblocks.py
|
2 |
+
# MIT License
|
3 |
+
|
4 |
+
# Copyright (c) 2018 Adrian Wolny
|
5 |
+
|
6 |
+
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
7 |
+
# of this software and associated documentation files (the "Software"), to deal
|
8 |
+
# in the Software without restriction, including without limitation the rights
|
9 |
+
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
10 |
+
# copies of the Software, and to permit persons to whom the Software is
|
11 |
+
# furnished to do so, subject to the following conditions:
|
12 |
+
|
13 |
+
# The above copyright notice and this permission notice shall be included in all
|
14 |
+
# copies or substantial portions of the Software.
|
15 |
+
|
16 |
+
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
17 |
+
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
18 |
+
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
19 |
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
20 |
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
21 |
+
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
22 |
+
# SOFTWARE.
|
23 |
+
|
24 |
+
import torch.nn as nn
|
25 |
+
|
26 |
+
from partfield.model.UNet.buildingblocks import DoubleConv, ResNetBlock, \
|
27 |
+
create_decoders, create_encoders
|
28 |
+
|
29 |
+
def number_of_features_per_level(init_channel_number, num_levels):
|
30 |
+
return [init_channel_number * 2 ** k for k in range(num_levels)]
|
31 |
+
|
32 |
+
class AbstractUNet(nn.Module):
|
33 |
+
"""
|
34 |
+
Base class for standard and residual UNet.
|
35 |
+
|
36 |
+
Args:
|
37 |
+
in_channels (int): number of input channels
|
38 |
+
out_channels (int): number of output segmentation masks;
|
39 |
+
Note that out_channels might correspond to either
|
40 |
+
different semantic classes or to different binary segmentation masks.
|
41 |
+
It's up to the user of the class to interpret the out_channels and
|
42 |
+
use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class)
|
43 |
+
or BCEWithLogitsLoss (two-class) respectively)
|
44 |
+
f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number
|
45 |
+
of feature maps is given by the geometric progression: f_maps * 2^k, k=0,1,...,num_levels-1
|
46 |
+
final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the final 1x1 convolution,
|
47 |
+
otherwise apply nn.Softmax. In effect only if `self.training == False`, i.e. during validation/testing
|
48 |
+
basic_module: basic model for the encoder/decoder (DoubleConv, ResNetBlock, ....)
|
49 |
+
layer_order (string): determines the order of layers in `SingleConv` module.
|
50 |
+
E.g. 'crg' stands for GroupNorm3d+Conv3d+ReLU. See `SingleConv` for more info
|
51 |
+
num_groups (int): number of groups for the GroupNorm
|
52 |
+
num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int)
|
53 |
+
default: 4
|
54 |
+
is_segmentation (bool): if True and the model is in eval mode, Sigmoid/Softmax normalization is applied
|
55 |
+
after the final convolution; if False (regression problem) the normalization layer is skipped
|
56 |
+
conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module
|
57 |
+
pool_kernel_size (int or tuple): the size of the window
|
58 |
+
conv_padding (int or tuple): add zero-padding added to all three sides of the input
|
59 |
+
conv_upscale (int): number of the convolution to upscale in encoder if DoubleConv, default: 2
|
60 |
+
upsample (str): algorithm used for decoder upsampling:
|
61 |
+
InterpolateUpsampling: 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
|
62 |
+
TransposeConvUpsampling: 'deconv'
|
63 |
+
No upsampling: None
|
64 |
+
Default: 'default' (chooses automatically)
|
65 |
+
dropout_prob (float or tuple): dropout probability, default: 0.1
|
66 |
+
is3d (bool): if True the model is 3D, otherwise 2D, default: True
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
|
70 |
+
num_groups=8, num_levels=4, is_segmentation=False, conv_kernel_size=3, pool_kernel_size=2,
|
71 |
+
conv_padding=1, conv_upscale=2, upsample='default', dropout_prob=0.1, is3d=True, encoder_only=False):
|
72 |
+
super(AbstractUNet, self).__init__()
|
73 |
+
|
74 |
+
if isinstance(f_maps, int):
|
75 |
+
f_maps = number_of_features_per_level(f_maps, num_levels=num_levels)
|
76 |
+
|
77 |
+
assert isinstance(f_maps, list) or isinstance(f_maps, tuple)
|
78 |
+
assert len(f_maps) > 1, "Required at least 2 levels in the U-Net"
|
79 |
+
if 'g' in layer_order:
|
80 |
+
assert num_groups is not None, "num_groups must be specified if GroupNorm is used"
|
81 |
+
|
82 |
+
# create encoder path
|
83 |
+
self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size,
|
84 |
+
conv_padding, conv_upscale, dropout_prob,
|
85 |
+
layer_order, num_groups, pool_kernel_size, is3d)
|
86 |
+
|
87 |
+
self.encoder_only = encoder_only
|
88 |
+
|
89 |
+
if encoder_only == False:
|
90 |
+
# create decoder path
|
91 |
+
self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding,
|
92 |
+
layer_order, num_groups, upsample, dropout_prob,
|
93 |
+
is3d)
|
94 |
+
|
95 |
+
# in the last layer a 1×1 convolution reduces the number of output channels to the number of labels
|
96 |
+
if is3d:
|
97 |
+
self.final_conv = nn.Conv3d(f_maps[1], out_channels, 1)
|
98 |
+
else:
|
99 |
+
self.final_conv = nn.Conv2d(f_maps[1], out_channels, 1)
|
100 |
+
|
101 |
+
if is_segmentation:
|
102 |
+
# semantic segmentation problem
|
103 |
+
if final_sigmoid:
|
104 |
+
self.final_activation = nn.Sigmoid()
|
105 |
+
else:
|
106 |
+
self.final_activation = nn.Softmax(dim=1)
|
107 |
+
else:
|
108 |
+
# regression problem
|
109 |
+
self.final_activation = None
|
110 |
+
|
111 |
+
def forward(self, x, return_bottleneck_feat=False):
|
112 |
+
# encoder part
|
113 |
+
encoders_features = []
|
114 |
+
for encoder in self.encoders:
|
115 |
+
x = encoder(x)
|
116 |
+
# reverse the encoder outputs to be aligned with the decoder
|
117 |
+
encoders_features.insert(0, x)
|
118 |
+
|
119 |
+
# remove the last encoder's output from the list
|
120 |
+
# !!remember: it's the 1st in the list
|
121 |
+
bottleneck_feat = encoders_features[0]
|
122 |
+
if self.encoder_only:
|
123 |
+
return bottleneck_feat
|
124 |
+
else:
|
125 |
+
encoders_features = encoders_features[1:]
|
126 |
+
|
127 |
+
# decoder part
|
128 |
+
for decoder, encoder_features in zip(self.decoders, encoders_features):
|
129 |
+
# pass the output from the corresponding encoder and the output
|
130 |
+
# of the previous decoder
|
131 |
+
x = decoder(encoder_features, x)
|
132 |
+
|
133 |
+
x = self.final_conv(x)
|
134 |
+
# During training the network outputs logits
|
135 |
+
if self.final_activation is not None:
|
136 |
+
x = self.final_activation(x)
|
137 |
+
|
138 |
+
if return_bottleneck_feat:
|
139 |
+
return x, bottleneck_feat
|
140 |
+
else:
|
141 |
+
return x
|
142 |
+
|
143 |
+
class ResidualUNet3D(AbstractUNet):
|
144 |
+
"""
|
145 |
+
Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
|
146 |
+
Uses ResNetBlock as a basic building block, summation joining instead
|
147 |
+
of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
|
148 |
+
Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
|
149 |
+
"""
|
150 |
+
|
151 |
+
def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=(8, 16, 64, 256, 1024), layer_order='gcr',
|
152 |
+
num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
|
153 |
+
conv_upscale=2, upsample='default', dropout_prob=0.1, encoder_only=False, **kwargs):
|
154 |
+
super(ResidualUNet3D, self).__init__(in_channels=in_channels,
|
155 |
+
out_channels=out_channels,
|
156 |
+
final_sigmoid=final_sigmoid,
|
157 |
+
basic_module=ResNetBlock,
|
158 |
+
f_maps=f_maps,
|
159 |
+
layer_order=layer_order,
|
160 |
+
num_groups=num_groups,
|
161 |
+
num_levels=num_levels,
|
162 |
+
is_segmentation=is_segmentation,
|
163 |
+
conv_padding=conv_padding,
|
164 |
+
conv_upscale=conv_upscale,
|
165 |
+
upsample=upsample,
|
166 |
+
dropout_prob=dropout_prob,
|
167 |
+
encoder_only=encoder_only,
|
168 |
+
is3d=True)
|
169 |
+
|
170 |
+
|
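A hedged usage sketch for `ResidualUNet3D` as defined above; the channel counts and input size are illustrative only, and passing `encoder_only=True` would instead make `forward` return just the bottleneck feature:

```python
import torch

net = ResidualUNet3D(in_channels=8, out_channels=2,
                     f_maps=(8, 16, 32, 64, 128), num_levels=5,
                     is_segmentation=False, encoder_only=False)
x = torch.randn(1, 8, 64, 64, 64)
logits = net(x)                                   # no final activation (regression mode)
logits, bottleneck = net(x, return_bottleneck_feat=True)
print(logits.shape, bottleneck.shape)
```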
modules/PartField/partfield/model/model_utils.py
ADDED
@@ -0,0 +1,54 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class VanillaMLP(nn.Module):
|
5 |
+
def __init__(self, input_dim, output_dim, out_activation, n_hidden_layers=4, n_neurons=64, activation="ReLU"):
|
6 |
+
super().__init__()
|
7 |
+
self.n_neurons = n_neurons
|
8 |
+
self.n_hidden_layers = n_hidden_layers
|
9 |
+
self.activation = activation
|
10 |
+
self.out_activation = out_activation
|
11 |
+
layers = [
|
12 |
+
self.make_linear(input_dim, self.n_neurons, is_first=True, is_last=False),
|
13 |
+
self.make_activation(),
|
14 |
+
]
|
15 |
+
for i in range(self.n_hidden_layers - 1):
|
16 |
+
layers += [
|
17 |
+
self.make_linear(
|
18 |
+
self.n_neurons, self.n_neurons, is_first=False, is_last=False
|
19 |
+
),
|
20 |
+
self.make_activation(),
|
21 |
+
]
|
22 |
+
layers += [
|
23 |
+
self.make_linear(self.n_neurons, output_dim, is_first=False, is_last=True)
|
24 |
+
]
|
25 |
+
if self.out_activation == "sigmoid":
|
26 |
+
layers += [nn.Sigmoid()]
|
27 |
+
elif self.out_activation == "tanh":
|
28 |
+
layers += [nn.Tanh()]
|
29 |
+
elif self.out_activation == "hardtanh":
|
30 |
+
layers += [nn.Hardtanh()]
|
31 |
+
elif self.out_activation == "GELU":
|
32 |
+
layers += [nn.GELU()]
|
33 |
+
elif self.out_activation == "RELU":
|
34 |
+
layers += [nn.ReLU()]
|
35 |
+
else:
|
36 |
+
raise NotImplementedError
|
37 |
+
self.layers = nn.Sequential(*layers)
|
38 |
+
|
39 |
+
def forward(self, x, split_size=100000):
|
40 |
+
with torch.cuda.amp.autocast(enabled=False):
|
41 |
+
out = self.layers(x)
|
42 |
+
return out
|
43 |
+
|
44 |
+
def make_linear(self, dim_in, dim_out, is_first, is_last):
|
45 |
+
layer = nn.Linear(dim_in, dim_out, bias=False)
|
46 |
+
return layer
|
47 |
+
|
48 |
+
def make_activation(self):
|
49 |
+
if self.activation == "ReLU":
|
50 |
+
return nn.ReLU(inplace=True)
|
51 |
+
elif self.activation == "GELU":
|
52 |
+
return nn.GELU()
|
53 |
+
else:
|
54 |
+
raise NotImplementedError
|
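A small sketch of `VanillaMLP` with the SDF-head configuration used later in this repo (input_dim=64, tanh output); the batch of features here is random placeholder data:

```python
import torch

sdf_head = VanillaMLP(input_dim=64, output_dim=1, out_activation="tanh",
                      n_neurons=64, n_hidden_layers=6)
feat = torch.randn(1024, 64)      # placeholder per-point features
sdf = sdf_head(feat)              # values squashed to (-1, 1) by the Tanh output
print(sdf.shape)                  # torch.Size([1024, 1])
```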
modules/PartField/partfield/model/triplane.py
ADDED
@@ -0,0 +1,331 @@
1 |
+
#https://github.com/3DTopia/OpenLRM/blob/main/openlrm/models/modeling_lrm.py
|
2 |
+
# Copyright (c) 2023-2024, Zexin He
|
3 |
+
#
|
4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
5 |
+
# you may not use this file except in compliance with the License.
|
6 |
+
# You may obtain a copy of the License at
|
7 |
+
#
|
8 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
9 |
+
#
|
10 |
+
# Unless required by applicable law or agreed to in writing, software
|
11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
13 |
+
# See the License for the specific language governing permissions and
|
14 |
+
# limitations under the License.
|
15 |
+
|
16 |
+
import torch
|
17 |
+
import torch.nn as nn
|
18 |
+
from functools import partial
|
19 |
+
|
20 |
+
def project_onto_planes(planes, coordinates):
|
21 |
+
"""
|
22 |
+
Does a projection of a 3D point onto a batch of 2D planes,
|
23 |
+
returning 2D plane coordinates.
|
24 |
+
|
25 |
+
Takes plane axes of shape (n_planes, 3, 3)
|
26 |
+
Takes coordinates of shape (N, M, 3)
|
27 |
+
Returns projections of shape (N*n_planes, M, 2)
|
28 |
+
"""
|
29 |
+
N, M, C = coordinates.shape
|
30 |
+
n_planes, _, _ = planes.shape
|
31 |
+
coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
|
32 |
+
inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
|
33 |
+
projections = torch.bmm(coordinates, inv_planes)
|
34 |
+
return projections[..., :2]
|
35 |
+
|
36 |
+
def sample_from_planes(plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
|
37 |
+
plane_axes = torch.tensor([[[1, 0, 0],
|
38 |
+
[0, 1, 0],
|
39 |
+
[0, 0, 1]],
|
40 |
+
[[1, 0, 0],
|
41 |
+
[0, 0, 1],
|
42 |
+
[0, 1, 0]],
|
43 |
+
[[0, 0, 1],
|
44 |
+
[0, 1, 0],
|
45 |
+
[1, 0, 0]]], dtype=torch.float32).cuda()
|
46 |
+
|
47 |
+
assert padding_mode == 'zeros'
|
48 |
+
N, n_planes, C, H, W = plane_features.shape
|
49 |
+
_, M, _ = coordinates.shape
|
50 |
+
plane_features = plane_features.view(N*n_planes, C, H, W)
|
51 |
+
|
52 |
+
projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
|
53 |
+
output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
|
54 |
+
return output_features
|
55 |
+
|
56 |
+
def get_grid_coord(grid_size = 256, align_corners=False):
|
57 |
+
if align_corners == False:
|
58 |
+
coords = torch.linspace(-1 + 1/(grid_size), 1 - 1/(grid_size), steps=grid_size)
|
59 |
+
else:
|
60 |
+
coords = torch.linspace(-1, 1, steps=grid_size)
|
61 |
+
i, j, k = torch.meshgrid(coords, coords, coords, indexing='ij')
|
62 |
+
coordinates = torch.stack((i, j, k), dim=-1).reshape(-1, 3)
|
63 |
+
return coordinates
|
64 |
+
|
65 |
+
class BasicBlock(nn.Module):
|
66 |
+
"""
|
67 |
+
Transformer block that is in its simplest form.
|
68 |
+
Designed for PF-LRM architecture.
|
69 |
+
"""
|
70 |
+
# Block contains a self-attention layer and an MLP
|
71 |
+
def __init__(self, inner_dim: int, num_heads: int, eps: float,
|
72 |
+
attn_drop: float = 0., attn_bias: bool = False,
|
73 |
+
mlp_ratio: float = 4., mlp_drop: float = 0.):
|
74 |
+
super().__init__()
|
75 |
+
self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
|
76 |
+
self.self_attn = nn.MultiheadAttention(
|
77 |
+
embed_dim=inner_dim, num_heads=num_heads,
|
78 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
79 |
+
self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
|
80 |
+
self.mlp = nn.Sequential(
|
81 |
+
nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
|
82 |
+
nn.GELU(),
|
83 |
+
nn.Dropout(mlp_drop),
|
84 |
+
nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
|
85 |
+
nn.Dropout(mlp_drop),
|
86 |
+
)
|
87 |
+
|
88 |
+
def forward(self, x):
|
89 |
+
# x: [N, L, D]
|
90 |
+
before_sa = self.norm1(x)
|
91 |
+
x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
|
92 |
+
x = x + self.mlp(self.norm2(x))
|
93 |
+
return x
|
94 |
+
|
95 |
+
class ConditionBlock(nn.Module):
|
96 |
+
"""
|
97 |
+
Transformer block that takes in a cross-attention condition.
|
98 |
+
Designed for SparseLRM architecture.
|
99 |
+
"""
|
100 |
+
# Block contains a cross-attention layer, a self-attention layer, and an MLP
|
101 |
+
def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float,
|
102 |
+
attn_drop: float = 0., attn_bias: bool = False,
|
103 |
+
mlp_ratio: float = 4., mlp_drop: float = 0.):
|
104 |
+
super().__init__()
|
105 |
+
self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
|
106 |
+
self.cross_attn = nn.MultiheadAttention(
|
107 |
+
embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
|
108 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
109 |
+
self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
|
110 |
+
self.self_attn = nn.MultiheadAttention(
|
111 |
+
embed_dim=inner_dim, num_heads=num_heads,
|
112 |
+
dropout=attn_drop, bias=attn_bias, batch_first=True)
|
113 |
+
self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
|
114 |
+
self.mlp = nn.Sequential(
|
115 |
+
nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
|
116 |
+
nn.GELU(),
|
117 |
+
nn.Dropout(mlp_drop),
|
118 |
+
nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
|
119 |
+
nn.Dropout(mlp_drop),
|
120 |
+
)
|
121 |
+
|
122 |
+
def forward(self, x, cond):
|
123 |
+
# x: [N, L, D]
|
124 |
+
# cond: [N, L_cond, D_cond]
|
125 |
+
x = x + self.cross_attn(self.norm1(x), cond, cond, need_weights=False)[0]
|
126 |
+
before_sa = self.norm2(x)
|
127 |
+
x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
|
128 |
+
x = x + self.mlp(self.norm3(x))
|
129 |
+
return x
|
130 |
+
|
131 |
+
class TransformerDecoder(nn.Module):
|
132 |
+
def __init__(self, block_type: str,
|
133 |
+
num_layers: int, num_heads: int,
|
134 |
+
inner_dim: int, cond_dim: int = None,
|
135 |
+
eps: float = 1e-6):
|
136 |
+
super().__init__()
|
137 |
+
self.block_type = block_type
|
138 |
+
self.layers = nn.ModuleList([
|
139 |
+
self._block_fn(inner_dim, cond_dim)(
|
140 |
+
num_heads=num_heads,
|
141 |
+
eps=eps,
|
142 |
+
)
|
143 |
+
for _ in range(num_layers)
|
144 |
+
])
|
145 |
+
self.norm = nn.LayerNorm(inner_dim, eps=eps)
|
146 |
+
|
147 |
+
@property
|
148 |
+
def block_type(self):
|
149 |
+
return self._block_type
|
150 |
+
|
151 |
+
@block_type.setter
|
152 |
+
def block_type(self, block_type):
|
153 |
+
assert block_type in ['cond', 'basic'], \
|
154 |
+
f"Unsupported block type: {block_type}"
|
155 |
+
self._block_type = block_type
|
156 |
+
|
157 |
+
def _block_fn(self, inner_dim, cond_dim):
|
158 |
+
assert inner_dim is not None, f"inner_dim must always be specified"
|
159 |
+
if self.block_type == 'basic':
|
160 |
+
return partial(BasicBlock, inner_dim=inner_dim)
|
161 |
+
elif self.block_type == 'cond':
|
162 |
+
assert cond_dim is not None, f"Condition dimension must be specified for ConditionBlock"
|
163 |
+
return partial(ConditionBlock, inner_dim=inner_dim, cond_dim=cond_dim)
|
164 |
+
else:
|
165 |
+
raise ValueError(f"Unsupported block type during runtime: {self.block_type}")
|
166 |
+
|
167 |
+
|
168 |
+
def forward_layer(self, layer: nn.Module, x: torch.Tensor, cond: torch.Tensor,):
|
169 |
+
if self.block_type == 'basic':
|
170 |
+
return layer(x)
|
171 |
+
elif self.block_type == 'cond':
|
172 |
+
return layer(x, cond)
|
173 |
+
else:
|
174 |
+
raise NotImplementedError
|
175 |
+
|
176 |
+
def forward(self, x: torch.Tensor, cond: torch.Tensor = None):
|
177 |
+
# x: [N, L, D]
|
178 |
+
# cond: [N, L_cond, D_cond] or None
|
179 |
+
for layer in self.layers:
|
180 |
+
x = self.forward_layer(layer, x, cond)
|
181 |
+
x = self.norm(x)
|
182 |
+
return x
|
183 |
+
|
184 |
+
class Voxel2Triplane(nn.Module):
|
185 |
+
"""
|
186 |
+
Transformer that decodes voxel features into a triplane representation via cross-attention (adapted from OpenLRM).
|
187 |
+
"""
|
188 |
+
def __init__(self, transformer_dim: int, transformer_layers: int, transformer_heads: int,
|
189 |
+
triplane_low_res: int, triplane_high_res: int, triplane_dim: int, voxel_feat_dim: int, normalize_vox_feat=False, voxel_dim=16):
|
190 |
+
super().__init__()
|
191 |
+
|
192 |
+
# attributes
|
193 |
+
self.triplane_low_res = triplane_low_res
|
194 |
+
self.triplane_high_res = triplane_high_res
|
195 |
+
self.triplane_dim = triplane_dim
|
196 |
+
self.voxel_feat_dim = voxel_feat_dim
|
197 |
+
|
198 |
+
# initialize pos_embed with 1/sqrt(dim) * N(0, 1)
|
199 |
+
self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, transformer_dim) * (1. / transformer_dim) ** 0.5)
|
200 |
+
self.transformer = TransformerDecoder(
|
201 |
+
block_type='cond',
|
202 |
+
num_layers=transformer_layers, num_heads=transformer_heads,
|
203 |
+
inner_dim=transformer_dim, cond_dim=voxel_feat_dim
|
204 |
+
)
|
205 |
+
self.upsampler = nn.ConvTranspose2d(transformer_dim, triplane_dim, kernel_size=8, stride=8, padding=0)
|
206 |
+
|
207 |
+
self.normalize_vox_feat = normalize_vox_feat
|
208 |
+
if normalize_vox_feat:
|
209 |
+
self.vox_norm = nn.LayerNorm(voxel_feat_dim, eps=1e-6)
|
210 |
+
self.vox_pos_embed = nn.Parameter(torch.randn(1, voxel_dim * voxel_dim * voxel_dim, voxel_feat_dim) * (1. / voxel_feat_dim) ** 0.5)
|
211 |
+
|
212 |
+
def forward_transformer(self, voxel_feats):
|
213 |
+
N = voxel_feats.shape[0]
|
214 |
+
x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
|
215 |
+
if self.normalize_vox_feat:
|
216 |
+
vox_pos_embed = self.vox_pos_embed.repeat(N, 1, 1) # [N, L, D]
|
217 |
+
voxel_feats = self.vox_norm(voxel_feats + vox_pos_embed)
|
218 |
+
x = self.transformer(
|
219 |
+
x,
|
220 |
+
cond=voxel_feats
|
221 |
+
)
|
222 |
+
return x
|
223 |
+
|
224 |
+
def reshape_upsample(self, tokens):
|
225 |
+
N = tokens.shape[0]
|
226 |
+
H = W = self.triplane_low_res
|
227 |
+
x = tokens.view(N, 3, H, W, -1)
|
228 |
+
x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
|
229 |
+
x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
|
230 |
+
x = self.upsampler(x) # [3*N, D', H', W']
|
231 |
+
x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
|
232 |
+
x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
|
233 |
+
x = x.contiguous()
|
234 |
+
return x
|
235 |
+
|
236 |
+
def forward(self, voxel_feats):
|
237 |
+
N = voxel_feats.shape[0]
|
238 |
+
|
239 |
+
# encode image
|
240 |
+
assert voxel_feats.shape[-1] == self.voxel_feat_dim, \
|
241 |
+
f"Feature dimension mismatch: {voxel_feats.shape[-1]} vs {self.voxel_feat_dim}"
|
242 |
+
|
243 |
+
# transformer generating planes
|
244 |
+
tokens = self.forward_transformer(voxel_feats)
|
245 |
+
planes = self.reshape_upsample(tokens)
|
246 |
+
assert planes.shape[0] == N, "Batch size mismatch for planes"
|
247 |
+
assert planes.shape[1] == 3, "Planes should have 3 channels"
|
248 |
+
|
249 |
+
return planes
|
250 |
+
|
251 |
+
|
252 |
+
class TriplaneTransformer(nn.Module):
|
253 |
+
"""
|
254 |
+
Transformer that refines an input triplane and adds the result back as a residual (adapted from OpenLRM).
|
255 |
+
"""
|
256 |
+
def __init__(self, input_dim: int, transformer_dim: int, transformer_layers: int, transformer_heads: int,
|
257 |
+
triplane_low_res: int, triplane_high_res: int, triplane_dim: int):
|
258 |
+
super().__init__()
|
259 |
+
|
260 |
+
# attributes
|
261 |
+
self.triplane_low_res = triplane_low_res
|
262 |
+
self.triplane_high_res = triplane_high_res
|
263 |
+
self.triplane_dim = triplane_dim
|
264 |
+
|
265 |
+
# initialize pos_embed with 1/sqrt(dim) * N(0, 1)
|
266 |
+
self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, transformer_dim) * (1. / transformer_dim) ** 0.5)
|
267 |
+
self.transformer = TransformerDecoder(
|
268 |
+
block_type='basic',
|
269 |
+
num_layers=transformer_layers, num_heads=transformer_heads,
|
270 |
+
inner_dim=transformer_dim,
|
271 |
+
)
|
272 |
+
|
273 |
+
self.downsampler = nn.Sequential(
|
274 |
+
nn.Conv2d(input_dim, transformer_dim, kernel_size=3, stride=1, padding=1),
|
275 |
+
nn.ReLU(),
|
276 |
+
nn.MaxPool2d(kernel_size=2, stride=2), # Reduces size from 128x128 to 64x64
|
277 |
+
|
278 |
+
nn.Conv2d(transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1),
|
279 |
+
nn.ReLU(),
|
280 |
+
nn.MaxPool2d(kernel_size=2, stride=2), # Reduces size from 64x64 to 32x32
|
281 |
+
)
|
282 |
+
|
283 |
+
self.upsampler = nn.ConvTranspose2d(transformer_dim, triplane_dim, kernel_size=4, stride=4, padding=0)
|
284 |
+
|
285 |
+
self.mlp = nn.Sequential(
|
286 |
+
nn.Linear(input_dim, triplane_dim),
|
287 |
+
nn.ReLU(),
|
288 |
+
nn.Linear(triplane_dim, triplane_dim)
|
289 |
+
)
|
290 |
+
|
291 |
+
def forward_transformer(self, triplanes):
|
292 |
+
N = triplanes.shape[0]
|
293 |
+
tokens = torch.einsum('nidhw->nihwd', triplanes).reshape(N, self.pos_embed.shape[1], -1) # [N, L, D]
|
294 |
+
x = self.pos_embed.repeat(N, 1, 1) + tokens # [N, L, D]
|
295 |
+
x = self.transformer(x)
|
296 |
+
return x
|
297 |
+
|
298 |
+
def reshape_downsample(self, triplanes):
|
299 |
+
N = triplanes.shape[0]
|
300 |
+
H = W = self.triplane_high_res
|
301 |
+
x = triplanes.view(N, 3, -1, H, W)
|
302 |
+
x = torch.einsum('nidhw->indhw', x) # [3, N, D, H, W]
|
303 |
+
x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
|
304 |
+
x = self.downsampler(x) # [3*N, D', H', W']
|
305 |
+
x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
|
306 |
+
x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
|
307 |
+
x = x.contiguous()
|
308 |
+
return x
|
309 |
+
|
310 |
+
def reshape_upsample(self, tokens):
|
311 |
+
N = tokens.shape[0]
|
312 |
+
H = W = self.triplane_low_res
|
313 |
+
x = tokens.view(N, 3, H, W, -1)
|
314 |
+
x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
|
315 |
+
x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
|
316 |
+
x = self.upsampler(x) # [3*N, D', H', W']
|
317 |
+
x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
|
318 |
+
x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
|
319 |
+
x = x.contiguous()
|
320 |
+
return x
|
321 |
+
|
322 |
+
def forward(self, triplanes):
|
323 |
+
downsampled_triplanes = self.reshape_downsample(triplanes)
|
324 |
+
tokens = self.forward_transformer(downsampled_triplanes)
|
325 |
+
residual = self.reshape_upsample(tokens)
|
326 |
+
|
327 |
+
triplanes = triplanes.permute(0, 1, 3, 4, 2).contiguous()
|
328 |
+
triplanes = self.mlp(triplanes)
|
329 |
+
triplanes = triplanes.permute(0, 1, 4, 2, 3).contiguous()
|
330 |
+
planes = triplanes + residual
|
331 |
+
return planes
|
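A standalone sketch of the projection step behind `sample_from_planes`: each 3D query point is expressed in the frame of one of the three axis-aligned planes and its last coordinate is dropped (CPU-only here; `sample_from_planes` itself moves the plane axes to CUDA):

```python
import torch

plane_axes = torch.tensor([[[1., 0., 0.], [0., 1., 0.], [0., 0., 1.]],
                           [[1., 0., 0.], [0., 0., 1.], [0., 1., 0.]],
                           [[0., 0., 1.], [0., 1., 0.], [1., 0., 0.]]])
coords = torch.rand(2, 5, 3) * 2 - 1            # N=2 batches of M=5 points in [-1, 1]^3
proj = project_onto_planes(plane_axes, coords)  # shape (N * n_planes, M, 2)
print(proj.shape)                               # torch.Size([6, 5, 2])
```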
modules/PartField/partfield/model_trainer_pvcnn_only_demo.py
ADDED
@@ -0,0 +1,283 @@
1 |
+
import torch
|
2 |
+
import lightning.pytorch as pl
|
3 |
+
from .dataloader import Demo_Dataset, Demo_Remesh_Dataset, Correspondence_Demo_Dataset
|
4 |
+
from torch.utils.data import DataLoader
|
5 |
+
from partfield.model.UNet.model import ResidualUNet3D
|
6 |
+
from partfield.model.triplane import TriplaneTransformer, get_grid_coord #, sample_from_planes, Voxel2Triplane
|
7 |
+
from partfield.model.model_utils import VanillaMLP
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torch.nn as nn
|
10 |
+
import os
|
11 |
+
import trimesh
|
12 |
+
import skimage
|
13 |
+
import numpy as np
|
14 |
+
import h5py
|
15 |
+
import torch.distributed as dist
|
16 |
+
from partfield.model.PVCNN.encoder_pc import TriPlanePC2Encoder, sample_triplane_feat
|
17 |
+
import json
|
18 |
+
import gc
|
19 |
+
import time
|
20 |
+
from plyfile import PlyData, PlyElement
|
21 |
+
|
22 |
+
|
23 |
+
class Model(pl.LightningModule):
|
24 |
+
def __init__(self, cfg):
|
25 |
+
super().__init__()
|
26 |
+
|
27 |
+
self.save_hyperparameters()
|
28 |
+
self.cfg = cfg
|
29 |
+
self.automatic_optimization = False
|
30 |
+
self.triplane_resolution = cfg.triplane_resolution
|
31 |
+
self.triplane_channels_low = cfg.triplane_channels_low
|
32 |
+
self.triplane_transformer = TriplaneTransformer(
|
33 |
+
input_dim=cfg.triplane_channels_low * 2,
|
34 |
+
transformer_dim=1024,
|
35 |
+
transformer_layers=6,
|
36 |
+
transformer_heads=8,
|
37 |
+
triplane_low_res=32,
|
38 |
+
triplane_high_res=128,
|
39 |
+
triplane_dim=cfg.triplane_channels_high,
|
40 |
+
)
|
41 |
+
self.sdf_decoder = VanillaMLP(input_dim=64,
|
42 |
+
output_dim=1,
|
43 |
+
out_activation="tanh",
|
44 |
+
n_neurons=64, #64
|
45 |
+
n_hidden_layers=6) #6
|
46 |
+
self.use_pvcnn = cfg.use_pvcnnonly
|
47 |
+
self.use_2d_feat = cfg.use_2d_feat
|
48 |
+
if self.use_pvcnn:
|
49 |
+
self.pvcnn = TriPlanePC2Encoder(
|
50 |
+
cfg.pvcnn,
|
51 |
+
device="cuda",
|
52 |
+
shape_min=-1,
|
53 |
+
shape_length=2,
|
54 |
+
use_2d_feat=self.use_2d_feat) #.cuda()
|
55 |
+
self.logit_scale = nn.Parameter(torch.tensor([1.0], requires_grad=True))
|
56 |
+
self.grid_coord = get_grid_coord(256)
|
57 |
+
self.mse_loss = torch.nn.MSELoss()
|
58 |
+
self.l1_loss = torch.nn.L1Loss(reduction='none')
|
59 |
+
|
60 |
+
if cfg.regress_2d_feat:
|
61 |
+
self.feat_decoder = VanillaMLP(input_dim=64,
|
62 |
+
output_dim=192,
|
63 |
+
out_activation="GELU",
|
64 |
+
n_neurons=64, #64
|
65 |
+
n_hidden_layers=6) #6
|
66 |
+
|
67 |
+
def predict_dataloader(self):
|
68 |
+
if self.cfg.remesh_demo:
|
69 |
+
dataset = Demo_Remesh_Dataset(self.cfg)
|
70 |
+
elif self.cfg.correspondence_demo:
|
71 |
+
dataset = Correspondence_Demo_Dataset(self.cfg)
|
72 |
+
else:
|
73 |
+
dataset = Demo_Dataset(self.cfg)
|
74 |
+
|
75 |
+
dataloader = DataLoader(dataset,
|
76 |
+
num_workers=self.cfg.dataset.val_num_workers,
|
77 |
+
batch_size=self.cfg.dataset.val_batch_size,
|
78 |
+
shuffle=False,
|
79 |
+
pin_memory=True,
|
80 |
+
drop_last=False)
|
81 |
+
|
82 |
+
return dataloader
|
83 |
+
|
84 |
+
|
85 |
+
@torch.no_grad()
|
86 |
+
def predict_step(self, batch, batch_idx):
|
87 |
+
save_dir = f"{self.cfg.result_name}"
|
88 |
+
os.makedirs(save_dir, exist_ok=True)
|
89 |
+
|
90 |
+
uid = batch['uid'][0]
|
91 |
+
view_id = 0
|
92 |
+
starttime = time.time()
|
93 |
+
|
94 |
+
if uid == "car" or uid == "complex_car":
|
95 |
+
# if uid == "complex_car":
|
96 |
+
print("Skipping this for now.")
|
97 |
+
print(uid)
|
98 |
+
return
|
99 |
+
|
100 |
+
### Skip if model already processed
|
101 |
+
if os.path.exists(f'{save_dir}/part_feat_{uid}_{view_id}.npy') or os.path.exists(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy'):
|
102 |
+
print("Already processed "+uid)
|
103 |
+
return
|
104 |
+
|
105 |
+
N = batch['pc'].shape[0]
|
106 |
+
assert N == 1
|
107 |
+
|
108 |
+
if self.use_2d_feat:
|
109 |
+
print("ERROR. Dataloader not implemented with input 2d feat.")
|
110 |
+
exit()
|
111 |
+
else:
|
112 |
+
pc_feat = self.pvcnn(batch['pc'], batch['pc'])
|
113 |
+
|
114 |
+
planes = pc_feat
|
115 |
+
planes = self.triplane_transformer(planes)
|
116 |
+
sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2)
|
117 |
+
|
118 |
+
if self.cfg.is_pc:
|
119 |
+
tensor_vertices = batch['pc'].reshape(1, -1, 3).cuda().to(torch.float16)
|
120 |
+
point_feat = sample_triplane_feat(part_planes, tensor_vertices) # N, M, C
|
121 |
+
point_feat = point_feat.cpu().detach().numpy().reshape(-1, 448)
|
122 |
+
|
123 |
+
np.save(f'{save_dir}/part_feat_{uid}_{view_id}.npy', point_feat)
|
124 |
+
print(f"Exported part_feat_{uid}_{view_id}.npy")
|
125 |
+
|
126 |
+
###########
|
127 |
+
from sklearn.decomposition import PCA
|
128 |
+
data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
|
129 |
+
|
130 |
+
pca = PCA(n_components=3)
|
131 |
+
|
132 |
+
data_reduced = pca.fit_transform(data_scaled)
|
133 |
+
data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min())
|
134 |
+
colors_255 = (data_reduced * 255).astype(np.uint8)
|
135 |
+
|
136 |
+
points = batch['pc'].squeeze().detach().cpu().numpy()
|
137 |
+
|
138 |
+
if colors_255 is None:
|
139 |
+
colors_255 = np.full_like(points, 255) # Default to white color (255,255,255)
|
140 |
+
else:
|
141 |
+
assert colors_255.shape == points.shape, "Colors must have the same shape as points"
|
142 |
+
|
143 |
+
# Convert to structured array for PLY format
|
144 |
+
vertex_data = np.array(
|
145 |
+
[(*point, *color) for point, color in zip(points, colors_255)],
|
146 |
+
dtype=[("x", "f4"), ("y", "f4"), ("z", "f4"), ("red", "u1"), ("green", "u1"), ("blue", "u1")]
|
147 |
+
)
|
148 |
+
|
149 |
+
# Create PLY element
|
150 |
+
el = PlyElement.describe(vertex_data, "vertex")
|
151 |
+
# Write to file
|
152 |
+
filename = f'{save_dir}/feat_pca_{uid}_{view_id}.ply'
|
153 |
+
PlyData([el], text=True).write(filename)
|
154 |
+
print(f"Saved PLY file: {filename}")
|
155 |
+
############
|
156 |
+
|
157 |
+
else:
|
158 |
+
use_cuda_version = True
|
159 |
+
if use_cuda_version:
|
160 |
+
|
161 |
+
def sample_points(vertices, faces, n_point_per_face):
|
162 |
+
# Generate random barycentric coordinates
|
163 |
+
# borrowed from Kaolin https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/mesh/trianglemesh.py#L43
|
164 |
+
n_f = faces.shape[0]
|
165 |
+
u = torch.sqrt(torch.rand((n_f, n_point_per_face, 1),
|
166 |
+
device=vertices.device,
|
167 |
+
dtype=vertices.dtype))
|
168 |
+
v = torch.rand((n_f, n_point_per_face, 1),
|
169 |
+
device=vertices.device,
|
170 |
+
dtype=vertices.dtype)
|
171 |
+
w0 = 1 - u
|
172 |
+
w1 = u * (1 - v)
|
173 |
+
w2 = u * v
|
174 |
+
|
175 |
+
face_v_0 = torch.index_select(vertices, 0, faces[:, 0].reshape(-1))
|
176 |
+
face_v_1 = torch.index_select(vertices, 0, faces[:, 1].reshape(-1))
|
177 |
+
face_v_2 = torch.index_select(vertices, 0, faces[:, 2].reshape(-1))
|
178 |
+
points = w0 * face_v_0.unsqueeze(dim=1) + w1 * face_v_1.unsqueeze(dim=1) + w2 * face_v_2.unsqueeze(dim=1)
|
179 |
+
return points
|
180 |
+
|
181 |
+
def sample_and_mean_memory_save_version(part_planes, tensor_vertices, n_point_per_face):
|
182 |
+
n_sample_each = self.cfg.n_sample_each # we iterate over this to avoid OOM
|
183 |
+
n_v = tensor_vertices.shape[1]
|
184 |
+
n_sample = n_v // n_sample_each + 1
|
185 |
+
all_sample = []
|
186 |
+
for i_sample in range(n_sample):
|
187 |
+
sampled_feature = sample_triplane_feat(part_planes, tensor_vertices[:, i_sample * n_sample_each: i_sample * n_sample_each + n_sample_each,])
|
188 |
+
assert sampled_feature.shape[1] % n_point_per_face == 0
|
189 |
+
sampled_feature = sampled_feature.reshape(1, -1, n_point_per_face, sampled_feature.shape[-1])
|
190 |
+
sampled_feature = torch.mean(sampled_feature, axis=-2)
|
191 |
+
all_sample.append(sampled_feature)
|
192 |
+
return torch.cat(all_sample, dim=1)
|
193 |
+
|
194 |
+
if self.cfg.vertex_feature:
|
195 |
+
tensor_vertices = batch['vertices'][0].reshape(1, -1, 3).to(torch.float32)
|
196 |
+
point_feat = sample_and_mean_memory_save_version(part_planes, tensor_vertices, 1)
|
197 |
+
else:
|
198 |
+
n_point_per_face = self.cfg.n_point_per_face
|
199 |
+
tensor_vertices = sample_points(batch['vertices'][0], batch['faces'][0], n_point_per_face)
|
200 |
+
tensor_vertices = tensor_vertices.reshape(1, -1, 3).to(torch.float32)
|
201 |
+
point_feat = sample_and_mean_memory_save_version(part_planes, tensor_vertices, n_point_per_face) # N, M, C
|
202 |
+
|
203 |
+
#### Take mean feature in the triangle
|
204 |
+
print("Time elapsed for feature prediction: " + str(time.time() - starttime))
|
205 |
+
point_feat = point_feat.reshape(-1, 448).cpu().numpy()
|
206 |
+
np.save(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy', point_feat)
|
207 |
+
print(f"Exported part_feat_{uid}_{view_id}.npy")
|
208 |
+
|
209 |
+
###########
|
210 |
+
from sklearn.decomposition import PCA
|
211 |
+
data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
|
212 |
+
|
213 |
+
pca = PCA(n_components=3)
|
214 |
+
|
215 |
+
data_reduced = pca.fit_transform(data_scaled)
|
216 |
+
data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min())
|
217 |
+
colors_255 = (data_reduced * 255).astype(np.uint8)
|
218 |
+
V = batch['vertices'][0].cpu().numpy()
|
219 |
+
F = batch['faces'][0].cpu().numpy()
|
220 |
+
if self.cfg.vertex_feature:
|
221 |
+
colored_mesh = trimesh.Trimesh(vertices=V, faces=F, vertex_colors=colors_255, process=False)
|
222 |
+
else:
|
223 |
+
colored_mesh = trimesh.Trimesh(vertices=V, faces=F, face_colors=colors_255, process=False)
|
224 |
+
colored_mesh.export(f'{save_dir}/feat_pca_{uid}_{view_id}.ply')
|
225 |
+
############
|
226 |
+
torch.cuda.empty_cache()
|
227 |
+
|
228 |
+
else:
|
229 |
+
### Mesh input (obj file)
|
230 |
+
V = batch['vertices'][0].cpu().numpy()
|
231 |
+
F = batch['faces'][0].cpu().numpy()
|
232 |
+
|
233 |
+
##### Loop through faces #####
|
234 |
+
num_samples_per_face = self.cfg.n_point_per_face
|
235 |
+
|
236 |
+
all_point_feats = []
|
237 |
+
for face in F:
|
238 |
+
# Get the vertices of the current face
|
239 |
+
v0, v1, v2 = V[face]
|
240 |
+
|
241 |
+
# Generate random barycentric coordinates
|
242 |
+
u = np.random.rand(num_samples_per_face, 1)
|
243 |
+
v = np.random.rand(num_samples_per_face, 1)
|
244 |
+
is_prob = (u+v) >1
|
245 |
+
u[is_prob] = 1 - u[is_prob]
|
246 |
+
v[is_prob] = 1 - v[is_prob]
|
247 |
+
w = 1 - u - v
|
248 |
+
|
249 |
+
# Calculate points in Cartesian coordinates
|
250 |
+
points = u * v0 + v * v1 + w * v2
|
251 |
+
|
252 |
+
tensor_vertices = torch.from_numpy(points.copy()).reshape(1, -1, 3).cuda().to(torch.float32)
|
253 |
+
point_feat = sample_triplane_feat(part_planes, tensor_vertices) # N, M, C
|
254 |
+
|
255 |
+
#### Take mean feature in the triangle
|
256 |
+
point_feat = torch.mean(point_feat, axis=1).cpu().detach().numpy()
|
257 |
+
all_point_feats.append(point_feat)
|
258 |
+
##############################
|
259 |
+
|
260 |
+
all_point_feats = np.array(all_point_feats).reshape(-1, 448)
|
261 |
+
|
262 |
+
point_feat = all_point_feats
|
263 |
+
|
264 |
+
np.save(f'{save_dir}/part_feat_{uid}_{view_id}.npy', point_feat)
|
265 |
+
print(f"Exported part_feat_{uid}_{view_id}.npy")
|
266 |
+
|
267 |
+
###########
|
268 |
+
from sklearn.decomposition import PCA
|
269 |
+
data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
|
270 |
+
|
271 |
+
pca = PCA(n_components=3)
|
272 |
+
|
273 |
+
data_reduced = pca.fit_transform(data_scaled)
|
274 |
+
data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min())
|
275 |
+
colors_255 = (data_reduced * 255).astype(np.uint8)
|
276 |
+
|
277 |
+
colored_mesh = trimesh.Trimesh(vertices=V, faces=F, face_colors=colors_255, process=False)
|
278 |
+
colored_mesh.export(f'{save_dir}/feat_pca_{uid}_{view_id}.ply')
|
279 |
+
############
|
280 |
+
|
281 |
+
print("Time elapsed: " + str(time.time()-starttime))
|
282 |
+
|
283 |
+
return
|
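The PCA colouring used above for the exported PLY meshes and point clouds boils down to a few lines; this sketch uses random placeholder features instead of real triplane samples:

```python
import numpy as np
from sklearn.decomposition import PCA

point_feat = np.random.randn(1000, 448).astype(np.float32)       # placeholder features
data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
data_reduced = PCA(n_components=3).fit_transform(data_scaled)     # 448-D -> 3-D
data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min())
colors_255 = (data_reduced * 255).astype(np.uint8)                # one RGB colour per point/face
print(colors_255.shape)                                           # (1000, 3)
```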
modules/PartField/partfield/partfield_encoder.py
ADDED
@@ -0,0 +1,103 @@
import torch
import lightning.pytorch as pl
# from .dataloader import Demo_Dataset, Demo_Remesh_Dataset, Correspondence_Demo_Dataset
from torch.utils.data import DataLoader
from partfield.model.UNet.model import ResidualUNet3D
from partfield.model.triplane import TriplaneTransformer, get_grid_coord  # , sample_from_planes, Voxel2Triplane
from partfield.model.model_utils import VanillaMLP
import torch.nn.functional as F
import torch.nn as nn
import os
import trimesh
import skimage
import numpy as np
import h5py
import torch.distributed as dist
from partfield.model.PVCNN.encoder_pc import TriPlanePC2Encoder, sample_triplane_feat
import json
import gc
import time
from plyfile import PlyData, PlyElement


class Model(pl.LightningModule):
    def __init__(self, cfg):
        super().__init__()

        self.save_hyperparameters()
        self.cfg = cfg
        self.automatic_optimization = False
        self.triplane_resolution = cfg.triplane_resolution
        self.triplane_channels_low = cfg.triplane_channels_low
        self.triplane_transformer = TriplaneTransformer(
            input_dim=cfg.triplane_channels_low * 2,
            transformer_dim=1024,
            transformer_layers=6,
            transformer_heads=8,
            triplane_low_res=32,
            triplane_high_res=128,
            triplane_dim=cfg.triplane_channels_high,
        )
        self.sdf_decoder = VanillaMLP(input_dim=64,
                                      output_dim=1,
                                      out_activation="tanh",
                                      n_neurons=64,  # 64
                                      n_hidden_layers=6)  # 6
        self.use_pvcnn = cfg.use_pvcnnonly
        self.use_2d_feat = cfg.use_2d_feat
        if self.use_pvcnn:
            self.pvcnn = TriPlanePC2Encoder(
                cfg.pvcnn,
                device="cuda",
                shape_min=-1,
                shape_length=2,
                use_2d_feat=self.use_2d_feat)  # .cuda()
        self.logit_scale = nn.Parameter(torch.tensor([1.0], requires_grad=True))
        self.grid_coord = get_grid_coord(256)
        self.mse_loss = torch.nn.MSELoss()
        self.l1_loss = torch.nn.L1Loss(reduction='none')

        if cfg.regress_2d_feat:
            self.feat_decoder = VanillaMLP(input_dim=64,
                                           output_dim=192,
                                           out_activation="GELU",
                                           n_neurons=64,  # 64
                                           n_hidden_layers=6)  # 6

    # def predict_dataloader(self):
    #     if self.cfg.remesh_demo:
    #         dataset = Demo_Remesh_Dataset(self.cfg)
    #     elif self.cfg.correspondence_demo:
    #         dataset = Correspondence_Demo_Dataset(self.cfg)
    #     else:
    #         dataset = Demo_Dataset(self.cfg)

    #     dataloader = DataLoader(dataset,
    #                             num_workers=self.cfg.dataset.val_num_workers,
    #                             batch_size=self.cfg.dataset.val_batch_size,
    #                             shuffle=False,
    #                             pin_memory=True,
    #                             drop_last=False)

    #     return dataloader


    @torch.no_grad()
    def encode(self, points):

        N = points.shape[0]
        # assert N == 1
        pcd = points[..., :3]

        pc_feat = self.pvcnn(pcd, pcd)

        planes = pc_feat
        planes = self.triplane_transformer(planes)
        sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2)

        tensor_vertices = pcd.reshape(N, -1, 3).cuda().to(pcd.dtype)
        point_feat = sample_triplane_feat(part_planes, tensor_vertices)  # N, M, C
        # point_feat = point_feat.cpu().detach().numpy().reshape(-1, 448)
        point_feat = point_feat.reshape(N, -1, 448)

        return point_feat
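A minimal sketch of how this encoder might be driven, assuming `cfg` was produced by the PartField `setup()` helper from `configs/final/demo.yaml` and the checkpoint weights were loaded, as done in `BboxGen._load_partfield_encoder` later in this commit:

    # Hypothetical usage sketch; cfg and the checkpoint are assumed to exist.
    import torch

    model = Model(cfg).cuda().eval()
    points = torch.rand(1, 10000, 3, device="cuda") * 2 - 1   # N x M x 3, roughly in [-1, 1]
    feat = model.encode(points)                               # N x M x 448 per-point features
    print(feat.shape)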
modules/PartField/partfield/utils.py
ADDED
@@ -0,0 +1,5 @@
import trimesh

def load_mesh_util(input_fname):
    mesh = trimesh.load(input_fname, force='mesh', process=False)
    return mesh
modules/bbox_gen/config.py
ADDED
@@ -0,0 +1,57 @@
import os
from omegaconf import OmegaConf, DictConfig
from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional, Union
from datetime import datetime

@dataclass
class ExperimentConfig:
    name: str = "default"
    tag: str = ""
    use_timestamp: bool = False
    timestamp: Optional[str] = None
    exp_root_dir: str = "outputs"

    ### these shouldn't be set manually
    exp_dir: str = "outputs/default"
    trial_name: str = "exp"
    trial_dir: str = "outputs/default/exp"
    ###

    resume: Optional[str] = None
    ckpt_path: Optional[str] = None

    data: dict = field(default_factory=dict)
    model_pl: dict = field(default_factory=dict)

    trainer: dict = field(default_factory=dict)
    checkpoint: dict = field(default_factory=dict)
    checkpoint_epoch: Optional[dict] = None
    wandb: dict = field(default_factory=dict)


def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any:
    if from_string:
        yaml_confs = [OmegaConf.create(s) for s in yamls]
    else:
        yaml_confs = [OmegaConf.load(f) for f in yamls]
    cli_conf = OmegaConf.from_cli(cli_args)
    cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs)
    OmegaConf.resolve(cfg)
    assert isinstance(cfg, DictConfig)
    scfg = parse_structured(ExperimentConfig, cfg)
    return scfg


def config_to_primitive(config, resolve: bool = True) -> Any:
    return OmegaConf.to_container(config, resolve=resolve)


def dump_config(path: str, config) -> None:
    with open(path, "w") as fp:
        OmegaConf.save(config=config, f=fp)


def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
    scfg = OmegaConf.structured(fields(**cfg))
    return scfg
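For reference, a minimal sketch of how these helpers compose, using an inline YAML string and a CLI-style override (both values are illustrative, not shipped configs):

    # Hypothetical call: build an ExperimentConfig from inline YAML plus an override.
    cfg = load_config("name: my_run\nexp_root_dir: outputs", cli_args=["tag=debug"], from_string=True)
    print(cfg.name, cfg.tag)          # my_run debug
    dump_config("/tmp/resolved.yaml", cfg)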
modules/bbox_gen/models/autogressive_bbox_gen.py
ADDED
@@ -0,0 +1,305 @@
from dataclasses import dataclass
import os
import sys
import torch
import trimesh
from torch import nn
from transformers import AutoModelForCausalLM
from transformers.generation.logits_process import LogitsProcessorList
from einops import rearrange

from modules.bbox_gen.models.image_encoder import DINOv2ImageEncoder
from modules.bbox_gen.config import parse_structured
from modules.bbox_gen.models.bboxopt import BBoxOPT, BBoxOPTConfig
from modules.bbox_gen.utils.bbox_tokenizer import BoundsTokenizerDiag
from modules.bbox_gen.models.bbox_gen_models import GroupEmbedding, MultiModalProjector, MeshDecodeLogitsProcessor, SparseStructureEncoder

current_dir = os.path.dirname(os.path.abspath(__file__))
modules_dir = os.path.dirname(os.path.dirname(current_dir))
partfield_dir = os.path.join(modules_dir, 'PartField')
if partfield_dir not in sys.path:
    sys.path.insert(0, partfield_dir)
import importlib.util
from partfield.config import default_argument_parser, setup


class BboxGen(nn.Module):

    @dataclass
    class Config:
        # encoder config
        encoder_dim_feat: int = 3
        encoder_dim: int = 64
        encoder_heads: int = 4
        encoder_token_num: int = 256
        encoder_qkv_bias: bool = False
        encoder_use_ln_post: bool = True
        encoder_use_checkpoint: bool = False
        encoder_num_embed_freqs: int = 8
        encoder_embed_include_pi: bool = False
        encoder_init_scale: float = 0.25
        encoder_random_fps: bool = True
        encoder_learnable_query: bool = False
        encoder_layers: int = 4
        group_embedding_dim: int = 64

        # decoder config
        vocab_size: int = 518
        decoder_hidden_size: int = 1536
        decoder_num_hidden_layers: int = 24
        decoder_ffn_dim: int = 6144
        decoder_heads: int = 16
        decoder_use_flash_attention: bool = True
        decoder_gradient_checkpointing: bool = True

        # data config
        bins: int = 64
        BOS_id: int = 64
        EOS_id: int = 65
        PAD_id: int = 66
        max_length: int = 2187  # bos + 50x2x3 + 1374 + 512
        voxel_token_length: int = 1886
        voxel_token_placeholder: int = -1

        # tokenizer config
        max_group_size: int = 50

        # voxel encoder
        partfield_encoder_path: str = ""

    cfg: Config

    def __init__(self, cfg):
        super().__init__()
        self.cfg = parse_structured(self.Config, cfg)

        self.image_encoder = DINOv2ImageEncoder(
            model_name="facebook/dinov2-with-registers-large",
        )

        self.image_projector = MultiModalProjector(
            in_features=(1024 + self.cfg.group_embedding_dim),
            out_features=self.cfg.decoder_hidden_size,
        )

        self.group_embedding = GroupEmbedding(
            max_group_size=self.cfg.max_group_size,
            hidden_size=self.cfg.group_embedding_dim,
        )

        self.decoder_config = BBoxOPTConfig(
            vocab_size=self.cfg.vocab_size,
            hidden_size=self.cfg.decoder_hidden_size,
            num_hidden_layers=self.cfg.decoder_num_hidden_layers,
            ffn_dim=self.cfg.decoder_ffn_dim,
            max_position_embeddings=self.cfg.max_length,
            num_attention_heads=self.cfg.decoder_heads,
            pad_token_id=self.cfg.PAD_id,
            bos_token_id=self.cfg.BOS_id,
            eos_token_id=self.cfg.EOS_id,
            use_cache=True,
            init_std=0.02,
        )

        if self.cfg.decoder_use_flash_attention:
            self.decoder: BBoxOPT = AutoModelForCausalLM.from_config(
                self.decoder_config,
                torch_dtype=torch.bfloat16,
                attn_implementation="flash_attention_2"
            )
        else:
            self.decoder: BBoxOPT = AutoModelForCausalLM.from_config(
                self.decoder_config,
            )
        if self.cfg.decoder_gradient_checkpointing:
            self.decoder.gradient_checkpointing_enable()

        self.logits_processor = LogitsProcessorList()

        self.logits_processor.append(MeshDecodeLogitsProcessor(
            bins=self.cfg.bins,
            BOS_id=self.cfg.BOS_id,
            EOS_id=self.cfg.EOS_id,
            PAD_id=self.cfg.PAD_id,
            vertices_num=2,
        ))
        self.tokenizer = BoundsTokenizerDiag(
            bins=self.cfg.bins,
            BOS_id=self.cfg.BOS_id,
            EOS_id=self.cfg.EOS_id,
            PAD_id=self.cfg.PAD_id,
        )

        self._load_partfield_encoder()

        self.partfield_voxel_encoder = SparseStructureEncoder(
            in_channels=451,
            channels=[448, 448, 448, 1024],
            latent_channels=448,
            num_res_blocks=1,
            num_res_blocks_middle=1,
            norm_type="layer",
        )


    def _load_partfield_encoder(self):
        # Load PartField encoder
        model_spec = importlib.util.spec_from_file_location(
            "partfield.partfield_encoder",
            os.path.join(partfield_dir, "partfield", "partfield_encoder.py")
        )
        model_module = importlib.util.module_from_spec(model_spec)
        model_spec.loader.exec_module(model_module)
        Model = model_module.Model
        parser = default_argument_parser()
        args = []
        args.extend(["-c", os.path.join(partfield_dir, "configs/final/demo.yaml")])
        args.append("--opts")
        args.extend(["continue_ckpt", self.cfg.partfield_encoder_path])
        parsed_args = parser.parse_args(args)
        cfg = setup(parsed_args, freeze=False)
        self.partfield_encoder = Model(cfg)
        self.partfield_encoder.eval()
        weights = torch.load(self.cfg.partfield_encoder_path)["state_dict"]
        self.partfield_encoder.load_state_dict(weights)
        for param in self.partfield_encoder.parameters():
            param.requires_grad = False
        print("PartField encoder loaded")

    def _prepare_lm_inputs(self, voxel_token, input_ids):
        inputs_embeds = torch.zeros(input_ids.shape[0], input_ids.shape[1], self.cfg.decoder_hidden_size, device=input_ids.device, dtype=voxel_token.dtype)
        voxel_token_mask = (input_ids == self.cfg.voxel_token_placeholder)
        inputs_embeds[voxel_token_mask] = voxel_token.view(-1, self.cfg.decoder_hidden_size)

        inputs_embeds[~voxel_token_mask] = self.decoder.get_input_embeddings()(input_ids[~voxel_token_mask]).to(dtype=inputs_embeds.dtype)

        attention_mask = (input_ids != self.cfg.PAD_id)
        return inputs_embeds, attention_mask.long()

    def forward(self, batch):

        image_latents = self.image_encoder(batch['images'])
        masks = batch['masks']
        masks_emb = self.group_embedding(masks)
        masks_emb = rearrange(masks_emb, 'b c h w -> b (h w) c')  # B x Q x C
        group_emb = torch.zeros((image_latents.shape[0], image_latents.shape[1], masks_emb.shape[2]), device=image_latents.device, dtype=image_latents.dtype)
        group_emb[:, :masks_emb.shape[1], :] = masks_emb
        image_latents = torch.cat([image_latents, group_emb], dim=-1)
        image_latents = self.image_projector(image_latents)

        points = batch['points'][..., :3]
        rot_matrix = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=points.device, dtype=points.dtype)
        rot_points = torch.matmul(points, rot_matrix)
        rot_points = rot_points * (2 * 0.9)  # from (-0.5, 0.5) to (-1, 1)

        partfield_feat = self.partfield_encoder.encode(rot_points)
        feat_volume = torch.zeros((points.shape[0], 448, 64, 64, 64), device=partfield_feat.device, dtype=partfield_feat.dtype)
        whole_voxel_index = batch['whole_voxel_index']  # (b, m, 3)

        batch_size, num_points = whole_voxel_index.shape[0], whole_voxel_index.shape[1]
        batch_indices = torch.arange(batch_size, device=whole_voxel_index.device).unsqueeze(1).expand(-1, num_points)  # (b, m)
        batch_flat = batch_indices.flatten()  # (b*m,)
        x_flat = whole_voxel_index[..., 0].flatten()  # (b*m,)
        y_flat = whole_voxel_index[..., 1].flatten()  # (b*m,)
        z_flat = whole_voxel_index[..., 2].flatten()  # (b*m,)
        partfield_feat_flat = partfield_feat.reshape(-1, 448)  # (b*m, 448)
        feat_volume[batch_flat, :, x_flat, y_flat, z_flat] = partfield_feat_flat

        xyz_volume = torch.zeros((points.shape[0], 3, 64, 64, 64), device=points.device, dtype=points.dtype)
        xyz_volume[batch_flat, :, x_flat, y_flat, z_flat] = points.reshape(-1, 3)
        feat_volume = torch.cat([feat_volume, xyz_volume], dim=1)

        feat_volume = self.partfield_voxel_encoder(feat_volume)
        feat_volume = rearrange(feat_volume, 'b c x y z -> b (x y z) c')

        voxel_token = torch.cat([image_latents, feat_volume], dim=1)  # B x N x D

        input_ids = batch['input_ids']
        inputs_embeds, attention_mask = self._prepare_lm_inputs(voxel_token, input_ids)
        output = self.decoder(
            attention_mask=attention_mask,
            inputs_embeds=inputs_embeds,
            return_dict=True,
        )
        return {
            "logits": output.logits,
        }

    def gen_mesh_from_bounds(self, bounds, random_color):
        bboxes = []
        for j in range(bounds.shape[0]):
            bbox = trimesh.primitives.Box(bounds=bounds[j])
            color = random_color[j]
            bbox.visual.vertex_colors = color
            bboxes.append(bbox)
        mesh = trimesh.Scene(bboxes)
        return mesh

    def generate(self, batch):

        image_latents = self.image_encoder(batch['images'])
        masks = batch['masks']
        masks_emb = self.group_embedding(masks)
        masks_emb = rearrange(masks_emb, 'b c h w -> b (h w) c')  # B x Q x C
        group_emb = torch.zeros((image_latents.shape[0], image_latents.shape[1], masks_emb.shape[2]), device=image_latents.device, dtype=image_latents.dtype)
        group_emb[:, :masks_emb.shape[1], :] = masks_emb
        image_latents = torch.cat([image_latents, group_emb], dim=-1)
        image_latents = self.image_projector(image_latents)

        points = batch['points'][..., :3]
        rot_matrix = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=points.device, dtype=points.dtype)
        rot_points = torch.matmul(points, rot_matrix)
        rot_points = rot_points * (2 * 0.9)  # from (-0.5, 0.5) to (-1, 1)

        partfield_feat = self.partfield_encoder.encode(rot_points)
        feat_volume = torch.zeros((points.shape[0], 448, 64, 64, 64), device=partfield_feat.device, dtype=partfield_feat.dtype)
        whole_voxel_index = batch['whole_voxel_index']  # (b, m, 3)

        batch_size, num_points = whole_voxel_index.shape[0], whole_voxel_index.shape[1]
        batch_indices = torch.arange(batch_size, device=whole_voxel_index.device).unsqueeze(1).expand(-1, num_points)  # (b, m)
        batch_flat = batch_indices.flatten()  # (b*m,)
        x_flat = whole_voxel_index[..., 0].flatten()  # (b*m,)
        y_flat = whole_voxel_index[..., 1].flatten()  # (b*m,)
        z_flat = whole_voxel_index[..., 2].flatten()  # (b*m,)
        partfield_feat_flat = partfield_feat.reshape(-1, 448)  # (b*m, 448)
        feat_volume[batch_flat, :, x_flat, y_flat, z_flat] = partfield_feat_flat

        xyz_volume = torch.zeros((points.shape[0], 3, 64, 64, 64), device=points.device, dtype=points.dtype)
        xyz_volume[batch_flat, :, x_flat, y_flat, z_flat] = points.reshape(-1, 3)
        feat_volume = torch.cat([feat_volume, xyz_volume], dim=1)

        feat_volume = self.partfield_voxel_encoder(feat_volume)
        feat_volume = rearrange(feat_volume, 'b c x y z -> b (x y z) c')

        voxel_token = torch.cat([image_latents, feat_volume], dim=1)  # B x N x D

        meshes = []
        mesh_names = []
        bboxes = []

        output = self.decoder.generate(
            inputs_embeds=voxel_token,
            max_new_tokens=self.cfg.max_length - voxel_token.shape[1],
            logits_processor=self.logits_processor,
            do_sample=True,
            top_k=5,
            top_p=0.95,
            temperature=0.5,
            use_cache=True,
        )

        for i in range(output.shape[0]):
            bounds = self.tokenizer.decode(output[i].detach().cpu().numpy(), coord_rg=(-0.5, 0.5))
            # mesh = self.gen_mesh_from_bounds(bounds, batch['random_color'][i])
            # meshes.append(mesh)
            mesh_names.append("topk=5")
            bboxes.append(bounds)

        return {
            # 'meshes': meshes,
            'mesh_names': mesh_names,
            'bboxes': bboxes,
        }
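Both forward() and generate() scatter per-point PartField features into a dense 64^3 volume by flattening the (batch, point) indices. A tiny self-contained sketch of the same advanced-indexing pattern, with toy sizes chosen only for illustration:

    # Toy-size version of the voxel scatter used above.
    import torch

    b, m, c, r = 2, 5, 4, 8                                  # batch, points, channels, grid resolution
    feat = torch.randn(b, m, c)
    voxel_index = torch.randint(0, r, (b, m, 3))             # integer voxel coords per point

    volume = torch.zeros(b, c, r, r, r)
    batch_flat = torch.arange(b).unsqueeze(1).expand(-1, m).flatten()
    x, y, z = (voxel_index[..., i].flatten() for i in range(3))
    volume[batch_flat, :, x, y, z] = feat.reshape(-1, c)     # advanced-indexing scatter

    print((volume.abs().sum(dim=1) > 0).sum())               # number of occupied voxels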
modules/bbox_gen/models/bbox_gen_models.py
ADDED
@@ -0,0 +1,215 @@
import torch
from torch import nn
import torch.nn.functional as F
from diffusers.models.normalization import FP32LayerNorm
from diffusers.models.attention import FeedForward
from transformers.generation.logits_process import LogitsProcessor
from typing import List, Literal, Optional

from modules.bbox_gen.modules.norm import GroupNorm32, ChannelLayerNorm32


class GroupEmbedding(nn.Module):
    def __init__(self, max_group_size, hidden_size=64):
        super().__init__()

        self.group_embedding = nn.Embedding(max_group_size + 1, hidden_size)  # +1 for background
        self.group_embedding.weight.data.normal_(mean=0.0, std=0.02)

    def forward(self, masks):
        batch_size, height, width = masks.shape
        masks_flat = masks.reshape(batch_size, -1)
        embeddings = self.group_embedding(masks_flat)
        embeddings = embeddings.reshape(batch_size, height, width, -1)
        embeddings = embeddings.permute(0, 3, 1, 2)
        return embeddings


class MultiModalProjector(torch.nn.Module):
    def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
        super().__init__()

        self.norm1 = FP32LayerNorm(in_features)
        self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
        self.norm2 = FP32LayerNorm(out_features)
        if pos_embed_seq_len is not None:
            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
        else:
            self.pos_embed = None

    def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
        if self.pos_embed is not None:
            batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
            encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
            encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed

        hidden_states = self.norm1(encoder_hidden_states_image)
        hidden_states = self.ff(hidden_states)
        hidden_states = self.norm2(hidden_states)
        return hidden_states


class MeshDecodeLogitsProcessor(LogitsProcessor):
    def __init__(self, bins, BOS_id, EOS_id, PAD_id, vertices_num=8):
        super().__init__()
        self.bins = bins
        self.BOS_id = BOS_id
        self.EOS_id = EOS_id
        self.PAD_id = PAD_id
        self.filter_value = -float('inf')
        self.vertices_num = vertices_num

    def force_token(self, scores, token_id):
        mask = torch.ones_like(scores, dtype=torch.bool)
        mask[:, token_id] = False
        scores[mask] = self.filter_value

    def __call__(self, input_ids, scores):
        # # all rules:
        # # 1. first token: BOS
        current_len = input_ids.shape[-1]
        if current_len == 0:
            # force bos
            self.force_token(scores, self.BOS_id)
        elif current_len <= self.vertices_num * 3 + 1:
            scores[:, self.bins:] = self.filter_value
        else:
            scores[:, self.BOS_id] = self.filter_value
            scores[:, self.PAD_id] = self.filter_value

            effective_tokens = current_len - 1
            complete_boxes = effective_tokens % (self.vertices_num * 3) == 0
            # print(effective_tokens, complete_boxes)
            if not complete_boxes:
                scores[:, self.EOS_id] = self.filter_value

        return scores


def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
    """
    Return a normalization layer.
    """
    if norm_type == "group":
        return GroupNorm32(32, *args, **kwargs)
    elif norm_type == "layer":
        return ChannelLayerNorm32(*args, **kwargs)
    else:
        raise ValueError(f"Invalid norm type {norm_type}")


class ResBlock3d(nn.Module):
    def __init__(
        self,
        channels: int,
        out_channels: Optional[int] = None,
        norm_type: Literal["group", "layer"] = "layer",
    ):
        super().__init__()
        self.channels = channels
        self.out_channels = out_channels or channels

        self.norm1 = norm_layer(norm_type, channels)
        self.norm2 = norm_layer(norm_type, self.out_channels)
        self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
        self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
        self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        h = self.norm1(x)
        h = F.silu(h)
        h = self.conv1(h)
        h = self.norm2(h)
        h = F.silu(h)
        h = self.conv2(h)
        h = h + self.skip_connection(x)
        return h


class DownsampleBlock3d(nn.Module):
    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        mode: Literal["conv", "avgpool"] = "conv",
    ):
        assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"

        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        if mode == "conv":
            self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
        elif mode == "avgpool":
            assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if hasattr(self, "conv"):
            return self.conv(x)
        else:
            return F.avg_pool3d(x, 2)


def zero_module(module):
    """
    Zero out the parameters of a module and return it.
    """
    for p in module.parameters():
        p.detach().zero_()
    return module


class SparseStructureEncoder(nn.Module):
    def __init__(
        self,
        in_channels: int,
        latent_channels: int,
        num_res_blocks: int,
        channels: List[int],
        num_res_blocks_middle: int = 2,
        norm_type: Literal["group", "layer"] = "layer",
    ):
        super().__init__()
        self.in_channels = in_channels
        self.latent_channels = latent_channels
        self.num_res_blocks = num_res_blocks
        self.channels = channels
        self.num_res_blocks_middle = num_res_blocks_middle
        self.norm_type = norm_type
        self.dtype = torch.float16
        self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)

        self.blocks = nn.ModuleList([])
        for i, ch in enumerate(channels):
            self.blocks.extend([
                ResBlock3d(ch, ch)
                for _ in range(num_res_blocks)
            ])
            if i < len(channels) - 1:
                self.blocks.append(
                    DownsampleBlock3d(ch, channels[i+1])
                )

        self.middle_block = nn.Sequential(*[
            ResBlock3d(channels[-1], channels[-1])
            for _ in range(num_res_blocks_middle)
        ])

    @property
    def device(self) -> torch.device:
        """
        Return the device of the model.
        """
        return next(self.parameters()).device

    def forward(self, x: torch.Tensor):
        h = self.input_layer(x)
        h = h.type(self.dtype)

        for block in self.blocks:
            h = block(h)
        h = self.middle_block(h)

        h = h.type(x.dtype)
        return h
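A small sketch of the decoding constraint MeshDecodeLogitsProcessor enforces: coordinate tokens stay in [0, bins), BOS/PAD are blocked after the prefix, and EOS only becomes reachable once a whole box (vertices_num x 3 tokens) has been emitted. The token ids below are toy values:

    # Hypothetical check of the masking rules with tiny toy inputs.
    import torch

    proc = MeshDecodeLogitsProcessor(bins=64, BOS_id=64, EOS_id=65, PAD_id=66, vertices_num=2)
    scores = torch.zeros(1, 67)                            # vocab = 64 bins + BOS/EOS/PAD
    ids = torch.tensor([[64, 1, 2, 3, 4, 5, 6, 7, 8]])     # BOS + 8 coords: mid-way through a box
    out = proc(ids, scores.clone())
    print(out[0, 65].item())                               # -inf: EOS blocked until the box completes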
modules/bbox_gen/models/bboxopt.py
ADDED
@@ -0,0 +1,221 @@
import torch
import torch.utils.checkpoint
from torch import nn

from transformers import AutoModelForCausalLM, AutoConfig
from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTModel, OPTDecoder, OPTConfig

from transformers.utils import logging
from typing import Optional, Union

from transformers.generation.logits_process import LogitsProcessorList
from transformers.generation.utils import GenerateNonBeamOutput, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
from transformers.generation.stopping_criteria import StoppingCriteriaList
from transformers.generation.configuration_utils import GenerationConfig
from transformers.generation.streamers import BaseStreamer

logger = logging.get_logger(__name__)

class BBoxOPTConfig(OPTConfig):
    model_type = "mesh_opt"

class BBoxOPTDecoder(OPTDecoder):
    config_class = BBoxOPTConfig

class BBoxOPTModel(OPTModel):
    config_class = BBoxOPTConfig
    def __init__(self, config: BBoxOPTConfig):
        super(OPTModel, self).__init__(config)
        self.decoder = BBoxOPTDecoder(config)
        # Initialize weights and apply final processing
        self.post_init()

class BBoxOPT(OPTForCausalLM):
    config_class = BBoxOPTConfig

    def __init__(self, config: BBoxOPTConfig):
        super(OPTForCausalLM, self).__init__(config)
        self.model = BBoxOPTModel(config)

        # the lm_head weight is automatically tied to the embed tokens weight
        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()

    def _sample(
        self,
        input_ids: torch.LongTensor,
        logits_processor: LogitsProcessorList,
        stopping_criteria: StoppingCriteriaList,
        generation_config: GenerationConfig,
        synced_gpus: bool,
        streamer: Optional["BaseStreamer"],
        **model_kwargs,
    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
        r"""
        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`):
                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
                used to modify the prediction scores of the language modeling head applied at each generation step.
            stopping_criteria (`StoppingCriteriaList`):
                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
                used to tell if the generation loop should stop.
            generation_config ([`~generation.GenerationConfig`]):
                The generation configuration to be used as parametrization of the decoding method.
            synced_gpus (`bool`):
                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
            streamer (`BaseStreamer`, *optional*):
                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
            model_kwargs:
                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
                an encoder-decoder model the kwargs should include `encoder_outputs`.

        Return:
            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
            A `torch.LongTensor` containing the generated tokens (default behaviour) or a
            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
            `model.config.is_encoder_decoder=True`.
        """
        # init values
        pad_token_id = generation_config._pad_token_tensor
        output_attentions = generation_config.output_attentions
        output_hidden_states = generation_config.output_hidden_states
        output_scores = generation_config.output_scores
        output_logits = generation_config.output_logits
        return_dict_in_generate = generation_config.return_dict_in_generate
        max_length = generation_config.max_length
        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
        do_sample = generation_config.do_sample

        # init attention / hidden states / scores tuples
        scores = () if (return_dict_in_generate and output_scores) else None
        raw_logits = () if (return_dict_in_generate and output_logits) else None
        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None

        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
        if return_dict_in_generate and self.config.is_encoder_decoder:
            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
            encoder_hidden_states = (
                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
            )

        # keep track of which sequences are already finished
        batch_size, cur_len = input_ids.shape
        this_peer_finished = False
        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)

        while self._has_unfinished_sequences(
            this_peer_finished, synced_gpus, device=input_ids.device
        ) and cur_len < max_length:
            # prepare model inputs
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # prepare variable output controls (note: some models won't accept all output controls)
            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})

            # forward pass to get next token
            outputs = self(**model_inputs, return_dict=True)

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
            # (the clone itself is always small)
            next_token_logits = outputs.logits.clone()[:, -1, :].float()

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_logits:
                    raw_logits += (next_token_logits,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # token selection
            if do_sample:
                probs = nn.functional.softmax(next_token_scores, dim=-1)
                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
            else:
                next_tokens = torch.argmax(next_token_scores, dim=-1)

            # finished sentences should have their next token be a padding token
            if has_eos_stopping_criteria:
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)

            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            if streamer is not None:
                streamer.put(next_tokens.cpu())
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs,
                model_kwargs,
                is_encoder_decoder=self.config.is_encoder_decoder,
            )

            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
            this_peer_finished = unfinished_sequences.max() == 0
            cur_len += 1

            # This is needed to properly delete outputs.logits which may be very large for first iteration
            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
            del outputs

        if streamer is not None:
            streamer.end()

        if return_dict_in_generate:
            if self.config.is_encoder_decoder:
                return GenerateEncoderDecoderOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    encoder_attentions=encoder_attentions,
                    encoder_hidden_states=encoder_hidden_states,
                    decoder_attentions=decoder_attentions,
                    cross_attentions=cross_attentions,
                    decoder_hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
            else:
                return GenerateDecoderOnlyOutput(
                    sequences=input_ids,
                    scores=scores,
                    logits=raw_logits,
                    attentions=decoder_attentions,
                    hidden_states=decoder_hidden_states,
                    past_key_values=model_kwargs.get("past_key_values"),
                )
        else:
            return input_ids


AutoConfig.register("mesh_opt", BBoxOPTConfig)
AutoModelForCausalLM.register(BBoxOPTConfig, BBoxOPT)
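Because "mesh_opt" is registered with the Auto* factories at import time, the decoder in BboxGen can be built generically from its config. A minimal sketch (the hyperparameters below are small illustrative values, not the ones used by BboxGen):

    # Hypothetical instantiation through the Auto factory after the registration above.
    from transformers import AutoModelForCausalLM

    config = BBoxOPTConfig(vocab_size=67, hidden_size=256, num_hidden_layers=2,
                           ffn_dim=1024, num_attention_heads=4, max_position_embeddings=512,
                           pad_token_id=66, bos_token_id=64, eos_token_id=65)
    model = AutoModelForCausalLM.from_config(config)
    print(type(model).__name__)   # BBoxOPT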
modules/bbox_gen/models/image_encoder.py
ADDED
@@ -0,0 +1,41 @@
from typing import Literal
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from PIL import Image

from transformers import AutoModel


class DINOv2ImageEncoder(nn.Module):
    def __init__(self, model_name: Literal[
        "facebook/dinov2-with-registers-large",
        "facebook/dinov2-large"
    ]):
        super().__init__()
        self.model = AutoModel.from_pretrained(model_name, torch_dtype=torch.bfloat16)
        self.model.requires_grad_(False)
        self.model.eval()

        DINOv2_INPUT_MEAN = torch.as_tensor([0.485, 0.456, 0.406], dtype=torch.float32)[
            None, :, None, None
        ]
        DINOv2_INPUT_STD = torch.as_tensor([0.229, 0.224, 0.225], dtype=torch.float32)[
            None, :, None, None
        ]
        self.register_buffer("DINOv2_INPUT_MEAN", DINOv2_INPUT_MEAN, persistent=False)
        self.register_buffer("DINOv2_INPUT_STD", DINOv2_INPUT_STD, persistent=False)
        self.max_size = 518
        self.hidden_size = self.model.config.hidden_size

    def preprocess(self, image: torch.Tensor):
        B, C, H, W = image.shape
        assert C == 3 and H <= self.max_size and W <= self.max_size
        image = (image - self.DINOv2_INPUT_MEAN.to(image)) / self.DINOv2_INPUT_STD.to(image)
        return image

    def forward(self, image: torch.Tensor):
        image = self.preprocess(image)
        features = self.model(image).last_hidden_state
        return features
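A minimal sketch of running this encoder on a batch of RGB images in [0, 1]; the 518-pixel resolution below is an assumption chosen to sit at the encoder's size limit:

    # Hypothetical usage: channel-first images in [0, 1], no larger than 518 px per side.
    import torch

    encoder = DINOv2ImageEncoder(model_name="facebook/dinov2-with-registers-large").cuda()
    images = torch.rand(2, 3, 518, 518, device="cuda", dtype=torch.bfloat16)
    with torch.no_grad():
        tokens = encoder(images)       # B x (CLS + register + patch tokens) x hidden_size
    print(tokens.shape)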