omnipart committed
Commit 491eded · 0 parent(s)

This view is limited to 50 files because the commit contains too many changes.

Files changed (50)
  1. .gitattributes +44 -0
  2. .gitignore +6 -0
  3. LICENSE +22 -0
  4. NOTICE +15 -0
  5. README.md +13 -0
  6. app.py +184 -0
  7. app_utils.py +412 -0
  8. assets/example_data/Batman.png +3 -0
  9. assets/example_data/astronaut.png +3 -0
  10. assets/example_data/car.png +3 -0
  11. assets/example_data/crossbow.jpg +0 -0
  12. assets/example_data/knight.png +3 -0
  13. assets/example_data/robot.jpg +0 -0
  14. assets/example_data/robot1.jpeg +3 -0
  15. assets/example_data/robot_dog.jpg +0 -0
  16. assets/example_data/ship.jpg +0 -0
  17. assets/example_data/snake.png +3 -0
  18. assets/example_data/warhammer.png +3 -0
  19. configs/bbox_gen.yaml +34 -0
  20. modules/PartField/configs/final/correspondence_demo.yaml +44 -0
  21. modules/PartField/configs/final/demo.yaml +28 -0
  22. modules/PartField/partfield/config/__init__.py +26 -0
  23. modules/PartField/partfield/config/defaults.py +92 -0
  24. modules/PartField/partfield/model/PVCNN/conv_pointnet.py +251 -0
  25. modules/PartField/partfield/model/PVCNN/dnnlib_util.py +1074 -0
  26. modules/PartField/partfield/model/PVCNN/encoder_pc.py +243 -0
  27. modules/PartField/partfield/model/PVCNN/pc_encoder.py +90 -0
  28. modules/PartField/partfield/model/PVCNN/pv_module/__init__.py +2 -0
  29. modules/PartField/partfield/model/PVCNN/pv_module/ball_query.py +34 -0
  30. modules/PartField/partfield/model/PVCNN/pv_module/frustum.py +141 -0
  31. modules/PartField/partfield/model/PVCNN/pv_module/functional/__init__.py +1 -0
  32. modules/PartField/partfield/model/PVCNN/pv_module/functional/devoxelization.py +12 -0
  33. modules/PartField/partfield/model/PVCNN/pv_module/loss.py +10 -0
  34. modules/PartField/partfield/model/PVCNN/pv_module/pointnet.py +113 -0
  35. modules/PartField/partfield/model/PVCNN/pv_module/pvconv.py +38 -0
  36. modules/PartField/partfield/model/PVCNN/pv_module/shared_mlp.py +35 -0
  37. modules/PartField/partfield/model/PVCNN/pv_module/voxelization.py +80 -0
  38. modules/PartField/partfield/model/PVCNN/unet_3daware.py +427 -0
  39. modules/PartField/partfield/model/UNet/buildingblocks.py +546 -0
  40. modules/PartField/partfield/model/UNet/model.py +170 -0
  41. modules/PartField/partfield/model/model_utils.py +54 -0
  42. modules/PartField/partfield/model/triplane.py +331 -0
  43. modules/PartField/partfield/model_trainer_pvcnn_only_demo.py +283 -0
  44. modules/PartField/partfield/partfield_encoder.py +103 -0
  45. modules/PartField/partfield/utils.py +5 -0
  46. modules/bbox_gen/config.py +57 -0
  47. modules/bbox_gen/models/autogressive_bbox_gen.py +305 -0
  48. modules/bbox_gen/models/bbox_gen_models.py +215 -0
  49. modules/bbox_gen/models/bboxopt.py +221 -0
  50. modules/bbox_gen/models/image_encoder.py +41 -0
.gitattributes ADDED
@@ -0,0 +1,44 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ assets/example_data/Batman.png filter=lfs diff=lfs merge=lfs -text
+ assets/example_data/astronaut.png filter=lfs diff=lfs merge=lfs -text
+ assets/example_data/car.png filter=lfs diff=lfs merge=lfs -text
+ assets/example_data/knight.png filter=lfs diff=lfs merge=lfs -text
+ assets/example_data/robot1.jpeg filter=lfs diff=lfs merge=lfs -text
+ assets/example_data/snake.png filter=lfs diff=lfs merge=lfs -text
+ assets/example_data/warhammer.png filter=lfs diff=lfs merge=lfs -text
+ modules/part_synthesis/representations/mesh/flexicubes/images/block_init.png filter=lfs diff=lfs merge=lfs -text
+ modules/part_synthesis/representations/mesh/flexicubes/images/teaser_top.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,6 @@
+ __pycache__/
+ output/
+ ckpt/
+ .DS_Store
+ tmp/
+ debug_images/
LICENSE ADDED
@@ -0,0 +1,22 @@
+ # MIT License
+
+ # Copyright (c) 2025 VAST-AI-Research and contributors.
+
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
+ # of this software and associated documentation files (the "Software"), to deal
+ # in the Software without restriction, including without limitation the rights
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+ # copies of the Software, and to permit persons to whom the Software is
+ # furnished to do so, subject to the following conditions:
+
+ # The above copyright notice and this permission notice shall be included in all
+ # copies or substantial portions of the Software.
+
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+ # SOFTWARE
+
NOTICE ADDED
@@ -0,0 +1,15 @@
+ OmniPart
+ Copyright (c) 2025 VAST-AI-Research and contributors
+
+ This project includes code from the following open source projects:
+
+ RMBG
+ Copyright (c) BRIA AI
+ License: bria-rmbg-2.0
+ Source: https://huggingface.co/briaai/RMBG-2.0
+
+ This software contains code derived from 🤗 Diffusers (https://github.com/huggingface/diffusers), available under the Apache License 2.0.
+
+ This software contains code derived from TRELLIS (https://github.com/Microsoft/TRELLIS), available under the MIT License.
+
+ This software contains code derived from PartPacker (https://github.com/NVlabs/PartPacker), available under the NVIDIA Source Code License.
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: OmniPart
+ emoji: 📚
+ colorFrom: green
+ colorTo: green
+ sdk: gradio
+ sdk_version: 5.35.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,184 @@
1
+ import gradio as gr
2
+ import spaces
3
+ import os
4
+ import shutil
5
+ os.environ['SPCONV_ALGO'] = 'native'
6
+ from huggingface_hub import hf_hub_download
7
+
8
+ from app_utils import (
9
+ generate_parts,
10
+ prepare_models,
11
+ process_image,
12
+ apply_merge,
13
+ DEFAULT_SIZE_TH,
14
+ TMP_ROOT,
15
+ )
16
+
17
+ EXAMPLES = [
18
+ ["assets/example_data/knight.png", 1800, "6,0,26,20,7;13,1,22,11,12,2,21,27,3,24,23;5,18;4,17;19,16,14,25,28", 42],
19
+ ["assets/example_data/car.png", 2000, "12,10,2,11;1,7", 42],
20
+ ["assets/example_data/warhammer.png", 1800, "7,1,0,8", 0],
21
+ ["assets/example_data/snake.png", 3000, "2,3;0,1;4,5,6,7", 42],
22
+ ["assets/example_data/Batman.png", 1800, "4,5", 42],
23
+ ["assets/example_data/robot1.jpeg", 1600, "0,5;10,14,3;1,12,2;13,11,4;7,15", 42],
24
+ ["assets/example_data/astronaut.png", 2000, "0,4,6;1,8,9,7;2,5", 42],
25
+ ["assets/example_data/crossbow.jpg", 2000, "2,9;10,12,0,7,11,8,13;4,3", 42],
26
+ ["assets/example_data/robot.jpg", 1600, "7,19;15,0;6,18", 42],
27
+ ["assets/example_data/robot_dog.jpg", 1000, "21,9;2,12,10,15,17;11,7;1,0;13,19;4,16", 0],
28
+ ["assets/example_data/crossbow.jpg", 1600, "9,2;10,15,13;7,14,8,11;0,12,16;5,3,1", 42],
29
+ ["assets/example_data/robot.jpg", 1800, "1,2,3,5,4,16,17;11,7,19;10,14;18,6,0,15;13,9;12,8", 0],
30
+ ["assets/example_data/robot_dog.jpg", 1000, "2,12,10,15,17,8,3,5,13,19,6,14;11,7;1,0,21,9,11;4,16", 0],
31
+ ]
32
+
33
+ HEADER = """
34
+
35
+ # OmniPart: Part-Aware 3D Generation with Semantic Decoupling and Structural Cohesion
36
+
37
+ 🔮 Generate **part-aware 3D content** from a single 2D image with **2D mask control**.
38
+
39
+ ## How to Use
40
+
41
+ **🚀 Quick Start**: Select an example below and click **"▶️ Run Example"**
42
+
43
+
44
+ **📋 Custom Image Processing**:
45
+ 1. **Upload Image** - Select your image file
46
+ 2. **Click "Segment Image"** - Get initial 2D segmentation
47
+ 3. **Merge Segments** - Enter merge groups like `0,1;3,4` and click **"Apply Merge"** (Recommend keeping **2-15 parts**)
48
+ 4. **Click "Generate 3D Model"** - Create the final 3D results
49
+ """
50
+
51
+
52
+ def start_session(req: gr.Request):
53
+ user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
54
+ os.makedirs(user_dir, exist_ok=True)
55
+
56
+
57
+ def end_session(req: gr.Request):
58
+ user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
59
+ shutil.rmtree(user_dir)
60
+
61
+
62
+ with gr.Blocks(title="OmniPart") as demo:
63
+ gr.Markdown(HEADER)
64
+
65
+ state = gr.State({})
66
+
67
+ with gr.Row():
68
+ with gr.Column(scale=1):
69
+ gr.Markdown("<div style='text-align: center'>\n\n## Input\n\n</div>")
70
+
71
+ input_image = gr.Image(label="Upload Image", type="filepath", height=250, width=250)
72
+
73
+ with gr.Row():
74
+ segment_btn = gr.Button("Segment Image", variant="primary", size="lg")
75
+ run_example_btn = gr.Button("▶️ Run Example", variant="secondary", size="lg")
76
+
77
+ size_threshold = gr.Slider(
78
+ minimum=600,
79
+ maximum=4000,
80
+ value=DEFAULT_SIZE_TH,
81
+ step=200,
82
+ label="Minimum Segment Size (pixels)",
83
+ info="Segments smaller than this will be ignored"
84
+ )
85
+
86
+ gr.Markdown("### Merge Controls")
87
+ merge_input = gr.Textbox(
88
+ label="Merge Groups",
89
+ placeholder="0,1;3,4",
90
+ lines=2,
91
+ info="Specify which segments to merge (e.g., '0,1;3,4' merges segments 0&1 together and 3&4 together)"
92
+ )
93
+ merge_btn = gr.Button("Apply Merge", variant="primary", size="lg")
94
+
95
+ gr.Markdown("### 3D Generation Controls")
96
+
97
+ seed_slider = gr.Slider(
98
+ minimum=0,
99
+ maximum=10000,
100
+ value=42,
101
+ step=1,
102
+ label="Generation Seed",
103
+ info="Random seed for 3D model generation"
104
+ )
105
+
106
+ cfg_slider = gr.Slider(
107
+ minimum=0.0,
108
+ maximum=15.0,
109
+ value=7.5,
110
+ step=0.5,
111
+ label="CFG Strength",
112
+ info="Classifier-Free Guidance strength"
113
+ )
114
+
115
+ generate_mesh_btn = gr.Button("Generate 3D Model", variant="secondary", size="lg")
116
+
117
+ with gr.Column(scale=2):
118
+ gr.Markdown("<div style='text-align: center'>\n\n## Results Display\n\n</div>")
119
+
120
+ with gr.Row():
121
+ initial_seg = gr.Image(label="Init Seg", height=220, width=220)
122
+ pre_merge_vis = gr.Image(label="Pre-merge", height=220, width=220)
123
+ merged_seg = gr.Image(label="Merged Seg", height=220, width=220)
124
+
125
+ with gr.Row():
126
+ bbox_mesh = gr.Model3D(label="Bounding Boxes", height=350)
127
+ whole_mesh = gr.Model3D(label="Combined Parts", height=350)
128
+ exploded_mesh = gr.Model3D(label="Exploded Parts", height=350)
129
+
130
+ with gr.Row():
131
+ combined_gs = gr.Model3D(label="Combined 3D Gaussians", clear_color=(0.0, 0.0, 0.0, 0.0), height=350)
132
+ exploded_gs = gr.Model3D(label="Exploded 3D Gaussians", clear_color=(0.0, 0.0, 0.0, 0.0), height=350)
133
+
134
+ with gr.Row():
135
+ examples = gr.Examples(
136
+ examples=EXAMPLES,
137
+ inputs=[input_image, size_threshold, merge_input, seed_slider],
138
+ cache_examples=False,
139
+ )
140
+
141
+ demo.load(start_session)
142
+ demo.unload(end_session)
143
+
144
+ segment_btn.click(
145
+ process_image,
146
+ inputs=[input_image, size_threshold],
147
+ outputs=[initial_seg, pre_merge_vis, state]
148
+ )
149
+
150
+ merge_btn.click(
151
+ apply_merge,
152
+ inputs=[merge_input, state],
153
+ outputs=[merged_seg, state]
154
+ )
155
+
156
+ generate_mesh_btn.click(
157
+ generate_parts,
158
+ inputs=[state, seed_slider, cfg_slider],
159
+ outputs=[bbox_mesh, whole_mesh, exploded_mesh, combined_gs, exploded_gs]
160
+ )
161
+
162
+ run_example_btn.click(
163
+ fn=process_image,
164
+ inputs=[input_image, size_threshold],
165
+ outputs=[initial_seg, pre_merge_vis, state]
166
+ ).then(
167
+ fn=apply_merge,
168
+ inputs=[merge_input, state],
169
+ outputs=[merged_seg, state]
170
+ ).then(
171
+ fn=generate_parts,
172
+ inputs=[state, seed_slider, cfg_slider],
173
+ outputs=[bbox_mesh, whole_mesh, exploded_mesh, combined_gs, exploded_gs]
174
+ )
175
+
176
+ if __name__ == "__main__":
177
+ os.makedirs("ckpt", exist_ok=True)
178
+ sam_ckpt_path = hf_hub_download(repo_id="omnipart/OmniPart_modules", filename="sam_vit_h_4b8939.pth", local_dir="ckpt")
179
+ partfield_ckpt_path = hf_hub_download(repo_id="omnipart/OmniPart_modules", filename="partfield_encoder.ckpt", local_dir="ckpt")
180
+ bbox_gen_ckpt_path = hf_hub_download(repo_id="omnipart/OmniPart_modules", filename="bbox_gen.ckpt", local_dir="ckpt")
181
+
182
+ prepare_models(sam_ckpt_path, partfield_ckpt_path, bbox_gen_ckpt_path)
183
+
184
+ demo.launch()
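For reference, the merge string used in the "Merge Groups" textbox and in the third column of `EXAMPLES` (e.g. `"0,1;3,4"`) separates groups with `;` and segment IDs inside a group with `,`. A minimal sketch of that mapping, mirroring the parsing loop in `apply_merge` from `app_utils.py` (the helper name is illustrative and not part of the repo):

```python
def parse_merge_groups(merge_input: str) -> list:
    """Illustrative helper: '0,1;3,4' -> [[0, 1], [3, 4]]."""
    groups = []
    for group_set in merge_input.split(";"):
        ids = [int(x) for x in group_set.split(",") if x.strip()]
        if ids:
            groups.append(ids)
    return groups


assert parse_merge_groups("0,1;3,4") == [[0, 1], [3, 4]]
```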
app_utils.py ADDED
@@ -0,0 +1,412 @@
1
+ import gradio as gr
2
+ import spaces
3
+ import os
4
+ import numpy as np
5
+ import trimesh
6
+ import time
7
+ import traceback
8
+ import torch
9
+ from PIL import Image
10
+ import cv2
11
+ import shutil
12
+ from segment_anything import SamAutomaticMaskGenerator, build_sam
13
+ from omegaconf import OmegaConf
14
+
15
+ from modules.bbox_gen.models.autogressive_bbox_gen import BboxGen
16
+ from modules.part_synthesis.process_utils import save_parts_outputs
17
+ from modules.inference_utils import load_img_mask, prepare_bbox_gen_input, prepare_part_synthesis_input, gen_mesh_from_bounds, vis_voxel_coords, merge_parts
18
+ from modules.part_synthesis.pipelines import OmniPartImageTo3DPipeline
19
+ from modules.label_2d_mask.visualizer import Visualizer
20
+ from transformers import AutoModelForImageSegmentation
21
+
22
+ from modules.label_2d_mask.label_parts import (
23
+ prepare_image,
24
+ get_sam_mask,
25
+ get_mask,
26
+ clean_segment_edges,
27
+ resize_and_pad_to_square,
28
+ size_th as DEFAULT_SIZE_TH
29
+ )
30
+
31
+ # Constants
32
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
33
+ DTYPE = torch.float16
34
+ MAX_SEED = np.iinfo(np.int32).max
35
+ TMP_ROOT = os.path.join(os.path.dirname(os.path.abspath(__file__)), "tmp")
36
+ os.makedirs(TMP_ROOT, exist_ok=True)
37
+
38
+ sam_mask_generator = None
39
+ rmbg_model = None
40
+ bbox_gen_model = None
41
+ part_synthesis_pipeline = None
42
+
43
+ size_th = DEFAULT_SIZE_TH
44
+
45
+
46
+ def prepare_models(sam_ckpt_path, partfield_ckpt_path, bbox_gen_ckpt_path):
47
+ global sam_mask_generator, rmbg_model, bbox_gen_model, part_synthesis_pipeline
48
+ if sam_mask_generator is None:
49
+ print("Loading SAM model...")
50
+ sam_model = build_sam(checkpoint=sam_ckpt_path).to(device=DEVICE)
51
+ sam_mask_generator = SamAutomaticMaskGenerator(sam_model)
52
+
53
+ if rmbg_model is None:
54
+ print("Loading BriaRMBG 2.0 model...")
55
+ rmbg_model = AutoModelForImageSegmentation.from_pretrained('briaai/RMBG-2.0', trust_remote_code=True)
56
+ rmbg_model.to(DEVICE)
57
+ rmbg_model.eval()
58
+
59
+ if part_synthesis_pipeline is None:
60
+ print("Loading PartSynthesis model...")
61
+ part_synthesis_pipeline = OmniPartImageTo3DPipeline.from_pretrained('omnipart/OmniPart')
62
+ part_synthesis_pipeline.to(DEVICE)
63
+
64
+ if bbox_gen_model is None:
65
+ print("Loading BboxGen model...")
66
+ bbox_gen_config = OmegaConf.load("configs/bbox_gen.yaml").model.args
67
+ bbox_gen_config.partfield_encoder_path = partfield_ckpt_path
68
+ bbox_gen_model = BboxGen(bbox_gen_config)
69
+ bbox_gen_model.load_state_dict(torch.load(bbox_gen_ckpt_path), strict=False)
70
+ bbox_gen_model.to(DEVICE)
71
+ bbox_gen_model.eval().half()
72
+
73
+ print("Models ready")
74
+
75
+
76
+ @spaces.GPU
77
+ def process_image(image_path, threshold, req: gr.Request):
78
+ """Process image and generate initial segmentation"""
79
+ global size_th
80
+
81
+ user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
82
+ os.makedirs(user_dir, exist_ok=True)
83
+
84
+ img_name = os.path.basename(image_path).split(".")[0]
85
+
86
+ size_th = threshold
87
+
88
+ img = Image.open(image_path).convert("RGB")
89
+ processed_image = prepare_image(img, rmbg_net=rmbg_model.to(DEVICE))
90
+
91
+ processed_image = resize_and_pad_to_square(processed_image)
92
+ white_bg = Image.new("RGBA", processed_image.size, (255, 255, 255, 255))
93
+ white_bg_img = Image.alpha_composite(white_bg, processed_image.convert("RGBA"))
94
+ image = np.array(white_bg_img.convert('RGB'))
95
+
96
+ rgba_path = os.path.join(user_dir, f"{img_name}_processed.png")
97
+ processed_image.save(rgba_path)
98
+
99
+ print("Generating raw SAM masks without post-processing...")
100
+ raw_masks = sam_mask_generator.generate(image)
101
+
102
+ raw_sam_vis = np.copy(image)
103
+ raw_sam_vis = np.ones_like(image) * 255
104
+
105
+ sorted_masks = sorted(raw_masks, key=lambda x: x["area"], reverse=True)
106
+
107
+ for i, mask_data in enumerate(sorted_masks):
108
+ if mask_data["area"] < size_th:
109
+ continue
110
+
111
+ color_r = (i * 50 + 80) % 256
112
+ color_g = (i * 120 + 40) % 256
113
+ color_b = (i * 180 + 20) % 256
114
+ color = np.array([color_r, color_g, color_b])
115
+
116
+ mask = mask_data["segmentation"]
117
+ raw_sam_vis[mask] = color
118
+
119
+ visual = Visualizer(image)
120
+
121
+ group_ids, pre_merge_im = get_sam_mask(
122
+ image,
123
+ sam_mask_generator,
124
+ visual,
125
+ merge_groups=None,
126
+ rgba_image=processed_image,
127
+ img_name=img_name,
128
+ save_dir=user_dir,
129
+ size_threshold=size_th
130
+ )
131
+
132
+ pre_merge_path = os.path.join(user_dir, f"{img_name}_mask_pre_merge.png")
133
+ Image.fromarray(pre_merge_im).save(pre_merge_path)
134
+ pre_split_vis = np.ones_like(image) * 255
135
+
136
+ unique_ids = np.unique(group_ids)
137
+ unique_ids = unique_ids[unique_ids >= 0]
138
+
139
+ for i, unique_id in enumerate(unique_ids):
140
+ color_r = (i * 50 + 80) % 256
141
+ color_g = (i * 120 + 40) % 256
142
+ color_b = (i * 180 + 20) % 256
143
+ color = np.array([color_r, color_g, color_b])
144
+
145
+ mask = (group_ids == unique_id)
146
+ pre_split_vis[mask] = color
147
+
148
+ y_indices, x_indices = np.where(mask)
149
+ if len(y_indices) > 0:
150
+ center_y = int(np.mean(y_indices))
151
+ center_x = int(np.mean(x_indices))
152
+ cv2.putText(pre_split_vis, str(unique_id),
153
+ (center_x, center_y), cv2.FONT_HERSHEY_SIMPLEX,
154
+ 0.5, (0, 0, 0), 1, cv2.LINE_AA)
155
+
156
+ pre_split_path = os.path.join(user_dir, f"{img_name}_pre_split.png")
157
+ Image.fromarray(pre_split_vis).save(pre_split_path)
158
+ print(f"Pre-split segmentation (before disconnected parts handling) saved to {pre_split_path}")
159
+
160
+ get_mask(group_ids, image, ids=2, img_name=img_name, save_dir=user_dir)
161
+
162
+ init_seg_path = os.path.join(user_dir, f"{img_name}_mask_segments_2.png")
163
+
164
+ seg_img = Image.open(init_seg_path)
165
+ if seg_img.mode == 'RGBA':
166
+ white_bg = Image.new('RGBA', seg_img.size, (255, 255, 255, 255))
167
+ seg_img = Image.alpha_composite(white_bg, seg_img)
168
+ seg_img.save(init_seg_path)
169
+
170
+ state = {
171
+ "image": image.tolist(),
172
+ "processed_image": rgba_path,
173
+ "group_ids": group_ids.tolist() if isinstance(group_ids, np.ndarray) else group_ids,
174
+ "original_group_ids": group_ids.tolist() if isinstance(group_ids, np.ndarray) else group_ids,
175
+ "img_name": img_name,
176
+ "pre_split_path": pre_split_path,
177
+ }
178
+
179
+ return init_seg_path, pre_merge_path, state
180
+
181
+
182
+ def apply_merge(merge_input, state, req: gr.Request):
183
+ """Apply merge parameters and generate merged segmentation"""
184
+ global sam_mask_generator
185
+
186
+ if not state:
187
+ return None, state
188
+
189
+ user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
190
+
191
+ # Convert back from list to numpy array
192
+ image = np.array(state["image"])
193
+ # Use original group IDs instead of the most recent ones
194
+ group_ids = np.array(state["original_group_ids"])
195
+ img_name = state["img_name"]
196
+
197
+ # Load processed image from path
198
+ processed_image = Image.open(state["processed_image"])
199
+
200
+ # Display the original IDs before merging, SORTED for easier reading
201
+ unique_ids = np.unique(group_ids)
202
+ unique_ids = unique_ids[unique_ids >= 0] # Exclude background
203
+ print(f"Original segment IDs (used for merging): {sorted(unique_ids.tolist())}")
204
+
205
+ # Parse merge groups
206
+ merge_groups = None
207
+ try:
208
+ if merge_input:
209
+ merge_groups = []
210
+ group_sets = merge_input.split(';')
211
+ for group_set in group_sets:
212
+ ids = [int(x) for x in group_set.split(',')]
213
+ if ids:
214
+ # Validate if these IDs exist in the segmentation
215
+ existing_ids = [id for id in ids if id in unique_ids]
216
+ missing_ids = [id for id in ids if id not in unique_ids]
217
+
218
+ if missing_ids:
219
+ print(f"Warning: These IDs don't exist in the segmentation: {missing_ids}")
220
+
221
+ # Only add group if it has valid IDs
222
+ if existing_ids:
223
+ merge_groups.append(ids)
224
+ print(f"Valid merge group: {ids} (missing: {missing_ids if missing_ids else 'none'})")
225
+ else:
226
+ print(f"Skipping merge group with no valid IDs: {ids}")
227
+
228
+ print(f"Using merge groups: {merge_groups}")
229
+ except Exception as e:
230
+ print(f"Error parsing merge groups: {e}")
231
+ return None, state
232
+
233
+ # Initialize visualizer
234
+ visual = Visualizer(image)
235
+
236
+ # Generate merged segmentation starting from original IDs
237
+ # Add skip_split=True to prevent splitting after merging
238
+ new_group_ids, merged_im = get_sam_mask(
239
+ image,
240
+ sam_mask_generator,
241
+ visual,
242
+ merge_groups=merge_groups,
243
+ existing_group_ids=group_ids,
244
+ rgba_image=processed_image,
245
+ skip_split=True,
246
+ img_name=img_name,
247
+ save_dir=user_dir,
248
+ size_threshold=size_th
249
+ )
250
+
251
+ # Display the new IDs after merging for future reference
252
+ new_unique_ids = np.unique(new_group_ids)
253
+ new_unique_ids = new_unique_ids[new_unique_ids >= 0] # Exclude background
254
+ print(f"New segment IDs (after merging): {new_unique_ids.tolist()}")
255
+
256
+ # Clean edges
257
+ new_group_ids = clean_segment_edges(new_group_ids)
258
+
259
+ # Save merged segmentation visualization
260
+ get_mask(new_group_ids, image, ids=3, img_name=img_name, save_dir=user_dir)
261
+
262
+ # Path to merged segmentation
263
+ merged_seg_path = os.path.join(user_dir, f"{img_name}_mask_segments_3.png")
264
+
265
+ save_mask = new_group_ids + 1
266
+ save_mask = save_mask.reshape(518, 518, 1).repeat(3, axis=-1)
267
+ cv2.imwrite(os.path.join(user_dir, f"{img_name}_mask.exr"), save_mask.astype(np.float32))
268
+
269
+ # Update state with the new group IDs but keep original IDs unchanged
270
+ state["group_ids"] = new_group_ids.tolist() if isinstance(new_group_ids, np.ndarray) else new_group_ids
271
+ state["save_mask_path"] = os.path.join(user_dir, f"{img_name}_mask.exr")
272
+
273
+ return merged_seg_path, state
274
+
275
+
276
+ def explode_mesh(mesh, explosion_scale=0.4):
277
+
278
+ if isinstance(mesh, trimesh.Scene):
279
+ scene = mesh
280
+ elif isinstance(mesh, trimesh.Trimesh):
281
+ print("Warning: Single mesh provided, can't create exploded view")
282
+ scene = trimesh.Scene(mesh)
283
+ return scene
284
+ else:
285
+ print(f"Warning: Unexpected mesh type: {type(mesh)}")
286
+ scene = mesh
287
+
288
+ if len(scene.geometry) <= 1:
289
+ print("Only one geometry found - nothing to explode")
290
+ return scene
291
+
292
+ print(f"[EXPLODE_MESH] Starting mesh explosion with scale {explosion_scale}")
293
+ print(f"[EXPLODE_MESH] Processing {len(scene.geometry)} parts")
294
+
295
+ exploded_scene = trimesh.Scene()
296
+
297
+ part_centers = []
298
+ geometry_names = []
299
+
300
+ for geometry_name, geometry in scene.geometry.items():
301
+ if hasattr(geometry, 'vertices'):
302
+ transform = scene.graph[geometry_name][0]
303
+ vertices_global = trimesh.transformations.transform_points(
304
+ geometry.vertices, transform)
305
+ center = np.mean(vertices_global, axis=0)
306
+ part_centers.append(center)
307
+ geometry_names.append(geometry_name)
308
+ print(f"[EXPLODE_MESH] Part {geometry_name}: center = {center}")
309
+
310
+ if not part_centers:
311
+ print("No valid geometries with vertices found")
312
+ return scene
313
+
314
+ part_centers = np.array(part_centers)
315
+ global_center = np.mean(part_centers, axis=0)
316
+
317
+ print(f"[EXPLODE_MESH] Global center: {global_center}")
318
+
319
+ for i, (geometry_name, geometry) in enumerate(scene.geometry.items()):
320
+ if hasattr(geometry, 'vertices'):
321
+ if i < len(part_centers):
322
+ part_center = part_centers[i]
323
+ direction = part_center - global_center
324
+
325
+ direction_norm = np.linalg.norm(direction)
326
+ if direction_norm > 1e-6:
327
+ direction = direction / direction_norm
328
+ else:
329
+ direction = np.random.randn(3)
330
+ direction = direction / np.linalg.norm(direction)
331
+
332
+ offset = direction * explosion_scale
333
+ else:
334
+ offset = np.zeros(3)
335
+
336
+ original_transform = scene.graph[geometry_name][0].copy()
337
+
338
+ new_transform = original_transform.copy()
339
+ new_transform[:3, 3] = new_transform[:3, 3] + offset
340
+
341
+ exploded_scene.add_geometry(
342
+ geometry,
343
+ transform=new_transform,
344
+ geom_name=geometry_name
345
+ )
346
+
347
+ print(f"[EXPLODE_MESH] Part {geometry_name}: moved by {np.linalg.norm(offset):.4f}")
348
+
349
+ print("[EXPLODE_MESH] Mesh explosion complete")
350
+ return exploded_scene
351
+
352
+ @spaces.GPU(duration=90)
353
+ def generate_parts(state, seed, cfg_strength, req: gr.Request):
354
+ explode_factor=0.3
355
+ img_path = state["processed_image"]
356
+ mask_path = state["save_mask_path"]
357
+ user_dir = os.path.join(TMP_ROOT, str(req.session_hash))
358
+ img_white_bg, img_black_bg, ordered_mask_input, img_mask_vis = load_img_mask(img_path, mask_path)
359
+ img_mask_vis.save(os.path.join(user_dir, "img_mask_vis.png"))
360
+
361
+ voxel_coords = part_synthesis_pipeline.get_coords(img_black_bg, num_samples=1, seed=seed, sparse_structure_sampler_params={"steps": 25, "cfg_strength": 7.5})
362
+ voxel_coords = voxel_coords.cpu().numpy()
363
+ np.save(os.path.join(user_dir, "voxel_coords.npy"), voxel_coords)
364
+ voxel_coords_ply = vis_voxel_coords(voxel_coords)
365
+ voxel_coords_ply.export(os.path.join(user_dir, "voxel_coords_vis.ply"))
366
+ print("[INFO] Voxel coordinates saved")
367
+
368
+ bbox_gen_input = prepare_bbox_gen_input(os.path.join(user_dir, "voxel_coords.npy"), img_white_bg, ordered_mask_input)
369
+ bbox_gen_output = bbox_gen_model.generate(bbox_gen_input)
370
+ np.save(os.path.join(user_dir, "bboxes.npy"), bbox_gen_output['bboxes'][0])
371
+ bboxes_vis = gen_mesh_from_bounds(bbox_gen_output['bboxes'][0])
372
+ bboxes_vis.export(os.path.join(user_dir, "bboxes_vis.glb"))
373
+ print("[INFO] BboxGen output saved")
374
+
375
+
376
+ part_synthesis_input = prepare_part_synthesis_input(os.path.join(user_dir, "voxel_coords.npy"), os.path.join(user_dir, "bboxes.npy"), ordered_mask_input)
377
+
378
+ torch.cuda.empty_cache()
379
+
380
+ part_synthesis_output = part_synthesis_pipeline.get_slat(
381
+ img_black_bg,
382
+ part_synthesis_input['coords'],
383
+ [part_synthesis_input['part_layouts']],
384
+ part_synthesis_input['masks'],
385
+ seed=seed,
386
+ slat_sampler_params={"steps": 25, "cfg_strength": cfg_strength},
387
+ formats=['mesh', 'gaussian'],
388
+ preprocess_image=False,
389
+ )
390
+ save_parts_outputs(
391
+ part_synthesis_output,
392
+ output_dir=user_dir,
393
+ simplify_ratio=0.0,
394
+ save_video=False,
395
+ save_glb=True,
396
+ textured=False,
397
+ )
398
+ merge_parts(user_dir)
399
+ print("[INFO] PartSynthesis output saved")
400
+
401
+ bbox_mesh_path = os.path.join(user_dir, "bboxes_vis.glb")
402
+ whole_mesh_path = os.path.join(user_dir, "mesh_segment.glb")
403
+
404
+ combined_mesh = trimesh.load(whole_mesh_path)
405
+ exploded_mesh_result = explode_mesh(combined_mesh, explosion_scale=explode_factor)
406
+ exploded_mesh_result.export(os.path.join(user_dir, "exploded_parts.glb"))
407
+
408
+ exploded_mesh_path = os.path.join(user_dir, "exploded_parts.glb")
409
+ combined_gs_path = os.path.join(user_dir, "merged_gs.ply")
410
+ exploded_gs_path = os.path.join(user_dir, "exploded_gs.ply")
411
+
412
+ return bbox_mesh_path, whole_mesh_path, exploded_mesh_path, combined_gs_path, exploded_gs_path
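`explode_mesh` above is a pure geometric post-process: every sub-mesh of the GLB scene is translated away from the average part center. A hedged usage sketch outside Gradio (the input path is a placeholder for the `mesh_segment.glb` that `generate_parts` writes into the session's temporary directory):

```python
import trimesh

from app_utils import explode_mesh

# Placeholder path: generate_parts() writes mesh_segment.glb under tmp/<session_hash>/.
scene = trimesh.load("tmp/<session_hash>/mesh_segment.glb")
exploded = explode_mesh(scene, explosion_scale=0.5)  # a larger scale pushes parts further apart
exploded.export("exploded_parts_0.5.glb")
```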
assets/example_data/Batman.png ADDED
Git LFS Details
  • SHA256: 9a9a80321c27ee38899bbc2bb4f346d449422898f3dc3214dba4dcd6e5cf6397
  • Pointer size: 131 Bytes
  • Size of remote file: 510 kB
assets/example_data/astronaut.png ADDED
Git LFS Details
  • SHA256: 49712b3a29aa24862e8a4d3c1c69326459585ef9f7aa15ff8c2b2d90101f3784
  • Pointer size: 131 Bytes
  • Size of remote file: 252 kB
assets/example_data/car.png ADDED
Git LFS Details
  • SHA256: 82239d215901c12d12ddaa5fbb5b6c3f928e3a6fec19bc1a2b26a4aa084d482d
  • Pointer size: 133 Bytes
  • Size of remote file: 10.1 MB
assets/example_data/crossbow.jpg ADDED
assets/example_data/knight.png ADDED
Git LFS Details
  • SHA256: 291db3fca9c1d63b91609d28352b3f6fbc1e9f143f7783b70dc9ec35a911d77c
  • Pointer size: 131 Bytes
  • Size of remote file: 604 kB
assets/example_data/robot.jpg ADDED
assets/example_data/robot1.jpeg ADDED
Git LFS Details
  • SHA256: 7131acb0e194caf8bac6bee72d668def184a18df14848ce731380b96486e996b
  • Pointer size: 131 Bytes
  • Size of remote file: 133 kB
assets/example_data/robot_dog.jpg ADDED
assets/example_data/ship.jpg ADDED
assets/example_data/snake.png ADDED
Git LFS Details
  • SHA256: fa4ec58625fed4dd0e5b65e323333c89ecace02fbe4161327d4e06e0ea4b678a
  • Pointer size: 132 Bytes
  • Size of remote file: 1.41 MB
assets/example_data/warhammer.png ADDED
Git LFS Details
  • SHA256: bc63bda34774288092d069808b7cb28c9544dd253cfbcfb33a98b22c9ec19537
  • Pointer size: 131 Bytes
  • Size of remote file: 163 kB
configs/bbox_gen.yaml ADDED
@@ -0,0 +1,34 @@
+ model:
+   name: bbox_gen
+   args:
+     encoder_dim_feat: 448
+     encoder_dim: 64
+     encoder_heads: 4
+     encoder_token_num: 2048
+     encoder_qkv_bias: false
+     encoder_use_ln_post: true
+     encoder_use_checkpoint: true
+     encoder_num_embed_freqs: 8
+     encoder_embed_include_pi: false
+     encoder_init_scale: 0.25
+     encoder_random_fps: true
+     encoder_learnable_query: true
+     encoder_layers: 8
+
+     max_group_size: 50
+
+     vocab_size: 67
+     decoder_hidden_size: 1024
+     decoder_num_hidden_layers: 24
+     decoder_ffn_dim: 4096
+     decoder_heads: 16
+     decoder_use_flash_attention: true
+     decoder_gradient_checkpointing: false
+
+     bins: 64
+     BOS_id: 64
+     EOS_id: 65
+     PAD_id: 66
+     max_length: 2187
+     voxel_token_length: 1886
+     voxel_token_placeholder: -1
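This YAML is consumed by `prepare_models()` in `app_utils.py`: the `model.args` node is loaded with OmegaConf, the PartField encoder path is injected at runtime, and the result is handed to `BboxGen`. A condensed sketch of that flow (the checkpoint paths are placeholders for the files downloaded in `app.py`):

```python
import torch
from omegaconf import OmegaConf

from modules.bbox_gen.models.autogressive_bbox_gen import BboxGen

cfg = OmegaConf.load("configs/bbox_gen.yaml").model.args
cfg.partfield_encoder_path = "ckpt/partfield_encoder.ckpt"  # injected at runtime, not stored in the YAML
bbox_gen_model = BboxGen(cfg)
bbox_gen_model.load_state_dict(torch.load("ckpt/bbox_gen.ckpt"), strict=False)
bbox_gen_model.eval().half()
```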
modules/PartField/configs/final/correspondence_demo.yaml ADDED
@@ -0,0 +1,44 @@
+ result_name: partfield_features/correspondence_demo
+
+ continue_ckpt: model/model.ckpt
+
+ triplane_channels_low: 128
+ triplane_channels_high: 512
+ triplane_resolution: 128
+
+ vertex_feature: True
+ n_point_per_face: 1000
+ n_sample_each: 10000
+ is_pc: True
+ remesh_demo: False
+ correspondence_demo: True
+
+ preprocess_mesh: True
+
+ dataset:
+   type: "Mix"
+   data_path: data/DenseCorr3D
+   train_batch_size: 1
+   val_batch_size: 1
+   train_num_workers: 8
+   all_files:
+     # pairs of examples to run correspondence
+     - animals/071b8_toy_animals_017/simple_mesh.obj
+     - animals/bdfd0_toy_animals_016/simple_mesh.obj
+     - animals/2d6b3_toy_animals_009/simple_mesh.obj
+     - animals/96615_toy_animals_018/simple_mesh.obj
+     - chairs/063d1_chair_006/simple_mesh.obj
+     - chairs/bea57_chair_012/simple_mesh.obj
+     - chairs/fe0fe_chair_004/simple_mesh.obj
+     - chairs/288dc_chair_011/simple_mesh.obj
+     # consider decimating animals/../color_mesh.obj yourself for better mesh topology than the provided simple_mesh.obj
+     # (e.g. <50k vertices for functional map efficiency).
+
+ loss:
+   triplet: 1.0
+
+ use_2d_feat: False
+ pvcnn:
+   point_encoder_type: 'pvcnn'
+   z_triplane_channels: 256
+   z_triplane_resolution: 128
modules/PartField/configs/final/demo.yaml ADDED
@@ -0,0 +1,28 @@
+ result_name: demo_test
+
+ continue_ckpt: model/model.ckpt
+
+ triplane_channels_low: 128
+ triplane_channels_high: 512
+ triplane_resolution: 128
+
+ n_point_per_face: 1000
+ n_sample_each: 10000
+ is_pc: True
+ remesh_demo: False
+
+ dataset:
+   type: "Mix"
+   data_path: "objaverse_data"
+   train_batch_size: 1
+   val_batch_size: 1
+   train_num_workers: 8
+
+ loss:
+   triplet: 1.0
+
+ use_2d_feat: False
+ pvcnn:
+   point_encoder_type: 'pvcnn'
+   z_triplane_channels: 256
+   z_triplane_resolution: 128
modules/PartField/partfield/config/__init__.py ADDED
@@ -0,0 +1,26 @@
+ import argparse
+ import os.path as osp
+ from datetime import datetime
+ import pytz
+
+ def default_argument_parser(add_help=True, default_config_file=""):
+     parser = argparse.ArgumentParser(add_help=add_help)
+     parser.add_argument("--config-file", '-c', default=default_config_file, metavar="FILE", help="path to config file")
+     parser.add_argument(
+         "--opts",
+         help="Modify config options using the command-line",
+         default=None,
+         nargs=argparse.REMAINDER,
+     )
+     return parser
+
+ def setup(args, freeze=True):
+     from .defaults import _C as cfg
+     cfg = cfg.clone()
+     cfg.merge_from_file(args.config_file)
+     cfg.merge_from_list(args.opts)
+     dt = datetime.now(pytz.timezone('America/Los_Angeles')).strftime('%y%m%d-%H%M%S')
+     cfg.output_dir = osp.join(cfg.output_dir, cfg.name, dt)
+     if freeze:
+         cfg.freeze()
+     return cfg
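A hedged usage sketch for these helpers, assuming it is run from `modules/PartField/` so that the `partfield` package is importable; the `--opts` override shown is only an example:

```python
from partfield.config import default_argument_parser, setup

# Parse a config file plus key/value overrides (yacs merge_from_list style).
args = default_argument_parser(default_config_file="configs/final/demo.yaml").parse_args(
    ["--opts", "dataset.val_batch_size", "1"]
)
cfg = setup(args, freeze=True)  # merges YAML and overrides, stamps output_dir with a timestamp
print(cfg.result_name, cfg.output_dir)
```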
modules/PartField/partfield/config/defaults.py ADDED
@@ -0,0 +1,92 @@
1
+ from yacs.config import CfgNode as CN
2
+
3
+ _C = CN()
4
+ _C.seed = 0
5
+ _C.output_dir = "results"
6
+ _C.result_name = "test_all"
7
+
8
+ _C.triplet_sampling = "random"
9
+ _C.load_original_mesh = False
10
+
11
+ _C.num_pos = 64
12
+ _C.num_neg_random = 256
13
+ _C.num_neg_hard_pc = 128
14
+ _C.num_neg_hard_emb = 128
15
+
16
+ _C.vertex_feature = False # if true, sample feature on vertices; if false, sample feature on faces
17
+ _C.n_point_per_face = 2000
18
+ _C.n_sample_each = 10000
19
+ _C.preprocess_mesh = False
20
+
21
+ _C.regress_2d_feat = False
22
+
23
+ _C.is_pc = False
24
+
25
+ _C.cut_manifold = False
26
+ _C.remesh_demo = False
27
+ _C.correspondence_demo = False
28
+
29
+ _C.save_every_epoch = 10
30
+ _C.training_epochs = 30
31
+ _C.continue_training = False
32
+
33
+ _C.continue_ckpt = None
34
+ _C.epoch_selected = "epoch=50.ckpt"
35
+
36
+ _C.triplane_resolution = 128
37
+ _C.triplane_channels_low = 128
38
+ _C.triplane_channels_high = 512
39
+ _C.lr = 1e-3
40
+ _C.train = True
41
+ _C.test = False
42
+
43
+ _C.inference_save_pred_sdf_to_mesh=True
44
+ _C.inference_save_feat_pca=True
45
+ _C.name = "test"
46
+ _C.test_subset = False
47
+ _C.test_corres = False
48
+ _C.test_partobjaversetiny = False
49
+
50
+ _C.dataset = CN()
51
+ _C.dataset.type = "Demo_Dataset"
52
+ _C.dataset.data_path = "objaverse_data/"
53
+ _C.dataset.train_num_workers = 64
54
+ _C.dataset.val_num_workers = 32
55
+ _C.dataset.train_batch_size = 2
56
+ _C.dataset.val_batch_size = 2
57
+ _C.dataset.all_files = [] # only used for correspondence demo
58
+
59
+ _C.voxel2triplane = CN()
60
+ _C.voxel2triplane.transformer_dim = 1024
61
+ _C.voxel2triplane.transformer_layers = 6
62
+ _C.voxel2triplane.transformer_heads = 8
63
+ _C.voxel2triplane.triplane_low_res = 32
64
+ _C.voxel2triplane.triplane_high_res = 256
65
+ _C.voxel2triplane.triplane_dim = 64
66
+ _C.voxel2triplane.normalize_vox_feat = False
67
+
68
+
69
+ _C.loss = CN()
70
+ _C.loss.triplet = 0.0
71
+ _C.loss.sdf = 1.0
72
+ _C.loss.feat = 10.0
73
+ _C.loss.l1 = 0.0
74
+
75
+ _C.use_pvcnn = False
76
+ _C.use_pvcnnonly = True
77
+
78
+ _C.pvcnn = CN()
79
+ _C.pvcnn.point_encoder_type = 'pvcnn'
80
+ _C.pvcnn.use_point_scatter = True
81
+ _C.pvcnn.z_triplane_channels = 64
82
+ _C.pvcnn.z_triplane_resolution = 256
83
+ _C.pvcnn.unet_cfg = CN()
84
+ _C.pvcnn.unet_cfg.depth = 3
85
+ _C.pvcnn.unet_cfg.enabled = True
86
+ _C.pvcnn.unet_cfg.rolled = True
87
+ _C.pvcnn.unet_cfg.use_3d_aware = True
88
+ _C.pvcnn.unet_cfg.start_hidden_channels = 32
89
+ _C.pvcnn.unet_cfg.use_initial_conv = False
90
+
91
+ _C.use_2d_feat = False
92
+ _C.inference_metrics_only = False
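The defaults above form a yacs `CfgNode`, so nested options are plain attribute accesses. A small sketch of cloning and overriding them programmatically, under the same import-root assumption as the previous sketch (the printed values are the defaults defined above):

```python
from partfield.config.defaults import _C

cfg = _C.clone()                 # work on a copy, not the module-level defaults
cfg.triplane_resolution = 256    # example override before freezing
cfg.freeze()

print(cfg.pvcnn.point_encoder_type)      # 'pvcnn'
print(cfg.pvcnn.unet_cfg.use_3d_aware)   # True
```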
modules/PartField/partfield/model/PVCNN/conv_pointnet.py ADDED
@@ -0,0 +1,251 @@
1
+ """
2
+ Taken from gensdf
3
+ https://github.com/princeton-computational-imaging/gensdf
4
+ """
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ # from dnnlib.util import printarr
10
+ try:
11
+ from torch_scatter import scatter_mean, scatter_max
12
+ except:
13
+ pass
14
+ # from .unet import UNet
15
+ import torch
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
+
19
+
20
+ # Resnet Blocks
21
+ class ResnetBlockFC(nn.Module):
22
+ ''' Fully connected ResNet Block class.
23
+ Args:
24
+ size_in (int): input dimension
25
+ size_out (int): output dimension
26
+ size_h (int): hidden dimension
27
+ '''
28
+
29
+ def __init__(self, size_in, size_out=None, size_h=None):
30
+ super().__init__()
31
+ # Attributes
32
+ if size_out is None:
33
+ size_out = size_in
34
+
35
+ if size_h is None:
36
+ size_h = min(size_in, size_out)
37
+
38
+ self.size_in = size_in
39
+ self.size_h = size_h
40
+ self.size_out = size_out
41
+ # Submodules
42
+ self.fc_0 = nn.Linear(size_in, size_h)
43
+ self.fc_1 = nn.Linear(size_h, size_out)
44
+ self.actvn = nn.ReLU()
45
+
46
+ if size_in == size_out:
47
+ self.shortcut = None
48
+ else:
49
+ self.shortcut = nn.Linear(size_in, size_out, bias=False)
50
+ # Initialization
51
+ nn.init.zeros_(self.fc_1.weight)
52
+
53
+ def forward(self, x):
54
+ net = self.fc_0(self.actvn(x))
55
+ dx = self.fc_1(self.actvn(net))
56
+
57
+ if self.shortcut is not None:
58
+ x_s = self.shortcut(x)
59
+ else:
60
+ x_s = x
61
+
62
+ return x_s + dx
63
+
64
+
65
+ class ConvPointnet(nn.Module):
66
+ ''' PointNet-based encoder network with ResNet blocks for each point.
67
+ The number of input points is fixed.
68
+
69
+ Args:
70
+ c_dim (int): dimension of latent code c
71
+ dim (int): input points dimension
72
+ hidden_dim (int): hidden dimension of the network
73
+ scatter_type (str): feature aggregation when doing local pooling
74
+ unet (bool): whether to use U-Net
75
+ unet_kwargs (str): U-Net parameters
76
+ plane_resolution (int): defined resolution for plane feature
77
+ plane_type (str): feature type, 'xz' - 1-plane, ['xz', 'xy', 'yz'] - 3-plane, ['grid'] - 3D grid volume
78
+ padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
79
+ n_blocks (int): number of blocks ResNetBlockFC layers
80
+ '''
81
+
82
+ def __init__(self, c_dim=128, dim=3, hidden_dim=128, scatter_type='max',
83
+ # unet=False, unet_kwargs=None,
84
+ plane_resolution=None, plane_type=['xz', 'xy', 'yz'], padding=0.1, n_blocks=5):
85
+ super().__init__()
86
+ self.c_dim = c_dim
87
+
88
+ self.fc_pos = nn.Linear(dim, 2*hidden_dim)
89
+ self.blocks = nn.ModuleList([
90
+ ResnetBlockFC(2*hidden_dim, hidden_dim) for i in range(n_blocks)
91
+ ])
92
+ self.fc_c = nn.Linear(hidden_dim, c_dim)
93
+
94
+ self.actvn = nn.ReLU()
95
+ self.hidden_dim = hidden_dim
96
+
97
+ # if unet:
98
+ # self.unet = UNet(c_dim, in_channels=c_dim, **unet_kwargs)
99
+ # else:
100
+ # self.unet = None
101
+
102
+ self.reso_plane = plane_resolution
103
+ self.plane_type = plane_type
104
+ self.padding = padding
105
+
106
+ if scatter_type == 'max':
107
+ self.scatter = scatter_max
108
+ elif scatter_type == 'mean':
109
+ self.scatter = scatter_mean
110
+
111
+
112
+ # takes in "p": point cloud and "query": sdf_xyz
113
+ # sample plane features for unlabeled_query as well
114
+ def forward(self, p):#, query2):
115
+ batch_size, T, D = p.size()
116
+
117
+ # acquire the index for each point
118
+ coord = {}
119
+ index = {}
120
+ if 'xz' in self.plane_type:
121
+ coord['xz'] = self.normalize_coordinate(p.clone(), plane='xz', padding=self.padding)
122
+ index['xz'] = self.coordinate2index(coord['xz'], self.reso_plane)
123
+ if 'xy' in self.plane_type:
124
+ coord['xy'] = self.normalize_coordinate(p.clone(), plane='xy', padding=self.padding)
125
+ index['xy'] = self.coordinate2index(coord['xy'], self.reso_plane)
126
+ if 'yz' in self.plane_type:
127
+ coord['yz'] = self.normalize_coordinate(p.clone(), plane='yz', padding=self.padding)
128
+ index['yz'] = self.coordinate2index(coord['yz'], self.reso_plane)
129
+
130
+
131
+ net = self.fc_pos(p)
132
+
133
+ net = self.blocks[0](net)
134
+ for block in self.blocks[1:]:
135
+ pooled = self.pool_local(coord, index, net)
136
+ net = torch.cat([net, pooled], dim=2)
137
+ net = block(net)
138
+
139
+ c = self.fc_c(net)
140
+
141
+ fea = {}
142
+ plane_feat_sum = 0
143
+ #second_sum = 0
144
+ if 'xz' in self.plane_type:
145
+ fea['xz'] = self.generate_plane_features(p, c, plane='xz') # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
146
+ # plane_feat_sum += self.sample_plane_feature(query, fea['xz'], 'xz')
147
+ #second_sum += self.sample_plane_feature(query2, fea['xz'], 'xz')
148
+ if 'xy' in self.plane_type:
149
+ fea['xy'] = self.generate_plane_features(p, c, plane='xy')
150
+ # plane_feat_sum += self.sample_plane_feature(query, fea['xy'], 'xy')
151
+ #second_sum += self.sample_plane_feature(query2, fea['xy'], 'xy')
152
+ if 'yz' in self.plane_type:
153
+ fea['yz'] = self.generate_plane_features(p, c, plane='yz')
154
+ # plane_feat_sum += self.sample_plane_feature(query, fea['yz'], 'yz')
155
+ #second_sum += self.sample_plane_feature(query2, fea['yz'], 'yz')
156
+ return fea
157
+
158
+ # return plane_feat_sum.transpose(2,1)#, second_sum.transpose(2,1)
159
+
160
+
161
+ def normalize_coordinate(self, p, padding=0.1, plane='xz'):
162
+ ''' Normalize coordinate to [0, 1] for unit cube experiments
163
+
164
+ Args:
165
+ p (tensor): point
166
+ padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
167
+ plane (str): plane feature type, ['xz', 'xy', 'yz']
168
+ '''
169
+ if plane == 'xz':
170
+ xy = p[:, :, [0, 2]]
171
+ elif plane =='xy':
172
+ xy = p[:, :, [0, 1]]
173
+ else:
174
+ xy = p[:, :, [1, 2]]
175
+
176
+ xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
177
+ xy_new = xy_new + 0.5 # range (0, 1)
178
+
179
+ # if there are outliers out of the range
180
+ if xy_new.max() >= 1:
181
+ xy_new[xy_new >= 1] = 1 - 10e-6
182
+ if xy_new.min() < 0:
183
+ xy_new[xy_new < 0] = 0.0
184
+ return xy_new
185
+
186
+
187
+ def coordinate2index(self, x, reso):
188
+ ''' Normalize coordinate to [0, 1] for unit cube experiments.
189
+ Corresponds to our 3D model
190
+
191
+ Args:
192
+ x (tensor): coordinate
193
+ reso (int): defined resolution
194
+ coord_type (str): coordinate type
195
+ '''
196
+ x = (x * reso).long()
197
+ index = x[:, :, 0] + reso * x[:, :, 1]
198
+ index = index[:, None, :]
199
+ return index
200
+
201
+
202
+ # xy is the normalized coordinates of the point cloud of each plane
203
+ # I'm pretty sure the keys of xy are the same as those of index, so xy isn't needed here as input
204
+ def pool_local(self, xy, index, c):
205
+ bs, fea_dim = c.size(0), c.size(2)
206
+ keys = xy.keys()
207
+
208
+ c_out = 0
209
+ for key in keys:
210
+ # scatter plane features from points
211
+ fea = self.scatter(c.permute(0, 2, 1), index[key], dim_size=self.reso_plane**2)
212
+ if self.scatter == scatter_max:
213
+ fea = fea[0]
214
+ # gather feature back to points
215
+ fea = fea.gather(dim=2, index=index[key].expand(-1, fea_dim, -1))
216
+ c_out += fea
217
+ return c_out.permute(0, 2, 1)
218
+
219
+
220
+ def generate_plane_features(self, p, c, plane='xz'):
221
+ # acquire indices of features in plane
222
+ xy = self.normalize_coordinate(p.clone(), plane=plane, padding=self.padding) # normalize to the range of (0, 1)
223
+ index = self.coordinate2index(xy, self.reso_plane)
224
+
225
+ # scatter plane features from points
226
+ fea_plane = c.new_zeros(p.size(0), self.c_dim, self.reso_plane**2)
227
+ c = c.permute(0, 2, 1) # B x 512 x T
228
+ fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
229
+ fea_plane = fea_plane.reshape(p.size(0), self.c_dim, self.reso_plane, self.reso_plane) # sparse matrix (B x 512 x reso x reso)
230
+
231
+ # printarr(fea_plane, c, p, xy, index)
232
+ # import pdb; pdb.set_trace()
233
+
234
+ # process the plane features with UNet
235
+ # if self.unet is not None:
236
+ # fea_plane = self.unet(fea_plane)
237
+
238
+ return fea_plane
239
+
240
+
241
+ # sample_plane_feature function copied from /src/conv_onet/models/decoder.py
242
+ # uses values from plane_feature and pixel locations from vgrid to interpolate feature
243
+ def sample_plane_feature(self, query, plane_feature, plane):
244
+ xy = self.normalize_coordinate(query.clone(), plane=plane, padding=self.padding)
245
+ xy = xy[:, :, None].float()
246
+ vgrid = 2.0 * xy - 1.0 # normalize to (-1, 1)
247
+ sampled_feat = F.grid_sample(plane_feature, vgrid, padding_mode='border', align_corners=True, mode='bilinear').squeeze(-1)
248
+ return sampled_feat
249
+
250
+
251
+
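A minimal forward-pass sketch for `ConvPointnet`. It assumes `torch_scatter` is installed (the guarded import above fails silently, but `pool_local` and `generate_plane_features` do need it) and that the repository root is on `PYTHONPATH`; the output shape follows `generate_plane_features`:

```python
import torch

from modules.PartField.partfield.model.PVCNN.conv_pointnet import ConvPointnet

encoder = ConvPointnet(c_dim=128, dim=3, hidden_dim=128, plane_resolution=64)
points = torch.rand(2, 2048, 3) - 0.5   # B x T x 3 point cloud roughly inside the unit cube
planes = encoder(points)                # dict of triplane features: 'xz', 'xy', 'yz'
print(planes["xz"].shape)               # torch.Size([2, 128, 64, 64])
```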
modules/PartField/partfield/model/PVCNN/dnnlib_util.py ADDED
@@ -0,0 +1,1074 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ """Miscellaneous utility classes and functions."""
10
+ from collections import namedtuple
11
+ import time
12
+ import ctypes
13
+ import fnmatch
14
+ import importlib
15
+ import inspect
16
+ import numpy as np
17
+ import json
18
+ import os
19
+ import shutil
20
+ import sys
21
+ import types
22
+ import io
23
+ import pickle
24
+ import re
25
+ # import requests
26
+ import html
27
+ import hashlib
28
+ import glob
29
+ import tempfile
30
+ import urllib
31
+ import urllib.request
32
+ import uuid
33
+ import boto3
34
+ import threading
35
+ from contextlib import ContextDecorator
36
+ from contextlib import contextmanager, nullcontext
37
+
38
+ from distutils.util import strtobool
39
+ from typing import Any, List, Tuple, Union
40
+ import importlib
41
+ from loguru import logger
42
+ # import wandb
43
+ import torch
44
+ import psutil
45
+ import subprocess
46
+
47
+ import random
48
+ import string
49
+ import pdb
50
+
51
+ # Util classes
52
+ # ------------------------------------------------------------------------------------------
53
+
54
+
55
+ class EasyDict(dict):
56
+ """Convenience class that behaves like a dict but allows access with the attribute syntax."""
57
+
58
+ def __getattr__(self, name: str) -> Any:
59
+ try:
60
+ return self[name]
61
+ except KeyError:
62
+ raise AttributeError(name)
63
+
64
+ def __setattr__(self, name: str, value: Any) -> None:
65
+ self[name] = value
66
+
67
+ def __delattr__(self, name: str) -> None:
68
+ del self[name]
69
+
70
+
71
+ class Logger(object):
72
+ """Redirect stderr to stdout, optionally print stdout to a file, and optionally force flushing on both stdout and the file."""
73
+
74
+ def __init__(self, file_name: str = None, file_mode: str = "w", should_flush: bool = True):
75
+ self.file = None
76
+
77
+ if file_name is not None:
78
+ self.file = open(file_name, file_mode)
79
+
80
+ self.should_flush = should_flush
81
+ self.stdout = sys.stdout
82
+ self.stderr = sys.stderr
83
+
84
+ sys.stdout = self
85
+ sys.stderr = self
86
+
87
+ def __enter__(self) -> "Logger":
88
+ return self
89
+
90
+ def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
91
+ self.close()
92
+
93
+ def write(self, text: Union[str, bytes]) -> None:
94
+ """Write text to stdout (and a file) and optionally flush."""
95
+ if isinstance(text, bytes):
96
+ text = text.decode()
97
+ if len(text) == 0: # workaround for a bug in VSCode debugger: sys.stdout.write(''); sys.stdout.flush() => crash
98
+ return
99
+
100
+ if self.file is not None:
101
+ self.file.write(text)
102
+
103
+ self.stdout.write(text)
104
+
105
+ if self.should_flush:
106
+ self.flush()
107
+
108
+ def flush(self) -> None:
109
+ """Flush written text to both stdout and a file, if open."""
110
+ if self.file is not None:
111
+ self.file.flush()
112
+
113
+ self.stdout.flush()
114
+
115
+ def close(self) -> None:
116
+ """Flush, close possible files, and remove stdout/stderr mirroring."""
117
+ self.flush()
118
+
119
+ # if using multiple loggers, prevent closing in wrong order
120
+ if sys.stdout is self:
121
+ sys.stdout = self.stdout
122
+ if sys.stderr is self:
123
+ sys.stderr = self.stderr
124
+
125
+ if self.file is not None:
126
+ self.file.close()
127
+ self.file = None
128
+
129
+
130
+ # Cache directories
131
+ # ------------------------------------------------------------------------------------------
132
+
133
+ _dnnlib_cache_dir = None
134
+
135
+
136
+ def set_cache_dir(path: str) -> None:
137
+ global _dnnlib_cache_dir
138
+ _dnnlib_cache_dir = path
139
+
140
+
141
+ def make_cache_dir_path(*paths: str) -> str:
142
+ if _dnnlib_cache_dir is not None:
143
+ return os.path.join(_dnnlib_cache_dir, *paths)
144
+ if 'DNNLIB_CACHE_DIR' in os.environ:
145
+ return os.path.join(os.environ['DNNLIB_CACHE_DIR'], *paths)
146
+ if 'HOME' in os.environ:
147
+ return os.path.join(os.environ['HOME'], '.cache', 'dnnlib', *paths)
148
+ if 'USERPROFILE' in os.environ:
149
+ return os.path.join(os.environ['USERPROFILE'], '.cache', 'dnnlib', *paths)
150
+ return os.path.join(tempfile.gettempdir(), '.cache', 'dnnlib', *paths)
151
+
152
+
153
+ # Small util functions
154
+ # ------------------------------------------------------------------------------------------
155
+
156
+
157
+ def format_time(seconds: Union[int, float]) -> str:
158
+ """Convert the seconds to human readable string with days, hours, minutes and seconds."""
159
+ s = int(np.rint(seconds))
160
+
161
+ if s < 60:
162
+ return "{0}s".format(s)
163
+ elif s < 60 * 60:
164
+ return "{0}m {1:02}s".format(s // 60, s % 60)
165
+ elif s < 24 * 60 * 60:
166
+ return "{0}h {1:02}m {2:02}s".format(s // (60 * 60), (s // 60) % 60, s % 60)
167
+ else:
168
+ return "{0}d {1:02}h {2:02}m".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24, (s // 60) % 60)
169
+
170
+
171
+ def format_time_brief(seconds: Union[int, float]) -> str:
172
+ """Convert seconds to a brief human-readable string with days, hours, minutes and seconds."""
173
+ s = int(np.rint(seconds))
174
+
175
+ if s < 60:
176
+ return "{0}s".format(s)
177
+ elif s < 60 * 60:
178
+ return "{0}m {1:02}s".format(s // 60, s % 60)
179
+ elif s < 24 * 60 * 60:
180
+ return "{0}h {1:02}m".format(s // (60 * 60), (s // 60) % 60)
181
+ else:
182
+ return "{0}d {1:02}h".format(s // (24 * 60 * 60), (s // (60 * 60)) % 24)
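A few hand-checked examples of the two formatters above:

assert format_time(75) == "1m 15s"
assert format_time(3 * 60 * 60 + 5) == "3h 00m 05s"
assert format_time_brief(90000) == "1d 01h"  # smaller units are dropped for long durations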
183
+
184
+
185
+ def ask_yes_no(question: str) -> bool:
186
+ """Ask the user the question until the user inputs a valid answer."""
187
+ while True:
188
+ try:
189
+ print("{0} [y/n]".format(question))
190
+ return strtobool(input().lower())
191
+ except ValueError:
192
+ pass
193
+
194
+
195
+ def tuple_product(t: Tuple) -> Any:
196
+ """Calculate the product of the tuple elements."""
197
+ result = 1
198
+
199
+ for v in t:
200
+ result *= v
201
+
202
+ return result
203
+
204
+
205
+ _str_to_ctype = {
206
+ "uint8": ctypes.c_ubyte,
207
+ "uint16": ctypes.c_uint16,
208
+ "uint32": ctypes.c_uint32,
209
+ "uint64": ctypes.c_uint64,
210
+ "int8": ctypes.c_byte,
211
+ "int16": ctypes.c_int16,
212
+ "int32": ctypes.c_int32,
213
+ "int64": ctypes.c_int64,
214
+ "float32": ctypes.c_float,
215
+ "float64": ctypes.c_double
216
+ }
217
+
218
+
219
+ def get_dtype_and_ctype(type_obj: Any) -> Tuple[np.dtype, Any]:
220
+ """Given a type name string (or an object having a __name__ attribute), return matching Numpy and ctypes types that have the same size in bytes."""
221
+ type_str = None
222
+
223
+ if isinstance(type_obj, str):
224
+ type_str = type_obj
225
+ elif hasattr(type_obj, "__name__"):
226
+ type_str = type_obj.__name__
227
+ elif hasattr(type_obj, "name"):
228
+ type_str = type_obj.name
229
+ else:
230
+ raise RuntimeError("Cannot infer type name from input")
231
+
232
+ assert type_str in _str_to_ctype.keys()
233
+
234
+ my_dtype = np.dtype(type_str)
235
+ my_ctype = _str_to_ctype[type_str]
236
+
237
+ assert my_dtype.itemsize == ctypes.sizeof(my_ctype)
238
+
239
+ return my_dtype, my_ctype
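Example of get_dtype_and_ctype: a type-name string and a numpy scalar type resolve to the same (numpy dtype, ctypes type) pair of equal byte size.

import ctypes
import numpy as np

dtype, ctype = get_dtype_and_ctype("float32")
assert dtype == np.dtype("float32") and ctype is ctypes.c_float
dtype, ctype = get_dtype_and_ctype(np.float32)  # resolved via its __name__
assert dtype.itemsize == ctypes.sizeof(ctype) == 4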
240
+
241
+
242
+ def is_pickleable(obj: Any) -> bool:
243
+ try:
244
+ with io.BytesIO() as stream:
245
+ pickle.dump(obj, stream)
246
+ return True
247
+ except:
248
+ return False
249
+
250
+
251
+ # Functionality to import modules/objects by name, and call functions by name
252
+ # ------------------------------------------------------------------------------------------
253
+
254
+ def get_module_from_obj_name(obj_name: str) -> Tuple[types.ModuleType, str]:
255
+ """Searches for the underlying module behind the name of some python object.
256
+ Returns the module and the object name (original name with module part removed)."""
257
+
258
+ # allow convenience shorthands, substitute them by full names
259
+ obj_name = re.sub("^np.", "numpy.", obj_name)
260
+ obj_name = re.sub("^tf.", "tensorflow.", obj_name)
261
+
262
+ # list alternatives for (module_name, local_obj_name)
263
+ parts = obj_name.split(".")
264
+ name_pairs = [(".".join(parts[:i]), ".".join(parts[i:])) for i in range(len(parts), 0, -1)]
265
+
266
+ # try each alternative in turn
267
+ for module_name, local_obj_name in name_pairs:
268
+ try:
269
+ module = importlib.import_module(module_name) # may raise ImportError
270
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
271
+ return module, local_obj_name
272
+ except:
273
+ pass
274
+
275
+ # maybe some of the modules themselves contain errors?
276
+ for module_name, _local_obj_name in name_pairs:
277
+ try:
278
+ importlib.import_module(module_name) # may raise ImportError
279
+ except ImportError:
280
+ if not str(sys.exc_info()[1]).startswith("No module named '" + module_name + "'"):
281
+ raise
282
+
283
+ # maybe the requested attribute is missing?
284
+ for module_name, local_obj_name in name_pairs:
285
+ try:
286
+ module = importlib.import_module(module_name) # may raise ImportError
287
+ get_obj_from_module(module, local_obj_name) # may raise AttributeError
288
+ except ImportError:
289
+ pass
290
+
291
+ # we are out of luck, but we have no idea why
292
+ raise ImportError(obj_name)
293
+
294
+
295
+ def get_obj_from_module(module: types.ModuleType, obj_name: str) -> Any:
296
+ """Traverses the object name and returns the last (rightmost) python object."""
297
+ if obj_name == '':
298
+ return module
299
+ obj = module
300
+ for part in obj_name.split("."):
301
+ obj = getattr(obj, part)
302
+ return obj
303
+
304
+
305
+ def get_obj_by_name(name: str) -> Any:
306
+ """Finds the python object with the given name."""
307
+ module, obj_name = get_module_from_obj_name(name)
308
+ return get_obj_from_module(module, obj_name)
309
+
310
+
311
+ def call_func_by_name(*args, func_name: str = None, **kwargs) -> Any:
312
+ """Finds the python object with the given name and calls it as a function."""
313
+ assert func_name is not None
314
+ func_obj = get_obj_by_name(func_name)
315
+ assert callable(func_obj)
316
+ return func_obj(*args, **kwargs)
317
+
318
+
319
+ def construct_class_by_name(*args, class_name: str = None, **kwargs) -> Any:
320
+ """Finds the python class with the given name and constructs it with the given arguments."""
321
+ return call_func_by_name(*args, func_name=class_name, **kwargs)
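A sketch of the name-based lookup helpers above; note that the "np." shorthand is expanded to "numpy." by get_module_from_obj_name.

arr_type = get_obj_by_name("np.ndarray")                  # numpy.ndarray
zeros = call_func_by_name(3, func_name="np.zeros")        # same as numpy.zeros(3)
od = construct_class_by_name(class_name="collections.OrderedDict", a=1)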
322
+
323
+
324
+ def get_module_dir_by_obj_name(obj_name: str) -> str:
325
+ """Get the directory path of the module containing the given object name."""
326
+ module, _ = get_module_from_obj_name(obj_name)
327
+ return os.path.dirname(inspect.getfile(module))
328
+
329
+
330
+ def is_top_level_function(obj: Any) -> bool:
331
+ """Determine whether the given object is a top-level function, i.e., defined at module scope using 'def'."""
332
+ return callable(obj) and obj.__name__ in sys.modules[obj.__module__].__dict__
333
+
334
+
335
+ def get_top_level_function_name(obj: Any) -> str:
336
+ """Return the fully-qualified name of a top-level function."""
337
+ assert is_top_level_function(obj)
338
+ module = obj.__module__
339
+ if module == '__main__':
340
+ module = os.path.splitext(os.path.basename(sys.modules[module].__file__))[0]
341
+ return module + "." + obj.__name__
342
+
343
+
344
+ # File system helpers
345
+ # ------------------------------------------------------------------------------------------
346
+
347
+ def list_dir_recursively_with_ignore(dir_path: str, ignores: List[str] = None, add_base_to_relative: bool = False) -> List[Tuple[str, str]]:
348
+ """List all files recursively in a given directory while ignoring given file and directory names.
349
+ Returns list of tuples containing both absolute and relative paths."""
350
+ assert os.path.isdir(dir_path)
351
+ base_name = os.path.basename(os.path.normpath(dir_path))
352
+
353
+ if ignores is None:
354
+ ignores = []
355
+
356
+ result = []
357
+
358
+ for root, dirs, files in os.walk(dir_path, topdown=True):
359
+ for ignore_ in ignores:
360
+ dirs_to_remove = [d for d in dirs if fnmatch.fnmatch(d, ignore_)]
361
+
362
+ # dirs need to be edited in-place
363
+ for d in dirs_to_remove:
364
+ dirs.remove(d)
365
+
366
+ files = [f for f in files if not fnmatch.fnmatch(f, ignore_)]
367
+
368
+ absolute_paths = [os.path.join(root, f) for f in files]
369
+ relative_paths = [os.path.relpath(p, dir_path) for p in absolute_paths]
370
+
371
+ if add_base_to_relative:
372
+ relative_paths = [os.path.join(base_name, p) for p in relative_paths]
373
+
374
+ assert len(absolute_paths) == len(relative_paths)
375
+ result += zip(absolute_paths, relative_paths)
376
+
377
+ return result
378
+
379
+
380
+ def copy_files_and_create_dirs(files: List[Tuple[str, str]]) -> None:
381
+ """Takes in a list of tuples of (src, dst) paths and copies files.
382
+ Will create all necessary directories."""
383
+ for file in files:
384
+ target_dir_name = os.path.dirname(file[1])
385
+
386
+ # will create all intermediate-level directories
387
+ if not os.path.exists(target_dir_name):
388
+ os.makedirs(target_dir_name)
389
+
390
+ shutil.copyfile(file[0], file[1])
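A sketch combining the two file-system helpers above; the directory names are hypothetical.

import os

files = list_dir_recursively_with_ignore("src", ignores=["__pycache__", "*.ckpt"], add_base_to_relative=True)
files = [(src, os.path.join("runs/00000", rel)) for src, rel in files]  # (src, dst) pairs
copy_files_and_create_dirs(files)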
391
+
392
+
393
+ # URL helpers
394
+ # ------------------------------------------------------------------------------------------
395
+
396
+ def is_url(obj: Any, allow_file_urls: bool = False) -> bool:
397
+ """Determine whether the given object is a valid URL string."""
398
+ if not isinstance(obj, str) or not "://" in obj:
399
+ return False
400
+ if allow_file_urls and obj.startswith('file://'):
401
+ return True
402
+ try:
403
+ res = requests.compat.urlparse(obj)
404
+ if not res.scheme or not res.netloc or not "." in res.netloc:
405
+ return False
406
+ res = requests.compat.urlparse(requests.compat.urljoin(obj, "/"))
407
+ if not res.scheme or not res.netloc or not "." in res.netloc:
408
+ return False
409
+ except:
410
+ return False
411
+ return True
412
+
413
+
414
+ def open_url(url: str, cache_dir: str = None, num_attempts: int = 10, verbose: bool = True, return_filename: bool = False, cache: bool = True) -> Any:
415
+ """Download the given URL and return a binary-mode file object to access the data."""
416
+ assert num_attempts >= 1
417
+ assert not (return_filename and (not cache))
418
+
419
+ # Doesn't look like a URL scheme, so interpret it as a local filename.
420
+ if not re.match('^[a-z]+://', url):
421
+ return url if return_filename else open(url, "rb")
422
+
423
+ # Handle file URLs. This code handles unusual file:// patterns that
424
+ # arise on Windows:
425
+ #
426
+ # file:///c:/foo.txt
427
+ #
428
+ # which would translate to a local '/c:/foo.txt' filename that's
429
+ # invalid. Drop the forward slash for such pathnames.
430
+ #
431
+ # If you touch this code path, you should test it on both Linux and
432
+ # Windows.
433
+ #
434
+ # Some internet resources suggest using urllib.request.url2pathname(),
+ # but that converts forward slashes to backslashes and this causes
436
+ # its own set of problems.
437
+ if url.startswith('file://'):
438
+ filename = urllib.parse.urlparse(url).path
439
+ if re.match(r'^/[a-zA-Z]:', filename):
440
+ filename = filename[1:]
441
+ return filename if return_filename else open(filename, "rb")
442
+
443
+ assert is_url(url)
444
+
445
+ # Lookup from cache.
446
+ if cache_dir is None:
447
+ cache_dir = make_cache_dir_path('downloads')
448
+
449
+ url_md5 = hashlib.md5(url.encode("utf-8")).hexdigest()
450
+ if cache:
451
+ cache_files = glob.glob(os.path.join(cache_dir, url_md5 + "_*"))
452
+ if len(cache_files) == 1:
453
+ filename = cache_files[0]
454
+ return filename if return_filename else open(filename, "rb")
455
+
456
+ # Download.
457
+ url_name = None
458
+ url_data = None
459
+ with requests.Session() as session:
460
+ if verbose:
461
+ print("Downloading %s ..." % url, end="", flush=True)
462
+ for attempts_left in reversed(range(num_attempts)):
463
+ try:
464
+ with session.get(url) as res:
465
+ res.raise_for_status()
466
+ if len(res.content) == 0:
467
+ raise IOError("No data received")
468
+
469
+ if len(res.content) < 8192:
470
+ content_str = res.content.decode("utf-8")
471
+ if "download_warning" in res.headers.get("Set-Cookie", ""):
472
+ links = [html.unescape(link) for link in content_str.split('"') if "export=download" in link]
473
+ if len(links) == 1:
474
+ url = requests.compat.urljoin(url, links[0])
475
+ raise IOError("Google Drive virus checker nag")
476
+ if "Google Drive - Quota exceeded" in content_str:
477
+ raise IOError("Google Drive download quota exceeded -- please try again later")
478
+
479
+ match = re.search(r'filename="([^"]*)"', res.headers.get("Content-Disposition", ""))
480
+ url_name = match[1] if match else url
481
+ url_data = res.content
482
+ if verbose:
483
+ print(" done")
484
+ break
485
+ except KeyboardInterrupt:
486
+ raise
487
+ except:
488
+ if not attempts_left:
489
+ if verbose:
490
+ print(" failed")
491
+ raise
492
+ if verbose:
493
+ print(".", end="", flush=True)
494
+
495
+ # Save to cache.
496
+ if cache:
497
+ safe_name = re.sub(r"[^0-9a-zA-Z-._]", "_", url_name)
498
+ cache_file = os.path.join(cache_dir, url_md5 + "_" + safe_name)
499
+ temp_file = os.path.join(cache_dir, "tmp_" + uuid.uuid4().hex + "_" + url_md5 + "_" + safe_name)
500
+ os.makedirs(cache_dir, exist_ok=True)
501
+ with open(temp_file, "wb") as f:
502
+ f.write(url_data)
503
+ os.replace(temp_file, cache_file) # atomic
504
+ if return_filename:
505
+ return cache_file
506
+
507
+ # Return data as file object.
508
+ assert not return_filename
509
+ return io.BytesIO(url_data)
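A usage sketch of open_url with a hypothetical URL: the download is cached under the dnnlib cache dir and returned either as a binary file object or as the cached filename.

with open_url("https://example.com/weights.pkl", cache=True) as f:
    data = f.read()
local_path = open_url("https://example.com/weights.pkl", return_filename=True)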
510
+
511
+ # ------------------------------------------------------------------------------------------
512
+ # util function modified from https://github.com/nv-tlabs/LION/blob/0467d2199076e95a7e88bafd99dcd7d48a04b4a7/utils/model_helper.py
513
+ def import_class(model_str):
514
+ from torch_utils.dist_utils import is_rank0
515
+ if is_rank0():
516
+ logger.info('import: {}', model_str)
517
+ p, m = model_str.rsplit('.', 1)
518
+ mod = importlib.import_module(p)
519
+ Model = getattr(mod, m)
520
+ return Model
521
+
522
+ class ScopedTorchProfiler(ContextDecorator):
523
+ """
524
+ Marks ranges for both nvtx profiling (with nsys) and torch autograd profiler
525
+ """
526
+ __global_counts = {}
527
+ enabled=False
528
+
529
+ def __init__(self, unique_name: str):
530
+ """
531
+ Names must be unique!
532
+ """
533
+ ScopedTorchProfiler.__global_counts[unique_name] = 0
534
+ self._name = unique_name
535
+ self._autograd_scope = torch.profiler.record_function(unique_name)
536
+
537
+ def __enter__(self):
538
+ if ScopedTorchProfiler.enabled:
539
+ torch.cuda.nvtx.range_push(self._name)
540
+ self._autograd_scope.__enter__()
541
+
542
+ def __exit__(self, exc_type, exc_value, traceback):
543
+ self._autograd_scope.__exit__(exc_type, exc_value, traceback)
544
+ if ScopedTorchProfiler.enabled:
545
+ torch.cuda.nvtx.range_pop()
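A sketch of ScopedTorchProfiler: it wraps a block (or, via ContextDecorator, a function) so the region shows up in both nsys/NVTX and the autograd profiler; 'model' and 'x' are placeholders for the profiled workload.

ScopedTorchProfiler.enabled = True  # emit NVTX ranges (assumes a CUDA build)
with ScopedTorchProfiler("forward_pass"):
    y = model(x)  # placeholder workload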
546
+
547
+ class TimingsMonitor():
548
+ CUDATimer = namedtuple('CUDATimer', ['start', 'end'])
549
+ def __init__(self, device, enabled=True, timing_names:List[str]=[], cuda_timing_names:List[str]=[]):
550
+ """
551
+ Usage:
552
+ tmonitor = TimingsMonitor(device)
553
+ for i in range(n_iter):
554
+ # Record arbitrary scopes
555
+ with tmonitor.timing_scope('regular_scope_name'):
556
+ ...
557
+ with tmonitor.cuda_timing_scope('nested_scope_name'):
558
+ ...
559
+ with tmonitor.cuda_timing_scope('cuda_scope_name'):
560
+ ...
561
+ tmonitor.record_timing('duration_name', end_time - start_time)
562
+
563
+ # Gather timings
564
+ tmonitor.record_all_cuda_timings()
565
+ tmonitor.update_all_averages()
566
+ averages = tmonitor.get_average_timings()
567
+ all_timings = tmonitor.get_timings()
568
+
569
+ Two types of timers, standard report timing and cuda timings.
570
+ Cuda timing supports the scoped context manager cuda_timing_scope.
571
+ Args:
572
+ device: device to time on (needed for cuda timers)
573
+ # enabled: HACK to only report timings from rank 0, set enabled=(global_rank==0)
574
+ timing_names: timings to report optional (will auto add new names)
575
+ cuda_timing_names: cuda periods to time optional (will auto add new names)
576
+ """
577
+ self.enabled=enabled
578
+ self.device = device
579
+
580
+ # Normal timing
581
+ # self.all_timings_dict = {k:None for k in timing_names + cuda_timing_names}
582
+ self.all_timings_dict = {}
583
+ self.avg_meter_dict = {}
584
+
585
+ # Cuda event timers to measure time spent on pushing data to gpu and on training step
586
+ self.cuda_event_timers = {}
587
+
588
+ for k in timing_names:
589
+ self.add_new_timing(k)
590
+
591
+ for k in cuda_timing_names:
592
+ self.add_new_cuda_timing(k)
593
+
594
+ # Running averages
595
+ # self.avg_meter_dict = {k:AverageMeter() for k in self.all_timings_dict}
596
+
597
+ def add_new_timing(self, name):
598
+ self.avg_meter_dict[name] = AverageMeter()
599
+ self.all_timings_dict[name] = None
600
+
601
+ def add_new_cuda_timing(self, name):
602
+ start_event = torch.cuda.Event(enable_timing=True)
603
+ end_event = torch.cuda.Event(enable_timing=True)
604
+ self.cuda_event_timers[name] = self.CUDATimer(start=start_event, end=end_event)
605
+ self.add_new_timing(name)
606
+
607
+ def clear_timings(self):
608
+ self.all_timings_dict = {k:None for k in self.all_timings_dict}
609
+
610
+ def get_timings(self):
611
+ return self.all_timings_dict
612
+
613
+ def get_average_timings(self):
614
+ return {k:v.avg for k,v in self.avg_meter_dict.items()}
615
+
616
+ def update_all_averages(self):
617
+ """
618
+ Once per iter, when timings have been finished recording, one should
619
+ call update_average_iter to keep running average of timings.
620
+ """
621
+ for k,v in self.all_timings_dict.items():
622
+ if v is None:
623
+ print("none_timing", k)
624
+ continue
625
+ self.avg_meter_dict[k].update(v)
626
+
627
+ def record_timing(self, name, value):
628
+ if name not in self.all_timings_dict: self.add_new_timing(name)
629
+ # assert name in self.all_timings_dict
630
+ self.all_timings_dict[name] = value
631
+
632
+ def _record_cuda_event_start(self, name):
633
+ if name in self.cuda_event_timers:
634
+ self.cuda_event_timers[name].start.record(
635
+ torch.cuda.current_stream(self.device))
636
+
637
+ def _record_cuda_event_end(self, name):
638
+ if name in self.cuda_event_timers:
639
+ self.cuda_event_timers[name].end.record(
640
+ torch.cuda.current_stream(self.device))
641
+
642
+ @contextmanager
643
+ def cuda_timing_scope(self, name, profile=True):
644
+ if name not in self.all_timings_dict: self.add_new_cuda_timing(name)
645
+ with ScopedTorchProfiler(name) if profile else nullcontext():
646
+ self._record_cuda_event_start(name)
647
+ try:
648
+ yield
649
+ finally:
650
+ self._record_cuda_event_end(name)
651
+
652
+ @contextmanager
653
+ def timing_scope(self, name, profile=True):
654
+ if name not in self.all_timings_dict: self.add_new_timing(name)
655
+ with ScopedTorchProfiler(name) if profile else nullcontext():
656
+ start_time = time.time()
657
+ try:
658
+ yield
659
+ finally:
660
+ self.record_timing(name, time.time()-start_time)
661
+
662
+ def record_all_cuda_timings(self):
663
+ """ After all the cuda events call this to synchronize and record down the cuda timings. """
664
+ for k, events in self.cuda_event_timers.items():
665
+ with torch.no_grad():
666
+ events.end.synchronize()
667
+ # Convert to seconds
668
+ time_elapsed = events.start.elapsed_time(events.end)/1000.
669
+ self.all_timings_dict[k] = time_elapsed
670
+
671
+ def init_s3(config_file):
672
+ config = json.load(open(config_file, 'r'))
673
+ s3_client = boto3.client("s3", **config)
674
+ return s3_client
675
+
676
+ def download_from_s3(file_path, target_path, cfg):
677
+ tic = time.time()
678
+ s3_client = init_s3(cfg.checkpoint.write_s3_config) # also verifies that the s3 client can be initialized
679
+ bucket_name = file_path.split('/')[2]
680
+ file_key = file_path.split(bucket_name+'/')[-1]
681
+ print(bucket_name, file_key)
682
+ s3_client.download_file(bucket_name, file_key, target_path)
683
+ logger.info(f'finished download from s3://{bucket_name}/{file_key} to {target_path} in %.1f sec' % (
684
+ time.time() - tic))
685
+
686
+ def upload_to_s3(buffer, bucket_name, key, config_dict):
687
+ logger.info(f'start upload_to_s3! bucket_name={bucket_name}, key={key}')
688
+ tic = time.time()
689
+ s3 = boto3.client('s3', **config_dict)
690
+ s3.put_object(Bucket=bucket_name, Key=key, Body=buffer.getvalue())
691
+ logger.info(f'finish upload_to_s3! s3://{bucket_name}/{key} %.1f sec'%(time.time() - tic))
692
+
693
+ def write_ckpt_to_s3(cfg, all_model_dict, ckpt_name):
694
+ buffer = io.BytesIO()
695
+ tic = time.time()
696
+ torch.save(all_model_dict, buffer) # take ~0.25 sec
697
+ # logger.info('write ckpt to buffer: %.2f sec'%(time.time() - tic))
698
+ group, name = cfg.outdir.rstrip("/").split("/")[-2:]
699
+ key = f"checkpoints/{group}/{name}/ckpt/{ckpt_name}"
700
+ bucket_name = cfg.checkpoint.write_s3_bucket
701
+
702
+ s3_client = init_s3(cfg.checkpoint.write_s3_config) # also verifies that the s3 client can be initialized
703
+
704
+ config_dict = json.load(open(cfg.checkpoint.write_s3_config, 'r'))
705
+ upload_thread = threading.Thread(target=upload_to_s3, args=(buffer, bucket_name, key, config_dict))
706
+ upload_thread.start()
707
+ path = f"s3://{bucket_name}/{key}"
708
+ return path
709
+
710
+ def upload_file_to_s3(cfg, file_path, key_name=None):
711
+ # file_path is the local file path, can be a yaml file
712
+ # this function is used to upload the checkpoint only
713
+ tic = time.time()
714
+ group, name = cfg.outdir.rstrip("/").split("/")[-2:]
715
+ if key_name is None:
716
+ key = os.path.basename(file_path)
717
+ key = f"checkpoints/{group}/{name}/{key}"
718
+ bucket_name = cfg.checkpoint.write_s3_bucket
719
+ s3_client = init_s3(cfg.checkpoint.write_s3_config)
720
+ # Upload the file
721
+ with open(file_path, 'rb') as f:
722
+ s3_client.upload_fileobj(f, bucket_name, key)
723
+ full_s3_path = f"s3://{bucket_name}/{key}"
724
+ logger.info(f'upload_to_s3: {file_path} {full_s3_path} | use time: {time.time()-tic}')
725
+
726
+ return full_s3_path
727
+
728
+
729
+ def load_from_s3(file_path, cfg, load_fn):
730
+ """
731
+ ckpt_path example:
732
+ s3://xzeng/checkpoints/2023_0413/vae_kl_5e-1/ckpt/snapshot_epo000163_iter164000.pt
733
+ """
734
+ s3_client = init_s3(cfg.checkpoint.write_s3_config) # also verifies that the s3 client can be initialized
735
+ bucket_name = file_path.split("s3://")[-1].split('/')[0]
736
+ key = file_path.split(f'{bucket_name}/')[-1]
737
+ # logger.info(f"-> try to load s3://{bucket_name}/{key} ")
738
+ tic = time.time()
739
+ for attempt in range(10):
740
+ try:
741
+ # Download the state dict from S3 into memory (as a binary stream)
742
+ with io.BytesIO() as buffer:
743
+ s3_client.download_fileobj(bucket_name, key, buffer)
744
+ buffer.seek(0)
745
+
746
+ # Load the state dict into a PyTorch model
747
+ # out = torch.load(buffer, map_location=torch.device("cpu"))
748
+ out = load_fn(buffer)
749
+ break
750
+ except:
751
+ logger.info(f"failed to load s3://{bucket_name}/{key}, attempt: {attempt}")
752
+ from torch_utils.dist_utils import is_rank0
753
+ if is_rank0():
754
+ logger.info(f'loaded {file_path} | use time: {time.time()-tic:.1f} sec')
755
+ return out
756
+
757
+ def load_torch_dict_from_s3(ckpt_path, cfg):
758
+ """
759
+ ckpt_path example:
760
+ s3://xzeng/checkpoints/2023_0413/vae_kl_5e-1/ckpt/snapshot_epo000163_iter164000.pt
761
+ """
762
+ s3_client = init_s3(cfg.checkpoint.write_s3_config) # also verifies that the s3 client can be initialized
763
+ bucket_name = ckpt_path.split("s3://")[-1].split('/')[0]
764
+ key = ckpt_path.split(f'{bucket_name}/')[-1]
765
+ for attempt in range(10):
766
+ try:
767
+ # Download the state dict from S3 into memory (as a binary stream)
768
+ with io.BytesIO() as buffer:
769
+ s3_client.download_fileobj(bucket_name, key, buffer)
770
+ buffer.seek(0)
771
+
772
+ # Load the state dict into a PyTorch model
773
+ out = torch.load(buffer, map_location=torch.device("cpu"))
774
+ break
775
+ except:
776
+ logger.info(f"failed to load s3://{bucket_name}/{key}, attempt: {attempt}")
777
+ return out
778
+
779
+ def count_parameters_in_M(model):
780
+ return np.sum(np.prod(v.size()) for name, v in model.named_parameters() if "auxiliary" not in name) / 1e6
781
+
782
+ def printarr(*arrs, float_width=6, **kwargs):
783
+ """
784
+ Print a pretty table giving name, shape, dtype, type, and content information for input tensors or scalars.
785
+
786
+ Call like: printarr(my_arr, some_other_arr, maybe_a_scalar). Accepts a variable number of arguments.
787
+
788
+ Inputs can be:
789
+ - Numpy tensor arrays
790
+ - Pytorch tensor arrays
791
+ - Jax tensor arrays
792
+ - Python ints / floats
793
+ - None
794
+
795
+ It may also work with other array-like types, but they have not been tested.
796
+
797
+ Use the `float_width` option to specify the precision to which floating point types are printed.
798
+
799
+ Author: Nicholas Sharp (nmwsharp.com)
800
+ Canonical source: https://gist.github.com/nmwsharp/54d04af87872a4988809f128e1a1d233
801
+ License: This snippet may be used under an MIT license, and it is also released into the public domain.
802
+ Please retain this docstring as a reference.
803
+ """
804
+
805
+ frame = inspect.currentframe().f_back
806
+ default_name = "[temporary]"
807
+
808
+ ## helpers to gather data about each array
809
+ def name_from_outer_scope(a):
810
+ if a is None:
811
+ return '[None]'
812
+ name = default_name
813
+ for k, v in frame.f_locals.items():
814
+ if v is a:
815
+ name = k
816
+ break
817
+ return name
818
+
819
+ def type_strip(type_str):
820
+ return type_str.lstrip('<class ').rstrip('>').replace('torch.', '').strip("'")
821
+
822
+ def dtype_str(a):
823
+ if a is None:
824
+ return 'None'
825
+ if isinstance(a, int):
826
+ return 'int'
827
+ if isinstance(a, float):
828
+ return 'float'
829
+ if isinstance(a, list) and len(a)>0:
830
+ return type_strip(str(type(a[0])))
831
+ if hasattr(a, 'dtype'):
832
+ return type_strip(str(a.dtype))
833
+ else:
834
+ return ''
835
+ def shape_str(a):
836
+ if a is None:
837
+ return 'N/A'
838
+ if isinstance(a, int):
839
+ return 'scalar'
840
+ if isinstance(a, float):
841
+ return 'scalar'
842
+ if isinstance(a, list):
843
+ return f"[{shape_str(a[0]) if len(a)>0 else '?'}]*{len(a)}"
844
+ if hasattr(a, 'shape'):
845
+ return str(tuple(a.shape))
846
+ else:
847
+ return ''
848
+ def type_str(a):
849
+ return type_strip(str(type(a))) # TODO this is weird... what's the better way?
850
+ def device_str(a):
851
+ if hasattr(a, 'device'):
852
+ device_str = str(a.device)
853
+ if len(device_str) < 10:
854
+ # heuristic: jax returns some goofy long string we don't want, ignore it
855
+ return device_str
856
+ return ""
857
+ def format_float(x):
858
+ return f"{x:{float_width}g}"
859
+ def minmaxmean_str(a):
860
+ if a is None:
861
+ return ('N/A', 'N/A', 'N/A', 'N/A')
862
+ if isinstance(a, int) or isinstance(a, float):
863
+ return (format_float(a),)*4
864
+
865
+ # compute min/max/mean. if anything goes wrong, just print 'N/A'
866
+ min_str = "N/A"
867
+ try: min_str = format_float(a.min())
868
+ except: pass
869
+ max_str = "N/A"
870
+ try: max_str = format_float(a.max())
871
+ except: pass
872
+ mean_str = "N/A"
873
+ try: mean_str = format_float(a.mean())
874
+ except: pass
875
+ try: median_str = format_float(a.median())
876
+ except:
877
+ try: median_str = format_float(np.median(np.array(a)))
878
+ except: median_str = 'N/A'
879
+ return (min_str, max_str, mean_str, median_str)
880
+
881
+ def get_prop_dict(a,k=None):
882
+ minmaxmean = minmaxmean_str(a)
883
+ props = {
884
+ 'name' : name_from_outer_scope(a) if k is None else k,
885
+ # 'type' : str(type(a)).replace('torch.',''),
886
+ 'dtype' : dtype_str(a),
887
+ 'shape' : shape_str(a),
888
+ 'type' : type_str(a),
889
+ 'device' : device_str(a),
890
+ 'min' : minmaxmean[0],
891
+ 'max' : minmaxmean[1],
892
+ 'mean' : minmaxmean[2],
893
+ 'median': minmaxmean[3]
894
+ }
895
+ return props
896
+
897
+ try:
898
+
899
+ props = ['name', 'type', 'dtype', 'shape', 'device', 'min', 'max', 'mean', 'median']
900
+
901
+ # precompute all of the properties for each input
902
+ str_props = []
903
+ for a in arrs:
904
+ str_props.append(get_prop_dict(a))
905
+ for k,a in kwargs.items():
906
+ str_props.append(get_prop_dict(a, k=k))
907
+
908
+ # for each property, compute its length
909
+ maxlen = {}
910
+ for p in props: maxlen[p] = 0
911
+ for sp in str_props:
912
+ for p in props:
913
+ maxlen[p] = max(maxlen[p], len(sp[p]))
914
+
915
+ # if any property got all empty strings, don't bother printing it, remove it from the list
916
+ props = [p for p in props if maxlen[p] > 0]
917
+
918
+ # print a header
919
+ header_str = ""
920
+ for p in props:
921
+ prefix = "" if p == 'name' else " | "
922
+ fmt_key = ">" if p == 'name' else "<"
923
+ header_str += f"{prefix}{p:{fmt_key}{maxlen[p]}}"
924
+ print(header_str)
925
+ print("-"*len(header_str))
926
+
927
+ # now print the actual arrays
928
+ for strp in str_props:
929
+ for p in props:
930
+ prefix = "" if p == 'name' else " | "
931
+ fmt_key = ">" if p == 'name' else "<"
932
+ print(f"{prefix}{strp[p]:{fmt_key}{maxlen[p]}}", end='')
933
+ print("")
934
+
935
+ finally:
936
+ del frame
937
+
938
+ def debug_print_all_tensor_sizes(min_tot_size = 0):
939
+ import gc
940
+ print("---------------------------------------"*3)
941
+ for obj in gc.get_objects():
942
+ try:
943
+ if torch.is_tensor(obj) or (hasattr(obj, 'data') and torch.is_tensor(obj.data)):
944
+ if np.prod(obj.size())>=min_tot_size:
945
+ print(type(obj), obj.size())
946
+ except:
947
+ pass
948
+ def print_cpu_usage():
949
+
950
+ # Get current CPU usage as a percentage
951
+ cpu_usage = psutil.cpu_percent()
952
+
953
+ # Get current memory usage
954
+ memory_usage = psutil.virtual_memory().used
955
+
956
+ # Convert memory usage to a human-readable format
957
+ memory_usage_str = psutil._common.bytes2human(memory_usage)
958
+
959
+ # Print CPU and memory usage
960
+ msg = f"Current CPU usage: {cpu_usage}% | "
961
+ msg += f"Current memory usage: {memory_usage_str}"
962
+ return msg
963
+
964
+ def calmsize(num_bytes):
965
+ if math.isnan(num_bytes):
966
+ return ''
967
+ for unit in ['', 'K', 'M', 'G', 'T', 'P', 'E', 'Z']:
968
+ if abs(num_bytes) < 1024.0:
969
+ return "{:.1f}{}B".format(num_bytes, unit)
970
+ num_bytes /= 1024.0
971
+ return "{:.1f}{}B".format(num_bytes, 'Y')
972
+
973
+ def readable_size(num_bytes: int) -> str:
974
+ return calmsize(num_bytes) ## '' if math.isnan(num_bytes) else '{:.1f}'.format(calmsize(num_bytes))
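Hand-checked examples of the byte-size formatter above:

assert readable_size(512) == "512.0B"
assert readable_size(1536) == "1.5KB"
assert readable_size(3 * 1024**3) == "3.0GB"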
975
+
976
+ def get_gpu_memory():
977
+ """
978
+ Get the current GPU memory usage for each device as a dictionary
979
+ """
980
+ output = subprocess.check_output(["nvidia-smi", "--query-gpu=memory.used", "--format=csv"])
981
+ output = output.decode("utf-8")
982
+ gpu_memory_values = output.split("\n")[1:-1]
983
+ gpu_memory_values = [int(x.strip().split()[0]) for x in gpu_memory_values]
984
+ gpu_memory = dict(zip(range(len(gpu_memory_values)), gpu_memory_values))
985
+ return gpu_memory
986
+
987
+ def get_gpu_util():
988
+ """
989
+ Get the current GPU utilization (in percent) for each device as a dictionary
990
+ """
991
+ output = subprocess.check_output(["nvidia-smi", "--query-gpu=utilization.gpu", "--format=csv"])
992
+ output = output.decode("utf-8")
993
+ gpu_memory_values = output.split("\n")[1:-1]
994
+ gpu_memory_values = [int(x.strip().split()[0]) for x in gpu_memory_values]
995
+ gpu_util = dict(zip(range(len(gpu_memory_values)), gpu_memory_values))
996
+ return gpu_util
997
+
998
+
999
+ def print_gpu_usage():
1000
+ usage = get_gpu_memory()
+ msg = " | GPU usage: "
+ for k, v in usage.items():
1003
+ msg += f"{k}: {v} MB "
1004
+ # utilization = get_gpu_util()
1005
+ # msg + ' | util '
1006
+ # for k, v in utilization.items():
1007
+ # msg += f"{k}: {v} % "
1008
+ return msg
1009
+
1010
+ class AverageMeter(object):
1011
+
1012
+ def __init__(self):
1013
+ self.reset()
1014
+
1015
+ def reset(self):
1016
+ self.avg = 0
1017
+ self.sum = 0
1018
+ self.cnt = 0
1019
+
1020
+ def update(self, val, n=1):
1021
+ self.sum += val * n
1022
+ self.cnt += n
1023
+ self.avg = self.sum / self.cnt
1024
+
1025
+
1026
+ def generate_random_string(length):
1027
+ # Generate a string of `length` random ASCII letters (both lowercase and uppercase).
1028
+ # You can adjust the length parameter to fit your needs.
1029
+ letters = string.ascii_letters
1030
+ return ''.join(random.choice(letters) for _ in range(length))
1031
+
1032
+
1033
+ class ForkedPdb(pdb.Pdb):
1034
+ """
1035
+ PDB Subclass for debugging multi-processed code
1036
+ Suggested in: https://stackoverflow.com/questions/4716533/how-to-attach-debugger-to-a-python-subproccess
1037
+ """
1038
+ def interaction(self, *args, **kwargs):
1039
+ _stdin = sys.stdin
1040
+ try:
1041
+ sys.stdin = open('/dev/stdin')
1042
+ pdb.Pdb.interaction(self, *args, **kwargs)
1043
+ finally:
1044
+ sys.stdin = _stdin
1045
+
1046
+ def check_exist_in_s3(file_path, s3_config):
1047
+ s3 = init_s3(s3_config)
1048
+ bucket_name, object_name = s3path_to_bucket_key(file_path)
1049
+
1050
+ try:
1051
+ s3.head_object(Bucket=bucket_name, Key=object_name)
1052
+ return 1
1053
+ except:
1054
+ logger.info(f'file not found: s3://{bucket_name}/{object_name}')
1055
+ return 0
1056
+
1057
+ def s3path_to_bucket_key(file_path):
1058
+ bucket_name = file_path.split('/')[2]
1059
+ object_name = file_path.split(bucket_name + '/')[-1]
1060
+ return bucket_name, object_name
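Example of splitting an s3:// URI with the helper above (bucket and key are hypothetical):

bucket, key = s3path_to_bucket_key("s3://my-bucket/checkpoints/run/ckpt.pt")
assert bucket == "my-bucket" and key == "checkpoints/run/ckpt.pt"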
1061
+
1062
+ def copy_file_to_s3(cfg, file_path_local, file_path_s3):
1063
+ # work similar as upload_file_to_s3, but not trying to parse the file path
1064
+ # file_path_s3: s3://{bucket}/{key}
1065
+ bucket_name, key = s3path_to_bucket_key(file_path_s3)
1066
+ tic = time.time()
1067
+ s3_client = init_s3(cfg.checkpoint.write_s3_config)
1068
+
1069
+ # Upload the file
1070
+ with open(file_path_local, 'rb') as f:
1071
+ s3_client.upload_fileobj(f, bucket_name, key)
1072
+ full_s3_path = f"s3://{bucket_name}/{key}"
1073
+ logger.info(f'copy file: {file_path_local} {full_s3_path} | use time: {time.time()-tic}')
1074
+ return full_s3_path
modules/PartField/partfield/model/PVCNN/encoder_pc.py ADDED
@@ -0,0 +1,243 @@
1
+ # Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ from typing import Dict
10
+ import math
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch import nn
15
+ import torch.nn.functional as F
16
+ from torch_scatter import scatter_mean #, scatter_max
17
+
18
+ from .unet_3daware import setup_unet #UNetTriplane3dAware
19
+ from .conv_pointnet import ConvPointnet
20
+
21
+ from .pc_encoder import PVCNNEncoder #PointNet
22
+
23
+ import einops
24
+
25
+ from .dnnlib_util import ScopedTorchProfiler, printarr
26
+
27
+ def generate_plane_features(p, c, resolution, plane='xz'):
28
+ """
29
+ Args:
30
+ p: (B,3,n_p)
31
+ c: (B,C,n_p)
32
+ """
33
+ padding = 0.
34
+ c_dim = c.size(1)
35
+ # acquire indices of features in plane
36
+ xy = normalize_coordinate(p.clone(), plane=plane, padding=padding) # normalize to the range of (0, 1)
37
+ index = coordinate2index(xy, resolution)
38
+
39
+ # scatter plane features from points
40
+ fea_plane = c.new_zeros(p.size(0), c_dim, resolution**2)
41
+ fea_plane = scatter_mean(c, index, out=fea_plane) # B x 512 x reso^2
42
+ fea_plane = fea_plane.reshape(p.size(0), c_dim, resolution, resolution) # sparse matrix (B x c_dim x reso x reso)
43
+ return fea_plane
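A shape sketch for generate_plane_features, assuming the imports at the top of this file: per-point features are scattered (mean-pooled) onto a 2D feature plane.

import torch

B, C, N, R = 2, 32, 1024, 64
p = torch.rand(B, N, 3) - 0.5   # (B, n_p, 3) coordinates in [-0.5, 0.5)
c = torch.randn(B, C, N)        # (B, C, n_p) per-point features
plane = generate_plane_features(p, c, resolution=R, plane='xy')
assert plane.shape == (B, C, R, R)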
44
+
45
+ def normalize_coordinate(p, padding=0.1, plane='xz'):
46
+ ''' Normalize coordinate to [0, 1] for unit cube experiments
47
+
48
+ Args:
49
+ p (tensor): point
50
+ padding (float): conventional padding parameter of ONet for unit cube, so [-0.5, 0.5] -> [-0.55, 0.55]
51
+ plane (str): plane feature type, ['xz', 'xy', 'yz']
52
+ '''
53
+ if plane == 'xz':
54
+ xy = p[:, :, [0, 2]]
55
+ elif plane =='xy':
56
+ xy = p[:, :, [0, 1]]
57
+ else:
58
+ xy = p[:, :, [1, 2]]
59
+
60
+ xy_new = xy / (1 + padding + 10e-6) # (-0.5, 0.5)
61
+ xy_new = xy_new + 0.5 # range (0, 1)
62
+
63
+ # if there are outliers out of the range
64
+ if xy_new.max() >= 1:
65
+ xy_new[xy_new >= 1] = 1 - 10e-6
66
+ if xy_new.min() < 0:
67
+ xy_new[xy_new < 0] = 0.0
68
+ return xy_new
69
+
70
+
71
+ def coordinate2index(x, resolution):
72
+ ''' Normalize coordinate to [0, 1] for unit cube experiments.
73
+ Corresponds to our 3D model
74
+
75
+ Args:
76
+ x (tensor): coordinate
77
+ reso (int): defined resolution
78
+ coord_type (str): coordinate type
79
+ '''
80
+ x = (x * resolution).long()
81
+ index = x[:, :, 0] + resolution * x[:, :, 1]
82
+ index = index[:, None, :]
83
+ return index
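A small hand-checked example of coordinate2index: normalized (x, y) pairs map to flat cell indices.

import torch

xy = torch.tensor([[[0.0, 0.0], [0.99, 0.0], [0.0, 0.99]]])  # (B=1, n_p=3, 2)
idx = coordinate2index(xy, resolution=4)
# idx == tensor([[[0, 3, 12]]]), shape (1, 1, 3)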
84
+
85
+ def softclip(x, min, max, hardness=5):
86
+ # Soft clipping for the logsigma
87
+ x = min + F.softplus(hardness*(x - min))/hardness
88
+ x = max - F.softplus(-hardness*(x - max))/hardness
89
+ return x
90
+
91
+
92
+ def sample_triplane_feat(feature_triplane, normalized_pos):
93
+ '''
94
+ normalized_pos [-1, 1]
95
+ '''
96
+ tri_plane = torch.unbind(feature_triplane, dim=1)
97
+
98
+ x_feat = F.grid_sample(
99
+ tri_plane[0],
100
+ torch.cat(
101
+ [normalized_pos[:, :, 0:1], normalized_pos[:, :, 1:2]],
102
+ dim=-1).unsqueeze(dim=1), padding_mode='border',
103
+ align_corners=True)
104
+ y_feat = F.grid_sample(
105
+ tri_plane[1],
106
+ torch.cat(
107
+ [normalized_pos[:, :, 1:2], normalized_pos[:, :, 2:3]],
108
+ dim=-1).unsqueeze(dim=1), padding_mode='border',
109
+ align_corners=True)
110
+
111
+ z_feat = F.grid_sample(
112
+ tri_plane[2],
113
+ torch.cat(
114
+ [normalized_pos[:, :, 0:1], normalized_pos[:, :, 2:3]],
115
+ dim=-1).unsqueeze(dim=1), padding_mode='border',
116
+ align_corners=True)
117
+ final_feat = (x_feat + y_feat + z_feat)
118
+ final_feat = final_feat.squeeze(dim=2).permute(0, 2, 1) # (B, n_points, C)
119
+ return final_feat
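A shape sketch for sample_triplane_feat: a (B, 3, C, H, W) triplane is queried at normalized positions in [-1, 1], yielding one aggregated feature vector per query point.

import torch

B, C, R, M = 2, 32, 64, 500
planes = torch.randn(B, 3, C, R, R)
queries = torch.rand(B, M, 3) * 2 - 1   # (B, n_points, 3) in [-1, 1]
feats = sample_triplane_feat(planes, queries)
assert feats.shape == (B, M, C)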
120
+
121
+
122
+ # @persistence.persistent_class
123
+ class TriPlanePC2Encoder(torch.nn.Module):
124
+ # Encoder that encode point cloud to triplane feature vector similar to ConvOccNet
125
+ def __init__(
126
+ self,
127
+ cfg,
128
+ device='cuda',
129
+ shape_min=-1.0,
130
+ shape_length=2.0,
131
+ use_2d_feat=False,
132
+ # point_encoder='pvcnn',
133
+ # use_point_scatter=False
134
+ ):
135
+ """
136
+ Outputs latent triplane from PC input
137
+ Configs:
138
+ max_logsigma: (float) Soft clip upper range for logsigma
139
+ min_logsigma: (float)
140
+ point_encoder_type: (str) one of ['pvcnn', 'pointnet']
141
+ pvcnn_flatten_voxels: (bool) for pvcnn whether to reduce voxel
142
+ features (instead of scattering point features)
143
+ unet_cfg: (dict)
144
+ z_triplane_channels: (int) output latent triplane
145
+ z_triplane_resolution: (int)
146
+ Args:
147
+
148
+ """
149
+ # assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
150
+ super().__init__()
151
+ self.device = device
152
+
153
+ self.cfg = cfg
154
+
155
+ self.shape_min = shape_min
156
+ self.shape_length = shape_length
157
+
158
+ self.z_triplane_resolution = cfg.z_triplane_resolution
159
+ z_triplane_channels = cfg.z_triplane_channels
160
+
161
+ point_encoder_out_dim = z_triplane_channels #* 2
162
+
163
+ in_channels = 6
164
+ # self.resample_filter=[1, 3, 3, 1]
165
+ if cfg.point_encoder_type == 'pvcnn':
166
+ self.pc_encoder = PVCNNEncoder(point_encoder_out_dim,
167
+ device=self.device, in_channels=in_channels, use_2d_feat=use_2d_feat) # Encode it to a volume vector.
168
+ elif cfg.point_encoder_type == 'pointnet':
169
+ # TODO the pointnet was buggy, investigate
170
+ self.pc_encoder = ConvPointnet(c_dim=point_encoder_out_dim,
171
+ dim=in_channels, hidden_dim=32,
172
+ plane_resolution=self.z_triplane_resolution,
173
+ padding=0)
174
+ else:
175
+ raise NotImplementedError(f"Point encoder {cfg.point_encoder_type} not implemented")
176
+
177
+ if cfg.unet_cfg.enabled:
178
+ self.unet_encoder = setup_unet(
179
+ output_channels=point_encoder_out_dim,
180
+ input_channels=point_encoder_out_dim,
181
+ unet_cfg=cfg.unet_cfg)
182
+ else:
183
+ self.unet_encoder = None
184
+
185
+ # @ScopedTorchProfiler('encode')
186
+ def encode(self, point_cloud_xyz, point_cloud_feature, mv_feat=None, pc2pc_idx=None) -> Dict:
187
+ # output = AttrDict()
188
+ point_cloud_xyz = (point_cloud_xyz - self.shape_min) / self.shape_length # [0, 1]
189
+ point_cloud_xyz = point_cloud_xyz - 0.5 # [-0.5, 0.5]
190
+ point_cloud = torch.cat([point_cloud_xyz, point_cloud_feature], dim=-1)
191
+
192
+ if self.cfg.point_encoder_type == 'pvcnn':
193
+ if mv_feat is not None:
194
+ pc_feat, points_feat = self.pc_encoder(point_cloud, mv_feat, pc2pc_idx)
195
+ else:
196
+ pc_feat, points_feat = self.pc_encoder(point_cloud) # 3D feature volume: BxDx32x32x32
197
+ if self.cfg.use_point_scatter:
198
+ # Scattering from PVCNN point features
199
+ points_feat_ = points_feat[0]
200
+ # shape: batch, latent size, resolution, resolution (e.g. 16, 256, 64, 64)
201
+ pc_feat_1 = generate_plane_features(point_cloud_xyz, points_feat_,
202
+ resolution=self.z_triplane_resolution, plane='xy')
203
+ pc_feat_2 = generate_plane_features(point_cloud_xyz, points_feat_,
204
+ resolution=self.z_triplane_resolution, plane='yz')
205
+ pc_feat_3 = generate_plane_features(point_cloud_xyz, points_feat_,
206
+ resolution=self.z_triplane_resolution, plane='xz')
207
+ pc_feat = pc_feat[0]
208
+
209
+ else:
210
+ pc_feat = pc_feat[0]
211
+ sf = self.z_triplane_resolution//32 # 32 is PVCNN's voxel dim
212
+
213
+ pc_feat_1 = torch.mean(pc_feat, dim=-1) #xy_plane, normalize in z plane
214
+ pc_feat_2 = torch.mean(pc_feat, dim=-3) #yz_plane, normalize in x plane
215
+ pc_feat_3 = torch.mean(pc_feat, dim=-2) #xz_plane, normalize in y plane
216
+
217
+ # nearest upsample
218
+ pc_feat_1 = einops.repeat(pc_feat_1, 'b c h w -> b c (h hm ) (w wm)', hm = sf, wm = sf)
219
+ pc_feat_2 = einops.repeat(pc_feat_2, 'b c h w -> b c (h hm) (w wm)', hm = sf, wm = sf)
220
+ pc_feat_3 = einops.repeat(pc_feat_3, 'b c h w -> b c (h hm) (w wm)', hm = sf, wm = sf)
221
+ elif self.cfg.point_encoder_type == 'pointnet':
222
+ assert self.cfg.use_point_scatter
223
+ # Run ConvPointnet
224
+ pc_feat = self.pc_encoder(point_cloud)
225
+ pc_feat_1 = pc_feat['xy'] #
226
+ pc_feat_2 = pc_feat['yz']
227
+ pc_feat_3 = pc_feat['xz']
228
+ else:
229
+ raise NotImplementedError()
230
+
231
+ if self.unet_encoder is not None:
232
+ # TODO eval adding a skip connection
233
+ # Unet expects B, 3, C, H, W
234
+ pc_feat_tri_plane_stack_pre = torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)
235
+ # dpc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre)
236
+ # pc_feat_tri_plane_stack = pc_feat_tri_plane_stack_pre + dpc_feat_tri_plane_stack
237
+ pc_feat_tri_plane_stack = self.unet_encoder(pc_feat_tri_plane_stack_pre)
238
+ pc_feat_1, pc_feat_2, pc_feat_3 = torch.unbind(pc_feat_tri_plane_stack, dim=1)
239
+
240
+ return torch.stack([pc_feat_1, pc_feat_2, pc_feat_3], dim=1)
241
+
242
+ def forward(self, point_cloud_xyz, point_cloud_feature=None, mv_feat=None, pc2pc_idx=None):
243
+ return self.encode(point_cloud_xyz, point_cloud_feature=point_cloud_feature, mv_feat=mv_feat, pc2pc_idx=pc2pc_idx)
modules/PartField/partfield/model/PVCNN/pc_encoder.py ADDED
@@ -0,0 +1,90 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import functools
5
+
6
+ from .pv_module import SharedMLP, PVConv
7
+
8
+ def create_pointnet_components(
9
+ blocks, in_channels, with_se=False, normalize=True, eps=0,
10
+ width_multiplier=1, voxel_resolution_multiplier=1, scale_pvcnn=False, device='cuda'):
11
+ r, vr = width_multiplier, voxel_resolution_multiplier
12
+ layers, concat_channels = [], 0
13
+ for out_channels, num_blocks, voxel_resolution in blocks:
14
+ out_channels = int(r * out_channels)
15
+ if voxel_resolution is None:
16
+ block = functools.partial(SharedMLP, device=device)
17
+ else:
18
+ block = functools.partial(
19
+ PVConv, kernel_size=3, resolution=int(vr * voxel_resolution),
20
+ with_se=with_se, normalize=normalize, eps=eps, scale_pvcnn=scale_pvcnn, device=device)
21
+ for _ in range(num_blocks):
22
+ layers.append(block(in_channels, out_channels))
23
+ in_channels = out_channels
24
+ concat_channels += out_channels
25
+ return layers, in_channels, concat_channels
26
+
27
+ class PCMerger(nn.Module):
28
+ # merge surface sampled PC and rendering backprojected PC (w/ 2D features):
29
+ def __init__(self, in_channels=204, device="cuda"):
30
+ super(PCMerger, self).__init__()
31
+ self.mlp_normal = SharedMLP(3, [128, 128], device=device)
32
+ self.mlp_rgb = SharedMLP(3, [128, 128], device=device)
33
+ self.mlp_sam = SharedMLP(204 - 6, [128, 128], device=device)
34
+
35
+ def forward(self, feat, mv_feat, pc2pc_idx):
36
+ mv_feat_normal = self.mlp_normal(mv_feat[:, :3, :])
37
+ mv_feat_rgb = self.mlp_rgb(mv_feat[:, 3:6, :])
38
+ mv_feat_sam = self.mlp_sam(mv_feat[:, 6:, :])
39
+
40
+ mv_feat_normal = mv_feat_normal.permute(0, 2, 1)
41
+ mv_feat_rgb = mv_feat_rgb.permute(0, 2, 1)
42
+ mv_feat_sam = mv_feat_sam.permute(0, 2, 1)
43
+ feat = feat.permute(0, 2, 1)
44
+
45
+ for i in range(mv_feat.shape[0]):
46
+ mask = (pc2pc_idx[i] != -1).reshape(-1)
47
+ idx = pc2pc_idx[i][mask].reshape(-1)
48
+ feat[i][mask] += mv_feat_normal[i][idx] + mv_feat_rgb[i][idx] + mv_feat_sam[i][idx]
49
+
50
+ return feat.permute(0, 2, 1)
51
+
52
+
53
+ class PVCNNEncoder(nn.Module):
54
+ def __init__(self, pvcnn_feat_dim, device='cuda', in_channels=3, use_2d_feat=False):
55
+ super(PVCNNEncoder, self).__init__()
56
+ self.device = device
57
+ self.blocks = ((pvcnn_feat_dim, 1, 32), (128, 2, 16), (256, 1, 8))
58
+ self.use_2d_feat=use_2d_feat
59
+ if in_channels == 6:
60
+ self.append_channel = 2
61
+ elif in_channels == 3:
62
+ self.append_channel = 1
63
+ else:
64
+ raise NotImplementedError
65
+ layers, channels_point, concat_channels_point = create_pointnet_components(
66
+ blocks=self.blocks, in_channels=in_channels + self.append_channel, with_se=False, normalize=False,
67
+ width_multiplier=1, voxel_resolution_multiplier=1, scale_pvcnn=True,
68
+ device=device
69
+ )
70
+ self.encoder = nn.ModuleList(layers)#.to(self.device)
71
+ if self.use_2d_feat:
72
+ self.merger = PCMerger()
73
+
74
+
75
+
76
+ def forward(self, input_pc, mv_feat=None, pc2pc_idx=None):
77
+ features = input_pc.permute(0, 2, 1) * 2 # make point cloud [-1, 1]
78
+ coords = features[:, :3, :]
79
+ out_features_list = []
80
+ voxel_feature_list = []
81
+ zero_padding = torch.zeros(features.shape[0], self.append_channel, features.shape[-1], device=features.device, dtype=features.dtype)
82
+ features = torch.cat([features, zero_padding], dim=1) # append zero channels expected by the voxel encoder
83
+
84
+ for i in range(len(self.encoder)):
85
+ features, _, voxel_feature = self.encoder[i]((features, coords))
86
+ if i == 0 and mv_feat is not None:
87
+ features = self.merger(features, mv_feat.permute(0, 2, 1), pc2pc_idx)
88
+ out_features_list.append(features)
89
+ voxel_feature_list.append(voxel_feature)
90
+ return voxel_feature_list, out_features_list
modules/PartField/partfield/model/PVCNN/pv_module/__init__.py ADDED
@@ -0,0 +1,2 @@
1
+ from .pvconv import PVConv
2
+ from .shared_mlp import SharedMLP
modules/PartField/partfield/model/PVCNN/pv_module/ball_query.py ADDED
@@ -0,0 +1,34 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from . import functional as F
5
+
6
+ __all__ = ['BallQuery']
7
+
8
+
9
+ class BallQuery(nn.Module):
10
+ def __init__(self, radius, num_neighbors, include_coordinates=True):
11
+ super().__init__()
12
+ self.radius = radius
13
+ self.num_neighbors = num_neighbors
14
+ self.include_coordinates = include_coordinates
15
+
16
+ def forward(self, points_coords, centers_coords, points_features=None):
17
+ points_coords = points_coords.contiguous()
18
+ centers_coords = centers_coords.contiguous()
19
+ neighbor_indices = F.ball_query(centers_coords, points_coords, self.radius, self.num_neighbors)
20
+ neighbor_coordinates = F.grouping(points_coords, neighbor_indices)
21
+ neighbor_coordinates = neighbor_coordinates - centers_coords.unsqueeze(-1)
22
+
23
+ if points_features is None:
24
+ assert self.include_coordinates, 'No Features For Grouping'
25
+ neighbor_features = neighbor_coordinates
26
+ else:
27
+ neighbor_features = F.grouping(points_features, neighbor_indices)
28
+ if self.include_coordinates:
29
+ neighbor_features = torch.cat([neighbor_coordinates, neighbor_features], dim=1)
30
+ return neighbor_features
31
+
32
+ def extra_repr(self):
33
+ return 'radius={}, num_neighbors={}{}'.format(
34
+ self.radius, self.num_neighbors, ', include coordinates' if self.include_coordinates else '')
modules/PartField/partfield/model/PVCNN/pv_module/frustum.py ADDED
@@ -0,0 +1,141 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
+ from . import functional as PF
7
+
8
+ __all__ = ['FrustumPointNetLoss', 'get_box_corners_3d']
9
+
10
+
11
+ class FrustumPointNetLoss(nn.Module):
12
+ def __init__(
13
+ self, num_heading_angle_bins, num_size_templates, size_templates, box_loss_weight=1.0,
14
+ corners_loss_weight=10.0, heading_residual_loss_weight=20.0, size_residual_loss_weight=20.0):
15
+ super().__init__()
16
+ self.box_loss_weight = box_loss_weight
17
+ self.corners_loss_weight = corners_loss_weight
18
+ self.heading_residual_loss_weight = heading_residual_loss_weight
19
+ self.size_residual_loss_weight = size_residual_loss_weight
20
+
21
+ self.num_heading_angle_bins = num_heading_angle_bins
22
+ self.num_size_templates = num_size_templates
23
+ self.register_buffer('size_templates', size_templates.view(self.num_size_templates, 3))
24
+ self.register_buffer(
25
+ 'heading_angle_bin_centers', torch.arange(0, 2 * np.pi, 2 * np.pi / self.num_heading_angle_bins)
26
+ )
27
+
28
+ def forward(self, inputs, targets):
29
+ mask_logits = inputs['mask_logits'] # (B, 2, N)
30
+ center_reg = inputs['center_reg'] # (B, 3)
31
+ center = inputs['center'] # (B, 3)
32
+ heading_scores = inputs['heading_scores'] # (B, NH)
33
+ heading_residuals_normalized = inputs['heading_residuals_normalized'] # (B, NH)
34
+ heading_residuals = inputs['heading_residuals'] # (B, NH)
35
+ size_scores = inputs['size_scores'] # (B, NS)
36
+ size_residuals_normalized = inputs['size_residuals_normalized'] # (B, NS, 3)
37
+ size_residuals = inputs['size_residuals'] # (B, NS, 3)
38
+
39
+ mask_logits_target = targets['mask_logits'] # (B, N)
40
+ center_target = targets['center'] # (B, 3)
41
+ heading_bin_id_target = targets['heading_bin_id'] # (B, )
42
+ heading_residual_target = targets['heading_residual'] # (B, )
43
+ size_template_id_target = targets['size_template_id'] # (B, )
44
+ size_residual_target = targets['size_residual'] # (B, 3)
45
+
46
+ batch_size = center.size(0)
47
+ batch_id = torch.arange(batch_size, device=center.device)
48
+
49
+ # Basic Classification and Regression losses
50
+ mask_loss = F.cross_entropy(mask_logits, mask_logits_target)
51
+ heading_loss = F.cross_entropy(heading_scores, heading_bin_id_target)
52
+ size_loss = F.cross_entropy(size_scores, size_template_id_target)
53
+ center_loss = PF.huber_loss(torch.norm(center_target - center, dim=-1), delta=2.0)
54
+ center_reg_loss = PF.huber_loss(torch.norm(center_target - center_reg, dim=-1), delta=1.0)
55
+
56
+ # Refinement losses for size/heading
57
+ heading_residuals_normalized = heading_residuals_normalized[batch_id, heading_bin_id_target] # (B, )
58
+ heading_residual_normalized_target = heading_residual_target / (np.pi / self.num_heading_angle_bins)
59
+ heading_residual_normalized_loss = PF.huber_loss(
60
+ heading_residuals_normalized - heading_residual_normalized_target, delta=1.0
61
+ )
62
+ size_residuals_normalized = size_residuals_normalized[batch_id, size_template_id_target] # (B, 3)
63
+ size_residual_normalized_target = size_residual_target / self.size_templates[size_template_id_target]
64
+ size_residual_normalized_loss = PF.huber_loss(
65
+ torch.norm(size_residual_normalized_target - size_residuals_normalized, dim=-1), delta=1.0
66
+ )
67
+
68
+ # Bounding box losses
69
+ heading = (heading_residuals[batch_id, heading_bin_id_target]
70
+ + self.heading_angle_bin_centers[heading_bin_id_target]) # (B, )
71
+ # Warning: in origin code, size_residuals are added twice (issue #43 and #49 in charlesq34/frustum-pointnets)
72
+ size = (size_residuals[batch_id, size_template_id_target]
73
+ + self.size_templates[size_template_id_target]) # (B, 3)
74
+ corners = get_box_corners_3d(centers=center, headings=heading, sizes=size, with_flip=False) # (B, 3, 8)
75
+ heading_target = self.heading_angle_bin_centers[heading_bin_id_target] + heading_residual_target # (B, )
76
+ size_target = self.size_templates[size_template_id_target] + size_residual_target # (B, 3)
77
+ corners_target, corners_target_flip = get_box_corners_3d(
78
+ centers=center_target, headings=heading_target,
79
+ sizes=size_target, with_flip=True) # (B, 3, 8)
80
+ corners_loss = PF.huber_loss(
81
+ torch.min(
82
+ torch.norm(corners - corners_target, dim=1), torch.norm(corners - corners_target_flip, dim=1)
83
+ ), delta=1.0)
84
+ # Summing up
85
+ loss = mask_loss + self.box_loss_weight * (
86
+ center_loss + center_reg_loss + heading_loss + size_loss
87
+ + self.heading_residual_loss_weight * heading_residual_normalized_loss
88
+ + self.size_residual_loss_weight * size_residual_normalized_loss
89
+ + self.corners_loss_weight * corners_loss
90
+ )
91
+
92
+ return loss
93
+
94
+
95
+ def get_box_corners_3d(centers, headings, sizes, with_flip=False):
96
+ """
97
+ :param centers: coords of box centers, FloatTensor[N, 3]
98
+ :param headings: heading angles, FloatTensor[N, ]
99
+ :param sizes: box sizes, FloatTensor[N, 3]
100
+ :param with_flip: bool, whether to return flipped box (headings + np.pi)
101
+ :return:
102
+ coords of box corners, FloatTensor[N, 3, 8]
103
+ NOTE: corner points are in counter clockwise order, e.g.,
104
+ 2--1
105
+ 3--0 5
106
+ 7--4
107
+ """
108
+ l = sizes[:, 0] # (N,)
109
+ w = sizes[:, 1] # (N,)
110
+ h = sizes[:, 2] # (N,)
111
+ x_corners = torch.stack([l / 2, l / 2, -l / 2, -l / 2, l / 2, l / 2, -l / 2, -l / 2], dim=1) # (N, 8)
112
+ y_corners = torch.stack([h / 2, h / 2, h / 2, h / 2, -h / 2, -h / 2, -h / 2, -h / 2], dim=1) # (N, 8)
113
+ z_corners = torch.stack([w / 2, -w / 2, -w / 2, w / 2, w / 2, -w / 2, -w / 2, w / 2], dim=1) # (N, 8)
114
+
115
+ c = torch.cos(headings) # (N,)
116
+ s = torch.sin(headings) # (N,)
117
+ o = torch.ones_like(headings) # (N,)
118
+ z = torch.zeros_like(headings) # (N,)
119
+
120
+ centers = centers.unsqueeze(-1) # (B, 3, 1)
121
+ corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8)
122
+ R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # roty matrix: (N, 3, 3)
123
+ if with_flip:
124
+ R_flip = torch.stack([-c, z, -s, z, o, z, s, z, -c], dim=1).view(-1, 3, 3)
125
+ return torch.matmul(R, corners) + centers, torch.matmul(R_flip, corners) + centers
126
+ else:
127
+ return torch.matmul(R, corners) + centers
128
+
129
+ # centers = centers.unsqueeze(1) # (B, 1, 3)
130
+ # corners = torch.stack([x_corners, y_corners, z_corners], dim=-1) # (N, 8, 3)
131
+ # RT = torch.stack([c, z, -s, z, o, z, s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3)
132
+ # if with_flip:
133
+ # RT_flip = torch.stack([-c, z, s, z, o, z, -s, z, -c], dim=1).view(-1, 3, 3) # (N, 3, 3)
134
+ # return torch.matmul(corners, RT) + centers, torch.matmul(corners, RT_flip) + centers # (N, 8, 3)
135
+ # else:
136
+ # return torch.matmul(corners, RT) + centers # (N, 8, 3)
137
+
138
+ # corners = torch.stack([x_corners, y_corners, z_corners], dim=1) # (N, 3, 8)
139
+ # R = torch.stack([c, z, s, z, o, z, -s, z, c], dim=1).view(-1, 3, 3) # (N, 3, 3)
140
+ # corners = torch.matmul(R, corners) + centers.unsqueeze(2) # (N, 3, 8)
141
+ # corners = corners.transpose(1, 2) # (N, 8, 3)
modules/PartField/partfield/model/PVCNN/pv_module/functional/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .devoxelization import trilinear_devoxelize
modules/PartField/partfield/model/PVCNN/pv_module/functional/devoxelization.py ADDED
@@ -0,0 +1,12 @@
1
+ from torch.autograd import Function
2
+ import torch
3
+ import torch.nn.functional as F
4
+
5
+ __all__ = ['trilinear_devoxelize']
6
+
7
+ def trilinear_devoxelize(c, coords, r, training=None):
8
+ coords = (coords * 2 + 1.0) / r - 1.0
9
+ coords = coords.permute(0, 2, 1).reshape(c.shape[0], 1, 1, -1, 3)
10
+ f = F.grid_sample(input=c, grid=coords, padding_mode='border', align_corners=False)
11
+ f = f.squeeze(dim=2).squeeze(dim=2)
12
+ return f
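A shape sketch for trilinear_devoxelize (not part of the commit): it expects a 5D voxel grid and per-point voxel-space coordinates in [0, r-1], and returns per-point features via F.grid_sample. Dummy tensors:

import torch

B, C, r, P = 2, 8, 32, 1024
voxel_grid = torch.randn(B, C, r, r, r)      # 5D grid expected by grid_sample
coords = torch.rand(B, 3, P) * (r - 1)       # per-point voxel coordinates

point_feats = trilinear_devoxelize(voxel_grid, coords, r)
print(point_feats.shape)  # torch.Size([2, 8, 1024])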
modules/PartField/partfield/model/PVCNN/pv_module/loss.py ADDED
@@ -0,0 +1,10 @@
1
+ import torch.nn as nn
2
+
3
+ from . import functional as F
4
+
5
+ __all__ = ['KLLoss']
6
+
7
+
8
+ class KLLoss(nn.Module):
9
+ def forward(self, x, y):
10
+ return F.kl_loss(x, y)
modules/PartField/partfield/model/PVCNN/pv_module/pointnet.py ADDED
@@ -0,0 +1,113 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from . import functional as F
5
+ from .ball_query import BallQuery
6
+ from .shared_mlp import SharedMLP
7
+
8
+ __all__ = ['PointNetAModule', 'PointNetSAModule', 'PointNetFPModule']
9
+
10
+
11
+ class PointNetAModule(nn.Module):
12
+ def __init__(self, in_channels, out_channels, include_coordinates=True):
13
+ super().__init__()
14
+ if not isinstance(out_channels, (list, tuple)):
15
+ out_channels = [[out_channels]]
16
+ elif not isinstance(out_channels[0], (list, tuple)):
17
+ out_channels = [out_channels]
18
+
19
+ mlps = []
20
+ total_out_channels = 0
21
+ for _out_channels in out_channels:
22
+ mlps.append(
23
+ SharedMLP(
24
+ in_channels=in_channels + (3 if include_coordinates else 0),
25
+ out_channels=_out_channels, dim=1)
26
+ )
27
+ total_out_channels += _out_channels[-1]
28
+
29
+ self.include_coordinates = include_coordinates
30
+ self.out_channels = total_out_channels
31
+ self.mlps = nn.ModuleList(mlps)
32
+
33
+ def forward(self, inputs):
34
+ features, coords = inputs
35
+ if self.include_coordinates:
36
+ features = torch.cat([features, coords], dim=1)
37
+ coords = torch.zeros((coords.size(0), 3, 1), device=coords.device)
38
+ if len(self.mlps) > 1:
39
+ features_list = []
40
+ for mlp in self.mlps:
41
+ features_list.append(mlp(features).max(dim=-1, keepdim=True).values)
42
+ return torch.cat(features_list, dim=1), coords
43
+ else:
44
+ return self.mlps[0](features).max(dim=-1, keepdim=True).values, coords
45
+
46
+ def extra_repr(self):
47
+ return f'out_channels={self.out_channels}, include_coordinates={self.include_coordinates}'
48
+
49
+
50
+ class PointNetSAModule(nn.Module):
51
+ def __init__(self, num_centers, radius, num_neighbors, in_channels, out_channels, include_coordinates=True):
52
+ super().__init__()
53
+ if not isinstance(radius, (list, tuple)):
54
+ radius = [radius]
55
+ if not isinstance(num_neighbors, (list, tuple)):
56
+ num_neighbors = [num_neighbors] * len(radius)
57
+ assert len(radius) == len(num_neighbors)
58
+ if not isinstance(out_channels, (list, tuple)):
59
+ out_channels = [[out_channels]] * len(radius)
60
+ elif not isinstance(out_channels[0], (list, tuple)):
61
+ out_channels = [out_channels] * len(radius)
62
+ assert len(radius) == len(out_channels)
63
+
64
+ groupers, mlps = [], []
65
+ total_out_channels = 0
66
+ for _radius, _out_channels, _num_neighbors in zip(radius, out_channels, num_neighbors):
67
+ groupers.append(
68
+ BallQuery(radius=_radius, num_neighbors=_num_neighbors, include_coordinates=include_coordinates)
69
+ )
70
+ mlps.append(
71
+ SharedMLP(
72
+ in_channels=in_channels + (3 if include_coordinates else 0),
73
+ out_channels=_out_channels, dim=2)
74
+ )
75
+ total_out_channels += _out_channels[-1]
76
+
77
+ self.num_centers = num_centers
78
+ self.out_channels = total_out_channels
79
+ self.groupers = nn.ModuleList(groupers)
80
+ self.mlps = nn.ModuleList(mlps)
81
+
82
+ def forward(self, inputs):
83
+ features, coords = inputs
84
+ centers_coords = F.furthest_point_sample(coords, self.num_centers)
85
+ features_list = []
86
+ for grouper, mlp in zip(self.groupers, self.mlps):
87
+ features_list.append(mlp(grouper(coords, centers_coords, features)).max(dim=-1).values)
88
+ if len(features_list) > 1:
89
+ return torch.cat(features_list, dim=1), centers_coords
90
+ else:
91
+ return features_list[0], centers_coords
92
+
93
+ def extra_repr(self):
94
+ return f'num_centers={self.num_centers}, out_channels={self.out_channels}'
95
+
96
+
97
+ class PointNetFPModule(nn.Module):
98
+ def __init__(self, in_channels, out_channels):
99
+ super().__init__()
100
+ self.mlp = SharedMLP(in_channels=in_channels, out_channels=out_channels, dim=1)
101
+
102
+ def forward(self, inputs):
103
+ if len(inputs) == 3:
104
+ points_coords, centers_coords, centers_features = inputs
105
+ points_features = None
106
+ else:
107
+ points_coords, centers_coords, centers_features, points_features = inputs
108
+ interpolated_features = F.nearest_neighbor_interpolate(points_coords, centers_coords, centers_features)
109
+ if points_features is not None:
110
+ interpolated_features = torch.cat(
111
+ [interpolated_features, points_features], dim=1
112
+ )
113
+ return self.mlp(interpolated_features), points_coords
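A minimal usage sketch for PointNetAModule (dummy tensors; a GPU is assumed because SharedMLP in this repo builds its layers on 'cuda' by default):

import torch

B, C, N = 2, 16, 1024
module = PointNetAModule(in_channels=C, out_channels=[32, 64])   # one shared MLP: 19 -> 32 -> 64
features = torch.randn(B, C, N, device='cuda')
coords = torch.randn(B, 3, N, device='cuda')

global_feat, global_coords = module((features, coords))
print(global_feat.shape, global_coords.shape)  # (2, 64, 1) and (2, 3, 1)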
modules/PartField/partfield/model/PVCNN/pv_module/pvconv.py ADDED
@@ -0,0 +1,38 @@
1
+ import torch.nn as nn
2
+
3
+ from . import functional as F
4
+ from .voxelization import Voxelization
5
+ from .shared_mlp import SharedMLP
6
+ import torch
7
+
8
+ __all__ = ['PVConv']
9
+
10
+
11
+ class PVConv(nn.Module):
12
+ def __init__(
13
+ self, in_channels, out_channels, kernel_size, resolution, with_se=False, normalize=True, eps=0, scale_pvcnn=False,
14
+ device='cuda'):
15
+ super().__init__()
16
+ self.in_channels = in_channels
17
+ self.out_channels = out_channels
18
+ self.kernel_size = kernel_size
19
+ self.resolution = resolution
20
+ self.voxelization = Voxelization(resolution, normalize=normalize, eps=eps, scale_pvcnn=scale_pvcnn)
21
+ voxel_layers = [
22
+ nn.Conv3d(in_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2, device=device),
23
+ nn.InstanceNorm3d(out_channels, eps=1e-4, device=device),
24
+ nn.LeakyReLU(0.1, True),
25
+ nn.Conv3d(out_channels, out_channels, kernel_size, stride=1, padding=kernel_size // 2, device=device),
26
+ nn.InstanceNorm3d(out_channels, eps=1e-4, device=device),
27
+ nn.LeakyReLU(0.1, True),
28
+ ]
29
+ self.voxel_layers = nn.Sequential(*voxel_layers)
30
+ self.point_features = SharedMLP(in_channels, out_channels, device=device)
31
+
32
+ def forward(self, inputs):
33
+ features, coords = inputs
34
+ voxel_features, voxel_coords = self.voxelization(features, coords)
35
+ voxel_features = self.voxel_layers(voxel_features)
36
+ devoxel_features = F.trilinear_devoxelize(voxel_features, voxel_coords, self.resolution, self.training)
37
+ fused_features = devoxel_features + self.point_features(features)
38
+ return fused_features, coords, voxel_features
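A shape sketch for PVConv (dummy tensors, not part of the commit). Voxelization asserts `not normalize`, so normalize=False is passed and point coordinates are assumed to lie in [-1, 1]; layers are built on 'cuda' by default:

import torch

B, C, N, R = 2, 16, 1024, 32
conv = PVConv(in_channels=C, out_channels=32, kernel_size=3, resolution=R, normalize=False)
features = torch.randn(B, C, N, device='cuda')
coords = torch.rand(B, 3, N, device='cuda') * 2 - 1   # points in [-1, 1]

fused, out_coords, voxel_feat = conv((features, coords))
print(fused.shape, voxel_feat.shape)  # (2, 32, 1024) and (2, 32, 32, 32, 32)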
modules/PartField/partfield/model/PVCNN/pv_module/shared_mlp.py ADDED
@@ -0,0 +1,35 @@
1
+ import torch.nn as nn
2
+
3
+ __all__ = ['SharedMLP']
4
+
5
+
6
+ class SharedMLP(nn.Module):
7
+ def __init__(self, in_channels, out_channels, dim=1, device='cuda'):
8
+ super().__init__()
9
+ # print('==> SharedMLP device: ', device)
10
+ if dim == 1:
11
+ conv = nn.Conv1d
12
+ bn = nn.InstanceNorm1d
13
+ elif dim == 2:
14
+ conv = nn.Conv2d
15
+ bn = nn.InstanceNorm2d
16
+ else:
17
+ raise ValueError
18
+ if not isinstance(out_channels, (list, tuple)):
19
+ out_channels = [out_channels]
20
+ layers = []
21
+ for oc in out_channels:
22
+ layers.extend(
23
+ [
24
+ conv(in_channels, oc, 1, device=device),
25
+ bn(oc, device=device),
26
+ nn.ReLU(True),
27
+ ])
28
+ in_channels = oc
29
+ self.layers = nn.Sequential(*layers)
30
+
31
+ def forward(self, inputs):
32
+ if isinstance(inputs, (list, tuple)):
33
+ return (self.layers(inputs[0]), *inputs[1:])
34
+ else:
35
+ return self.layers(inputs)
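SharedMLP is just a stack of 1x1 convolutions applied point-wise. A quick sketch (dummy tensors; GPU assumed since device defaults to 'cuda'):

import torch

mlp = SharedMLP(in_channels=16, out_channels=[32, 64], dim=1)
x = torch.randn(4, 16, 2048, device='cuda')   # (batch, channels, points)
y = mlp(x)
print(y.shape)  # torch.Size([4, 64, 2048])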
modules/PartField/partfield/model/PVCNN/pv_module/voxelization.py ADDED
@@ -0,0 +1,80 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from . import functional as F
5
+
6
+ __all__ = ['Voxelization']
7
+
8
+
9
+ def my_voxelization(features, coords, resolution):
10
+ b, c, _ = features.shape
11
+ result = torch.zeros(b, c + 1, resolution * resolution * resolution, device=features.device, dtype=features.dtype)
12
+ r = resolution
13
+ r2 = resolution * resolution
14
+ coords = coords.long()
15
+ indices = coords[:, 0] * r2 + coords[:, 1] * r + coords[:, 2]
16
+
17
+ # print(r, r2, coords[:, 0].max(), coords[:, 1].max(), coords[:, 2].max())
18
+
19
+ # print(f"Resolution: {resolution}")
20
+ # print(f"Coords shape: {coords.shape}")
21
+ # print(f"Coords max per dim: x={coords[:, 0].max()}, y={coords[:, 1].max()}, z={coords[:, 2].max()}")
22
+ # print(f"Coords min per dim: x={coords[:, 0].min()}, y={coords[:, 1].min()}, z={coords[:, 2].min()}")
23
+ # print(f"Indices shape: {indices.shape}")
24
+ # print(f"Indices max: {indices.max()}, min: {indices.min()}")
25
+ # print(f"Expected max index: {resolution * resolution * resolution - 1}")
26
+
27
+ # # Check whether any indices are out of range
28
+ # max_valid_index = resolution * resolution * resolution - 1
29
+ # invalid_mask = (indices > max_valid_index) | (indices < 0)
30
+ # if invalid_mask.any():
31
+ # print(f"Found {invalid_mask.sum()} invalid indices!")
32
+ # print(f"Invalid indices: {indices[invalid_mask]}")
33
+ # # Find the corresponding coordinates
34
+ # invalid_coords = coords[:, :, invalid_mask.any(dim=0)]
35
+ # print(f"Invalid coords shape: {invalid_coords.shape}")
36
+ # if invalid_coords.numel() > 0:
37
+ # print(f"Sample invalid coords: {invalid_coords[:, :, :5]}") # show the first 5 invalid coords
38
+
39
+ indices = indices.unsqueeze(dim=1).expand(-1, result.shape[1], -1)
40
+ features = torch.cat([features, torch.ones(features.shape[0], 1, features.shape[2], device=features.device, dtype=features.dtype)], dim=1)
41
+ out_feature = result.scatter_(index=indices.long(), src=features, dim=2, reduce='add')
42
+ cnt = out_feature[:, -1:, :]
43
+ zero_mask = (cnt == 0).to(features.dtype)
44
+ cnt = cnt * (1 - zero_mask) + zero_mask * 1e-5
45
+ vox_feature = out_feature[:, :-1, :] / cnt
46
+ return vox_feature.view(b, c, resolution, resolution, resolution)
47
+
48
+ class Voxelization(nn.Module):
49
+ def __init__(self, resolution, normalize=True, eps=0, scale_pvcnn=False):
50
+ super().__init__()
51
+ self.r = int(resolution)
52
+ self.normalize = normalize
53
+ self.eps = eps
54
+ self.scale_pvcnn = scale_pvcnn
55
+ assert not normalize
56
+
57
+ def forward(self, features, coords):
58
+ # import pdb; pdb.set_trace()
59
+ with torch.no_grad():
60
+ coords = coords.detach()
61
+
62
+ if self.normalize:
63
+ norm_coords = coords / (coords.norm(dim=1, keepdim=True).max(dim=2, keepdim=True).values * 2.0 + self.eps) + 0.5
64
+ else:
65
+ if self.scale_pvcnn:
66
+ norm_coords = (coords + 1) / 2.0 # [0, 1]
67
+ # print(norm_coords.shape, norm_coords.max(), norm_coords.min())
68
+ else:
69
+ # norm_coords = (norm_coords + 1) / 2.0
70
+ norm_coords = (coords + 1) / 2.0
71
+ norm_coords = torch.clamp(norm_coords * self.r, 0, self.r - 1)
72
+ # print(norm_coords.shape, norm_coords.max(), norm_coords.min())
73
+ vox_coords = torch.round(norm_coords)
74
+ # print(vox_coords.shape, vox_coords.max(), vox_coords.min())
75
+ # print(features.shape)
76
+ new_vox_feat = my_voxelization(features, vox_coords, self.r)
77
+ return new_vox_feat, norm_coords
78
+
79
+ def extra_repr(self):
80
+ return 'resolution={}{}'.format(self.r, ', normalized eps = {}'.format(self.eps) if self.normalize else '')
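A round-trip sketch (dummy tensors, not part of the commit): voxelize point features with Voxelization, then sample them back with trilinear_devoxelize from the functional module above. Coordinates are assumed to lie in [-1, 1], and normalize must be False per the assert:

import torch
# assuming: from .functional import trilinear_devoxelize

B, C, N, R = 2, 8, 4096, 32
vox = Voxelization(resolution=R, normalize=False, scale_pvcnn=False)
features = torch.randn(B, C, N)
coords = torch.rand(B, 3, N) * 2 - 1

voxel_feat, norm_coords = vox(features, coords)                 # (2, 8, 32, 32, 32), coords in [0, R-1]
point_feat = trilinear_devoxelize(voxel_feat, norm_coords, R)   # (2, 8, 4096)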
modules/PartField/partfield/model/PVCNN/unet_3daware.py ADDED
@@ -0,0 +1,427 @@
1
+ import numpy as np
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch.nn import init
7
+
8
+ import einops
9
+
10
+ def conv3x3(in_channels, out_channels, stride=1,
11
+ padding=1, bias=True, groups=1):
12
+ return nn.Conv2d(
13
+ in_channels,
14
+ out_channels,
15
+ kernel_size=3,
16
+ stride=stride,
17
+ padding=padding,
18
+ bias=bias,
19
+ groups=groups)
20
+
21
+ def upconv2x2(in_channels, out_channels, mode='transpose'):
22
+ if mode == 'transpose':
23
+ return nn.ConvTranspose2d(
24
+ in_channels,
25
+ out_channels,
26
+ kernel_size=2,
27
+ stride=2)
28
+ else:
29
+ # out_channels is always going to be the same
30
+ # as in_channels
31
+ return nn.Sequential(
32
+ nn.Upsample(mode='bilinear', scale_factor=2),
33
+ conv1x1(in_channels, out_channels))
34
+
35
+ def conv1x1(in_channels, out_channels, groups=1):
36
+ return nn.Conv2d(
37
+ in_channels,
38
+ out_channels,
39
+ kernel_size=1,
40
+ groups=groups,
41
+ stride=1)
42
+
43
+ class ConvTriplane3dAware(nn.Module):
44
+ """ 3D aware triplane conv (as described in RODIN) """
45
+ def __init__(self, internal_conv_f, in_channels, out_channels, order='xz'):
46
+ """
47
+ Args:
48
+ internal_conv_f: function that should return a 2D convolution Module
49
+ given in and out channels
50
+ order: if triplane input is in 'xz' order
51
+ """
52
+ super(ConvTriplane3dAware, self).__init__()
53
+ # Need 3 separate convolutions
54
+ self.in_channels = in_channels
55
+ self.out_channels = out_channels
56
+ assert order in ['xz', 'zx']
57
+ self.order = order
58
+ # Going to stack from other planes
59
+ self.plane_convs = nn.ModuleList([
60
+ internal_conv_f(3*self.in_channels, self.out_channels) for _ in range(3)])
61
+
62
+ def forward(self, triplanes_list):
63
+ """
64
+ Args:
65
+ triplanes_list: [(B,Ci,H,W)]*3 in xy,yz,(zx or xz) depending on order
66
+ Returns:
67
+ out_triplanes_list: [(B,Co,H,W)]*3 in xy,yz,(zx or xz) depending on order
68
+ """
69
+ inps = list(triplanes_list)
70
+ xp = 1 #(yz)
71
+ yp = 2 #(zx)
72
+ zp = 0 #(xy)
73
+
74
+ if self.order == 'xz':
75
+ # get into zx order
76
+ inps[yp] = einops.rearrange(inps[yp], 'b c x z -> b c z x')
77
+
78
+
79
+ oplanes = [None]*3
80
+ # order shouldn't matter
81
+ for iplane in [zp, xp, yp]:
82
+ # i_plane -> (j,k)
83
+
84
+ # need to average out i and convert to (j,k)
85
+ # j_plane -> (k,i)
86
+ # k_plane -> (i,j)
87
+ jplane = (iplane+1)%3
88
+ kplane = (iplane+2)%3
89
+
90
+ ifeat = inps[iplane]
91
+ # need to average out nonshared dim
92
+ # Average pool across
93
+
94
+ # j_plane -> (k,i) -> (k,1) -> (1,k) -> (j,k)
95
+ # b c k i -> b c k 1
96
+ jpool = torch.mean(inps[jplane], dim=3 ,keepdim=True)
97
+ jpool = einops.rearrange(jpool, 'b c k 1 -> b c 1 k')
98
+ jpool = einops.repeat(jpool, 'b c 1 k -> b c j k', j=ifeat.size(2))
99
+
100
+ # k_plane -> (i,j) -> (1,j) -> (j,1) -> (j,k)
101
+ # b c i j -> b c 1 j
102
+ kpool = torch.mean(inps[kplane], dim=2 ,keepdim=True)
103
+ kpool = einops.rearrange(kpool, 'b c 1 j -> b c j 1')
104
+ kpool = einops.repeat(kpool, 'b c j 1 -> b c j k', k=ifeat.size(3))
105
+
106
+ # b c h w
107
+ # jpool = jpool.expand_as(ifeat)
108
+ # kpool = kpool.expand_as(ifeat)
109
+
110
+ # concat and conv on feature dim
111
+ catfeat = torch.cat([ifeat, jpool, kpool], dim=1)
112
+ oplane = self.plane_convs[iplane](catfeat)
113
+ oplanes[iplane] = oplane
114
+
115
+ if self.order == 'xz':
116
+ # get back into xz order
117
+ oplanes[yp] = einops.rearrange(oplanes[yp], 'b c z x -> b c x z')
118
+
119
+ return oplanes
120
+
121
+ def roll_triplanes(triplanes_list):
122
+ # B, C, tri, h, w
123
+ tristack = torch.stack((triplanes_list),dim=2)
124
+ return einops.rearrange(tristack, 'b c tri h w -> b c (tri h) w', tri=3)
125
+
126
+ def unroll_triplanes(rolled_triplane):
127
+ # B, C, tri*h, w
128
+ tristack = einops.rearrange(rolled_triplane, 'b c (tri h) w -> b c tri h w', tri=3)
129
+ return torch.unbind(tristack, dim=2)
130
+
131
+ def conv1x1triplane3daware(in_channels, out_channels, order='xz', **kwargs):
132
+ return ConvTriplane3dAware(lambda inp, out: conv1x1(inp,out,**kwargs),
133
+ in_channels, out_channels,order=order)
134
+
135
+ def Normalize(in_channels, num_groups=32):
136
+ num_groups = min(in_channels, num_groups) # avoid error if in_channels < 32
137
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
138
+
139
+ def nonlinearity(x):
140
+ # return F.relu(x)
141
+ # Swish
142
+ return x*torch.sigmoid(x)
143
+
144
+ class Upsample(nn.Module):
145
+ def __init__(self, in_channels, with_conv):
146
+ super().__init__()
147
+ self.with_conv = with_conv
148
+ if self.with_conv:
149
+ self.conv = torch.nn.Conv2d(in_channels,
150
+ in_channels,
151
+ kernel_size=3,
152
+ stride=1,
153
+ padding=1)
154
+
155
+ def forward(self, x):
156
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
157
+ if self.with_conv:
158
+ x = self.conv(x)
159
+ return x
160
+
161
+ class Downsample(nn.Module):
162
+ def __init__(self, in_channels, with_conv):
163
+ super().__init__()
164
+ self.with_conv = with_conv
165
+ if self.with_conv:
166
+ # no asymmetric padding in torch conv, must do it ourselves
167
+ self.conv = torch.nn.Conv2d(in_channels,
168
+ in_channels,
169
+ kernel_size=3,
170
+ stride=2,
171
+ padding=0)
172
+
173
+ def forward(self, x):
174
+ if self.with_conv:
175
+ pad = (0,1,0,1)
176
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
177
+ x = self.conv(x)
178
+ else:
179
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
180
+ return x
181
+
182
+ class ResnetBlock3dAware(nn.Module):
183
+ def __init__(self, in_channels, out_channels=None):
184
+ #, conv_shortcut=False):
185
+ super().__init__()
186
+ self.in_channels = in_channels
187
+ out_channels = in_channels if out_channels is None else out_channels
188
+ self.out_channels = out_channels
189
+ # self.use_conv_shortcut = conv_shortcut
190
+
191
+ self.norm1 = Normalize(in_channels)
192
+ self.conv1 = conv3x3(self.in_channels, self.out_channels)
193
+
194
+ self.norm_mid = Normalize(out_channels)
195
+ self.conv_3daware = conv1x1triplane3daware(self.out_channels, self.out_channels)
196
+
197
+ self.norm2 = Normalize(out_channels)
198
+ self.conv2 = conv3x3(self.out_channels, self.out_channels)
199
+
200
+ if self.in_channels != self.out_channels:
201
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
202
+ out_channels,
203
+ kernel_size=1,
204
+ stride=1,
205
+ padding=0)
206
+
207
+ def forward(self, x):
208
+ # 3x3 plane comm
209
+ h = x
210
+ h = self.norm1(h)
211
+ h = nonlinearity(h)
212
+ h = self.conv1(h)
213
+
214
+ # 1x1 3d aware, crossplane comm
215
+ h = self.norm_mid(h)
216
+ h = nonlinearity(h)
217
+ h = unroll_triplanes(h)
218
+ h = self.conv_3daware(h)
219
+ h = roll_triplanes(h)
220
+
221
+ # 3x3 plane comm
222
+ h = self.norm2(h)
223
+ h = nonlinearity(h)
224
+ h = self.conv2(h)
225
+
226
+ if self.in_channels != self.out_channels:
227
+ x = self.nin_shortcut(x)
228
+
229
+ return x+h
230
+
231
+ class DownConv3dAware(nn.Module):
232
+ """
233
+ A helper Module that performs 2 convolutions and 1 MaxPool.
234
+ A ReLU activation follows each convolution.
235
+ """
236
+ def __init__(self, in_channels, out_channels, downsample=True, with_conv=False):
237
+ super(DownConv3dAware, self).__init__()
238
+
239
+ self.in_channels = in_channels
240
+ self.out_channels = out_channels
241
+
242
+ self.block = ResnetBlock3dAware(in_channels=in_channels,
243
+ out_channels=out_channels)
244
+
245
+ self.do_downsample = downsample
246
+ self.downsample = Downsample(out_channels, with_conv=with_conv)
247
+
248
+ def forward(self, x):
249
+ """
250
+ rolled input, rolled output
251
+ Args:
252
+ x: rolled (b c (tri*h) w)
253
+ """
254
+ x = self.block(x)
255
+ before_pool = x
256
+ # if self.pooling:
257
+ # x = self.pool(x)
258
+ if self.do_downsample:
259
+ # unroll and cat channel-wise (to prevent pooling across triplane boundaries)
260
+ x = einops.rearrange(x, 'b c (tri h) w -> b (c tri) h w', tri=3)
261
+ x = self.downsample(x)
262
+ # undo
263
+ x = einops.rearrange(x, 'b (c tri) h w -> b c (tri h) w', tri=3)
264
+ return x, before_pool
265
+
266
+ class UpConv3dAware(nn.Module):
267
+ """
268
+ A helper Module that performs 2 convolutions and 1 UpConvolution.
269
+ A ReLU activation follows each convolution.
270
+ """
271
+ def __init__(self, in_channels, out_channels,
272
+ merge_mode='concat', with_conv=False): #up_mode='transpose', ):
273
+ super(UpConv3dAware, self).__init__()
274
+
275
+ self.in_channels = in_channels
276
+ self.out_channels = out_channels
277
+ self.merge_mode = merge_mode
278
+
279
+ self.upsample = Upsample(in_channels, with_conv)
280
+
281
+ if self.merge_mode == 'concat':
282
+ self.norm1 = Normalize(in_channels+out_channels)
283
+ self.block = ResnetBlock3dAware(in_channels=in_channels+out_channels,
284
+ out_channels=out_channels)
285
+ else:
286
+ self.norm1 = Normalize(in_channels)
287
+ self.block = ResnetBlock3dAware(in_channels=in_channels,
288
+ out_channels=out_channels)
289
+
290
+
291
+ def forward(self, from_down, from_up):
292
+ """ Forward pass
293
+ rolled inputs, rolled output
294
+ rolled (b c (tri*h) w)
295
+ Arguments:
296
+ from_down: tensor from the encoder pathway
297
+ from_up: upconv'd tensor from the decoder pathway
298
+ """
299
+ # from_up = self.upconv(from_up)
300
+ from_up = self.upsample(from_up)
301
+ if self.merge_mode == 'concat':
302
+ x = torch.cat((from_up, from_down), 1)
303
+ else:
304
+ x = from_up + from_down
305
+
306
+ x = self.norm1(x)
307
+ x = self.block(x)
308
+ return x
309
+
310
+ class UNetTriplane3dAware(nn.Module):
311
+ def __init__(self, out_channels, in_channels=3, depth=5,
312
+ start_filts=64,# up_mode='transpose',
313
+ use_initial_conv=False,
314
+ merge_mode='concat', **kwargs):
315
+ """
316
+ Arguments:
317
+ in_channels: int, number of channels in the input tensor.
318
+ Default is 3 for RGB images.
319
+ depth: int, number of MaxPools in the U-Net.
320
+ start_filts: int, number of convolutional filters for the
321
+ first conv.
322
+ """
323
+ super(UNetTriplane3dAware, self).__init__()
324
+
325
+
326
+ self.out_channels = out_channels
327
+ self.in_channels = in_channels
328
+ self.start_filts = start_filts
329
+ self.depth = depth
330
+
331
+ self.use_initial_conv = use_initial_conv
332
+ if use_initial_conv:
333
+ self.conv_initial = conv1x1(self.in_channels, self.start_filts)
334
+
335
+ self.down_convs = []
336
+ self.up_convs = []
337
+
338
+ # create the encoder pathway and add to a list
339
+ for i in range(depth):
340
+ if i == 0:
341
+ ins = self.start_filts if use_initial_conv else self.in_channels
342
+ else:
343
+ ins = outs
344
+ outs = self.start_filts*(2**i)
345
+ downsamp_it = True if i < depth-1 else False
346
+
347
+ down_conv = DownConv3dAware(ins, outs, downsample = downsamp_it)
348
+ self.down_convs.append(down_conv)
349
+
350
+ for i in range(depth-1):
351
+ ins = outs
352
+ outs = ins // 2
353
+ up_conv = UpConv3dAware(ins, outs,
354
+ merge_mode=merge_mode)
355
+ self.up_convs.append(up_conv)
356
+
357
+ # add the list of modules to current module
358
+ self.down_convs = nn.ModuleList(self.down_convs)
359
+ self.up_convs = nn.ModuleList(self.up_convs)
360
+
361
+ self.norm_out = Normalize(outs)
362
+ self.conv_final = conv1x1(outs, self.out_channels)
363
+
364
+ self.reset_params()
365
+
366
+ @staticmethod
367
+ def weight_init(m):
368
+ if isinstance(m, nn.Conv2d):
369
+ # init.xavier_normal_(m.weight, gain=0.1)
370
+ init.xavier_normal_(m.weight)
371
+ init.constant_(m.bias, 0)
372
+
373
+
374
+ def reset_params(self):
375
+ for i, m in enumerate(self.modules()):
376
+ self.weight_init(m)
377
+
378
+
379
+ def forward(self, x):
380
+ """
381
+ Args:
382
+ x: Stacked triplane expected to be in (B,3,C,H,W)
383
+ """
384
+ # Roll
385
+ x = einops.rearrange(x, 'b tri c h w -> b c (tri h) w', tri=3)
386
+
387
+ if self.use_initial_conv:
388
+ x = self.conv_initial(x)
389
+
390
+ encoder_outs = []
391
+ # encoder pathway, save outputs for merging
392
+ for i, module in enumerate(self.down_convs):
393
+ x, before_pool = module(x)
394
+ encoder_outs.append(before_pool)
395
+
396
+ # Spend a block in the middle
397
+ # x = self.block_mid(x)
398
+
399
+ for i, module in enumerate(self.up_convs):
400
+ before_pool = encoder_outs[-(i+2)]
401
+ x = module(before_pool, x)
402
+
403
+ x = self.norm_out(x)
404
+
405
+ # No softmax is used. This means you need to use
406
+ # nn.CrossEntropyLoss in your training script,
407
+ # since that loss applies the softmax itself.
408
+ x = self.conv_final(nonlinearity(x))
409
+
410
+ # Unroll
411
+ x = einops.rearrange(x, 'b c (tri h) w -> b tri c h w', tri=3)
412
+ return x
413
+
414
+
415
+ def setup_unet(output_channels, input_channels, unet_cfg):
416
+ if unet_cfg['use_3d_aware']:
417
+ assert(unet_cfg['rolled'])
418
+ unet = UNetTriplane3dAware(
419
+ out_channels=output_channels,
420
+ in_channels=input_channels,
421
+ depth=unet_cfg['depth'],
422
+ use_initial_conv=unet_cfg['use_initial_conv'],
423
+ start_filts=unet_cfg['start_hidden_channels'],)
424
+ else:
425
+ raise NotImplementedError
426
+ return unet
427
+
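A sketch of building the 3D-aware triplane U-Net through setup_unet (config keys taken from the code above; the shapes are dummy values chosen for illustration):

import torch

unet_cfg = {
    'use_3d_aware': True,
    'rolled': True,
    'depth': 3,
    'use_initial_conv': False,
    'start_hidden_channels': 32,
}
unet = setup_unet(output_channels=64, input_channels=16, unet_cfg=unet_cfg)

x = torch.randn(1, 3, 16, 64, 64)   # stacked triplanes: (B, 3, C, H, W)
y = unet(x)
print(y.shape)  # torch.Size([1, 3, 64, 64, 64])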
modules/PartField/partfield/model/UNet/buildingblocks.py ADDED
@@ -0,0 +1,546 @@
1
+ #https://github.com/wolny/pytorch-3dunet/blob/master/pytorch3dunet/unet3d/buildingblocks.py
2
+ # MIT License
3
+
4
+ # Copyright (c) 2018 Adrian Wolny
5
+
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+
24
+ from functools import partial
25
+
26
+ import torch
27
+ from torch import nn as nn
28
+ from torch.nn import functional as F
29
+
30
+ # from pytorch3dunet.unet3d.se import ChannelSELayer3D, ChannelSpatialSELayer3D, SpatialSELayer3D
31
+
32
+
33
+ def create_conv(in_channels, out_channels, kernel_size, order, num_groups, padding,
34
+ dropout_prob, is3d):
35
+ """
36
+ Create a list of modules that together constitute a single conv layer with non-linearity
37
+ and optional batchnorm/groupnorm.
38
+
39
+ Args:
40
+ in_channels (int): number of input channels
41
+ out_channels (int): number of output channels
42
+ kernel_size(int or tuple): size of the convolving kernel
43
+ order (string): order of things, e.g.
44
+ 'cr' -> conv + ReLU
45
+ 'gcr' -> groupnorm + conv + ReLU
46
+ 'cl' -> conv + LeakyReLU
47
+ 'ce' -> conv + ELU
48
+ 'bcr' -> batchnorm + conv + ReLU
49
+ 'cbrd' -> conv + batchnorm + ReLU + dropout
50
+ 'cbrD' -> conv + batchnorm + ReLU + dropout2d
51
+ num_groups (int): number of groups for the GroupNorm
52
+ padding (int or tuple): add zero-padding added to all three sides of the input
53
+ dropout_prob (float): dropout probability
54
+ is3d (bool): if True use Conv3d, otherwise use Conv2d
55
+ Return:
56
+ list of tuple (name, module)
57
+ """
58
+ assert 'c' in order, "Conv layer MUST be present"
59
+ assert order[0] not in 'rle', 'Non-linearity cannot be the first operation in the layer'
60
+
61
+ modules = []
62
+ for i, char in enumerate(order):
63
+ if char == 'r':
64
+ modules.append(('ReLU', nn.ReLU(inplace=True)))
65
+ elif char == 'l':
66
+ modules.append(('LeakyReLU', nn.LeakyReLU(inplace=True)))
67
+ elif char == 'e':
68
+ modules.append(('ELU', nn.ELU(inplace=True)))
69
+ elif char == 'c':
70
+ # add learnable bias only in the absence of batchnorm/groupnorm
71
+ bias = not ('g' in order or 'b' in order)
72
+ if is3d:
73
+ conv = nn.Conv3d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
74
+ else:
75
+ conv = nn.Conv2d(in_channels, out_channels, kernel_size, padding=padding, bias=bias)
76
+
77
+ modules.append(('conv', conv))
78
+ elif char == 'g':
79
+ is_before_conv = i < order.index('c')
80
+ if is_before_conv:
81
+ num_channels = in_channels
82
+ else:
83
+ num_channels = out_channels
84
+
85
+ # use only one group if the given number of groups is greater than the number of channels
86
+ if num_channels < num_groups:
87
+ num_groups = 1
88
+
89
+ assert num_channels % num_groups == 0, f'Expected number of channels in input to be divisible by num_groups. num_channels={num_channels}, num_groups={num_groups}'
90
+ modules.append(('groupnorm', nn.GroupNorm(num_groups=num_groups, num_channels=num_channels)))
91
+ elif char == 'b':
92
+ is_before_conv = i < order.index('c')
93
+ if is3d:
94
+ bn = nn.BatchNorm3d
95
+ else:
96
+ bn = nn.BatchNorm2d
97
+
98
+ if is_before_conv:
99
+ modules.append(('batchnorm', bn(in_channels)))
100
+ else:
101
+ modules.append(('batchnorm', bn(out_channels)))
102
+ elif char == 'd':
103
+ modules.append(('dropout', nn.Dropout(p=dropout_prob)))
104
+ elif char == 'D':
105
+ modules.append(('dropout2d', nn.Dropout2d(p=dropout_prob)))
106
+ else:
107
+ raise ValueError(f"Unsupported layer type '{char}'. MUST be one of ['b', 'g', 'r', 'l', 'e', 'c', 'd', 'D']")
108
+
109
+ return modules
110
+
111
+
112
+ class SingleConv(nn.Sequential):
113
+ """
114
+ Basic convolutional module consisting of a Conv3d, non-linearity and optional batchnorm/groupnorm. The order
115
+ of operations can be specified via the `order` parameter
116
+
117
+ Args:
118
+ in_channels (int): number of input channels
119
+ out_channels (int): number of output channels
120
+ kernel_size (int or tuple): size of the convolving kernel
121
+ order (string): determines the order of layers, e.g.
122
+ 'cr' -> conv + ReLU
123
+ 'crg' -> conv + ReLU + groupnorm
124
+ 'cl' -> conv + LeakyReLU
125
+ 'ce' -> conv + ELU
126
+ num_groups (int): number of groups for the GroupNorm
127
+ padding (int or tuple): add zero-padding
128
+ dropout_prob (float): dropout probability, default 0.1
129
+ is3d (bool): if True use Conv3d, otherwise use Conv2d
130
+ """
131
+
132
+ def __init__(self, in_channels, out_channels, kernel_size=3, order='gcr', num_groups=8,
133
+ padding=1, dropout_prob=0.1, is3d=True):
134
+ super(SingleConv, self).__init__()
135
+
136
+ for name, module in create_conv(in_channels, out_channels, kernel_size, order,
137
+ num_groups, padding, dropout_prob, is3d):
138
+ self.add_module(name, module)
139
+
140
+
141
+ class DoubleConv(nn.Sequential):
142
+ """
143
+ A module consisting of two consecutive convolution layers (e.g. BatchNorm3d+ReLU+Conv3d).
144
+ We use (Conv3d+ReLU+GroupNorm3d) by default.
145
+ This can be changed however by providing the 'order' argument, e.g. in order
146
+ to change to Conv3d+BatchNorm3d+ELU use order='cbe'.
147
+ Use padded convolutions to make sure that the output (H_out, W_out) is the same
148
+ as (H_in, W_in), so that you don't have to crop in the decoder path.
149
+
150
+ Args:
151
+ in_channels (int): number of input channels
152
+ out_channels (int): number of output channels
153
+ encoder (bool): if True we're in the encoder path, otherwise we're in the decoder
154
+ kernel_size (int or tuple): size of the convolving kernel
155
+ order (string): determines the order of layers, e.g.
156
+ 'cr' -> conv + ReLU
157
+ 'crg' -> conv + ReLU + groupnorm
158
+ 'cl' -> conv + LeakyReLU
159
+ 'ce' -> conv + ELU
160
+ num_groups (int): number of groups for the GroupNorm
161
+ padding (int or tuple): add zero-padding added to all three sides of the input
162
+ upscale (int): number of the convolution to upscale in encoder if DoubleConv, default: 2
163
+ dropout_prob (float or tuple): dropout probability for each convolution, default 0.1
164
+ is3d (bool): if True use Conv3d instead of Conv2d layers
165
+ """
166
+
167
+ def __init__(self, in_channels, out_channels, encoder, kernel_size=3, order='gcr',
168
+ num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
169
+ super(DoubleConv, self).__init__()
170
+ if encoder:
171
+ # we're in the encoder path
172
+ conv1_in_channels = in_channels
173
+ if upscale == 1:
174
+ conv1_out_channels = out_channels
175
+ else:
176
+ conv1_out_channels = out_channels // 2
177
+ if conv1_out_channels < in_channels:
178
+ conv1_out_channels = in_channels
179
+ conv2_in_channels, conv2_out_channels = conv1_out_channels, out_channels
180
+ else:
181
+ # we're in the decoder path, decrease the number of channels in the 1st convolution
182
+ conv1_in_channels, conv1_out_channels = in_channels, out_channels
183
+ conv2_in_channels, conv2_out_channels = out_channels, out_channels
184
+
185
+ # check if dropout_prob is a tuple and if so
186
+ # split it for different dropout probabilities for each convolution.
187
+ if isinstance(dropout_prob, list) or isinstance(dropout_prob, tuple):
188
+ dropout_prob1 = dropout_prob[0]
189
+ dropout_prob2 = dropout_prob[1]
190
+ else:
191
+ dropout_prob1 = dropout_prob2 = dropout_prob
192
+
193
+ # conv1
194
+ self.add_module('SingleConv1',
195
+ SingleConv(conv1_in_channels, conv1_out_channels, kernel_size, order, num_groups,
196
+ padding=padding, dropout_prob=dropout_prob1, is3d=is3d))
197
+ # conv2
198
+ self.add_module('SingleConv2',
199
+ SingleConv(conv2_in_channels, conv2_out_channels, kernel_size, order, num_groups,
200
+ padding=padding, dropout_prob=dropout_prob2, is3d=is3d))
201
+
202
+
203
+ class ResNetBlock(nn.Module):
204
+ """
205
+ Residual block that can be used instead of standard DoubleConv in the Encoder module.
206
+ Motivated by: https://arxiv.org/pdf/1706.00120.pdf
207
+
208
+ Notice we use ELU instead of ReLU (order='cge') and put non-linearity after the groupnorm.
209
+ """
210
+
211
+ def __init__(self, in_channels, out_channels, kernel_size=3, order='cge', num_groups=8, is3d=True, **kwargs):
212
+ super(ResNetBlock, self).__init__()
213
+
214
+ if in_channels != out_channels:
215
+ # conv1x1 for increasing the number of channels
216
+ if is3d:
217
+ self.conv1 = nn.Conv3d(in_channels, out_channels, 1)
218
+ else:
219
+ self.conv1 = nn.Conv2d(in_channels, out_channels, 1)
220
+ else:
221
+ self.conv1 = nn.Identity()
222
+
223
+ self.conv2 = SingleConv(in_channels, out_channels, kernel_size=kernel_size, order=order, num_groups=num_groups,
224
+ is3d=is3d)
225
+ # remove non-linearity from the 3rd convolution since it's going to be applied after adding the residual
226
+ n_order = order
227
+ for c in 'rel':
228
+ n_order = n_order.replace(c, '')
229
+ self.conv3 = SingleConv(out_channels, out_channels, kernel_size=kernel_size, order=n_order,
230
+ num_groups=num_groups, is3d=is3d)
231
+
232
+ # create non-linearity separately
233
+ if 'l' in order:
234
+ self.non_linearity = nn.LeakyReLU(negative_slope=0.1, inplace=True)
235
+ elif 'e' in order:
236
+ self.non_linearity = nn.ELU(inplace=True)
237
+ else:
238
+ self.non_linearity = nn.ReLU(inplace=True)
239
+
240
+ def forward(self, x):
241
+ # apply first convolution to bring the number of channels to out_channels
242
+ residual = self.conv1(x)
243
+
244
+ out = self.conv2(x)
245
+ out = self.conv3(out)
246
+
247
+ out += residual
248
+ out = self.non_linearity(out)
249
+
250
+ return out
251
+
252
+ class Encoder(nn.Module):
253
+ """
254
+ A single module from the encoder path consisting of the optional max
255
+ pooling layer (one may specify the MaxPool kernel_size to be different
256
+ from the standard (2,2,2), e.g. if the volumetric data is anisotropic
257
+ (make sure to use complementary scale_factor in the decoder path) followed by
258
+ a basic module (DoubleConv or ResNetBlock).
259
+
260
+ Args:
261
+ in_channels (int): number of input channels
262
+ out_channels (int): number of output channels
263
+ conv_kernel_size (int or tuple): size of the convolving kernel
264
+ apply_pooling (bool): if True use MaxPool3d before DoubleConv
265
+ pool_kernel_size (int or tuple): the size of the window
266
+ pool_type (str): pooling layer: 'max' or 'avg'
267
+ basic_module(nn.Module): either ResNetBlock or DoubleConv
268
+ conv_layer_order (string): determines the order of layers
269
+ in `DoubleConv` module. See `DoubleConv` for more info.
270
+ num_groups (int): number of groups for the GroupNorm
271
+ padding (int or tuple): add zero-padding added to all three sides of the input
272
+ upscale (int): number of the convolution to upscale in encoder if DoubleConv, default: 2
273
+ dropout_prob (float or tuple): dropout probability, default 0.1
274
+ is3d (bool): use 3d or 2d convolutions/pooling operation
275
+ """
276
+
277
+ def __init__(self, in_channels, out_channels, conv_kernel_size=3, apply_pooling=True,
278
+ pool_kernel_size=2, pool_type='max', basic_module=DoubleConv, conv_layer_order='gcr',
279
+ num_groups=8, padding=1, upscale=2, dropout_prob=0.1, is3d=True):
280
+ super(Encoder, self).__init__()
281
+ assert pool_type in ['max', 'avg']
282
+ if apply_pooling:
283
+ if pool_type == 'max':
284
+ if is3d:
285
+ self.pooling = nn.MaxPool3d(kernel_size=pool_kernel_size)
286
+ else:
287
+ self.pooling = nn.MaxPool2d(kernel_size=pool_kernel_size)
288
+ else:
289
+ if is3d:
290
+ self.pooling = nn.AvgPool3d(kernel_size=pool_kernel_size)
291
+ else:
292
+ self.pooling = nn.AvgPool2d(kernel_size=pool_kernel_size)
293
+ else:
294
+ self.pooling = None
295
+
296
+ self.basic_module = basic_module(in_channels, out_channels,
297
+ encoder=True,
298
+ kernel_size=conv_kernel_size,
299
+ order=conv_layer_order,
300
+ num_groups=num_groups,
301
+ padding=padding,
302
+ upscale=upscale,
303
+ dropout_prob=dropout_prob,
304
+ is3d=is3d)
305
+
306
+ def forward(self, x):
307
+ if self.pooling is not None:
308
+ x = self.pooling(x)
309
+ x = self.basic_module(x)
310
+ return x
311
+
312
+
313
+ class Decoder(nn.Module):
314
+ """
315
+ A single module for decoder path consisting of the upsampling layer
316
+ (either learned ConvTranspose3d or nearest neighbor interpolation)
317
+ followed by a basic module (DoubleConv or ResNetBlock).
318
+
319
+ Args:
320
+ in_channels (int): number of input channels
321
+ out_channels (int): number of output channels
322
+ conv_kernel_size (int or tuple): size of the convolving kernel
323
+ scale_factor (int or tuple): used as the multiplier for the image H/W/D in
324
+ case of nn.Upsample or as stride in case of ConvTranspose3d, must reverse the MaxPool3d operation
325
+ from the corresponding encoder
326
+ basic_module(nn.Module): either ResNetBlock or DoubleConv
327
+ conv_layer_order (string): determines the order of layers
328
+ in `DoubleConv` module. See `DoubleConv` for more info.
329
+ num_groups (int): number of groups for the GroupNorm
330
+ padding (int or tuple): add zero-padding added to all three sides of the input
331
+ upsample (str): algorithm used for upsampling:
332
+ InterpolateUpsampling: 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
333
+ TransposeConvUpsampling: 'deconv'
334
+ No upsampling: None
335
+ Default: 'default' (chooses automatically)
336
+ dropout_prob (float or tuple): dropout probability, default 0.1
337
+ """
338
+
339
+ def __init__(self, in_channels, out_channels, conv_kernel_size=3, scale_factor=2, basic_module=DoubleConv,
340
+ conv_layer_order='gcr', num_groups=8, padding=1, upsample='default',
341
+ dropout_prob=0.1, is3d=True):
342
+ super(Decoder, self).__init__()
343
+
344
+ # perform concat joining per default
345
+ concat = True
346
+
347
+ # don't adapt channels after join operation
348
+ adapt_channels = False
349
+
350
+ if upsample is not None and upsample != 'none':
351
+ if upsample == 'default':
352
+ if basic_module == DoubleConv:
353
+ upsample = 'nearest' # use nearest neighbor interpolation for upsampling
354
+ concat = True # use concat joining
355
+ adapt_channels = False # don't adapt channels
356
+ elif basic_module == ResNetBlock: #or basic_module == ResNetBlockSE:
357
+ upsample = 'deconv' # use deconvolution upsampling
358
+ concat = False # use summation joining
359
+ adapt_channels = True # adapt channels after joining
360
+
361
+ # perform deconvolution upsampling if mode is deconv
362
+ if upsample == 'deconv':
363
+ self.upsampling = TransposeConvUpsampling(in_channels=in_channels, out_channels=out_channels,
364
+ kernel_size=conv_kernel_size, scale_factor=scale_factor,
365
+ is3d=is3d)
366
+ else:
367
+ self.upsampling = InterpolateUpsampling(mode=upsample)
368
+ else:
369
+ # no upsampling
370
+ self.upsampling = NoUpsampling()
371
+ # concat joining
372
+ self.joining = partial(self._joining, concat=True)
373
+
374
+ # perform joining operation
375
+ self.joining = partial(self._joining, concat=concat)
376
+
377
+ # adapt the number of in_channels for the ResNetBlock
378
+ if adapt_channels is True:
379
+ in_channels = out_channels
380
+
381
+ self.basic_module = basic_module(in_channels, out_channels,
382
+ encoder=False,
383
+ kernel_size=conv_kernel_size,
384
+ order=conv_layer_order,
385
+ num_groups=num_groups,
386
+ padding=padding,
387
+ dropout_prob=dropout_prob,
388
+ is3d=is3d)
389
+
390
+ def forward(self, encoder_features, x):
391
+ x = self.upsampling(encoder_features=encoder_features, x=x)
392
+ x = self.joining(encoder_features, x)
393
+ x = self.basic_module(x)
394
+ return x
395
+
396
+ @staticmethod
397
+ def _joining(encoder_features, x, concat):
398
+ if concat:
399
+ return torch.cat((encoder_features, x), dim=1)
400
+ else:
401
+ return encoder_features + x
402
+
403
+
404
+ def create_encoders(in_channels, f_maps, basic_module, conv_kernel_size, conv_padding,
405
+ conv_upscale, dropout_prob,
406
+ layer_order, num_groups, pool_kernel_size, is3d):
407
+ # create encoder path consisting of Encoder modules. Depth of the encoder is equal to `len(f_maps)`
408
+ encoders = []
409
+ for i, out_feature_num in enumerate(f_maps):
410
+ if i == 0:
411
+ # apply conv_coord only in the first encoder if any
412
+ encoder = Encoder(in_channels, out_feature_num,
413
+ apply_pooling=False, # skip pooling in the first encoder
414
+ basic_module=basic_module,
415
+ conv_layer_order=layer_order,
416
+ conv_kernel_size=conv_kernel_size,
417
+ num_groups=num_groups,
418
+ padding=conv_padding,
419
+ upscale=conv_upscale,
420
+ dropout_prob=dropout_prob,
421
+ is3d=is3d)
422
+ else:
423
+ encoder = Encoder(f_maps[i - 1], out_feature_num,
424
+ basic_module=basic_module,
425
+ conv_layer_order=layer_order,
426
+ conv_kernel_size=conv_kernel_size,
427
+ num_groups=num_groups,
428
+ pool_kernel_size=pool_kernel_size,
429
+ padding=conv_padding,
430
+ upscale=conv_upscale,
431
+ dropout_prob=dropout_prob,
432
+ is3d=is3d)
433
+
434
+ encoders.append(encoder)
435
+
436
+ return nn.ModuleList(encoders)
437
+
438
+
439
+ def create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding, layer_order,
440
+ num_groups, upsample, dropout_prob, is3d):
441
+ # create decoder path consisting of the Decoder modules. The length of the decoder list is equal to `len(f_maps) - 1`
442
+ decoders = []
443
+ reversed_f_maps = list(reversed(f_maps[1:]))
444
+ for i in range(len(reversed_f_maps) - 1):
445
+ if basic_module == DoubleConv and upsample != 'deconv':
446
+ in_feature_num = reversed_f_maps[i] + reversed_f_maps[i + 1]
447
+ else:
448
+ in_feature_num = reversed_f_maps[i]
449
+
450
+ out_feature_num = reversed_f_maps[i + 1]
451
+
452
+ decoder = Decoder(in_feature_num, out_feature_num,
453
+ basic_module=basic_module,
454
+ conv_layer_order=layer_order,
455
+ conv_kernel_size=conv_kernel_size,
456
+ num_groups=num_groups,
457
+ padding=conv_padding,
458
+ upsample=upsample,
459
+ dropout_prob=dropout_prob,
460
+ is3d=is3d)
461
+ decoders.append(decoder)
462
+ return nn.ModuleList(decoders)
463
+
464
+
465
+ class AbstractUpsampling(nn.Module):
466
+ """
467
+ Abstract class for upsampling. A given implementation should upsample a given 5D input tensor using either
468
+ interpolation or learned transposed convolution.
469
+ """
470
+
471
+ def __init__(self, upsample):
472
+ super(AbstractUpsampling, self).__init__()
473
+ self.upsample = upsample
474
+
475
+ def forward(self, encoder_features, x):
476
+ # get the spatial dimensions of the output given the encoder_features
477
+ output_size = encoder_features.size()[2:]
478
+ # upsample the input and return
479
+ return self.upsample(x, output_size)
480
+
481
+
482
+ class InterpolateUpsampling(AbstractUpsampling):
483
+ """
484
+ Args:
485
+ mode (str): algorithm used for upsampling:
486
+ 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'. Default: 'nearest'
487
+ used only if transposed_conv is False
488
+ """
489
+
490
+ def __init__(self, mode='nearest'):
491
+ upsample = partial(self._interpolate, mode=mode)
492
+ super().__init__(upsample)
493
+
494
+ @staticmethod
495
+ def _interpolate(x, size, mode):
496
+ return F.interpolate(x, size=size, mode=mode)
497
+
498
+
499
+ class TransposeConvUpsampling(AbstractUpsampling):
500
+ """
501
+ Args:
502
+ in_channels (int): number of input channels for transposed conv
503
+ used only if transposed_conv is True
504
+ out_channels (int): number of output channels for transpose conv
505
+ used only if transposed_conv is True
506
+ kernel_size (int or tuple): size of the convolving kernel
507
+ used only if transposed_conv is True
508
+ scale_factor (int or tuple): stride of the convolution
509
+ used only if transposed_conv is True
510
+ is3d (bool): if True use ConvTranspose3d, otherwise use ConvTranspose2d
511
+ """
512
+
513
+ class Upsample(nn.Module):
514
+ """
515
+ Workaround the 'ValueError: requested an output size...' in the `_output_padding` method in
516
+ transposed convolution. It performs transposed conv followed by the interpolation to the correct size if necessary.
517
+ """
518
+
519
+ def __init__(self, conv_transposed, is3d):
520
+ super().__init__()
521
+ self.conv_transposed = conv_transposed
522
+ self.is3d = is3d
523
+
524
+ def forward(self, x, size):
525
+ x = self.conv_transposed(x)
526
+ return F.interpolate(x, size=size)
527
+
528
+ def __init__(self, in_channels, out_channels, kernel_size=3, scale_factor=2, is3d=True):
529
+ # make sure that the output size reverses the MaxPool3d from the corresponding encoder
530
+ if is3d is True:
531
+ conv_transposed = nn.ConvTranspose3d(in_channels, out_channels, kernel_size=kernel_size,
532
+ stride=scale_factor, padding=1, bias=False)
533
+ else:
534
+ conv_transposed = nn.ConvTranspose2d(in_channels, out_channels, kernel_size=kernel_size,
535
+ stride=scale_factor, padding=1, bias=False)
536
+ upsample = self.Upsample(conv_transposed, is3d)
537
+ super().__init__(upsample)
538
+
539
+
540
+ class NoUpsampling(AbstractUpsampling):
541
+ def __init__(self):
542
+ super().__init__(self._no_upsampling)
543
+
544
+ @staticmethod
545
+ def _no_upsampling(x, size):
546
+ return x
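A small sketch of the `order` string convention used by SingleConv/DoubleConv ('gcr' = GroupNorm + Conv + ReLU); dummy volume, not part of the commit:

import torch

block = DoubleConv(in_channels=8, out_channels=16, encoder=True, order='gcr', num_groups=8)
x = torch.randn(1, 8, 32, 32, 32)
print(block(x).shape)  # torch.Size([1, 16, 32, 32, 32])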
modules/PartField/partfield/model/UNet/model.py ADDED
@@ -0,0 +1,170 @@
1
+ # https://github.com/wolny/pytorch-3dunet/blob/master/pytorch3dunet/unet3d/buildingblocks.py
2
+ # MIT License
3
+
4
+ # Copyright (c) 2018 Adrian Wolny
5
+
6
+ # Permission is hereby granted, free of charge, to any person obtaining a copy
7
+ # of this software and associated documentation files (the "Software"), to deal
8
+ # in the Software without restriction, including without limitation the rights
9
+ # to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
10
+ # copies of the Software, and to permit persons to whom the Software is
11
+ # furnished to do so, subject to the following conditions:
12
+
13
+ # The above copyright notice and this permission notice shall be included in all
14
+ # copies or substantial portions of the Software.
15
+
16
+ # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17
+ # IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18
+ # FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21
+ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22
+ # SOFTWARE.
23
+
24
+ import torch.nn as nn
25
+
26
+ from partfield.model.UNet.buildingblocks import DoubleConv, ResNetBlock, \
27
+ create_decoders, create_encoders
28
+
29
+ def number_of_features_per_level(init_channel_number, num_levels):
30
+ return [init_channel_number * 2 ** k for k in range(num_levels)]
31
+
32
+ class AbstractUNet(nn.Module):
33
+ """
34
+ Base class for standard and residual UNet.
35
+
36
+ Args:
37
+ in_channels (int): number of input channels
38
+ out_channels (int): number of output segmentation masks;
39
+ Note that the of out_channels might correspond to either
40
+ different semantic classes or to different binary segmentation mask.
41
+ It's up to the user of the class to interpret the out_channels and
42
+ use the proper loss criterion during training (i.e. CrossEntropyLoss (multi-class)
43
+ or BCEWithLogitsLoss (two-class) respectively)
44
+ f_maps (int, tuple): number of feature maps at each level of the encoder; if it's an integer the number
45
+ of feature maps is given by the geometric progression: f_maps ^ k, k=1,2,3,4
46
+ final_sigmoid (bool): if True apply element-wise nn.Sigmoid after the final 1x1 convolution,
47
+ otherwise apply nn.Softmax. In effect only if `self.training == False`, i.e. during validation/testing
48
+ basic_module: basic model for the encoder/decoder (DoubleConv, ResNetBlock, ....)
49
+ layer_order (string): determines the order of layers in `SingleConv` module.
50
+ E.g. 'crg' stands for GroupNorm3d+Conv3d+ReLU. See `SingleConv` for more info
51
+ num_groups (int): number of groups for the GroupNorm
52
+ num_levels (int): number of levels in the encoder/decoder path (applied only if f_maps is an int)
53
+ default: 4
54
+ is_segmentation (bool): if True and the model is in eval mode, Sigmoid/Softmax normalization is applied
55
+ after the final convolution; if False (regression problem) the normalization layer is skipped
56
+ conv_kernel_size (int or tuple): size of the convolving kernel in the basic_module
57
+ pool_kernel_size (int or tuple): the size of the window
58
+ conv_padding (int or tuple): add zero-padding added to all three sides of the input
59
+ conv_upscale (int): number of the convolution to upscale in encoder if DoubleConv, default: 2
60
+ upsample (str): algorithm used for decoder upsampling:
61
+ InterpolateUpsampling: 'nearest' | 'linear' | 'bilinear' | 'trilinear' | 'area'
62
+ TransposeConvUpsampling: 'deconv'
63
+ No upsampling: None
64
+ Default: 'default' (chooses automatically)
65
+ dropout_prob (float or tuple): dropout probability, default: 0.1
66
+ is3d (bool): if True the model is 3D, otherwise 2D, default: True
67
+ """
68
+
69
+ def __init__(self, in_channels, out_channels, final_sigmoid, basic_module, f_maps=64, layer_order='gcr',
70
+ num_groups=8, num_levels=4, is_segmentation=False, conv_kernel_size=3, pool_kernel_size=2,
71
+ conv_padding=1, conv_upscale=2, upsample='default', dropout_prob=0.1, is3d=True, encoder_only=False):
72
+ super(AbstractUNet, self).__init__()
73
+
74
+ if isinstance(f_maps, int):
75
+ f_maps = number_of_features_per_level(f_maps, num_levels=num_levels)
76
+
77
+ assert isinstance(f_maps, list) or isinstance(f_maps, tuple)
78
+ assert len(f_maps) > 1, "Required at least 2 levels in the U-Net"
79
+ if 'g' in layer_order:
80
+ assert num_groups is not None, "num_groups must be specified if GroupNorm is used"
81
+
82
+ # create encoder path
83
+ self.encoders = create_encoders(in_channels, f_maps, basic_module, conv_kernel_size,
84
+ conv_padding, conv_upscale, dropout_prob,
85
+ layer_order, num_groups, pool_kernel_size, is3d)
86
+
87
+ self.encoder_only = encoder_only
88
+
89
+ if encoder_only == False:
90
+ # create decoder path
91
+ self.decoders = create_decoders(f_maps, basic_module, conv_kernel_size, conv_padding,
92
+ layer_order, num_groups, upsample, dropout_prob,
93
+ is3d)
94
+
95
+ # in the last layer a 1×1 convolution reduces the number of output channels to the number of labels
96
+ if is3d:
97
+ self.final_conv = nn.Conv3d(f_maps[1], out_channels, 1)
98
+ else:
99
+ self.final_conv = nn.Conv2d(f_maps[1], out_channels, 1)
100
+
101
+ if is_segmentation:
102
+ # semantic segmentation problem
103
+ if final_sigmoid:
104
+ self.final_activation = nn.Sigmoid()
105
+ else:
106
+ self.final_activation = nn.Softmax(dim=1)
107
+ else:
108
+ # regression problem
109
+ self.final_activation = None
110
+
111
+ def forward(self, x, return_bottleneck_feat=False):
112
+ # encoder part
113
+ encoders_features = []
114
+ for encoder in self.encoders:
115
+ x = encoder(x)
116
+ # reverse the encoder outputs to be aligned with the decoder
117
+ encoders_features.insert(0, x)
118
+
119
+ # remove the last encoder's output from the list
120
+ # !!remember: it's the 1st in the list
121
+ bottleneck_feat = encoders_features[0]
122
+ if self.encoder_only:
123
+ return bottleneck_feat
124
+ else:
125
+ encoders_features = encoders_features[1:]
126
+
127
+ # decoder part
128
+ for decoder, encoder_features in zip(self.decoders, encoders_features):
129
+ # pass the output from the corresponding encoder and the output
130
+ # of the previous decoder
131
+ x = decoder(encoder_features, x)
132
+
133
+ x = self.final_conv(x)
134
+ # During training the network outputs logits
135
+ if self.final_activation is not None:
136
+ x = self.final_activation(x)
137
+
138
+ if return_bottleneck_feat:
139
+ return x, bottleneck_feat
140
+ else:
141
+ return x
142
+
143
+ class ResidualUNet3D(AbstractUNet):
144
+ """
145
+ Residual 3DUnet model implementation based on https://arxiv.org/pdf/1706.00120.pdf.
146
+ Uses ResNetBlock as a basic building block, summation joining instead
147
+ of concatenation joining and transposed convolutions for upsampling (watch out for block artifacts).
148
+ Since the model effectively becomes a residual net, in theory it allows for deeper UNet.
149
+ """
150
+
151
+ def __init__(self, in_channels, out_channels, final_sigmoid=True, f_maps=(8, 16, 64, 256, 1024), layer_order='gcr',
152
+ num_groups=8, num_levels=5, is_segmentation=True, conv_padding=1,
153
+ conv_upscale=2, upsample='default', dropout_prob=0.1, encoder_only=False, **kwargs):
154
+ super(ResidualUNet3D, self).__init__(in_channels=in_channels,
155
+ out_channels=out_channels,
156
+ final_sigmoid=final_sigmoid,
157
+ basic_module=ResNetBlock,
158
+ f_maps=f_maps,
159
+ layer_order=layer_order,
160
+ num_groups=num_groups,
161
+ num_levels=num_levels,
162
+ is_segmentation=is_segmentation,
163
+ conv_padding=conv_padding,
164
+ conv_upscale=conv_upscale,
165
+ upsample=upsample,
166
+ dropout_prob=dropout_prob,
167
+ encoder_only=encoder_only,
168
+ is3d=True)
169
+
170
+
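
A quick usage sketch for the ResidualUNet3D defined above (illustrative only, not one of the added files): it assumes the bundled buildingblocks helpers behave like upstream pytorch-3dunet, and the channel counts and input resolution are made up. With encoder_only=True the forward pass returns just the bottleneck feature map and never touches the decoder path.

import torch
from partfield.model.UNet.model import ResidualUNet3D

# encoder-only mode: run the encoder stack and return the deepest feature map
net = ResidualUNet3D(in_channels=4, out_channels=2, f_maps=(16, 32, 64),
                     layer_order='cr', is_segmentation=False, encoder_only=True)
x = torch.randn(1, 4, 32, 32, 32)      # (batch, channels, D, H, W)
bottleneck = net(x)                    # two 2x poolings -> expected (1, 64, 8, 8, 8)
print(bottleneck.shape)
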
modules/PartField/partfield/model/model_utils.py ADDED
@@ -0,0 +1,54 @@
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ class VanillaMLP(nn.Module):
5
+ def __init__(self, input_dim, output_dim, out_activation, n_hidden_layers=4, n_neurons=64, activation="ReLU"):
6
+ super().__init__()
7
+ self.n_neurons = n_neurons
8
+ self.n_hidden_layers = n_hidden_layers
9
+ self.activation = activation
10
+ self.out_activation = out_activation
11
+ layers = [
12
+ self.make_linear(input_dim, self.n_neurons, is_first=True, is_last=False),
13
+ self.make_activation(),
14
+ ]
15
+ for i in range(self.n_hidden_layers - 1):
16
+ layers += [
17
+ self.make_linear(
18
+ self.n_neurons, self.n_neurons, is_first=False, is_last=False
19
+ ),
20
+ self.make_activation(),
21
+ ]
22
+ layers += [
23
+ self.make_linear(self.n_neurons, output_dim, is_first=False, is_last=True)
24
+ ]
25
+ if self.out_activation == "sigmoid":
26
+ layers += [nn.Sigmoid()]
27
+ elif self.out_activation == "tanh":
28
+ layers += [nn.Tanh()]
29
+ elif self.out_activation == "hardtanh":
30
+ layers += [nn.Hardtanh()]
31
+ elif self.out_activation == "GELU":
32
+ layers += [nn.GELU()]
33
+ elif self.out_activation == "RELU":
34
+ layers += [nn.ReLU()]
35
+ else:
36
+ raise NotImplementedError
37
+ self.layers = nn.Sequential(*layers)
38
+
39
+ def forward(self, x, split_size=100000):
40
+ with torch.cuda.amp.autocast(enabled=False):
41
+ out = self.layers(x)
42
+ return out
43
+
44
+ def make_linear(self, dim_in, dim_out, is_first, is_last):
45
+ layer = nn.Linear(dim_in, dim_out, bias=False)
46
+ return layer
47
+
48
+ def make_activation(self):
49
+ if self.activation == "ReLU":
50
+ return nn.ReLU(inplace=True)
51
+ elif self.activation == "GELU":
52
+ return nn.GELU()
53
+ else:
54
+ raise NotImplementedError
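
A small usage sketch for VanillaMLP (dimensions here are arbitrary; the trainer later in this commit uses input_dim=64 with a tanh or GELU head): any tensor whose last dimension equals input_dim can be fed straight through.

import torch
from partfield.model.model_utils import VanillaMLP

mlp = VanillaMLP(input_dim=64, output_dim=1, out_activation="tanh",
                 n_hidden_layers=2, n_neurons=32)
feats = torch.randn(8, 64)
out = mlp(feats)       # the tanh head squashes values into (-1, 1)
print(out.shape)       # torch.Size([8, 1])
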
modules/PartField/partfield/model/triplane.py ADDED
@@ -0,0 +1,331 @@
 
1
+ #https://github.com/3DTopia/OpenLRM/blob/main/openlrm/models/modeling_lrm.py
2
+ # Copyright (c) 2023-2024, Zexin He
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # https://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ import torch.nn as nn
18
+ from functools import partial
19
+
20
+ def project_onto_planes(planes, coordinates):
21
+ """
22
+ Does a projection of a 3D point onto a batch of 2D planes,
23
+ returning 2D plane coordinates.
24
+
25
+ Takes plane axes of shape n_planes, 3, 3
26
+ Takes coordinates of shape N, M, 3
27
+ Returns projections of shape N*n_planes, M, 2
28
+ """
29
+ N, M, C = coordinates.shape
30
+ n_planes, _, _ = planes.shape
31
+ coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
32
+ inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
33
+ projections = torch.bmm(coordinates, inv_planes)
34
+ return projections[..., :2]
35
+
36
+ def sample_from_planes(plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
37
+ plane_axes = torch.tensor([[[1, 0, 0],
38
+ [0, 1, 0],
39
+ [0, 0, 1]],
40
+ [[1, 0, 0],
41
+ [0, 0, 1],
42
+ [0, 1, 0]],
43
+ [[0, 0, 1],
44
+ [0, 1, 0],
45
+ [1, 0, 0]]], dtype=torch.float32).cuda()
46
+
47
+ assert padding_mode == 'zeros'
48
+ N, n_planes, C, H, W = plane_features.shape
49
+ _, M, _ = coordinates.shape
50
+ plane_features = plane_features.view(N*n_planes, C, H, W)
51
+
52
+ projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
53
+ output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
54
+ return output_features
55
+
56
+ def get_grid_coord(grid_size = 256, align_corners=False):
57
+ if not align_corners:
58
+ coords = torch.linspace(-1 + 1/(grid_size), 1 - 1/(grid_size), steps=grid_size)
59
+ else:
60
+ coords = torch.linspace(-1, 1, steps=grid_size)
61
+ i, j, k = torch.meshgrid(coords, coords, coords, indexing='ij')
62
+ coordinates = torch.stack((i, j, k), dim=-1).reshape(-1, 3)
63
+ return coordinates
64
+
65
+ class BasicBlock(nn.Module):
66
+ """
67
+ Transformer block in its simplest form: pre-norm self-attention followed by an MLP.
68
+ Designed for PF-LRM architecture.
69
+ """
70
+ # Block contains a self-attention layer and an MLP
71
+ def __init__(self, inner_dim: int, num_heads: int, eps: float,
72
+ attn_drop: float = 0., attn_bias: bool = False,
73
+ mlp_ratio: float = 4., mlp_drop: float = 0.):
74
+ super().__init__()
75
+ self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
76
+ self.self_attn = nn.MultiheadAttention(
77
+ embed_dim=inner_dim, num_heads=num_heads,
78
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
79
+ self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
80
+ self.mlp = nn.Sequential(
81
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
82
+ nn.GELU(),
83
+ nn.Dropout(mlp_drop),
84
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
85
+ nn.Dropout(mlp_drop),
86
+ )
87
+
88
+ def forward(self, x):
89
+ # x: [N, L, D]
90
+ before_sa = self.norm1(x)
91
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
92
+ x = x + self.mlp(self.norm2(x))
93
+ return x
94
+
95
+ class ConditionBlock(nn.Module):
96
+ """
97
+ Transformer block that takes in a cross-attention condition.
98
+ Designed for SparseLRM architecture.
99
+ """
100
+ # Block contains a cross-attention layer, a self-attention layer, and an MLP
101
+ def __init__(self, inner_dim: int, cond_dim: int, num_heads: int, eps: float,
102
+ attn_drop: float = 0., attn_bias: bool = False,
103
+ mlp_ratio: float = 4., mlp_drop: float = 0.):
104
+ super().__init__()
105
+ self.norm1 = nn.LayerNorm(inner_dim, eps=eps)
106
+ self.cross_attn = nn.MultiheadAttention(
107
+ embed_dim=inner_dim, num_heads=num_heads, kdim=cond_dim, vdim=cond_dim,
108
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
109
+ self.norm2 = nn.LayerNorm(inner_dim, eps=eps)
110
+ self.self_attn = nn.MultiheadAttention(
111
+ embed_dim=inner_dim, num_heads=num_heads,
112
+ dropout=attn_drop, bias=attn_bias, batch_first=True)
113
+ self.norm3 = nn.LayerNorm(inner_dim, eps=eps)
114
+ self.mlp = nn.Sequential(
115
+ nn.Linear(inner_dim, int(inner_dim * mlp_ratio)),
116
+ nn.GELU(),
117
+ nn.Dropout(mlp_drop),
118
+ nn.Linear(int(inner_dim * mlp_ratio), inner_dim),
119
+ nn.Dropout(mlp_drop),
120
+ )
121
+
122
+ def forward(self, x, cond):
123
+ # x: [N, L, D]
124
+ # cond: [N, L_cond, D_cond]
125
+ x = x + self.cross_attn(self.norm1(x), cond, cond, need_weights=False)[0]
126
+ before_sa = self.norm2(x)
127
+ x = x + self.self_attn(before_sa, before_sa, before_sa, need_weights=False)[0]
128
+ x = x + self.mlp(self.norm3(x))
129
+ return x
130
+
131
+ class TransformerDecoder(nn.Module):
132
+ def __init__(self, block_type: str,
133
+ num_layers: int, num_heads: int,
134
+ inner_dim: int, cond_dim: int = None,
135
+ eps: float = 1e-6):
136
+ super().__init__()
137
+ self.block_type = block_type
138
+ self.layers = nn.ModuleList([
139
+ self._block_fn(inner_dim, cond_dim)(
140
+ num_heads=num_heads,
141
+ eps=eps,
142
+ )
143
+ for _ in range(num_layers)
144
+ ])
145
+ self.norm = nn.LayerNorm(inner_dim, eps=eps)
146
+
147
+ @property
148
+ def block_type(self):
149
+ return self._block_type
150
+
151
+ @block_type.setter
152
+ def block_type(self, block_type):
153
+ assert block_type in ['cond', 'basic'], \
154
+ f"Unsupported block type: {block_type}"
155
+ self._block_type = block_type
156
+
157
+ def _block_fn(self, inner_dim, cond_dim):
158
+ assert inner_dim is not None, f"inner_dim must always be specified"
159
+ if self.block_type == 'basic':
160
+ return partial(BasicBlock, inner_dim=inner_dim)
161
+ elif self.block_type == 'cond':
162
+ assert cond_dim is not None, f"Condition dimension must be specified for ConditionBlock"
163
+ return partial(ConditionBlock, inner_dim=inner_dim, cond_dim=cond_dim)
164
+ else:
165
+ raise ValueError(f"Unsupported block type during runtime: {self.block_type}")
166
+
167
+
168
+ def forward_layer(self, layer: nn.Module, x: torch.Tensor, cond: torch.Tensor,):
169
+ if self.block_type == 'basic':
170
+ return layer(x)
171
+ elif self.block_type == 'cond':
172
+ return layer(x, cond)
173
+ else:
174
+ raise NotImplementedError
175
+
176
+ def forward(self, x: torch.Tensor, cond: torch.Tensor = None):
177
+ # x: [N, L, D]
178
+ # cond: [N, L_cond, D_cond] or None
179
+ for layer in self.layers:
180
+ x = self.forward_layer(layer, x, cond)
181
+ x = self.norm(x)
182
+ return x
183
+
184
+ class Voxel2Triplane(nn.Module):
185
+ """
186
+ Transformer that lifts voxel features to triplane tokens via cross-attention and upsamples them into triplanes (adapted from the OpenLRM single-view LRM).
187
+ """
188
+ def __init__(self, transformer_dim: int, transformer_layers: int, transformer_heads: int,
189
+ triplane_low_res: int, triplane_high_res: int, triplane_dim: int, voxel_feat_dim: int, normalize_vox_feat=False, voxel_dim=16):
190
+ super().__init__()
191
+
192
+ # attributes
193
+ self.triplane_low_res = triplane_low_res
194
+ self.triplane_high_res = triplane_high_res
195
+ self.triplane_dim = triplane_dim
196
+ self.voxel_feat_dim = voxel_feat_dim
197
+
198
+ # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
199
+ self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, transformer_dim) * (1. / transformer_dim) ** 0.5)
200
+ self.transformer = TransformerDecoder(
201
+ block_type='cond',
202
+ num_layers=transformer_layers, num_heads=transformer_heads,
203
+ inner_dim=transformer_dim, cond_dim=voxel_feat_dim
204
+ )
205
+ self.upsampler = nn.ConvTranspose2d(transformer_dim, triplane_dim, kernel_size=8, stride=8, padding=0)
206
+
207
+ self.normalize_vox_feat = normalize_vox_feat
208
+ if normalize_vox_feat:
209
+ self.vox_norm = nn.LayerNorm(voxel_feat_dim, eps=1e-6)
210
+ self.vox_pos_embed = nn.Parameter(torch.randn(1, voxel_dim * voxel_dim * voxel_dim, voxel_feat_dim) * (1. / voxel_feat_dim) ** 0.5)
211
+
212
+ def forward_transformer(self, voxel_feats):
213
+ N = voxel_feats.shape[0]
214
+ x = self.pos_embed.repeat(N, 1, 1) # [N, L, D]
215
+ if self.normalize_vox_feat:
216
+ vox_pos_embed = self.vox_pos_embed.repeat(N, 1, 1) # [N, L, D]
217
+ voxel_feats = self.vox_norm(voxel_feats + vox_pos_embed)
218
+ x = self.transformer(
219
+ x,
220
+ cond=voxel_feats
221
+ )
222
+ return x
223
+
224
+ def reshape_upsample(self, tokens):
225
+ N = tokens.shape[0]
226
+ H = W = self.triplane_low_res
227
+ x = tokens.view(N, 3, H, W, -1)
228
+ x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
229
+ x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
230
+ x = self.upsampler(x) # [3*N, D', H', W']
231
+ x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
232
+ x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
233
+ x = x.contiguous()
234
+ return x
235
+
236
+ def forward(self, voxel_feats):
237
+ N = voxel_feats.shape[0]
238
+
239
+ # encode image
240
+ assert voxel_feats.shape[-1] == self.voxel_feat_dim, \
241
+ f"Feature dimension mismatch: {voxel_feats.shape[-1]} vs {self.voxel_feat_dim}"
242
+
243
+ # transformer generating planes
244
+ tokens = self.forward_transformer(voxel_feats)
245
+ planes = self.reshape_upsample(tokens)
246
+ assert planes.shape[0] == N, "Batch size mismatch for planes"
247
+ assert planes.shape[1] == 3, "Planes should have 3 channels"
248
+
249
+ return planes
250
+
251
+
252
+ class TriplaneTransformer(nn.Module):
253
+ """
254
+ Triplane refinement transformer: downsamples the input planes to tokens, applies self-attention, upsamples them back, and adds a per-pixel MLP residual (adapted from the OpenLRM single-view LRM).
255
+ """
256
+ def __init__(self, input_dim: int, transformer_dim: int, transformer_layers: int, transformer_heads: int,
257
+ triplane_low_res: int, triplane_high_res: int, triplane_dim: int):
258
+ super().__init__()
259
+
260
+ # attributes
261
+ self.triplane_low_res = triplane_low_res
262
+ self.triplane_high_res = triplane_high_res
263
+ self.triplane_dim = triplane_dim
264
+
265
+ # initialize pos_embed with 1/sqrt(dim) * N(0, 1)
266
+ self.pos_embed = nn.Parameter(torch.randn(1, 3*triplane_low_res**2, transformer_dim) * (1. / transformer_dim) ** 0.5)
267
+ self.transformer = TransformerDecoder(
268
+ block_type='basic',
269
+ num_layers=transformer_layers, num_heads=transformer_heads,
270
+ inner_dim=transformer_dim,
271
+ )
272
+
273
+ self.downsampler = nn.Sequential(
274
+ nn.Conv2d(input_dim, transformer_dim, kernel_size=3, stride=1, padding=1),
275
+ nn.ReLU(),
276
+ nn.MaxPool2d(kernel_size=2, stride=2), # Reduces size from 128x128 to 64x64
277
+
278
+ nn.Conv2d(transformer_dim, transformer_dim, kernel_size=3, stride=1, padding=1),
279
+ nn.ReLU(),
280
+ nn.MaxPool2d(kernel_size=2, stride=2), # Reduces size from 64x64 to 32x32
281
+ )
282
+
283
+ self.upsampler = nn.ConvTranspose2d(transformer_dim, triplane_dim, kernel_size=4, stride=4, padding=0)
284
+
285
+ self.mlp = nn.Sequential(
286
+ nn.Linear(input_dim, triplane_dim),
287
+ nn.ReLU(),
288
+ nn.Linear(triplane_dim, triplane_dim)
289
+ )
290
+
291
+ def forward_transformer(self, triplanes):
292
+ N = triplanes.shape[0]
293
+ tokens = torch.einsum('nidhw->nihwd', triplanes).reshape(N, self.pos_embed.shape[1], -1) # [N, L, D]
294
+ x = self.pos_embed.repeat(N, 1, 1) + tokens # [N, L, D]
295
+ x = self.transformer(x)
296
+ return x
297
+
298
+ def reshape_downsample(self, triplanes):
299
+ N = triplanes.shape[0]
300
+ H = W = self.triplane_high_res
301
+ x = triplanes.view(N, 3, -1, H, W)
302
+ x = torch.einsum('nidhw->indhw', x) # [3, N, D, H, W]
303
+ x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
304
+ x = self.downsampler(x) # [3*N, D', H', W']
305
+ x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
306
+ x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
307
+ x = x.contiguous()
308
+ return x
309
+
310
+ def reshape_upsample(self, tokens):
311
+ N = tokens.shape[0]
312
+ H = W = self.triplane_low_res
313
+ x = tokens.view(N, 3, H, W, -1)
314
+ x = torch.einsum('nihwd->indhw', x) # [3, N, D, H, W]
315
+ x = x.contiguous().view(3*N, -1, H, W) # [3*N, D, H, W]
316
+ x = self.upsampler(x) # [3*N, D', H', W']
317
+ x = x.view(3, N, *x.shape[-3:]) # [3, N, D', H', W']
318
+ x = torch.einsum('indhw->nidhw', x) # [N, 3, D', H', W']
319
+ x = x.contiguous()
320
+ return x
321
+
322
+ def forward(self, triplanes):
323
+ downsampled_triplanes = self.reshape_downsample(triplanes)
324
+ tokens = self.forward_transformer(downsampled_triplanes)
325
+ residual = self.reshape_upsample(tokens)
326
+
327
+ triplanes = triplanes.permute(0, 1, 3, 4, 2).contiguous()
328
+ triplanes = self.mlp(triplanes)
329
+ triplanes = triplanes.permute(0, 1, 4, 2, 3).contiguous()
330
+ planes = triplanes + residual
331
+ return planes
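
A shape sketch for the triplane utilities above, with deliberately small sizes so it runs on CPU (the demo model uses transformer_dim=1024 and six layers). Only get_grid_coord and TriplaneTransformer are exercised here, since sample_from_planes builds its plane axes directly on CUDA.

import torch
from partfield.model.triplane import TriplaneTransformer, get_grid_coord

coords = get_grid_coord(grid_size=4)          # cell-centre coordinates in (-1, 1)
print(coords.shape)                           # torch.Size([64, 3])

refiner = TriplaneTransformer(input_dim=32, transformer_dim=64,
                              transformer_layers=1, transformer_heads=4,
                              triplane_low_res=32, triplane_high_res=128,
                              triplane_dim=16)
triplanes = torch.randn(1, 3, 32, 128, 128)   # (N, planes, C, H, W)
refined = refiner(triplanes)                  # downsample -> self-attention -> upsample + MLP residual
print(refined.shape)                          # torch.Size([1, 3, 16, 128, 128])
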
modules/PartField/partfield/model_trainer_pvcnn_only_demo.py ADDED
@@ -0,0 +1,283 @@
 
1
+ import torch
2
+ import lightning.pytorch as pl
3
+ from .dataloader import Demo_Dataset, Demo_Remesh_Dataset, Correspondence_Demo_Dataset
4
+ from torch.utils.data import DataLoader
5
+ from partfield.model.UNet.model import ResidualUNet3D
6
+ from partfield.model.triplane import TriplaneTransformer, get_grid_coord #, sample_from_planes, Voxel2Triplane
7
+ from partfield.model.model_utils import VanillaMLP
8
+ import torch.nn.functional as F
9
+ import torch.nn as nn
10
+ import os
11
+ import trimesh
12
+ import skimage
13
+ import numpy as np
14
+ import h5py
15
+ import torch.distributed as dist
16
+ from partfield.model.PVCNN.encoder_pc import TriPlanePC2Encoder, sample_triplane_feat
17
+ import json
18
+ import gc
19
+ import time
20
+ from plyfile import PlyData, PlyElement
21
+
22
+
23
+ class Model(pl.LightningModule):
24
+ def __init__(self, cfg):
25
+ super().__init__()
26
+
27
+ self.save_hyperparameters()
28
+ self.cfg = cfg
29
+ self.automatic_optimization = False
30
+ self.triplane_resolution = cfg.triplane_resolution
31
+ self.triplane_channels_low = cfg.triplane_channels_low
32
+ self.triplane_transformer = TriplaneTransformer(
33
+ input_dim=cfg.triplane_channels_low * 2,
34
+ transformer_dim=1024,
35
+ transformer_layers=6,
36
+ transformer_heads=8,
37
+ triplane_low_res=32,
38
+ triplane_high_res=128,
39
+ triplane_dim=cfg.triplane_channels_high,
40
+ )
41
+ self.sdf_decoder = VanillaMLP(input_dim=64,
42
+ output_dim=1,
43
+ out_activation="tanh",
44
+ n_neurons=64, #64
45
+ n_hidden_layers=6) #6
46
+ self.use_pvcnn = cfg.use_pvcnnonly
47
+ self.use_2d_feat = cfg.use_2d_feat
48
+ if self.use_pvcnn:
49
+ self.pvcnn = TriPlanePC2Encoder(
50
+ cfg.pvcnn,
51
+ device="cuda",
52
+ shape_min=-1,
53
+ shape_length=2,
54
+ use_2d_feat=self.use_2d_feat) #.cuda()
55
+ self.logit_scale = nn.Parameter(torch.tensor([1.0], requires_grad=True))
56
+ self.grid_coord = get_grid_coord(256)
57
+ self.mse_loss = torch.nn.MSELoss()
58
+ self.l1_loss = torch.nn.L1Loss(reduction='none')
59
+
60
+ if cfg.regress_2d_feat:
61
+ self.feat_decoder = VanillaMLP(input_dim=64,
62
+ output_dim=192,
63
+ out_activation="GELU",
64
+ n_neurons=64, #64
65
+ n_hidden_layers=6) #6
66
+
67
+ def predict_dataloader(self):
68
+ if self.cfg.remesh_demo:
69
+ dataset = Demo_Remesh_Dataset(self.cfg)
70
+ elif self.cfg.correspondence_demo:
71
+ dataset = Correspondence_Demo_Dataset(self.cfg)
72
+ else:
73
+ dataset = Demo_Dataset(self.cfg)
74
+
75
+ dataloader = DataLoader(dataset,
76
+ num_workers=self.cfg.dataset.val_num_workers,
77
+ batch_size=self.cfg.dataset.val_batch_size,
78
+ shuffle=False,
79
+ pin_memory=True,
80
+ drop_last=False)
81
+
82
+ return dataloader
83
+
84
+
85
+ @torch.no_grad()
86
+ def predict_step(self, batch, batch_idx):
87
+ save_dir = f"{self.cfg.result_name}"
88
+ os.makedirs(save_dir, exist_ok=True)
89
+
90
+ uid = batch['uid'][0]
91
+ view_id = 0
92
+ starttime = time.time()
93
+
94
+ if uid == "car" or uid == "complex_car":
95
+ # if uid == "complex_car":
96
+ print("Skipping this for now.")
97
+ print(uid)
98
+ return
99
+
100
+ ### Skip if model already processed
101
+ if os.path.exists(f'{save_dir}/part_feat_{uid}_{view_id}.npy') or os.path.exists(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy'):
102
+ print("Already processed "+uid)
103
+ return
104
+
105
+ N = batch['pc'].shape[0]
106
+ assert N == 1
107
+
108
+ if self.use_2d_feat:
109
+ print("ERROR. Dataloader not implemented with input 2d feat.")
110
+ exit()
111
+ else:
112
+ pc_feat = self.pvcnn(batch['pc'], batch['pc'])
113
+
114
+ planes = pc_feat
115
+ planes = self.triplane_transformer(planes)
116
+ sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2)
117
+
118
+ if self.cfg.is_pc:
119
+ tensor_vertices = batch['pc'].reshape(1, -1, 3).cuda().to(torch.float16)
120
+ point_feat = sample_triplane_feat(part_planes, tensor_vertices) # N, M, C
121
+ point_feat = point_feat.cpu().detach().numpy().reshape(-1, 448)
122
+
123
+ np.save(f'{save_dir}/part_feat_{uid}_{view_id}.npy', point_feat)
124
+ print(f"Exported part_feat_{uid}_{view_id}.npy")
125
+
126
+ ###########
127
+ from sklearn.decomposition import PCA
128
+ data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
129
+
130
+ pca = PCA(n_components=3)
131
+
132
+ data_reduced = pca.fit_transform(data_scaled)
133
+ data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min())
134
+ colors_255 = (data_reduced * 255).astype(np.uint8)
135
+
136
+ points = batch['pc'].squeeze().detach().cpu().numpy()
137
+
138
+ if colors_255 is None:
139
+ colors_255 = np.full_like(points, 255) # Default to white color (255,255,255)
140
+ else:
141
+ assert colors_255.shape == points.shape, "Colors must have the same shape as points"
142
+
143
+ # Convert to structured array for PLY format
144
+ vertex_data = np.array(
145
+ [(*point, *color) for point, color in zip(points, colors_255)],
146
+ dtype=[("x", "f4"), ("y", "f4"), ("z", "f4"), ("red", "u1"), ("green", "u1"), ("blue", "u1")]
147
+ )
148
+
149
+ # Create PLY element
150
+ el = PlyElement.describe(vertex_data, "vertex")
151
+ # Write to file
152
+ filename = f'{save_dir}/feat_pca_{uid}_{view_id}.ply'
153
+ PlyData([el], text=True).write(filename)
154
+ print(f"Saved PLY file: {filename}")
155
+ ############
156
+
157
+ else:
158
+ use_cuda_version = True
159
+ if use_cuda_version:
160
+
161
+ def sample_points(vertices, faces, n_point_per_face):
162
+ # Generate random barycentric coordinates
163
+ # borrowed from Kaolin https://github.com/NVIDIAGameWorks/kaolin/blob/master/kaolin/ops/mesh/trianglemesh.py#L43
164
+ n_f = faces.shape[0]
165
+ u = torch.sqrt(torch.rand((n_f, n_point_per_face, 1),
166
+ device=vertices.device,
167
+ dtype=vertices.dtype))
168
+ v = torch.rand((n_f, n_point_per_face, 1),
169
+ device=vertices.device,
170
+ dtype=vertices.dtype)
171
+ w0 = 1 - u
172
+ w1 = u * (1 - v)
173
+ w2 = u * v
174
+
175
+ face_v_0 = torch.index_select(vertices, 0, faces[:, 0].reshape(-1))
176
+ face_v_1 = torch.index_select(vertices, 0, faces[:, 1].reshape(-1))
177
+ face_v_2 = torch.index_select(vertices, 0, faces[:, 2].reshape(-1))
178
+ points = w0 * face_v_0.unsqueeze(dim=1) + w1 * face_v_1.unsqueeze(dim=1) + w2 * face_v_2.unsqueeze(dim=1)
179
+ return points
180
+
181
+ def sample_and_mean_memory_save_version(part_planes, tensor_vertices, n_point_per_face):
182
+ n_sample_each = self.cfg.n_sample_each # we iterate over this to avoid OOM
183
+ n_v = tensor_vertices.shape[1]
184
+ n_sample = n_v // n_sample_each + 1
185
+ all_sample = []
186
+ for i_sample in range(n_sample):
187
+ sampled_feature = sample_triplane_feat(part_planes, tensor_vertices[:, i_sample * n_sample_each: i_sample * n_sample_each + n_sample_each,])
188
+ assert sampled_feature.shape[1] % n_point_per_face == 0
189
+ sampled_feature = sampled_feature.reshape(1, -1, n_point_per_face, sampled_feature.shape[-1])
190
+ sampled_feature = torch.mean(sampled_feature, axis=-2)
191
+ all_sample.append(sampled_feature)
192
+ return torch.cat(all_sample, dim=1)
193
+
194
+ if self.cfg.vertex_feature:
195
+ tensor_vertices = batch['vertices'][0].reshape(1, -1, 3).to(torch.float32)
196
+ point_feat = sample_and_mean_memory_save_version(part_planes, tensor_vertices, 1)
197
+ else:
198
+ n_point_per_face = self.cfg.n_point_per_face
199
+ tensor_vertices = sample_points(batch['vertices'][0], batch['faces'][0], n_point_per_face)
200
+ tensor_vertices = tensor_vertices.reshape(1, -1, 3).to(torch.float32)
201
+ point_feat = sample_and_mean_memory_save_version(part_planes, tensor_vertices, n_point_per_face) # N, M, C
202
+
203
+ #### Take mean feature in the triangle
204
+ print("Time elapsed for feature prediction: " + str(time.time() - starttime))
205
+ point_feat = point_feat.reshape(-1, 448).cpu().numpy()
206
+ np.save(f'{save_dir}/part_feat_{uid}_{view_id}_batch.npy', point_feat)
207
+ print(f"Exported part_feat_{uid}_{view_id}.npy")
208
+
209
+ ###########
210
+ from sklearn.decomposition import PCA
211
+ data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
212
+
213
+ pca = PCA(n_components=3)
214
+
215
+ data_reduced = pca.fit_transform(data_scaled)
216
+ data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min())
217
+ colors_255 = (data_reduced * 255).astype(np.uint8)
218
+ V = batch['vertices'][0].cpu().numpy()
219
+ F = batch['faces'][0].cpu().numpy()
220
+ if self.cfg.vertex_feature:
221
+ colored_mesh = trimesh.Trimesh(vertices=V, faces=F, vertex_colors=colors_255, process=False)
222
+ else:
223
+ colored_mesh = trimesh.Trimesh(vertices=V, faces=F, face_colors=colors_255, process=False)
224
+ colored_mesh.export(f'{save_dir}/feat_pca_{uid}_{view_id}.ply')
225
+ ############
226
+ torch.cuda.empty_cache()
227
+
228
+ else:
229
+ ### Mesh input (obj file)
230
+ V = batch['vertices'][0].cpu().numpy()
231
+ F = batch['faces'][0].cpu().numpy()
232
+
233
+ ##### Loop through faces #####
234
+ num_samples_per_face = self.cfg.n_point_per_face
235
+
236
+ all_point_feats = []
237
+ for face in F:
238
+ # Get the vertices of the current face
239
+ v0, v1, v2 = V[face]
240
+
241
+ # Generate random barycentric coordinates
242
+ u = np.random.rand(num_samples_per_face, 1)
243
+ v = np.random.rand(num_samples_per_face, 1)
244
+ is_prob = (u+v) >1
245
+ u[is_prob] = 1 - u[is_prob]
246
+ v[is_prob] = 1 - v[is_prob]
247
+ w = 1 - u - v
248
+
249
+ # Calculate points in Cartesian coordinates
250
+ points = u * v0 + v * v1 + w * v2
251
+
252
+ tensor_vertices = torch.from_numpy(points.copy()).reshape(1, -1, 3).cuda().to(torch.float32)
253
+ point_feat = sample_triplane_feat(part_planes, tensor_vertices) # N, M, C
254
+
255
+ #### Take mean feature in the triangle
256
+ point_feat = torch.mean(point_feat, axis=1).cpu().detach().numpy()
257
+ all_point_feats.append(point_feat)
258
+ ##############################
259
+
260
+ all_point_feats = np.array(all_point_feats).reshape(-1, 448)
261
+
262
+ point_feat = all_point_feats
263
+
264
+ np.save(f'{save_dir}/part_feat_{uid}_{view_id}.npy', point_feat)
265
+ print(f"Exported part_feat_{uid}_{view_id}.npy")
266
+
267
+ ###########
268
+ from sklearn.decomposition import PCA
269
+ data_scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
270
+
271
+ pca = PCA(n_components=3)
272
+
273
+ data_reduced = pca.fit_transform(data_scaled)
274
+ data_reduced = (data_reduced - data_reduced.min()) / (data_reduced.max() - data_reduced.min())
275
+ colors_255 = (data_reduced * 255).astype(np.uint8)
276
+
277
+ colored_mesh = trimesh.Trimesh(vertices=V, faces=F, face_colors=colors_255, process=False)
278
+ colored_mesh.export(f'{save_dir}/feat_pca_{uid}_{view_id}.ply')
279
+ ############
280
+
281
+ print("Time elapsed: " + str(time.time()-starttime))
282
+
283
+ return
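
The PCA-to-RGB visualisation used three times in predict_step above can be run on its own; the random array below merely stands in for the 448-dimensional PartField features.

import numpy as np
from sklearn.decomposition import PCA

point_feat = np.random.randn(1000, 448).astype(np.float32)    # stand-in features
scaled = point_feat / np.linalg.norm(point_feat, axis=-1, keepdims=True)
reduced = PCA(n_components=3).fit_transform(scaled)
reduced = (reduced - reduced.min()) / (reduced.max() - reduced.min())
colors_255 = (reduced * 255).astype(np.uint8)                 # one RGB triple per point/face
print(colors_255.shape, colors_255.dtype)                     # (1000, 3) uint8
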
modules/PartField/partfield/partfield_encoder.py ADDED
@@ -0,0 +1,103 @@
 
1
+ import torch
2
+ import lightning.pytorch as pl
3
+ # from .dataloader import Demo_Dataset, Demo_Remesh_Dataset, Correspondence_Demo_Dataset
4
+ from torch.utils.data import DataLoader
5
+ from partfield.model.UNet.model import ResidualUNet3D
6
+ from partfield.model.triplane import TriplaneTransformer, get_grid_coord #, sample_from_planes, Voxel2Triplane
7
+ from partfield.model.model_utils import VanillaMLP
8
+ import torch.nn.functional as F
9
+ import torch.nn as nn
10
+ import os
11
+ import trimesh
12
+ import skimage
13
+ import numpy as np
14
+ import h5py
15
+ import torch.distributed as dist
16
+ from partfield.model.PVCNN.encoder_pc import TriPlanePC2Encoder, sample_triplane_feat
17
+ import json
18
+ import gc
19
+ import time
20
+ from plyfile import PlyData, PlyElement
21
+
22
+
23
+ class Model(pl.LightningModule):
24
+ def __init__(self, cfg):
25
+ super().__init__()
26
+
27
+ self.save_hyperparameters()
28
+ self.cfg = cfg
29
+ self.automatic_optimization = False
30
+ self.triplane_resolution = cfg.triplane_resolution
31
+ self.triplane_channels_low = cfg.triplane_channels_low
32
+ self.triplane_transformer = TriplaneTransformer(
33
+ input_dim=cfg.triplane_channels_low * 2,
34
+ transformer_dim=1024,
35
+ transformer_layers=6,
36
+ transformer_heads=8,
37
+ triplane_low_res=32,
38
+ triplane_high_res=128,
39
+ triplane_dim=cfg.triplane_channels_high,
40
+ )
41
+ self.sdf_decoder = VanillaMLP(input_dim=64,
42
+ output_dim=1,
43
+ out_activation="tanh",
44
+ n_neurons=64, #64
45
+ n_hidden_layers=6) #6
46
+ self.use_pvcnn = cfg.use_pvcnnonly
47
+ self.use_2d_feat = cfg.use_2d_feat
48
+ if self.use_pvcnn:
49
+ self.pvcnn = TriPlanePC2Encoder(
50
+ cfg.pvcnn,
51
+ device="cuda",
52
+ shape_min=-1,
53
+ shape_length=2,
54
+ use_2d_feat=self.use_2d_feat) #.cuda()
55
+ self.logit_scale = nn.Parameter(torch.tensor([1.0], requires_grad=True))
56
+ self.grid_coord = get_grid_coord(256)
57
+ self.mse_loss = torch.nn.MSELoss()
58
+ self.l1_loss = torch.nn.L1Loss(reduction='none')
59
+
60
+ if cfg.regress_2d_feat:
61
+ self.feat_decoder = VanillaMLP(input_dim=64,
62
+ output_dim=192,
63
+ out_activation="GELU",
64
+ n_neurons=64, #64
65
+ n_hidden_layers=6) #6
66
+
67
+ # def predict_dataloader(self):
68
+ # if self.cfg.remesh_demo:
69
+ # dataset = Demo_Remesh_Dataset(self.cfg)
70
+ # elif self.cfg.correspondence_demo:
71
+ # dataset = Correspondence_Demo_Dataset(self.cfg)
72
+ # else:
73
+ # dataset = Demo_Dataset(self.cfg)
74
+
75
+ # dataloader = DataLoader(dataset,
76
+ # num_workers=self.cfg.dataset.val_num_workers,
77
+ # batch_size=self.cfg.dataset.val_batch_size,
78
+ # shuffle=False,
79
+ # pin_memory=True,
80
+ # drop_last=False)
81
+
82
+ # return dataloader
83
+
84
+
85
+ @torch.no_grad()
86
+ def encode(self, points):
87
+
88
+ N = points.shape[0]
89
+ # assert N == 1
90
+ pcd = points[..., :3]
91
+
92
+ pc_feat = self.pvcnn(pcd, pcd)
93
+
94
+ planes = pc_feat
95
+ planes = self.triplane_transformer(planes)
96
+ sdf_planes, part_planes = torch.split(planes, [64, planes.shape[2] - 64], dim=2)
97
+
98
+ tensor_vertices = pcd.reshape(N, -1, 3).cuda().to(pcd.dtype)
99
+ point_feat = sample_triplane_feat(part_planes, tensor_vertices) # N, M, C
100
+ # point_feat = point_feat.cpu().detach().numpy().reshape(-1, 448)
101
+ point_feat = point_feat.reshape(N, -1, 448)
102
+
103
+ return point_feat
modules/PartField/partfield/utils.py ADDED
@@ -0,0 +1,5 @@
 
1
+ import trimesh
2
+
3
+ def load_mesh_util(input_fname):
4
+ mesh = trimesh.load(input_fname, force='mesh', process=False)
5
+ return mesh
modules/bbox_gen/config.py ADDED
@@ -0,0 +1,57 @@
 
1
+ import os
2
+ from omegaconf import OmegaConf, DictConfig
3
+ from dataclasses import dataclass, field
4
+ from typing import Any, Dict, List, Optional, Union
5
+ from datetime import datetime
6
+
7
+ @dataclass
8
+ class ExperimentConfig:
9
+ name: str = "default"
10
+ tag: str = ""
11
+ use_timestamp: bool = False
12
+ timestamp: Optional[str] = None
13
+ exp_root_dir: str = "outputs"
14
+
15
+ ### these shouldn't be set manually
16
+ exp_dir: str = "outputs/default"
17
+ trial_name: str = "exp"
18
+ trial_dir: str = "outputs/default/exp"
19
+ ###
20
+
21
+ resume: Optional[str] = None
22
+ ckpt_path: Optional[str] = None
23
+
24
+ data: dict = field(default_factory=dict)
25
+ model_pl: dict = field(default_factory=dict)
26
+
27
+ trainer: dict = field(default_factory=dict)
28
+ checkpoint: dict = field(default_factory=dict)
29
+ checkpoint_epoch: Optional[dict] = None
30
+ wandb: dict = field(default_factory=dict)
31
+
32
+
33
+ def load_config(*yamls: str, cli_args: list = [], from_string=False, **kwargs) -> Any:
34
+ if from_string:
35
+ yaml_confs = [OmegaConf.create(s) for s in yamls]
36
+ else:
37
+ yaml_confs = [OmegaConf.load(f) for f in yamls]
38
+ cli_conf = OmegaConf.from_cli(cli_args)
39
+ cfg = OmegaConf.merge(*yaml_confs, cli_conf, kwargs)
40
+ OmegaConf.resolve(cfg)
41
+ assert isinstance(cfg, DictConfig)
42
+ scfg = parse_structured(ExperimentConfig, cfg)
43
+ return scfg
44
+
45
+
46
+ def config_to_primitive(config, resolve: bool = True) -> Any:
47
+ return OmegaConf.to_container(config, resolve=resolve)
48
+
49
+
50
+ def dump_config(path: str, config) -> None:
51
+ with open(path, "w") as fp:
52
+ OmegaConf.save(config=config, f=fp)
53
+
54
+
55
+ def parse_structured(fields: Any, cfg: Optional[Union[dict, DictConfig]] = None) -> Any:
56
+ scfg = OmegaConf.structured(fields(**cfg))
57
+ return scfg
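
A minimal sketch of how these helpers compose (the inline YAML and its values are made up; only fields declared on ExperimentConfig above are accepted):

from modules.bbox_gen.config import load_config, config_to_primitive

cfg = load_config("name: bbox_gen\nexp_root_dir: outputs", from_string=True)
print(cfg.name, cfg.trial_dir)                 # bbox_gen outputs/default/exp
print(config_to_primitive(cfg)["model_pl"])    # {} until a real YAML such as configs/bbox_gen.yaml fills it in
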
modules/bbox_gen/models/autogressive_bbox_gen.py ADDED
@@ -0,0 +1,305 @@
 
1
+ from dataclasses import dataclass
2
+ import os
3
+ import sys
4
+ import torch
5
+ import trimesh
6
+ from torch import nn
7
+ from transformers import AutoModelForCausalLM
8
+ from transformers.generation.logits_process import LogitsProcessorList
9
+ from einops import rearrange
10
+
11
+ from modules.bbox_gen.models.image_encoder import DINOv2ImageEncoder
12
+ from modules.bbox_gen.config import parse_structured
13
+ from modules.bbox_gen.models.bboxopt import BBoxOPT, BBoxOPTConfig
14
+ from modules.bbox_gen.utils.bbox_tokenizer import BoundsTokenizerDiag
15
+ from modules.bbox_gen.models.bbox_gen_models import GroupEmbedding, MultiModalProjector, MeshDecodeLogitsProcessor, SparseStructureEncoder
16
+
17
+ current_dir = os.path.dirname(os.path.abspath(__file__))
18
+ modules_dir = os.path.dirname(os.path.dirname(current_dir))
19
+ partfield_dir = os.path.join(modules_dir, 'PartField')
20
+ if partfield_dir not in sys.path:
21
+ sys.path.insert(0, partfield_dir)
22
+ import importlib.util
23
+ from partfield.config import default_argument_parser, setup
24
+
25
+
26
+ class BboxGen(nn.Module):
27
+
28
+ @dataclass
29
+ class Config:
30
+ # encoder config
31
+ encoder_dim_feat: int = 3
32
+ encoder_dim: int = 64
33
+ encoder_heads: int = 4
34
+ encoder_token_num: int = 256
35
+ encoder_qkv_bias: bool = False
36
+ encoder_use_ln_post: bool = True
37
+ encoder_use_checkpoint: bool = False
38
+ encoder_num_embed_freqs: int = 8
39
+ encoder_embed_include_pi: bool = False
40
+ encoder_init_scale: float = 0.25
41
+ encoder_random_fps: bool = True
42
+ encoder_learnable_query: bool = False
43
+ encoder_layers: int = 4
44
+ group_embedding_dim: int = 64
45
+
46
+ # decoder config
47
+ vocab_size: int = 518
48
+ decoder_hidden_size: int = 1536
49
+ decoder_num_hidden_layers: int = 24
50
+ decoder_ffn_dim: int = 6144
51
+ decoder_heads: int = 16
52
+ decoder_use_flash_attention: bool = True
53
+ decoder_gradient_checkpointing: bool = True
54
+
55
+ # data config
56
+ bins: int = 64
57
+ BOS_id: int = 64
58
+ EOS_id: int = 65
59
+ PAD_id: int = 66
60
+ max_length: int = 2187 # bos + 50x2x3 + 1374 + 512
61
+ voxel_token_length: int = 1886
62
+ voxel_token_placeholder: int = -1
63
+
64
+ # tokenizer config
65
+ max_group_size: int = 50
66
+
67
+ # voxel encoder
68
+ partfield_encoder_path: str = ""
69
+
70
+ cfg: Config
71
+
72
+ def __init__(self, cfg):
73
+ super().__init__()
74
+ self.cfg = parse_structured(self.Config, cfg)
75
+
76
+ self.image_encoder = DINOv2ImageEncoder(
77
+ model_name="facebook/dinov2-with-registers-large",
78
+ )
79
+
80
+ self.image_projector = MultiModalProjector(
81
+ in_features=(1024 + self.cfg.group_embedding_dim),
82
+ out_features=self.cfg.decoder_hidden_size,
83
+ )
84
+
85
+ self.group_embedding = GroupEmbedding(
86
+ max_group_size=self.cfg.max_group_size,
87
+ hidden_size=self.cfg.group_embedding_dim,
88
+ )
89
+
90
+ self.decoder_config = BBoxOPTConfig(
91
+ vocab_size=self.cfg.vocab_size,
92
+ hidden_size=self.cfg.decoder_hidden_size,
93
+ num_hidden_layers=self.cfg.decoder_num_hidden_layers,
94
+ ffn_dim=self.cfg.decoder_ffn_dim,
95
+ max_position_embeddings=self.cfg.max_length,
96
+ num_attention_heads=self.cfg.decoder_heads,
97
+ pad_token_id=self.cfg.PAD_id,
98
+ bos_token_id=self.cfg.BOS_id,
99
+ eos_token_id=self.cfg.EOS_id,
100
+ use_cache=True,
101
+ init_std=0.02,
102
+ )
103
+
104
+ if self.cfg.decoder_use_flash_attention:
105
+ self.decoder: BBoxOPT = AutoModelForCausalLM.from_config(
106
+ self.decoder_config,
107
+ torch_dtype=torch.bfloat16,
108
+ attn_implementation="flash_attention_2"
109
+ )
110
+ else:
111
+ self.decoder: BBoxOPT = AutoModelForCausalLM.from_config(
112
+ self.decoder_config,
113
+ )
114
+ if self.cfg.decoder_gradient_checkpointing:
115
+ self.decoder.gradient_checkpointing_enable()
116
+
117
+ self.logits_processor = LogitsProcessorList()
118
+
119
+ self.logits_processor.append(MeshDecodeLogitsProcessor(
120
+ bins=self.cfg.bins,
121
+ BOS_id=self.cfg.BOS_id,
122
+ EOS_id=self.cfg.EOS_id,
123
+ PAD_id=self.cfg.PAD_id,
124
+ vertices_num=2,
125
+ ))
126
+ self.tokenizer = BoundsTokenizerDiag(
127
+ bins=self.cfg.bins,
128
+ BOS_id=self.cfg.BOS_id,
129
+ EOS_id=self.cfg.EOS_id,
130
+ PAD_id=self.cfg.PAD_id,
131
+ )
132
+
133
+ self._load_partfield_encoder()
134
+
135
+ self.partfield_voxel_encoder = SparseStructureEncoder(
136
+ in_channels=451,
137
+ channels=[448, 448, 448, 1024],
138
+ latent_channels=448,
139
+ num_res_blocks=1,
140
+ num_res_blocks_middle=1,
141
+ norm_type="layer",
142
+ )
143
+
144
+
145
+ def _load_partfield_encoder(self):
146
+ # Load PartField encoder
147
+ model_spec = importlib.util.spec_from_file_location(
148
+ "partfield.partfield_encoder",
149
+ os.path.join(partfield_dir, "partfield", "partfield_encoder.py")
150
+ )
151
+ model_module = importlib.util.module_from_spec(model_spec)
152
+ model_spec.loader.exec_module(model_module)
153
+ Model = model_module.Model
154
+ parser = default_argument_parser()
155
+ args = []
156
+ args.extend(["-c", os.path.join(partfield_dir, "configs/final/demo.yaml")])
157
+ args.append("--opts")
158
+ args.extend(["continue_ckpt", self.cfg.partfield_encoder_path])
159
+ parsed_args = parser.parse_args(args)
160
+ cfg = setup(parsed_args, freeze=False)
161
+ self.partfield_encoder = Model(cfg)
162
+ self.partfield_encoder.eval()
163
+ weights = torch.load(self.cfg.partfield_encoder_path)["state_dict"]
164
+ self.partfield_encoder.load_state_dict(weights)
165
+ for param in self.partfield_encoder.parameters():
166
+ param.requires_grad = False
167
+ print("PartField encoder loaded")
168
+
169
+ def _prepare_lm_inputs(self, voxel_token, input_ids):
170
+ inputs_embeds = torch.zeros(input_ids.shape[0], input_ids.shape[1], self.cfg.decoder_hidden_size, device=input_ids.device, dtype=voxel_token.dtype)
171
+ voxel_token_mask = (input_ids == self.cfg.voxel_token_placeholder)
172
+ inputs_embeds[voxel_token_mask] = voxel_token.view(-1, self.cfg.decoder_hidden_size)
173
+
174
+ inputs_embeds[~voxel_token_mask] = self.decoder.get_input_embeddings()(input_ids[~voxel_token_mask]).to(dtype=inputs_embeds.dtype)
175
+
176
+ attention_mask = (input_ids != self.cfg.PAD_id)
177
+ return inputs_embeds, attention_mask.long()
178
+
179
+ def forward(self, batch):
180
+
181
+ image_latents = self.image_encoder(batch['images'])
182
+ masks = batch['masks']
183
+ masks_emb = self.group_embedding(masks)
184
+ masks_emb = rearrange(masks_emb, 'b c h w -> b (h w) c') # B x Q x C
185
+ group_emb = torch.zeros((image_latents.shape[0], image_latents.shape[1], masks_emb.shape[2]), device=image_latents.device, dtype=image_latents.dtype)
186
+ group_emb[:, :masks_emb.shape[1], :] = masks_emb
187
+ image_latents = torch.cat([image_latents, group_emb], dim=-1)
188
+ image_latents = self.image_projector(image_latents)
189
+
190
+ points = batch['points'][..., :3]
191
+ rot_matrix = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=points.device, dtype=points.dtype)
192
+ rot_points = torch.matmul(points, rot_matrix)
193
+ rot_points = rot_points * (2 * 0.9) # scale (-0.5, 0.5) to (-0.9, 0.9), i.e. 90% of the (-1, 1) range PartField expects
194
+
195
+ partfield_feat = self.partfield_encoder.encode(rot_points)
196
+ feat_volume = torch.zeros((points.shape[0], 448, 64, 64, 64), device=partfield_feat.device, dtype=partfield_feat.dtype)
197
+ whole_voxel_index = batch['whole_voxel_index'] # (b, m, 3)
198
+
199
+ batch_size, num_points = whole_voxel_index.shape[0], whole_voxel_index.shape[1]
200
+ batch_indices = torch.arange(batch_size, device=whole_voxel_index.device).unsqueeze(1).expand(-1, num_points) # (b, m)
201
+ batch_flat = batch_indices.flatten() # (b*m,)
202
+ x_flat = whole_voxel_index[..., 0].flatten() # (b*m,)
203
+ y_flat = whole_voxel_index[..., 1].flatten() # (b*m,)
204
+ z_flat = whole_voxel_index[..., 2].flatten() # (b*m,)
205
+ partfield_feat_flat = partfield_feat.reshape(-1, 448) # (b*m, 448)
206
+ feat_volume[batch_flat, :, x_flat, y_flat, z_flat] = partfield_feat_flat
207
+
208
+ xyz_volume = torch.zeros((points.shape[0], 3, 64, 64, 64), device=points.device, dtype=points.dtype)
209
+ xyz_volume[batch_flat, :, x_flat, y_flat, z_flat] = points.reshape(-1, 3)
210
+ feat_volume = torch.cat([feat_volume, xyz_volume], dim=1)
211
+
212
+ feat_volume = self.partfield_voxel_encoder(feat_volume)
213
+ feat_volume = rearrange(feat_volume, 'b c x y z -> b (x y z) c')
214
+
215
+ voxel_token = torch.cat([image_latents, feat_volume], dim=1) # B x N x D
216
+
217
+ input_ids = batch['input_ids']
218
+ inputs_embeds, attention_mask = self._prepare_lm_inputs(voxel_token, input_ids)
219
+ output = self.decoder(
220
+ attention_mask=attention_mask,
221
+ inputs_embeds=inputs_embeds,
222
+ return_dict=True,
223
+ )
224
+ return {
225
+ "logits": output.logits,
226
+ }
227
+
228
+ def gen_mesh_from_bounds(self, bounds, random_color):
229
+ bboxes = []
230
+ for j in range(bounds.shape[0]):
231
+ bbox = trimesh.primitives.Box(bounds=bounds[j])
232
+ color = random_color[j]
233
+ bbox.visual.vertex_colors = color
234
+ bboxes.append(bbox)
235
+ mesh = trimesh.Scene(bboxes)
236
+ return mesh
237
+
238
+ def generate(self, batch):
239
+
240
+ image_latents = self.image_encoder(batch['images'])
241
+ masks = batch['masks']
242
+ masks_emb = self.group_embedding(masks)
243
+ masks_emb = rearrange(masks_emb, 'b c h w -> b (h w) c') # B x Q x C
244
+ group_emb = torch.zeros((image_latents.shape[0], image_latents.shape[1], masks_emb.shape[2]), device=image_latents.device, dtype=image_latents.dtype)
245
+ group_emb[:, :masks_emb.shape[1], :] = masks_emb
246
+ image_latents = torch.cat([image_latents, group_emb], dim=-1)
247
+ image_latents = self.image_projector(image_latents)
248
+
249
+ points = batch['points'][..., :3]
250
+ rot_matrix = torch.tensor([[1, 0, 0], [0, 0, -1], [0, 1, 0]], device=points.device, dtype=points.dtype)
251
+ rot_points = torch.matmul(points, rot_matrix)
252
+ rot_points = rot_points * (2 * 0.9) # scale (-0.5, 0.5) to (-0.9, 0.9), i.e. 90% of the (-1, 1) range PartField expects
253
+
254
+ partfield_feat = self.partfield_encoder.encode(rot_points)
255
+ feat_volume = torch.zeros((points.shape[0], 448, 64, 64, 64), device=partfield_feat.device, dtype=partfield_feat.dtype)
256
+ whole_voxel_index = batch['whole_voxel_index'] # (b, m, 3)
257
+
258
+ batch_size, num_points = whole_voxel_index.shape[0], whole_voxel_index.shape[1]
259
+ batch_indices = torch.arange(batch_size, device=whole_voxel_index.device).unsqueeze(1).expand(-1, num_points) # (b, m)
260
+ batch_flat = batch_indices.flatten() # (b*m,)
261
+ x_flat = whole_voxel_index[..., 0].flatten() # (b*m,)
262
+ y_flat = whole_voxel_index[..., 1].flatten() # (b*m,)
263
+ z_flat = whole_voxel_index[..., 2].flatten() # (b*m,)
264
+ partfield_feat_flat = partfield_feat.reshape(-1, 448) # (b*m, 448)
265
+ feat_volume[batch_flat, :, x_flat, y_flat, z_flat] = partfield_feat_flat
266
+
267
+ xyz_volume = torch.zeros((points.shape[0], 3, 64, 64, 64), device=points.device, dtype=points.dtype)
268
+ xyz_volume[batch_flat, :, x_flat, y_flat, z_flat] = points.reshape(-1, 3)
269
+ feat_volume = torch.cat([feat_volume, xyz_volume], dim=1)
270
+
271
+ feat_volume = self.partfield_voxel_encoder(feat_volume)
272
+ feat_volume = rearrange(feat_volume, 'b c x y z -> b (x y z) c')
273
+
274
+ voxel_token = torch.cat([image_latents, feat_volume], dim=1) # B x N x D
275
+
276
+ meshes = []
277
+ mesh_names = []
278
+ bboxes = []
279
+
280
+ output = self.decoder.generate(
281
+ inputs_embeds=voxel_token,
282
+ max_new_tokens=self.cfg.max_length - voxel_token.shape[1],
283
+ logits_processor=self.logits_processor,
284
+ do_sample=True,
285
+ top_k=5,
286
+ top_p=0.95,
287
+ temperature=0.5,
288
+ use_cache=True,
289
+ )
290
+
291
+ for i in range(output.shape[0]):
292
+ bounds = self.tokenizer.decode(output[i].detach().cpu().numpy(), coord_rg=(-0.5, 0.5))
293
+ # mesh = self.gen_mesh_from_bounds(bounds, batch['random_color'][i])
294
+ # meshes.append(mesh)
295
+ mesh_names.append("topk=5")
296
+ bboxes.append(bounds)
297
+
298
+ return {
299
+ # 'meshes': meshes,
300
+ 'mesh_names': mesh_names,
301
+ 'bboxes': bboxes,
302
+ }
303
+
304
+
305
+
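
The scatter that builds the dense feature volume in both forward and generate above reduces to a single advanced-indexing assignment; here is a standalone, down-scaled sketch of that pattern (C=8 instead of 448):

import torch

B, M, C = 2, 5, 8                                # batch, occupied voxels, channels
feat = torch.randn(B, M, C)                      # per-voxel features
idx = torch.randint(0, 64, (B, M, 3))            # integer voxel coordinates in a 64^3 grid

volume = torch.zeros(B, C, 64, 64, 64)
b = torch.arange(B).unsqueeze(1).expand(-1, M).flatten()
x, y, z = idx[..., 0].flatten(), idx[..., 1].flatten(), idx[..., 2].flatten()
volume[b, :, x, y, z] = feat.reshape(-1, C)      # all voxels written in one shot
print(torch.allclose(volume[b[0], :, x[0], y[0], z[0]], feat[0, 0]))  # True, barring an index collision
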
modules/bbox_gen/models/bbox_gen_models.py ADDED
@@ -0,0 +1,215 @@
 
1
+ import torch
2
+ from torch import nn
3
+ import torch.nn.functional as F
4
+ from diffusers.models.normalization import FP32LayerNorm
5
+ from diffusers.models.attention import FeedForward
6
+ from transformers.generation.logits_process import LogitsProcessor
7
+ from typing import List, Literal, Optional
8
+
9
+ from modules.bbox_gen.modules.norm import GroupNorm32, ChannelLayerNorm32
10
+
11
+
12
+ class GroupEmbedding(nn.Module):
13
+ def __init__(self, max_group_size, hidden_size=64):
14
+ super().__init__()
15
+
16
+ self.group_embedding = nn.Embedding(max_group_size + 1, hidden_size) # +1 for background
17
+ self.group_embedding.weight.data.normal_(mean=0.0, std=0.02)
18
+
19
+ def forward(self, masks):
20
+ batch_size, height, width = masks.shape
21
+ masks_flat = masks.reshape(batch_size, -1)
22
+ embeddings = self.group_embedding(masks_flat)
23
+ embeddings = embeddings.reshape(batch_size, height, width, -1)
24
+ embeddings = embeddings.permute(0, 3, 1, 2)
25
+ return embeddings
26
+
27
+
28
+ class MultiModalProjector(torch.nn.Module):
29
+ def __init__(self, in_features: int, out_features: int, pos_embed_seq_len=None):
30
+ super().__init__()
31
+
32
+ self.norm1 = FP32LayerNorm(in_features)
33
+ self.ff = FeedForward(in_features, out_features, mult=1, activation_fn="gelu")
34
+ self.norm2 = FP32LayerNorm(out_features)
35
+ if pos_embed_seq_len is not None:
36
+ self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_seq_len, in_features))
37
+ else:
38
+ self.pos_embed = None
39
+
40
+ def forward(self, encoder_hidden_states_image: torch.Tensor) -> torch.Tensor:
41
+ if self.pos_embed is not None:
42
+ batch_size, seq_len, embed_dim = encoder_hidden_states_image.shape
43
+ encoder_hidden_states_image = encoder_hidden_states_image.view(-1, 2 * seq_len, embed_dim)
44
+ encoder_hidden_states_image = encoder_hidden_states_image + self.pos_embed
45
+
46
+ hidden_states = self.norm1(encoder_hidden_states_image)
47
+ hidden_states = self.ff(hidden_states)
48
+ hidden_states = self.norm2(hidden_states)
49
+ return hidden_states
50
+
51
+
52
+ class MeshDecodeLogitsProcessor(LogitsProcessor):
53
+ def __init__(self, bins, BOS_id, EOS_id, PAD_id, vertices_num=8):
54
+ super().__init__()
55
+ self.bins = bins
56
+ self.BOS_id = BOS_id
57
+ self.EOS_id = EOS_id
58
+ self.PAD_id = PAD_id
59
+ self.filter_value = -float('inf')
60
+ self.vertices_num = vertices_num
61
+
62
+ def force_token(self, scores, token_id):
63
+ mask = torch.ones_like(scores, dtype=torch.bool)
64
+ mask[:, token_id] = False
65
+ scores[mask] = self.filter_value
66
+
67
+ def __call__(self, input_ids, scores):
68
+ # decoding rules: the first token must be BOS; the tokens that follow
69
+ # must be coordinate bins; EOS is only allowed once a whole box is complete
70
+ current_len = input_ids.shape[-1]
71
+ if current_len == 0:
72
+ # force bos
73
+ self.force_token(scores, self.BOS_id)
74
+ elif current_len <= self.vertices_num * 3 + 1:
75
+ scores[:, self.bins:] = self.filter_value
76
+ else:
77
+ scores[:, self.BOS_id] = self.filter_value
78
+ scores[:, self.PAD_id] = self.filter_value
79
+
80
+ effective_tokens = current_len - 1
81
+ complete_boxes = effective_tokens % (self.vertices_num * 3) == 0
82
+ # print(effective_tokens, complete_boxes)
83
+ if not complete_boxes:
84
+ scores[:, self.EOS_id] = self.filter_value
85
+
86
+ return scores
87
+
88
+
89
+ def norm_layer(norm_type: str, *args, **kwargs) -> nn.Module:
90
+ """
91
+ Return a normalization layer.
92
+ """
93
+ if norm_type == "group":
94
+ return GroupNorm32(32, *args, **kwargs)
95
+ elif norm_type == "layer":
96
+ return ChannelLayerNorm32(*args, **kwargs)
97
+ else:
98
+ raise ValueError(f"Invalid norm type {norm_type}")
99
+
100
+
101
+ class ResBlock3d(nn.Module):
102
+ def __init__(
103
+ self,
104
+ channels: int,
105
+ out_channels: Optional[int] = None,
106
+ norm_type: Literal["group", "layer"] = "layer",
107
+ ):
108
+ super().__init__()
109
+ self.channels = channels
110
+ self.out_channels = out_channels or channels
111
+
112
+ self.norm1 = norm_layer(norm_type, channels)
113
+ self.norm2 = norm_layer(norm_type, self.out_channels)
114
+ self.conv1 = nn.Conv3d(channels, self.out_channels, 3, padding=1)
115
+ self.conv2 = zero_module(nn.Conv3d(self.out_channels, self.out_channels, 3, padding=1))
116
+ self.skip_connection = nn.Conv3d(channels, self.out_channels, 1) if channels != self.out_channels else nn.Identity()
117
+
118
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
119
+ h = self.norm1(x)
120
+ h = F.silu(h)
121
+ h = self.conv1(h)
122
+ h = self.norm2(h)
123
+ h = F.silu(h)
124
+ h = self.conv2(h)
125
+ h = h + self.skip_connection(x)
126
+ return h
127
+
128
+
129
+ class DownsampleBlock3d(nn.Module):
130
+ def __init__(
131
+ self,
132
+ in_channels: int,
133
+ out_channels: int,
134
+ mode: Literal["conv", "avgpool"] = "conv",
135
+ ):
136
+ assert mode in ["conv", "avgpool"], f"Invalid mode {mode}"
137
+
138
+ super().__init__()
139
+ self.in_channels = in_channels
140
+ self.out_channels = out_channels
141
+
142
+ if mode == "conv":
143
+ self.conv = nn.Conv3d(in_channels, out_channels, 2, stride=2)
144
+ elif mode == "avgpool":
145
+ assert in_channels == out_channels, "Pooling mode requires in_channels to be equal to out_channels"
146
+
147
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
148
+ if hasattr(self, "conv"):
149
+ return self.conv(x)
150
+ else:
151
+ return F.avg_pool3d(x, 2)
152
+
153
+
154
+ def zero_module(module):
155
+ """
156
+ Zero out the parameters of a module and return it.
157
+ """
158
+ for p in module.parameters():
159
+ p.detach().zero_()
160
+ return module
161
+
162
+
163
+ class SparseStructureEncoder(nn.Module):
164
+ def __init__(
165
+ self,
166
+ in_channels: int,
167
+ latent_channels: int,
168
+ num_res_blocks: int,
169
+ channels: List[int],
170
+ num_res_blocks_middle: int = 2,
171
+ norm_type: Literal["group", "layer"] = "layer",
172
+ ):
173
+ super().__init__()
174
+ self.in_channels = in_channels
175
+ self.latent_channels = latent_channels
176
+ self.num_res_blocks = num_res_blocks
177
+ self.channels = channels
178
+ self.num_res_blocks_middle = num_res_blocks_middle
179
+ self.norm_type = norm_type
180
+ self.dtype = torch.float16
181
+ self.input_layer = nn.Conv3d(in_channels, channels[0], 3, padding=1)
182
+
183
+ self.blocks = nn.ModuleList([])
184
+ for i, ch in enumerate(channels):
185
+ self.blocks.extend([
186
+ ResBlock3d(ch, ch)
187
+ for _ in range(num_res_blocks)
188
+ ])
189
+ if i < len(channels) - 1:
190
+ self.blocks.append(
191
+ DownsampleBlock3d(ch, channels[i+1])
192
+ )
193
+
194
+ self.middle_block = nn.Sequential(*[
195
+ ResBlock3d(channels[-1], channels[-1])
196
+ for _ in range(num_res_blocks_middle)
197
+ ])
198
+
199
+ @property
200
+ def device(self) -> torch.device:
201
+ """
202
+ Return the device of the model.
203
+ """
204
+ return next(self.parameters()).device
205
+
206
+ def forward(self, x: torch.Tensor):
207
+ h = self.input_layer(x)
208
+ h = h.type(self.dtype)
209
+
210
+ for block in self.blocks:
211
+ h = block(h)
212
+ h = self.middle_block(h)
213
+
214
+ h = h.type(x.dtype)
215
+ return h
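
A usage sketch for GroupEmbedding above (sizes are illustrative): per-pixel part-group ids are looked up in an embedding table and returned channels-first, ready to be concatenated with image features as BboxGen does.

import torch
from modules.bbox_gen.models.bbox_gen_models import GroupEmbedding

embed = GroupEmbedding(max_group_size=50, hidden_size=64)
masks = torch.randint(0, 51, (2, 37, 37))   # ids 0..50, with one slot reserved for background
emb = embed(masks)                          # (B, hidden_size, H, W)
print(emb.shape)                            # torch.Size([2, 64, 37, 37])
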
modules/bbox_gen/models/bboxopt.py ADDED
@@ -0,0 +1,221 @@
 
+import torch
+import torch.utils.checkpoint
+from torch import nn
+
+from transformers import AutoModelForCausalLM, AutoConfig
+from transformers.models.opt.modeling_opt import OPTForCausalLM, OPTModel, OPTDecoder, OPTConfig
+
+from transformers.utils import logging
+from typing import Optional, Union
+
+from transformers.generation.logits_process import LogitsProcessorList
+from transformers.generation.utils import GenerateNonBeamOutput, GenerateEncoderDecoderOutput, GenerateDecoderOnlyOutput
+from transformers.generation.stopping_criteria import StoppingCriteriaList
+from transformers.generation.configuration_utils import GenerationConfig
+from transformers.generation.streamers import BaseStreamer
+
+logger = logging.get_logger(__name__)
+
+class BBoxOPTConfig(OPTConfig):
+    model_type = "mesh_opt"
+
+class BBoxOPTDecoder(OPTDecoder):
+    config_class = BBoxOPTConfig
+
+class BBoxOPTModel(OPTModel):
+    config_class = BBoxOPTConfig
+    def __init__(self, config: BBoxOPTConfig):
+        super(OPTModel, self).__init__(config)
+        self.decoder = BBoxOPTDecoder(config)
+        # Initialize weights and apply final processing
+        self.post_init()
+
+class BBoxOPT(OPTForCausalLM):
+    config_class = BBoxOPTConfig
+
+    def __init__(self, config: BBoxOPTConfig):
+        super(OPTForCausalLM, self).__init__(config)
+        self.model = BBoxOPTModel(config)
+
+        # the lm_head weight is automatically tied to the embed tokens weight
+        self.lm_head = nn.Linear(config.word_embed_proj_dim, config.vocab_size, bias=False)
+
+        # Initialize weights and apply final processing
+        self.post_init()
+
+    def _sample(
+        self,
+        input_ids: torch.LongTensor,
+        logits_processor: LogitsProcessorList,
+        stopping_criteria: StoppingCriteriaList,
+        generation_config: GenerationConfig,
+        synced_gpus: bool,
+        streamer: Optional["BaseStreamer"],
+        **model_kwargs,
+    ) -> Union[GenerateNonBeamOutput, torch.LongTensor]:
+        r"""
+        Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
+        can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
+        Parameters:
+            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+                The sequence used as a prompt for the generation.
+            logits_processor (`LogitsProcessorList`):
+                An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+                used to modify the prediction scores of the language modeling head applied at each generation step.
+            stopping_criteria (`StoppingCriteriaList`):
+                An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+                used to tell if the generation loop should stop.
+            generation_config ([`~generation.GenerationConfig`]):
+                The generation configuration to be used as parametrization of the decoding method.
+            synced_gpus (`bool`):
+                Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+            streamer (`BaseStreamer`, *optional*):
+                Streamer object that will be used to stream the generated sequences. Generated tokens are passed
+                through `streamer.put(token_ids)` and the streamer is responsible for any further processing.
+            model_kwargs:
+                Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
+                an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+        Return:
+            [`~generation.GenerateDecoderOnlyOutput`], [`~generation.GenerateEncoderDecoderOutput`] or `torch.LongTensor`:
+            A `torch.LongTensor` containing the generated tokens (default behaviour) or a
+            [`~generation.GenerateDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+            `return_dict_in_generate=True` or a [`~generation.GenerateEncoderDecoderOutput`] if
+            `model.config.is_encoder_decoder=True`.
+        """
+        # init values
+        pad_token_id = generation_config._pad_token_tensor
+        output_attentions = generation_config.output_attentions
+        output_hidden_states = generation_config.output_hidden_states
+        output_scores = generation_config.output_scores
+        output_logits = generation_config.output_logits
+        return_dict_in_generate = generation_config.return_dict_in_generate
+        max_length = generation_config.max_length
+        has_eos_stopping_criteria = any(hasattr(criteria, "eos_token_id") for criteria in stopping_criteria)
+        do_sample = generation_config.do_sample
+
+        # init attention / hidden states / scores tuples
+        scores = () if (return_dict_in_generate and output_scores) else None
+        raw_logits = () if (return_dict_in_generate and output_logits) else None
+        decoder_attentions = () if (return_dict_in_generate and output_attentions) else None
+        cross_attentions = () if (return_dict_in_generate and output_attentions) else None
+        decoder_hidden_states = () if (return_dict_in_generate and output_hidden_states) else None
+
+        # if model is an encoder-decoder, retrieve encoder attention weights and hidden states
+        if return_dict_in_generate and self.config.is_encoder_decoder:
+            encoder_attentions = model_kwargs["encoder_outputs"].get("attentions") if output_attentions else None
+            encoder_hidden_states = (
+                model_kwargs["encoder_outputs"].get("hidden_states") if output_hidden_states else None
+            )
+
+        # keep track of which sequences are already finished
+        batch_size, cur_len = input_ids.shape
+        this_peer_finished = False
+        unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device)
+        model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs)
+
+        while self._has_unfinished_sequences(
+            this_peer_finished, synced_gpus, device=input_ids.device
+        ) and cur_len < max_length:
+            # prepare model inputs
+            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+            # prepare variable output controls (note: some models won't accept all output controls)
+            model_inputs.update({"output_attentions": output_attentions} if output_attentions else {})
+            model_inputs.update({"output_hidden_states": output_hidden_states} if output_hidden_states else {})
+
+            # forward pass to get next token
+            outputs = self(**model_inputs, return_dict=True)
+
+            if synced_gpus and this_peer_finished:
+                continue  # don't waste resources running the code we don't need
+
+            # Clone is needed to avoid keeping a hanging ref to outputs.logits which may be very large for first iteration
+            # (the clone itself is always small)
+            next_token_logits = outputs.logits.clone()[:, -1, :].float()
+
+            # pre-process distribution
+            next_token_scores = logits_processor(input_ids, next_token_logits)
+
+            # Store scores, attentions and hidden_states when required
+            if return_dict_in_generate:
+                if output_scores:
+                    scores += (next_token_scores,)
+                if output_logits:
+                    raw_logits += (next_token_logits,)
+                if output_attentions:
+                    decoder_attentions += (
+                        (outputs.decoder_attentions,) if self.config.is_encoder_decoder else (outputs.attentions,)
+                    )
+                    if self.config.is_encoder_decoder:
+                        cross_attentions += (outputs.cross_attentions,)
+
+                if output_hidden_states:
+                    decoder_hidden_states += (
+                        (outputs.decoder_hidden_states,)
+                        if self.config.is_encoder_decoder
+                        else (outputs.hidden_states,)
+                    )
+
+            # token selection
+            if do_sample:
+                probs = nn.functional.softmax(next_token_scores, dim=-1)
+                # TODO (joao): this OP throws "skipping cudagraphs due to ['incompatible ops']", find solution
+                next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+            else:
+                next_tokens = torch.argmax(next_token_scores, dim=-1)
+
+            # finished sentences should have their next token be a padding token
+            if has_eos_stopping_criteria:
+                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
+
+            # update generated ids, model inputs, and length for next step
+            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+            if streamer is not None:
+                streamer.put(next_tokens.cpu())
+            model_kwargs = self._update_model_kwargs_for_generation(
+                outputs,
+                model_kwargs,
+                is_encoder_decoder=self.config.is_encoder_decoder,
+            )
+
+            unfinished_sequences = unfinished_sequences & ~stopping_criteria(input_ids, scores)
+            this_peer_finished = unfinished_sequences.max() == 0
+            cur_len += 1
+
+            # This is needed to properly delete outputs.logits which may be very large for first iteration
+            # Otherwise a reference to outputs is kept which keeps the logits alive in the next iteration
+            del outputs
+
+        if streamer is not None:
+            streamer.end()
+
+        if return_dict_in_generate:
+            if self.config.is_encoder_decoder:
+                return GenerateEncoderDecoderOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    encoder_attentions=encoder_attentions,
+                    encoder_hidden_states=encoder_hidden_states,
+                    decoder_attentions=decoder_attentions,
+                    cross_attentions=cross_attentions,
+                    decoder_hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+            else:
+                return GenerateDecoderOnlyOutput(
+                    sequences=input_ids,
+                    scores=scores,
+                    logits=raw_logits,
+                    attentions=decoder_attentions,
+                    hidden_states=decoder_hidden_states,
+                    past_key_values=model_kwargs.get("past_key_values"),
+                )
+        else:
+            return input_ids
+
+
+AutoConfig.register("mesh_opt", BBoxOPTConfig)
+AutoModelForCausalLM.register(BBoxOPTConfig, BBoxOPT)
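A minimal usage sketch for the registration above, assuming only the classes defined in this file. The tiny config sizes are hypothetical smoke-test values, and `generate` is the stock Hugging Face entry point, which routes through the overridden `_sample` for greedy and multinomial decoding.

# Illustrative only: resolve BBoxOPT through the Auto* registry ("mesh_opt")
# and run greedy decoding, which goes through the _sample override above.
import torch
from transformers import AutoModelForCausalLM

config = BBoxOPTConfig(            # hypothetical, tiny sizes for a smoke test
    vocab_size=1024,
    hidden_size=256,
    num_hidden_layers=2,
    num_attention_heads=4,
    ffn_dim=512,
    max_position_embeddings=512,
    word_embed_proj_dim=256,
)
model = AutoModelForCausalLM.from_config(config)   # resolves to BBoxOPT
assert isinstance(model, BBoxOPT)

prompt = torch.randint(4, config.vocab_size, (1, 8))   # avoid special token ids
tokens = model.generate(prompt, max_length=32, do_sample=False)
print(tokens.shape)                                # (1, <=32)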
modules/bbox_gen/models/image_encoder.py ADDED
@@ -0,0 +1,41 @@
+from typing import Literal
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from PIL import Image
+
+from transformers import AutoModel
+
+
+class DINOv2ImageEncoder(nn.Module):
+    def __init__(self, model_name: Literal[
+        "facebook/dinov2-with-registers-large",
+        "facebook/dinov2-large"
+    ]):
+        super().__init__()
+        self.model = AutoModel.from_pretrained(model_name, torch_dtype=torch.bfloat16)
+        self.model.requires_grad_(False)
+        self.model.eval()
+
+        DINOv2_INPUT_MEAN = torch.as_tensor([0.485, 0.456, 0.406], dtype=torch.float32)[
+            None, :, None, None
+        ]
+        DINOv2_INPUT_STD = torch.as_tensor([0.229, 0.224, 0.225], dtype=torch.float32)[
+            None, :, None, None
+        ]
+        self.register_buffer("DINOv2_INPUT_MEAN", DINOv2_INPUT_MEAN, persistent=False)
+        self.register_buffer("DINOv2_INPUT_STD", DINOv2_INPUT_STD, persistent=False)
+        self.max_size = 518
+        self.hidden_size = self.model.config.hidden_size
+
+    def preprocess(self, image: torch.Tensor):
+        B, C, H, W = image.shape
+        assert C == 3 and H <= self.max_size and W <= self.max_size
+        image = (image - self.DINOv2_INPUT_MEAN.to(image)) / self.DINOv2_INPUT_STD.to(image)
+        return image
+
+    def forward(self, image: torch.Tensor):
+        image = self.preprocess(image)
+        features = self.model(image).last_hidden_state
+        return features
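A minimal usage sketch for the encoder above, assuming a CUDA device and an image path (`example.png`) that are purely illustrative; resizing to 518x518 is one simple way to satisfy the `max_size` assertion and is not a policy taken from this repository.

# Illustrative only: encode a single RGB image with DINOv2ImageEncoder.
import numpy as np
import torch
from PIL import Image

encoder = DINOv2ImageEncoder("facebook/dinov2-large").cuda()

img = Image.open("example.png").convert("RGB").resize((518, 518))        # hypothetical input
x = torch.from_numpy(np.asarray(img)).permute(2, 0, 1).float().div(255)  # (3, 518, 518) in [0, 1]
x = x.unsqueeze(0).cuda().to(torch.bfloat16)      # match the bfloat16 DINOv2 weights

with torch.no_grad():
    feats = encoder(x)                            # (1, 1 + num_patches, hidden_size)
print(feats.shape)                                # (1, 1370, 1024) for ViT-L/14 at 518x518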