vid2avatar baseline
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +43 -0
- .gitignore +19 -0
- LICENSE +399 -0
- README.md +74 -11
- assets/exstrimalik.gif +3 -0
- assets/martial.gif +3 -0
- assets/parkinglot_360.gif +3 -0
- assets/roger.gif +3 -0
- assets/smpl_init.pth +3 -0
- assets/teaser.png +3 -0
- code/check_cuda.py +11 -0
- code/confs/base.yaml +13 -0
- code/confs/dataset/video.yaml +37 -0
- code/confs/model/model_w_bg.yaml +77 -0
- code/lib/datasets/__init__.py +26 -0
- code/lib/datasets/dataset.py +175 -0
- code/lib/libmise/mise.cp37-win_amd64.pyd +0 -0
- code/lib/libmise/mise.cpp +0 -0
- code/lib/libmise/mise.pyx +370 -0
- code/lib/model/body_model_params.py +49 -0
- code/lib/model/deformer.py +89 -0
- code/lib/model/density.py +46 -0
- code/lib/model/embedders.py +50 -0
- code/lib/model/loss.py +64 -0
- code/lib/model/networks.py +178 -0
- code/lib/model/ray_sampler.py +234 -0
- code/lib/model/sampler.py +29 -0
- code/lib/model/smpl.py +94 -0
- code/lib/model/v2a.py +368 -0
- code/lib/smpl/body_models.py +365 -0
- code/lib/smpl/lbs.py +377 -0
- code/lib/smpl/smpl_model/SMPL_FEMALE.pkl +3 -0
- code/lib/smpl/smpl_model/SMPL_MALE.pkl +3 -0
- code/lib/smpl/utils.py +49 -0
- code/lib/smpl/vertex_ids.py +71 -0
- code/lib/smpl/vertex_joint_selector.py +77 -0
- code/lib/utils/meshing.py +63 -0
- code/lib/utils/utils.py +232 -0
- code/setup.py +34 -0
- code/test.py +39 -0
- code/train.py +45 -0
- code/v2a_model.py +311 -0
- data/parkinglot/cameras.npz +3 -0
- data/parkinglot/cameras_normalize.npz +3 -0
- data/parkinglot/checkpoints/epoch=6299-loss=0.01887552998960018.ckpt +3 -0
- data/parkinglot/image/0000.png +3 -0
- data/parkinglot/image/0001.png +3 -0
- data/parkinglot/image/0002.png +3 -0
- data/parkinglot/image/0003.png +3 -0
- data/parkinglot/image/0004.png +3 -0
.gitattributes
CHANGED
@@ -33,3 +33,46 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
36 |
+
assets/exstrimalik.gif filter=lfs diff=lfs merge=lfs -text
|
37 |
+
assets/martial.gif filter=lfs diff=lfs merge=lfs -text
|
38 |
+
assets/parkinglot_360.gif filter=lfs diff=lfs merge=lfs -text
|
39 |
+
assets/roger.gif filter=lfs diff=lfs merge=lfs -text
|
40 |
+
assets/teaser.png filter=lfs diff=lfs merge=lfs -text
|
41 |
+
data/parkinglot/image/0000.png filter=lfs diff=lfs merge=lfs -text
|
42 |
+
data/parkinglot/image/0001.png filter=lfs diff=lfs merge=lfs -text
|
43 |
+
data/parkinglot/image/0002.png filter=lfs diff=lfs merge=lfs -text
|
44 |
+
data/parkinglot/image/0003.png filter=lfs diff=lfs merge=lfs -text
|
45 |
+
data/parkinglot/image/0004.png filter=lfs diff=lfs merge=lfs -text
|
46 |
+
data/parkinglot/image/0005.png filter=lfs diff=lfs merge=lfs -text
|
47 |
+
data/parkinglot/image/0006.png filter=lfs diff=lfs merge=lfs -text
|
48 |
+
data/parkinglot/image/0007.png filter=lfs diff=lfs merge=lfs -text
|
49 |
+
data/parkinglot/image/0010.png filter=lfs diff=lfs merge=lfs -text
|
50 |
+
data/parkinglot/image/0011.png filter=lfs diff=lfs merge=lfs -text
|
51 |
+
data/parkinglot/image/0012.png filter=lfs diff=lfs merge=lfs -text
|
52 |
+
data/parkinglot/image/0013.png filter=lfs diff=lfs merge=lfs -text
|
53 |
+
data/parkinglot/image/0014.png filter=lfs diff=lfs merge=lfs -text
|
54 |
+
data/parkinglot/image/0015.png filter=lfs diff=lfs merge=lfs -text
|
55 |
+
data/parkinglot/image/0016.png filter=lfs diff=lfs merge=lfs -text
|
56 |
+
data/parkinglot/image/0017.png filter=lfs diff=lfs merge=lfs -text
|
57 |
+
data/parkinglot/image/0018.png filter=lfs diff=lfs merge=lfs -text
|
58 |
+
data/parkinglot/image/0020.png filter=lfs diff=lfs merge=lfs -text
|
59 |
+
data/parkinglot/image/0021.png filter=lfs diff=lfs merge=lfs -text
|
60 |
+
data/parkinglot/image/0022.png filter=lfs diff=lfs merge=lfs -text
|
61 |
+
data/parkinglot/image/0023.png filter=lfs diff=lfs merge=lfs -text
|
62 |
+
data/parkinglot/image/0024.png filter=lfs diff=lfs merge=lfs -text
|
63 |
+
data/parkinglot/image/0025.png filter=lfs diff=lfs merge=lfs -text
|
64 |
+
data/parkinglot/image/0027.png filter=lfs diff=lfs merge=lfs -text
|
65 |
+
data/parkinglot/image/0028.png filter=lfs diff=lfs merge=lfs -text
|
66 |
+
data/parkinglot/image/0029.png filter=lfs diff=lfs merge=lfs -text
|
67 |
+
data/parkinglot/image/0030.png filter=lfs diff=lfs merge=lfs -text
|
68 |
+
data/parkinglot/image/0031.png filter=lfs diff=lfs merge=lfs -text
|
69 |
+
data/parkinglot/image/0032.png filter=lfs diff=lfs merge=lfs -text
|
70 |
+
data/parkinglot/image/0033.png filter=lfs diff=lfs merge=lfs -text
|
71 |
+
data/parkinglot/image/0034.png filter=lfs diff=lfs merge=lfs -text
|
72 |
+
data/parkinglot/image/0035.png filter=lfs diff=lfs merge=lfs -text
|
73 |
+
data/parkinglot/image/0036.png filter=lfs diff=lfs merge=lfs -text
|
74 |
+
data/parkinglot/image/0037.png filter=lfs diff=lfs merge=lfs -text
|
75 |
+
data/parkinglot/image/0038.png filter=lfs diff=lfs merge=lfs -text
|
76 |
+
data/parkinglot/image/0039.png filter=lfs diff=lfs merge=lfs -text
|
77 |
+
data/parkinglot/image/0040.png filter=lfs diff=lfs merge=lfs -text
|
78 |
+
data/parkinglot/image/0041.png filter=lfs diff=lfs merge=lfs -text
|
.gitignore
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
.vscode/
|
2 |
+
env/
|
3 |
+
__pycache__/
|
4 |
+
*.ply
|
5 |
+
*.npz
|
6 |
+
*.npy
|
7 |
+
*.pkl
|
8 |
+
outputs
|
9 |
+
*.obj
|
10 |
+
*.ipynb
|
11 |
+
*.so
|
12 |
+
data
|
13 |
+
code/UNKNOWN.egg-info
|
14 |
+
code/dist
|
15 |
+
code/build
|
16 |
+
visualization/imgui.ini
|
17 |
+
export
|
18 |
+
preprocessing/raw_data
|
19 |
+
preprocessing/romp
|
LICENSE
ADDED
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Attribution-NonCommercial 4.0 International
|
2 |
+
|
3 |
+
=======================================================================
|
4 |
+
|
5 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
6 |
+
does not provide legal services or legal advice. Distribution of
|
7 |
+
Creative Commons public licenses does not create a lawyer-client or
|
8 |
+
other relationship. Creative Commons makes its licenses and related
|
9 |
+
information available on an "as-is" basis. Creative Commons gives no
|
10 |
+
warranties regarding its licenses, any material licensed under their
|
11 |
+
terms and conditions, or any related information. Creative Commons
|
12 |
+
disclaims all liability for damages resulting from their use to the
|
13 |
+
fullest extent possible.
|
14 |
+
|
15 |
+
Using Creative Commons Public Licenses
|
16 |
+
|
17 |
+
Creative Commons public licenses provide a standard set of terms and
|
18 |
+
conditions that creators and other rights holders may use to share
|
19 |
+
original works of authorship and other material subject to copyright
|
20 |
+
and certain other rights specified in the public license below. The
|
21 |
+
following considerations are for informational purposes only, are not
|
22 |
+
exhaustive, and do not form part of our licenses.
|
23 |
+
|
24 |
+
Considerations for licensors: Our public licenses are
|
25 |
+
intended for use by those authorized to give the public
|
26 |
+
permission to use material in ways otherwise restricted by
|
27 |
+
copyright and certain other rights. Our licenses are
|
28 |
+
irrevocable. Licensors should read and understand the terms
|
29 |
+
and conditions of the license they choose before applying it.
|
30 |
+
Licensors should also secure all rights necessary before
|
31 |
+
applying our licenses so that the public can reuse the
|
32 |
+
material as expected. Licensors should clearly mark any
|
33 |
+
material not subject to the license. This includes other CC-
|
34 |
+
licensed material, or material used under an exception or
|
35 |
+
limitation to copyright. More considerations for licensors:
|
36 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
37 |
+
|
38 |
+
Considerations for the public: By using one of our public
|
39 |
+
licenses, a licensor grants the public permission to use the
|
40 |
+
licensed material under specified terms and conditions. If
|
41 |
+
the licensor's permission is not necessary for any reason--for
|
42 |
+
example, because of any applicable exception or limitation to
|
43 |
+
copyright--then that use is not regulated by the license. Our
|
44 |
+
licenses grant only permissions under copyright and certain
|
45 |
+
other rights that a licensor has authority to grant. Use of
|
46 |
+
the licensed material may still be restricted for other
|
47 |
+
reasons, including because others have copyright or other
|
48 |
+
rights in the material. A licensor may make special requests,
|
49 |
+
such as asking that all changes be marked or described.
|
50 |
+
Although not required by our licenses, you are encouraged to
|
51 |
+
respect those requests where reasonable. More_considerations
|
52 |
+
for the public:
|
53 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
54 |
+
|
55 |
+
=======================================================================
|
56 |
+
|
57 |
+
Creative Commons Attribution-NonCommercial 4.0 International Public
|
58 |
+
License
|
59 |
+
|
60 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
61 |
+
to be bound by the terms and conditions of this Creative Commons
|
62 |
+
Attribution-NonCommercial 4.0 International Public License ("Public
|
63 |
+
License"). To the extent this Public License may be interpreted as a
|
64 |
+
contract, You are granted the Licensed Rights in consideration of Your
|
65 |
+
acceptance of these terms and conditions, and the Licensor grants You
|
66 |
+
such rights in consideration of benefits the Licensor receives from
|
67 |
+
making the Licensed Material available under these terms and
|
68 |
+
conditions.
|
69 |
+
|
70 |
+
Section 1 -- Definitions.
|
71 |
+
|
72 |
+
a. Adapted Material means material subject to Copyright and Similar
|
73 |
+
Rights that is derived from or based upon the Licensed Material
|
74 |
+
and in which the Licensed Material is translated, altered,
|
75 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
76 |
+
permission under the Copyright and Similar Rights held by the
|
77 |
+
Licensor. For purposes of this Public License, where the Licensed
|
78 |
+
Material is a musical work, performance, or sound recording,
|
79 |
+
Adapted Material is always produced where the Licensed Material is
|
80 |
+
synched in timed relation with a moving image.
|
81 |
+
|
82 |
+
b. Adapter's License means the license You apply to Your Copyright
|
83 |
+
and Similar Rights in Your contributions to Adapted Material in
|
84 |
+
accordance with the terms and conditions of this Public License.
|
85 |
+
|
86 |
+
c. Copyright and Similar Rights means copyright and/or similar rights
|
87 |
+
closely related to copyright including, without limitation,
|
88 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
89 |
+
Rights, without regard to how the rights are labeled or
|
90 |
+
categorized. For purposes of this Public License, the rights
|
91 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
92 |
+
Rights.
|
93 |
+
d. Effective Technological Measures means those measures that, in the
|
94 |
+
absence of proper authority, may not be circumvented under laws
|
95 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
96 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
97 |
+
agreements.
|
98 |
+
|
99 |
+
e. Exceptions and Limitations means fair use, fair dealing, and/or
|
100 |
+
any other exception or limitation to Copyright and Similar Rights
|
101 |
+
that applies to Your use of the Licensed Material.
|
102 |
+
|
103 |
+
f. Licensed Material means the artistic or literary work, database,
|
104 |
+
or other material to which the Licensor applied this Public
|
105 |
+
License.
|
106 |
+
|
107 |
+
g. Licensed Rights means the rights granted to You subject to the
|
108 |
+
terms and conditions of this Public License, which are limited to
|
109 |
+
all Copyright and Similar Rights that apply to Your use of the
|
110 |
+
Licensed Material and that the Licensor has authority to license.
|
111 |
+
|
112 |
+
h. Licensor means the individual(s) or entity(ies) granting rights
|
113 |
+
under this Public License.
|
114 |
+
|
115 |
+
i. NonCommercial means not primarily intended for or directed towards
|
116 |
+
commercial advantage or monetary compensation. For purposes of
|
117 |
+
this Public License, the exchange of the Licensed Material for
|
118 |
+
other material subject to Copyright and Similar Rights by digital
|
119 |
+
file-sharing or similar means is NonCommercial provided there is
|
120 |
+
no payment of monetary compensation in connection with the
|
121 |
+
exchange.
|
122 |
+
|
123 |
+
j. Share means to provide material to the public by any means or
|
124 |
+
process that requires permission under the Licensed Rights, such
|
125 |
+
as reproduction, public display, public performance, distribution,
|
126 |
+
dissemination, communication, or importation, and to make material
|
127 |
+
available to the public including in ways that members of the
|
128 |
+
public may access the material from a place and at a time
|
129 |
+
individually chosen by them.
|
130 |
+
|
131 |
+
k. Sui Generis Database Rights means rights other than copyright
|
132 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
133 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
134 |
+
as amended and/or succeeded, as well as other essentially
|
135 |
+
equivalent rights anywhere in the world.
|
136 |
+
|
137 |
+
l. You means the individual or entity exercising the Licensed Rights
|
138 |
+
under this Public License. Your has a corresponding meaning.
|
139 |
+
|
140 |
+
Section 2 -- Scope.
|
141 |
+
|
142 |
+
a. License grant.
|
143 |
+
|
144 |
+
1. Subject to the terms and conditions of this Public License,
|
145 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
146 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
147 |
+
exercise the Licensed Rights in the Licensed Material to:
|
148 |
+
|
149 |
+
a. reproduce and Share the Licensed Material, in whole or
|
150 |
+
in part, for NonCommercial purposes only; and
|
151 |
+
|
152 |
+
b. produce, reproduce, and Share Adapted Material for
|
153 |
+
NonCommercial purposes only.
|
154 |
+
|
155 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
156 |
+
Exceptions and Limitations apply to Your use, this Public
|
157 |
+
License does not apply, and You do not need to comply with
|
158 |
+
its terms and conditions.
|
159 |
+
|
160 |
+
3. Term. The term of this Public License is specified in Section
|
161 |
+
6(a).
|
162 |
+
|
163 |
+
4. Media and formats; technical modifications allowed. The
|
164 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
165 |
+
all media and formats whether now known or hereafter created,
|
166 |
+
and to make technical modifications necessary to do so. The
|
167 |
+
Licensor waives and/or agrees not to assert any right or
|
168 |
+
authority to forbid You from making technical modifications
|
169 |
+
necessary to exercise the Licensed Rights, including
|
170 |
+
technical modifications necessary to circumvent Effective
|
171 |
+
Technological Measures. For purposes of this Public License,
|
172 |
+
simply making modifications authorized by this Section 2(a)
|
173 |
+
(4) never produces Adapted Material.
|
174 |
+
|
175 |
+
5. Downstream recipients.
|
176 |
+
|
177 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
178 |
+
recipient of the Licensed Material automatically
|
179 |
+
receives an offer from the Licensor to exercise the
|
180 |
+
Licensed Rights under the terms and conditions of this
|
181 |
+
Public License.
|
182 |
+
|
183 |
+
b. No downstream restrictions. You may not offer or impose
|
184 |
+
any additional or different terms or conditions on, or
|
185 |
+
apply any Effective Technological Measures to, the
|
186 |
+
Licensed Material if doing so restricts exercise of the
|
187 |
+
Licensed Rights by any recipient of the Licensed
|
188 |
+
Material.
|
189 |
+
|
190 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
191 |
+
may be construed as permission to assert or imply that You
|
192 |
+
are, or that Your use of the Licensed Material is, connected
|
193 |
+
with, or sponsored, endorsed, or granted official status by,
|
194 |
+
the Licensor or others designated to receive attribution as
|
195 |
+
provided in Section 3(a)(1)(A)(i).
|
196 |
+
|
197 |
+
b. Other rights.
|
198 |
+
|
199 |
+
1. Moral rights, such as the right of integrity, are not
|
200 |
+
licensed under this Public License, nor are publicity,
|
201 |
+
privacy, and/or other similar personality rights; however, to
|
202 |
+
the extent possible, the Licensor waives and/or agrees not to
|
203 |
+
assert any such rights held by the Licensor to the limited
|
204 |
+
extent necessary to allow You to exercise the Licensed
|
205 |
+
Rights, but not otherwise.
|
206 |
+
|
207 |
+
2. Patent and trademark rights are not licensed under this
|
208 |
+
Public License.
|
209 |
+
|
210 |
+
3. To the extent possible, the Licensor waives any right to
|
211 |
+
collect royalties from You for the exercise of the Licensed
|
212 |
+
Rights, whether directly or through a collecting society
|
213 |
+
under any voluntary or waivable statutory or compulsory
|
214 |
+
licensing scheme. In all other cases the Licensor expressly
|
215 |
+
reserves any right to collect such royalties, including when
|
216 |
+
the Licensed Material is used other than for NonCommercial
|
217 |
+
purposes.
|
218 |
+
|
219 |
+
Section 3 -- License Conditions.
|
220 |
+
|
221 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
222 |
+
following conditions.
|
223 |
+
|
224 |
+
a. Attribution.
|
225 |
+
|
226 |
+
1. If You Share the Licensed Material (including in modified
|
227 |
+
form), You must:
|
228 |
+
|
229 |
+
a. retain the following if it is supplied by the Licensor
|
230 |
+
with the Licensed Material:
|
231 |
+
|
232 |
+
i. identification of the creator(s) of the Licensed
|
233 |
+
Material and any others designated to receive
|
234 |
+
attribution, in any reasonable manner requested by
|
235 |
+
the Licensor (including by pseudonym if
|
236 |
+
designated);
|
237 |
+
|
238 |
+
ii. a copyright notice;
|
239 |
+
|
240 |
+
iii. a notice that refers to this Public License;
|
241 |
+
|
242 |
+
iv. a notice that refers to the disclaimer of
|
243 |
+
warranties;
|
244 |
+
|
245 |
+
v. a URI or hyperlink to the Licensed Material to the
|
246 |
+
extent reasonably practicable;
|
247 |
+
|
248 |
+
b. indicate if You modified the Licensed Material and
|
249 |
+
retain an indication of any previous modifications; and
|
250 |
+
|
251 |
+
c. indicate the Licensed Material is licensed under this
|
252 |
+
Public License, and include the text of, or the URI or
|
253 |
+
hyperlink to, this Public License.
|
254 |
+
|
255 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
256 |
+
reasonable manner based on the medium, means, and context in
|
257 |
+
which You Share the Licensed Material. For example, it may be
|
258 |
+
reasonable to satisfy the conditions by providing a URI or
|
259 |
+
hyperlink to a resource that includes the required
|
260 |
+
information.
|
261 |
+
|
262 |
+
3. If requested by the Licensor, You must remove any of the
|
263 |
+
information required by Section 3(a)(1)(A) to the extent
|
264 |
+
reasonably practicable.
|
265 |
+
|
266 |
+
4. If You Share Adapted Material You produce, the Adapter's
|
267 |
+
License You apply must not prevent recipients of the Adapted
|
268 |
+
Material from complying with this Public License.
|
269 |
+
|
270 |
+
Section 4 -- Sui Generis Database Rights.
|
271 |
+
|
272 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
273 |
+
apply to Your use of the Licensed Material:
|
274 |
+
|
275 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
276 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
277 |
+
portion of the contents of the database for NonCommercial purposes
|
278 |
+
only;
|
279 |
+
|
280 |
+
b. if You include all or a substantial portion of the database
|
281 |
+
contents in a database in which You have Sui Generis Database
|
282 |
+
Rights, then the database in which You have Sui Generis Database
|
283 |
+
Rights (but not its individual contents) is Adapted Material; and
|
284 |
+
|
285 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
286 |
+
all or a substantial portion of the contents of the database.
|
287 |
+
|
288 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
289 |
+
replace Your obligations under this Public License where the Licensed
|
290 |
+
Rights include other Copyright and Similar Rights.
|
291 |
+
|
292 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
293 |
+
|
294 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
295 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
296 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
297 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
298 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
299 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
300 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
301 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
302 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
303 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
304 |
+
|
305 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
306 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
307 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
308 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
309 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
310 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
311 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
312 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
313 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
314 |
+
|
315 |
+
c. The disclaimer of warranties and limitation of liability provided
|
316 |
+
above shall be interpreted in a manner that, to the extent
|
317 |
+
possible, most closely approximates an absolute disclaimer and
|
318 |
+
waiver of all liability.
|
319 |
+
|
320 |
+
Section 6 -- Term and Termination.
|
321 |
+
|
322 |
+
a. This Public License applies for the term of the Copyright and
|
323 |
+
Similar Rights licensed here. However, if You fail to comply with
|
324 |
+
this Public License, then Your rights under this Public License
|
325 |
+
terminate automatically.
|
326 |
+
|
327 |
+
b. Where Your right to use the Licensed Material has terminated under
|
328 |
+
Section 6(a), it reinstates:
|
329 |
+
|
330 |
+
1. automatically as of the date the violation is cured, provided
|
331 |
+
it is cured within 30 days of Your discovery of the
|
332 |
+
violation; or
|
333 |
+
|
334 |
+
2. upon express reinstatement by the Licensor.
|
335 |
+
|
336 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
337 |
+
right the Licensor may have to seek remedies for Your violations
|
338 |
+
of this Public License.
|
339 |
+
|
340 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
341 |
+
Licensed Material under separate terms or conditions or stop
|
342 |
+
distributing the Licensed Material at any time; however, doing so
|
343 |
+
will not terminate this Public License.
|
344 |
+
|
345 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
346 |
+
License.
|
347 |
+
|
348 |
+
Section 7 -- Other Terms and Conditions.
|
349 |
+
|
350 |
+
a. The Licensor shall not be bound by any additional or different
|
351 |
+
terms or conditions communicated by You unless expressly agreed.
|
352 |
+
|
353 |
+
b. Any arrangements, understandings, or agreements regarding the
|
354 |
+
Licensed Material not stated herein are separate from and
|
355 |
+
independent of the terms and conditions of this Public License.
|
356 |
+
|
357 |
+
Section 8 -- Interpretation.
|
358 |
+
|
359 |
+
a. For the avoidance of doubt, this Public License does not, and
|
360 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
361 |
+
conditions on any use of the Licensed Material that could lawfully
|
362 |
+
be made without permission under this Public License.
|
363 |
+
|
364 |
+
b. To the extent possible, if any provision of this Public License is
|
365 |
+
deemed unenforceable, it shall be automatically reformed to the
|
366 |
+
minimum extent necessary to make it enforceable. If the provision
|
367 |
+
cannot be reformed, it shall be severed from this Public License
|
368 |
+
without affecting the enforceability of the remaining terms and
|
369 |
+
conditions.
|
370 |
+
|
371 |
+
c. No term or condition of this Public License will be waived and no
|
372 |
+
failure to comply consented to unless expressly agreed to by the
|
373 |
+
Licensor.
|
374 |
+
|
375 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
376 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
377 |
+
that apply to the Licensor or You, including from the legal
|
378 |
+
processes of any jurisdiction or authority.
|
379 |
+
|
380 |
+
=======================================================================
|
381 |
+
|
382 |
+
Creative Commons is not a party to its public
|
383 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
384 |
+
its public licenses to material it publishes and in those instances
|
385 |
+
will be considered the “Licensor.” The text of the Creative Commons
|
386 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
387 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
388 |
+
material is shared under a Creative Commons public license or as
|
389 |
+
otherwise permitted by the Creative Commons policies published at
|
390 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
391 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
392 |
+
of Creative Commons without its prior written consent including,
|
393 |
+
without limitation, in connection with any unauthorized modifications
|
394 |
+
to any of its public licenses or any other arrangements,
|
395 |
+
understandings, or agreements concerning use of licensed material. For
|
396 |
+
the avoidance of doubt, this paragraph does not form part of the
|
397 |
+
public licenses.
|
398 |
+
|
399 |
+
Creative Commons may be contacted at creativecommons.org.
|
README.md
CHANGED
@@ -1,11 +1,74 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
6 |
-
|
7 |
-
|
8 |
-
|
9 |
-
|
10 |
-
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Vid2Avatar: 3D Avatar Reconstruction from Videos in the Wild via Self-supervised Scene Decomposition
|
2 |
+
## [Paper](https://arxiv.org/abs/2302.11566) | [Video Youtube](https://youtu.be/EGi47YeIeGQ) | [Project Page](https://moygcc.github.io/vid2avatar/) | [SynWild Data](https://synwild.ait.ethz.ch/)
|
3 |
+
|
4 |
+
|
5 |
+
Official Repository for CVPR 2023 paper [*Vid2Avatar: 3D Avatar Reconstruction from Videos in the Wild via Self-supervised Scene Decomposition*](https://arxiv.org/abs/2302.11566).
|
6 |
+
|
7 |
+
<img src="assets/teaser.png" width="800" height="223"/>
|
8 |
+
|
9 |
+
## Getting Started
|
10 |
+
* Clone this repo: `git clone https://github.com/MoyGcc/vid2avatar`
|
11 |
+
* Create a python virtual environment and activate. `conda create -n v2a python=3.7` and `conda activate v2a`
|
12 |
+
* Install dependenices. `cd vid2avatar`, `pip install -r requirement.txt` and `cd code; python setup.py develop`
|
13 |
+
* Install [Kaolin](https://kaolin.readthedocs.io/en/v0.10.0/notes/installation.html). We use version 0.10.0.
|
14 |
+
* Download [SMPL model](https://smpl.is.tue.mpg.de/download.php) (1.0.0 for Python 2.7 (10 shape PCs)) and move them to the corresponding places:
|
15 |
+
```
|
16 |
+
mkdir code/lib/smpl/smpl_model/
|
17 |
+
mv /path/to/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl code/lib/smpl/smpl_model/SMPL_FEMALE.pkl
|
18 |
+
mv /path/to/smpl/models/basicmodel_m_lbs_10_207_0_v1.0.0.pkl code/lib/smpl/smpl_model/SMPL_MALE.pkl
|
19 |
+
```
|
20 |
+
## Download preprocessed demo data
|
21 |
+
You can quickly start trying out Vid2Avatar with a preprocessed demo sequence including the pre-trained checkpoint. This can be downloaded from [Google drive](https://drive.google.com/drive/u/1/folders/1AUtKSmib7CvpWBCFO6mQ9spVrga_CTU4) which is originally a video clip provided by [NeuMan](https://github.com/apple/ml-neuman). Put this preprocessed demo data under the folder `data/` and put the folder `checkpoints` under `outputs/parkinglot/`.
|
22 |
+
|
23 |
+
## Training
|
24 |
+
Before training, make sure that the `metaninfo` in the data config file `/code/confs/dataset/video.yaml` does match the expected training video. You can also continue the training by changing the flag `is_continue` in the model config file `code/confs/model/model_w_bg`. And then run:
|
25 |
+
```
|
26 |
+
cd code
|
27 |
+
python train.py
|
28 |
+
```
|
29 |
+
The training usually takes 24-48 hours. The validation results can be found at `outputs/`.
|
30 |
+
## Test
|
31 |
+
Run the following command to obtain the final outputs. By default, this loads the latest checkpoint.
|
32 |
+
```
|
33 |
+
cd code
|
34 |
+
python test.py
|
35 |
+
```
|
36 |
+
## 3D Visualization
|
37 |
+
We use [AITViewer](https://github.com/eth-ait/aitviewer) to visualize the human models in 3D. First install AITViewer: `pip install aitviewer imgui==1.4.1`, and then run the following command to visualize the canonical mesh (--mode static) or deformed mesh sequence (--mode dynamic):
|
38 |
+
```
|
39 |
+
cd visualization
|
40 |
+
python vis.py --mode {MODE} --path {PATH}
|
41 |
+
```
|
42 |
+
<p align="center">
|
43 |
+
<img src="assets/parkinglot_360.gif" width="623" height="346"/>
|
44 |
+
</p>
|
45 |
+
|
46 |
+
## Play on custom video
|
47 |
+
* We use [ROMP](https://github.com/Arthur151/ROMP#installation) to obtain initial SMPL shape and poses: `pip install --upgrade simple-romp`
|
48 |
+
* Install [OpenPose](https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/installation/0_index.md) as well as the python bindings.
|
49 |
+
* Put the video frames under the folder `preprocessing/raw_data/{SEQUENCE_NAME}/frames`
|
50 |
+
* Modify the preprocessing script `preprocessing/run_preprocessing.sh` accordingly: the data source, sequence name, and the gender. The data source is by default "custom" which will estimate camera intrinsics. If the camera intrinsics are known, it's better if the true camera parameters can be given.
|
51 |
+
* Run preprocessing: `cd preprocessing` and `bash run_preprocessing.sh`. The processed data will be stored in `data/`. The intermediate outputs of the preprocessing can be found at `preprocessing/raw_data/{SEQUENCE_NAME}/`
|
52 |
+
* Launch training and test in the same way as above. The `metainfo` in the data config file `/code/confs/dataset/video.yaml` should be changed according to the custom video.
|
53 |
+
|
54 |
+
<p align="center">
|
55 |
+
<img src="assets/roger.gif" width="240" height="270"/> <img src="assets/exstrimalik.gif" width="240" height="270"/> <img src="assets/martial.gif" width="240" height="270"/>
|
56 |
+
</p>
|
57 |
+
|
58 |
+
## Acknowledgement
|
59 |
+
We have used codes from other great research work, including [VolSDF](https://github.com/lioryariv/volsdf), [NeRF++](https://github.com/Kai-46/nerfplusplus), [SMPL-X](https://github.com/vchoutas/smplx), [Anim-NeRF](https://github.com/JanaldoChen/Anim-NeRF), [I M Avatar](https://github.com/zhengyuf/IMavatar) and [SNARF](https://github.com/xuchen-ethz/snarf). We sincerely thank the authors for their awesome work! We also thank the authors of [ICON](https://github.com/YuliangXiu/ICON) and [SelfRecon](https://github.com/jby1993/SelfReconCode) for discussing experiment.
|
60 |
+
|
61 |
+
## Related Works
|
62 |
+
Here are more recent related human body reconstruction projects from our team:
|
63 |
+
* [Jiang and Chen et. al. - InstantAvatar: Learning Avatars from Monocular Video in 60 Seconds](https://github.com/tijiang13/InstantAvatar)
|
64 |
+
* [Shen and Guo et. al. - X-Avatar: Expressive Human Avatars](https://skype-line.github.io/projects/X-Avatar/)
|
65 |
+
* [Yin et. al. - Hi4D: 4D Instance Segmentation of Close Human Interaction](https://yifeiyin04.github.io/Hi4D/)
|
66 |
+
|
67 |
+
```
|
68 |
+
@inproceedings{guo2023vid2avatar,
|
69 |
+
title={Vid2Avatar: 3D Avatar Reconstruction from Videos in the Wild via Self-supervised Scene Decomposition},
|
70 |
+
author={Guo, Chen and Jiang, Tianjian and Chen, Xu and Song, Jie and Hilliges, Otmar},
|
71 |
+
booktitle = {Computer Vision and Pattern Recognition (CVPR)},
|
72 |
+
year = {2023}
|
73 |
+
}
|
74 |
+
```
|
assets/exstrimalik.gif
ADDED
![]() |
Git LFS Details
|
assets/martial.gif
ADDED
![]() |
Git LFS Details
|
assets/parkinglot_360.gif
ADDED
![]() |
Git LFS Details
|
assets/roger.gif
ADDED
![]() |
Git LFS Details
|
assets/smpl_init.pth
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:93541fbf3eb32ade3201565d4b0d793c851aa4173b58a46f1e62f3da292037ce
|
3 |
+
size 2415862
|
assets/teaser.png
ADDED
![]() |
Git LFS Details
|
code/check_cuda.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
print("Number of GPUs:", torch.cuda.device_count())
|
4 |
+
|
5 |
+
print("Torch version:",torch.__version__)
|
6 |
+
|
7 |
+
print("Is CUDA enabled?",torch.cuda.is_available())
|
8 |
+
|
9 |
+
print(torch.cuda.device_count())
|
10 |
+
|
11 |
+
# pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
|
code/confs/base.yaml
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
hydra:
|
2 |
+
run:
|
3 |
+
dir: "../outputs/${exp}/${run}"
|
4 |
+
|
5 |
+
defaults:
|
6 |
+
- model: model_w_bg
|
7 |
+
- dataset: video
|
8 |
+
- _self_
|
9 |
+
|
10 |
+
seed: 42
|
11 |
+
project_name: "model_w_bg"
|
12 |
+
exp: ${dataset.train.type}
|
13 |
+
run: ${dataset.metainfo.subject}
|
code/confs/dataset/video.yaml
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
metainfo:
|
2 |
+
gender: 'male'
|
3 |
+
data_dir : C:\Users\leob3\vid2avatar\data\parkinglot
|
4 |
+
subject: "parkinglot"
|
5 |
+
start_frame: 0
|
6 |
+
end_frame: 42
|
7 |
+
|
8 |
+
train:
|
9 |
+
type: "Video"
|
10 |
+
batch_size: 1
|
11 |
+
drop_last: False
|
12 |
+
shuffle: True
|
13 |
+
worker: 8
|
14 |
+
|
15 |
+
num_sample : 512
|
16 |
+
|
17 |
+
valid:
|
18 |
+
type: "VideoVal"
|
19 |
+
image_id: 0
|
20 |
+
batch_size: 1
|
21 |
+
drop_last: False
|
22 |
+
shuffle: False
|
23 |
+
worker: 8
|
24 |
+
|
25 |
+
num_sample : -1
|
26 |
+
pixel_per_batch: 2048
|
27 |
+
|
28 |
+
test:
|
29 |
+
type: "VideoTest"
|
30 |
+
image_id: 0
|
31 |
+
batch_size: 1
|
32 |
+
drop_last: False
|
33 |
+
shuffle: False
|
34 |
+
worker: 8
|
35 |
+
|
36 |
+
num_sample : -1
|
37 |
+
pixel_per_batch: 2048
|
code/confs/model/model_w_bg.yaml
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
learning_rate : 5.0e-4
|
2 |
+
sched_milestones : [200,500]
|
3 |
+
sched_factor : 0.5
|
4 |
+
smpl_init: True
|
5 |
+
is_continue: False
|
6 |
+
use_body_parsing: False
|
7 |
+
with_bkgd: True
|
8 |
+
using_inpainting: False
|
9 |
+
use_smpl_deformer: True
|
10 |
+
use_bbox_sampler: False
|
11 |
+
|
12 |
+
implicit_network:
|
13 |
+
feature_vector_size: 256
|
14 |
+
d_in: 3
|
15 |
+
d_out: 1
|
16 |
+
dims: [ 256, 256, 256, 256, 256, 256, 256, 256 ]
|
17 |
+
init: 'geometry'
|
18 |
+
bias: 0.6
|
19 |
+
skip_in: [4]
|
20 |
+
weight_norm: True
|
21 |
+
embedder_mode: 'fourier'
|
22 |
+
multires: 6
|
23 |
+
cond: 'smpl'
|
24 |
+
scene_bounding_sphere: 3.0
|
25 |
+
rendering_network:
|
26 |
+
feature_vector_size: 256
|
27 |
+
mode: "pose"
|
28 |
+
d_in: 14
|
29 |
+
d_out: 3
|
30 |
+
dims: [ 256, 256, 256, 256]
|
31 |
+
weight_norm: True
|
32 |
+
multires_view: -1
|
33 |
+
bg_implicit_network:
|
34 |
+
feature_vector_size: 256
|
35 |
+
d_in: 4
|
36 |
+
d_out: 1
|
37 |
+
dims: [ 256, 256, 256, 256, 256, 256, 256, 256 ]
|
38 |
+
init: 'none'
|
39 |
+
bias: 0.0
|
40 |
+
skip_in: [4]
|
41 |
+
weight_norm: False
|
42 |
+
embedder_mode: 'fourier'
|
43 |
+
multires: 10
|
44 |
+
cond: 'frame'
|
45 |
+
dim_frame_encoding: 32
|
46 |
+
bg_rendering_network:
|
47 |
+
feature_vector_size: 256
|
48 |
+
mode: 'nerf_frame_encoding'
|
49 |
+
d_in: 3
|
50 |
+
d_out: 3
|
51 |
+
dims: [128]
|
52 |
+
weight_norm: False
|
53 |
+
multires_view: 4
|
54 |
+
dim_frame_encoding: 32
|
55 |
+
shadow_network:
|
56 |
+
d_in: 3
|
57 |
+
d_out: 1
|
58 |
+
dims: [128, 128]
|
59 |
+
weight_norm: False
|
60 |
+
density:
|
61 |
+
params_init: {beta: 0.1}
|
62 |
+
beta_min: 0.0001
|
63 |
+
ray_sampler:
|
64 |
+
near: 0.0
|
65 |
+
N_samples: 64
|
66 |
+
N_samples_eval: 128
|
67 |
+
N_samples_extra: 32
|
68 |
+
eps: 0.1
|
69 |
+
beta_iters: 10
|
70 |
+
max_total_iters: 5
|
71 |
+
N_samples_inverse_sphere: 32
|
72 |
+
add_tiny: 1.0e-6
|
73 |
+
loss:
|
74 |
+
eikonal_weight : 0.1
|
75 |
+
bce_weight: 5.0e-3
|
76 |
+
opacity_sparse_weight: 3.0e-3
|
77 |
+
in_shape_weight: 1.0e-2
|
code/lib/datasets/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .dataset import Dataset, ValDataset, TestDataset
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
|
4 |
+
def find_dataset_using_name(name):
|
5 |
+
mapping = {
|
6 |
+
"Video": Dataset,
|
7 |
+
"VideoVal": ValDataset,
|
8 |
+
"VideoTest": TestDataset,
|
9 |
+
}
|
10 |
+
cls = mapping.get(name, None)
|
11 |
+
if cls is None:
|
12 |
+
raise ValueError(f"Fail to find dataset {name}")
|
13 |
+
return cls
|
14 |
+
|
15 |
+
|
16 |
+
def create_dataset(metainfo, split):
|
17 |
+
dataset_cls = find_dataset_using_name(split.type)
|
18 |
+
dataset = dataset_cls(metainfo, split)
|
19 |
+
return DataLoader(
|
20 |
+
dataset,
|
21 |
+
batch_size=split.batch_size,
|
22 |
+
drop_last=split.drop_last,
|
23 |
+
shuffle=split.shuffle,
|
24 |
+
num_workers=split.worker,
|
25 |
+
pin_memory=True
|
26 |
+
)
|
code/lib/datasets/dataset.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import glob
|
3 |
+
import hydra
|
4 |
+
import cv2
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
from lib.utils import utils
|
8 |
+
|
9 |
+
|
10 |
+
class Dataset(torch.utils.data.Dataset):
|
11 |
+
def __init__(self, metainfo, split):
|
12 |
+
root = os.path.join("../data", metainfo.data_dir)
|
13 |
+
root = hydra.utils.to_absolute_path(root)
|
14 |
+
|
15 |
+
self.start_frame = metainfo.start_frame
|
16 |
+
self.end_frame = metainfo.end_frame
|
17 |
+
self.skip_step = 1
|
18 |
+
self.images, self.img_sizes = [], []
|
19 |
+
self.training_indices = list(range(metainfo.start_frame, metainfo.end_frame, self.skip_step))
|
20 |
+
|
21 |
+
# images
|
22 |
+
img_dir = os.path.join(root, "image")
|
23 |
+
self.img_paths = sorted(glob.glob(f"{img_dir}/*.png"))
|
24 |
+
|
25 |
+
# only store the image paths to avoid OOM
|
26 |
+
self.img_paths = [self.img_paths[i] for i in self.training_indices]
|
27 |
+
self.img_size = cv2.imread(self.img_paths[0]).shape[:2]
|
28 |
+
self.n_images = len(self.img_paths)
|
29 |
+
|
30 |
+
# coarse projected SMPL masks, only for sampling
|
31 |
+
mask_dir = os.path.join(root, "mask")
|
32 |
+
self.mask_paths = sorted(glob.glob(f"{mask_dir}/*.png"))
|
33 |
+
self.mask_paths = [self.mask_paths[i] for i in self.training_indices]
|
34 |
+
|
35 |
+
self.shape = np.load(os.path.join(root, "mean_shape.npy"))
|
36 |
+
self.poses = np.load(os.path.join(root, 'poses.npy'))[self.training_indices]
|
37 |
+
self.trans = np.load(os.path.join(root, 'normalize_trans.npy'))[self.training_indices]
|
38 |
+
# cameras
|
39 |
+
camera_dict = np.load(os.path.join(root, "cameras_normalize.npz"))
|
40 |
+
scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in self.training_indices]
|
41 |
+
world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in self.training_indices]
|
42 |
+
|
43 |
+
self.scale = 1 / scale_mats[0][0, 0]
|
44 |
+
|
45 |
+
self.intrinsics_all = []
|
46 |
+
self.pose_all = []
|
47 |
+
for scale_mat, world_mat in zip(scale_mats, world_mats):
|
48 |
+
P = world_mat @ scale_mat
|
49 |
+
P = P[:3, :4]
|
50 |
+
intrinsics, pose = utils.load_K_Rt_from_P(None, P)
|
51 |
+
self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
|
52 |
+
self.pose_all.append(torch.from_numpy(pose).float())
|
53 |
+
assert len(self.intrinsics_all) == len(self.pose_all)
|
54 |
+
|
55 |
+
# other properties
|
56 |
+
self.num_sample = split.num_sample
|
57 |
+
self.sampling_strategy = "weighted"
|
58 |
+
|
59 |
+
def __len__(self):
|
60 |
+
return self.n_images
|
61 |
+
|
62 |
+
def __getitem__(self, idx):
|
63 |
+
# normalize RGB
|
64 |
+
img = cv2.imread(self.img_paths[idx])
|
65 |
+
# preprocess: BGR -> RGB -> Normalize
|
66 |
+
|
67 |
+
img = img[:, :, ::-1] / 255
|
68 |
+
|
69 |
+
mask = cv2.imread(self.mask_paths[idx])
|
70 |
+
# preprocess: BGR -> Gray -> Mask
|
71 |
+
mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) > 0
|
72 |
+
|
73 |
+
img_size = self.img_size
|
74 |
+
|
75 |
+
uv = np.mgrid[:img_size[0], :img_size[1]].astype(np.int32)
|
76 |
+
uv = np.flip(uv, axis=0).copy().transpose(1, 2, 0).astype(np.float32)
|
77 |
+
|
78 |
+
smpl_params = torch.zeros([86]).float()
|
79 |
+
smpl_params[0] = torch.from_numpy(np.asarray(self.scale)).float()
|
80 |
+
|
81 |
+
smpl_params[1:4] = torch.from_numpy(self.trans[idx]).float()
|
82 |
+
smpl_params[4:76] = torch.from_numpy(self.poses[idx]).float()
|
83 |
+
smpl_params[76:] = torch.from_numpy(self.shape).float()
|
84 |
+
|
85 |
+
if self.num_sample > 0:
|
86 |
+
data = {
|
87 |
+
"rgb": img,
|
88 |
+
"uv": uv,
|
89 |
+
"object_mask": mask,
|
90 |
+
}
|
91 |
+
|
92 |
+
samples, index_outside = utils.weighted_sampling(data, img_size, self.num_sample)
|
93 |
+
inputs = {
|
94 |
+
"uv": samples["uv"].astype(np.float32),
|
95 |
+
"intrinsics": self.intrinsics_all[idx],
|
96 |
+
"pose": self.pose_all[idx],
|
97 |
+
"smpl_params": smpl_params,
|
98 |
+
'index_outside': index_outside,
|
99 |
+
"idx": idx
|
100 |
+
}
|
101 |
+
images = {"rgb": samples["rgb"].astype(np.float32)}
|
102 |
+
return inputs, images
|
103 |
+
else:
|
104 |
+
inputs = {
|
105 |
+
"uv": uv.reshape(-1, 2).astype(np.float32),
|
106 |
+
"intrinsics": self.intrinsics_all[idx],
|
107 |
+
"pose": self.pose_all[idx],
|
108 |
+
"smpl_params": smpl_params,
|
109 |
+
"idx": idx
|
110 |
+
}
|
111 |
+
images = {
|
112 |
+
"rgb": img.reshape(-1, 3).astype(np.float32),
|
113 |
+
"img_size": self.img_size
|
114 |
+
}
|
115 |
+
return inputs, images
|
116 |
+
|
117 |
+
class ValDataset(torch.utils.data.Dataset):
|
118 |
+
def __init__(self, metainfo, split):
|
119 |
+
self.dataset = Dataset(metainfo, split)
|
120 |
+
self.img_size = self.dataset.img_size
|
121 |
+
|
122 |
+
self.total_pixels = np.prod(self.img_size)
|
123 |
+
self.pixel_per_batch = split.pixel_per_batch
|
124 |
+
|
125 |
+
def __len__(self):
|
126 |
+
return 1
|
127 |
+
|
128 |
+
def __getitem__(self, idx):
|
129 |
+
image_id = int(np.random.choice(len(self.dataset), 1))
|
130 |
+
self.data = self.dataset[image_id]
|
131 |
+
inputs, images = self.data
|
132 |
+
|
133 |
+
inputs = {
|
134 |
+
"uv": inputs["uv"],
|
135 |
+
"intrinsics": inputs['intrinsics'],
|
136 |
+
"pose": inputs['pose'],
|
137 |
+
"smpl_params": inputs["smpl_params"],
|
138 |
+
'image_id': image_id,
|
139 |
+
"idx": inputs['idx']
|
140 |
+
}
|
141 |
+
images = {
|
142 |
+
"rgb": images["rgb"],
|
143 |
+
"img_size": images["img_size"],
|
144 |
+
'pixel_per_batch': self.pixel_per_batch,
|
145 |
+
'total_pixels': self.total_pixels
|
146 |
+
}
|
147 |
+
return inputs, images
|
148 |
+
|
149 |
+
class TestDataset(torch.utils.data.Dataset):
|
150 |
+
def __init__(self, metainfo, split):
|
151 |
+
self.dataset = Dataset(metainfo, split)
|
152 |
+
|
153 |
+
self.img_size = self.dataset.img_size
|
154 |
+
|
155 |
+
self.total_pixels = np.prod(self.img_size)
|
156 |
+
self.pixel_per_batch = split.pixel_per_batch
|
157 |
+
def __len__(self):
|
158 |
+
return len(self.dataset)
|
159 |
+
|
160 |
+
def __getitem__(self, idx):
|
161 |
+
data = self.dataset[idx]
|
162 |
+
|
163 |
+
inputs, images = data
|
164 |
+
inputs = {
|
165 |
+
"uv": inputs["uv"],
|
166 |
+
"intrinsics": inputs['intrinsics'],
|
167 |
+
"pose": inputs['pose'],
|
168 |
+
"smpl_params": inputs["smpl_params"],
|
169 |
+
"idx": inputs['idx']
|
170 |
+
}
|
171 |
+
images = {
|
172 |
+
"rgb": images["rgb"],
|
173 |
+
"img_size": images["img_size"]
|
174 |
+
}
|
175 |
+
return inputs, images, self.pixel_per_batch, self.total_pixels, idx
|
code/lib/libmise/mise.cp37-win_amd64.pyd
ADDED
Binary file (180 kB). View file
|
|
code/lib/libmise/mise.cpp
ADDED
The diff for this file is too large to render.
See raw diff
|
|
code/lib/libmise/mise.pyx
ADDED
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# distutils: language = c++
|
2 |
+
cimport cython
|
3 |
+
from libc.stdint cimport int32_t, int64_t
|
4 |
+
from cython.operator cimport dereference as dref
|
5 |
+
from libcpp.vector cimport vector
|
6 |
+
from libcpp.map cimport map
|
7 |
+
from libc.math cimport isnan, NAN
|
8 |
+
import numpy as np
|
9 |
+
|
10 |
+
|
11 |
+
cdef struct Vector3D:
|
12 |
+
int x, y, z
|
13 |
+
|
14 |
+
|
15 |
+
cdef struct Voxel:
|
16 |
+
Vector3D loc
|
17 |
+
unsigned int level
|
18 |
+
bint is_leaf
|
19 |
+
unsigned long children[2][2][2]
|
20 |
+
|
21 |
+
|
22 |
+
cdef struct GridPoint:
|
23 |
+
Vector3D loc
|
24 |
+
double value
|
25 |
+
bint known
|
26 |
+
|
27 |
+
|
28 |
+
cdef inline unsigned long vec_to_idx(Vector3D coord, long resolution):
|
29 |
+
cdef unsigned long idx
|
30 |
+
idx = resolution * resolution * coord.x + resolution * coord.y + coord.z
|
31 |
+
return idx
|
32 |
+
|
33 |
+
|
34 |
+
cdef class MISE:
|
35 |
+
cdef vector[Voxel] voxels
|
36 |
+
cdef vector[GridPoint] grid_points
|
37 |
+
cdef map[long, long] grid_point_hash
|
38 |
+
cdef readonly int resolution_0
|
39 |
+
cdef readonly int depth
|
40 |
+
cdef readonly double threshold
|
41 |
+
cdef readonly int voxel_size_0
|
42 |
+
cdef readonly int resolution
|
43 |
+
|
44 |
+
def __cinit__(self, int resolution_0, int depth, double threshold):
|
45 |
+
self.resolution_0 = resolution_0
|
46 |
+
self.depth = depth
|
47 |
+
self.threshold = threshold
|
48 |
+
self.voxel_size_0 = (1 << depth)
|
49 |
+
self.resolution = resolution_0 * self.voxel_size_0
|
50 |
+
|
51 |
+
# Create initial voxels
|
52 |
+
self.voxels.reserve(resolution_0 * resolution_0 * resolution_0)
|
53 |
+
|
54 |
+
cdef Voxel voxel
|
55 |
+
cdef GridPoint point
|
56 |
+
cdef Vector3D loc
|
57 |
+
cdef int i, j, k
|
58 |
+
for i in range(resolution_0):
|
59 |
+
for j in range(resolution_0):
|
60 |
+
for k in range (resolution_0):
|
61 |
+
loc = Vector3D(
|
62 |
+
i * self.voxel_size_0,
|
63 |
+
j * self.voxel_size_0,
|
64 |
+
k * self.voxel_size_0,
|
65 |
+
)
|
66 |
+
voxel = Voxel(
|
67 |
+
loc=loc,
|
68 |
+
level=0,
|
69 |
+
is_leaf=True,
|
70 |
+
)
|
71 |
+
|
72 |
+
assert(self.voxels.size() == vec_to_idx(Vector3D(i, j, k), resolution_0))
|
73 |
+
self.voxels.push_back(voxel)
|
74 |
+
|
75 |
+
# Create initial grid points
|
76 |
+
self.grid_points.reserve((resolution_0 + 1) * (resolution_0 + 1) * (resolution_0 + 1))
|
77 |
+
for i in range(resolution_0 + 1):
|
78 |
+
for j in range(resolution_0 + 1):
|
79 |
+
for k in range(resolution_0 + 1):
|
80 |
+
loc = Vector3D(
|
81 |
+
i * self.voxel_size_0,
|
82 |
+
j * self.voxel_size_0,
|
83 |
+
k * self.voxel_size_0,
|
84 |
+
)
|
85 |
+
assert(self.grid_points.size() == vec_to_idx(Vector3D(i, j, k), resolution_0 + 1))
|
86 |
+
self.add_grid_point(loc)
|
87 |
+
|
88 |
+
def update(self, int64_t[:, :] points, double[:] values):
|
89 |
+
"""Update points and set their values. Also determine all active voxels and subdivide them."""
|
90 |
+
assert(points.shape[0] == values.shape[0])
|
91 |
+
assert(points.shape[1] == 3)
|
92 |
+
cdef Vector3D loc
|
93 |
+
cdef long idx
|
94 |
+
cdef int i
|
95 |
+
|
96 |
+
# Find all indices of point and set value
|
97 |
+
for i in range(points.shape[0]):
|
98 |
+
loc = Vector3D(points[i, 0], points[i, 1], points[i, 2])
|
99 |
+
idx = self.get_grid_point_idx(loc)
|
100 |
+
if idx == -1:
|
101 |
+
raise ValueError('Point not in grid!')
|
102 |
+
self.grid_points[idx].value = values[i]
|
103 |
+
self.grid_points[idx].known = True
|
104 |
+
# Subdivide activate voxels and add new points
|
105 |
+
self.subdivide_voxels()
|
106 |
+
|
107 |
+
def query(self):
|
108 |
+
"""Query points to evaluate."""
|
109 |
+
# Find all points with unknown value
|
110 |
+
cdef vector[Vector3D] points
|
111 |
+
cdef int n_unknown = 0
|
112 |
+
for p in self.grid_points:
|
113 |
+
if not p.known:
|
114 |
+
n_unknown += 1
|
115 |
+
|
116 |
+
points.reserve(n_unknown)
|
117 |
+
for p in self.grid_points:
|
118 |
+
if not p.known:
|
119 |
+
points.push_back(p.loc)
|
120 |
+
|
121 |
+
# Convert to numpy
|
122 |
+
points_np = np.zeros((points.size(), 3), dtype=np.int64)
|
123 |
+
cdef int64_t[:, :] points_view = points_np
|
124 |
+
for i in range(points.size()):
|
125 |
+
points_view[i, 0] = points[i].x
|
126 |
+
points_view[i, 1] = points[i].y
|
127 |
+
points_view[i, 2] = points[i].z
|
128 |
+
|
129 |
+
return points_np
|
130 |
+
|
131 |
+
def to_dense(self):
|
132 |
+
"""Output dense matrix at highest resolution."""
|
133 |
+
out_array = np.full((self.resolution + 1,) * 3, np.nan)
|
134 |
+
cdef double[:, :, :] out_view = out_array
|
135 |
+
cdef GridPoint point
|
136 |
+
cdef int i, j, k
|
137 |
+
|
138 |
+
for point in self.grid_points:
|
139 |
+
# Take voxel for which points is upper left corner
|
140 |
+
# assert(point.known)
|
141 |
+
out_view[point.loc.x, point.loc.y, point.loc.z] = point.value
|
142 |
+
|
143 |
+
# Complete along x axis
|
144 |
+
for i in range(1, self.resolution + 1):
|
145 |
+
for j in range(self.resolution + 1):
|
146 |
+
for k in range(self.resolution + 1):
|
147 |
+
if isnan(out_view[i, j, k]):
|
148 |
+
out_view[i, j, k] = out_view[i-1, j, k]
|
149 |
+
|
150 |
+
# Complete along y axis
|
151 |
+
for i in range(self.resolution + 1):
|
152 |
+
for j in range(1, self.resolution + 1):
|
153 |
+
for k in range(self.resolution + 1):
|
154 |
+
if isnan(out_view[i, j, k]):
|
155 |
+
out_view[i, j, k] = out_view[i, j-1, k]
|
156 |
+
|
157 |
+
|
158 |
+
# Complete along z axis
|
159 |
+
for i in range(self.resolution + 1):
|
160 |
+
for j in range(self.resolution + 1):
|
161 |
+
for k in range(1, self.resolution + 1):
|
162 |
+
if isnan(out_view[i, j, k]):
|
163 |
+
out_view[i, j, k] = out_view[i, j, k-1]
|
164 |
+
assert(not isnan(out_view[i, j, k]))
|
165 |
+
return out_array
|
166 |
+
|
167 |
+
def get_points(self):
|
168 |
+
points_np = np.zeros((self.grid_points.size(), 3), dtype=np.int64)
|
169 |
+
values_np = np.zeros((self.grid_points.size()), dtype=np.float64)
|
170 |
+
|
171 |
+
cdef long[:, :] points_view = points_np
|
172 |
+
cdef double[:] values_view = values_np
|
173 |
+
cdef Vector3D loc
|
174 |
+
cdef int i
|
175 |
+
|
176 |
+
for i in range(self.grid_points.size()):
|
177 |
+
loc = self.grid_points[i].loc
|
178 |
+
points_view[i, 0] = loc.x
|
179 |
+
points_view[i, 1] = loc.y
|
180 |
+
points_view[i, 2] = loc.z
|
181 |
+
values_view[i] = self.grid_points[i].value
|
182 |
+
|
183 |
+
return points_np, values_np
|
184 |
+
|
185 |
+
cdef void subdivide_voxels(self) except +:
|
186 |
+
cdef vector[bint] next_to_positive
|
187 |
+
cdef vector[bint] next_to_negative
|
188 |
+
cdef int i, j, k
|
189 |
+
cdef long idx
|
190 |
+
cdef Vector3D loc, adj_loc
|
191 |
+
|
192 |
+
# Initialize vectors
|
193 |
+
next_to_positive.resize(self.voxels.size(), False)
|
194 |
+
next_to_negative.resize(self.voxels.size(), False)
|
195 |
+
|
196 |
+
# Iterate over grid points and mark voxels active
|
197 |
+
# TODO: can move this to update operation and add attibute to voxel
|
198 |
+
for grid_point in self.grid_points:
|
199 |
+
loc = grid_point.loc
|
200 |
+
if not grid_point.known:
|
201 |
+
continue
|
202 |
+
|
203 |
+
# Iterate over the 8 adjacent voxels
|
204 |
+
for i in range(-1, 1):
|
205 |
+
for j in range(-1, 1):
|
206 |
+
for k in range(-1, 1):
|
207 |
+
adj_loc = Vector3D(
|
208 |
+
x=loc.x + i,
|
209 |
+
y=loc.y + j,
|
210 |
+
z=loc.z + k,
|
211 |
+
)
|
212 |
+
idx = self.get_voxel_idx(adj_loc)
|
213 |
+
if idx == -1:
|
214 |
+
continue
|
215 |
+
|
216 |
+
if grid_point.value >= self.threshold:
|
217 |
+
next_to_positive[idx] = True
|
218 |
+
if grid_point.value <= self.threshold:
|
219 |
+
next_to_negative[idx] = True
|
220 |
+
|
221 |
+
cdef int n_subdivide = 0
|
222 |
+
|
223 |
+
for idx in range(self.voxels.size()):
|
224 |
+
if not self.voxels[idx].is_leaf or self.voxels[idx].level == self.depth:
|
225 |
+
continue
|
226 |
+
if next_to_positive[idx] and next_to_negative[idx]:
|
227 |
+
n_subdivide += 1
|
228 |
+
|
229 |
+
self.voxels.reserve(self.voxels.size() + 8 * n_subdivide)
|
230 |
+
self.grid_points.reserve(self.voxels.size() + 19 * n_subdivide)
|
231 |
+
|
232 |
+
for idx in range(self.voxels.size()):
|
233 |
+
if not self.voxels[idx].is_leaf or self.voxels[idx].level == self.depth:
|
234 |
+
continue
|
235 |
+
if next_to_positive[idx] and next_to_negative[idx]:
|
236 |
+
self.subdivide_voxel(idx)
|
237 |
+
|
238 |
+
cdef void subdivide_voxel(self, long idx):
|
239 |
+
cdef Voxel voxel
|
240 |
+
cdef GridPoint point
|
241 |
+
cdef Vector3D loc0 = self.voxels[idx].loc
|
242 |
+
cdef Vector3D loc
|
243 |
+
cdef int new_level = self.voxels[idx].level + 1
|
244 |
+
cdef int new_size = 1 << (self.depth - new_level)
|
245 |
+
assert(new_level <= self.depth)
|
246 |
+
assert(1 <= new_size <= self.voxel_size_0)
|
247 |
+
|
248 |
+
# Current voxel is not leaf anymore
|
249 |
+
self.voxels[idx].is_leaf = False
|
250 |
+
# Add new voxels
|
251 |
+
cdef int i, j, k
|
252 |
+
for i in range(2):
|
253 |
+
for j in range(2):
|
254 |
+
for k in range(2):
|
255 |
+
loc = Vector3D(
|
256 |
+
x=loc0.x + i * new_size,
|
257 |
+
y=loc0.y + j * new_size,
|
258 |
+
z=loc0.z + k * new_size,
|
259 |
+
)
|
260 |
+
voxel = Voxel(
|
261 |
+
loc=loc,
|
262 |
+
level=new_level,
|
263 |
+
is_leaf=True
|
264 |
+
)
|
265 |
+
|
266 |
+
self.voxels[idx].children[i][j][k] = self.voxels.size()
|
267 |
+
self.voxels.push_back(voxel)
|
268 |
+
|
269 |
+
# Add new grid points
|
270 |
+
for i in range(3):
|
271 |
+
for j in range(3):
|
272 |
+
for k in range(3):
|
273 |
+
loc = Vector3D(
|
274 |
+
loc0.x + i * new_size,
|
275 |
+
loc0.y + j * new_size,
|
276 |
+
loc0.z + k * new_size,
|
277 |
+
)
|
278 |
+
|
279 |
+
# Only add new grid points
|
280 |
+
if self.get_grid_point_idx(loc) == -1:
|
281 |
+
self.add_grid_point(loc)
|
282 |
+
|
283 |
+
|
284 |
+
@cython.cdivision(True)
|
285 |
+
cdef long get_voxel_idx(self, Vector3D loc) except +:
|
286 |
+
"""Utility function for getting voxel index corresponding to 3D coordinates."""
|
287 |
+
# Shorthands
|
288 |
+
cdef long resolution = self.resolution
|
289 |
+
cdef long resolution_0 = self.resolution_0
|
290 |
+
cdef long depth = self.depth
|
291 |
+
cdef long voxel_size_0 = self.voxel_size_0
|
292 |
+
|
293 |
+
# Return -1 if point lies outside bounds
|
294 |
+
if not (0 <= loc.x < resolution and 0<= loc.y < resolution and 0 <= loc.z < resolution):
|
295 |
+
return -1
|
296 |
+
|
297 |
+
# Coordinates in coarse voxel grid
|
298 |
+
cdef Vector3D loc0 = Vector3D(
|
299 |
+
x=loc.x >> depth,
|
300 |
+
y=loc.y >> depth,
|
301 |
+
z=loc.z >> depth,
|
302 |
+
)
|
303 |
+
|
304 |
+
# Initial voxels
|
305 |
+
cdef int idx = vec_to_idx(loc0, resolution_0)
|
306 |
+
cdef Voxel voxel = self.voxels[idx]
|
307 |
+
assert(voxel.loc.x == loc0.x * voxel_size_0)
|
308 |
+
assert(voxel.loc.y == loc0.y * voxel_size_0)
|
309 |
+
assert(voxel.loc.z == loc0.z * voxel_size_0)
|
310 |
+
|
311 |
+
# Relative coordinates
|
312 |
+
cdef Vector3D loc_rel = Vector3D(
|
313 |
+
x=loc.x - (loc0.x << depth),
|
314 |
+
y=loc.y - (loc0.y << depth),
|
315 |
+
z=loc.z - (loc0.z << depth),
|
316 |
+
)
|
317 |
+
|
318 |
+
cdef Vector3D loc_offset
|
319 |
+
cdef long voxel_size = voxel_size_0
|
320 |
+
|
321 |
+
while not voxel.is_leaf:
|
322 |
+
voxel_size = voxel_size >> 1
|
323 |
+
assert(voxel_size >= 1)
|
324 |
+
|
325 |
+
# Determine child
|
326 |
+
loc_offset = Vector3D(
|
327 |
+
x=1 if (loc_rel.x >= voxel_size) else 0,
|
328 |
+
y=1 if (loc_rel.y >= voxel_size) else 0,
|
329 |
+
z=1 if (loc_rel.z >= voxel_size) else 0,
|
330 |
+
)
|
331 |
+
# New voxel
|
332 |
+
idx = voxel.children[loc_offset.x][loc_offset.y][loc_offset.z]
|
333 |
+
voxel = self.voxels[idx]
|
334 |
+
|
335 |
+
# New relative coordinates
|
336 |
+
loc_rel = Vector3D(
|
337 |
+
x=loc_rel.x - loc_offset.x * voxel_size,
|
338 |
+
y=loc_rel.y - loc_offset.y * voxel_size,
|
339 |
+
z=loc_rel.z - loc_offset.z * voxel_size,
|
340 |
+
)
|
341 |
+
|
342 |
+
assert(0<= loc_rel.x < voxel_size)
|
343 |
+
assert(0<= loc_rel.y < voxel_size)
|
344 |
+
assert(0<= loc_rel.z < voxel_size)
|
345 |
+
|
346 |
+
|
347 |
+
# Return idx
|
348 |
+
return idx
|
349 |
+
|
350 |
+
|
351 |
+
cdef inline void add_grid_point(self, Vector3D loc):
|
352 |
+
cdef GridPoint point = GridPoint(
|
353 |
+
loc=loc,
|
354 |
+
value=0.,
|
355 |
+
known=False,
|
356 |
+
)
|
357 |
+
self.grid_point_hash[vec_to_idx(loc, self.resolution + 1)] = self.grid_points.size()
|
358 |
+
self.grid_points.push_back(point)
|
359 |
+
|
360 |
+
cdef inline int get_grid_point_idx(self, Vector3D loc):
|
361 |
+
p_idx = self.grid_point_hash.find(vec_to_idx(loc, self.resolution + 1))
|
362 |
+
if p_idx == self.grid_point_hash.end():
|
363 |
+
return -1
|
364 |
+
|
365 |
+
cdef int idx = dref(p_idx).second
|
366 |
+
assert(self.grid_points[idx].loc.x == loc.x)
|
367 |
+
assert(self.grid_points[idx].loc.y == loc.y)
|
368 |
+
assert(self.grid_points[idx].loc.z == loc.z)
|
369 |
+
|
370 |
+
return idx
|
code/lib/model/body_model_params.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
class BodyModelParams(nn.Module):
|
5 |
+
def __init__(self, num_frames, model_type='smpl'):
|
6 |
+
super(BodyModelParams, self).__init__()
|
7 |
+
self.num_frames = num_frames
|
8 |
+
self.model_type = model_type
|
9 |
+
self.params_dim = {
|
10 |
+
'betas': 10,
|
11 |
+
'global_orient': 3,
|
12 |
+
'transl': 3,
|
13 |
+
}
|
14 |
+
if model_type == 'smpl':
|
15 |
+
self.params_dim.update({
|
16 |
+
'body_pose': 69,
|
17 |
+
})
|
18 |
+
else:
|
19 |
+
assert ValueError(f'Unknown model type {model_type}, exiting!')
|
20 |
+
|
21 |
+
self.param_names = self.params_dim.keys()
|
22 |
+
|
23 |
+
for param_name in self.param_names:
|
24 |
+
if param_name == 'betas':
|
25 |
+
param = nn.Embedding(1, self.params_dim[param_name])
|
26 |
+
param.weight.data.fill_(0)
|
27 |
+
param.weight.requires_grad = False
|
28 |
+
setattr(self, param_name, param)
|
29 |
+
else:
|
30 |
+
param = nn.Embedding(num_frames, self.params_dim[param_name])
|
31 |
+
param.weight.data.fill_(0)
|
32 |
+
param.weight.requires_grad = False
|
33 |
+
setattr(self, param_name, param)
|
34 |
+
|
35 |
+
def init_parameters(self, param_name, data, requires_grad=False):
|
36 |
+
getattr(self, param_name).weight.data = data[..., :self.params_dim[param_name]]
|
37 |
+
getattr(self, param_name).weight.requires_grad = requires_grad
|
38 |
+
|
39 |
+
def set_requires_grad(self, param_name, requires_grad=True):
|
40 |
+
getattr(self, param_name).weight.requires_grad = requires_grad
|
41 |
+
|
42 |
+
def forward(self, frame_ids):
|
43 |
+
params = {}
|
44 |
+
for param_name in self.param_names:
|
45 |
+
if param_name == 'betas':
|
46 |
+
params[param_name] = getattr(self, param_name)(torch.zeros_like(frame_ids))
|
47 |
+
else:
|
48 |
+
params[param_name] = getattr(self, param_name)(frame_ids)
|
49 |
+
return params
|
code/lib/model/deformer.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from .smpl import SMPLServer
|
4 |
+
from pytorch3d import ops
|
5 |
+
|
6 |
+
class SMPLDeformer():
|
7 |
+
def __init__(self, max_dist=0.1, K=1, gender='female', betas=None):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
self.max_dist = max_dist
|
11 |
+
self.K = K
|
12 |
+
self.smpl = SMPLServer(gender=gender)
|
13 |
+
smpl_params_canoical = self.smpl.param_canonical.clone()
|
14 |
+
smpl_params_canoical[:, 76:] = torch.tensor(betas).float().to(self.smpl.param_canonical.device)
|
15 |
+
cano_scale, cano_transl, cano_thetas, cano_betas = torch.split(smpl_params_canoical, [1, 3, 72, 10], dim=1)
|
16 |
+
smpl_output = self.smpl(cano_scale, cano_transl, cano_thetas, cano_betas)
|
17 |
+
self.smpl_verts = smpl_output['smpl_verts']
|
18 |
+
self.smpl_weights = smpl_output['smpl_weights']
|
19 |
+
def forward(self, x, smpl_tfs, return_weights=True, inverse=False, smpl_verts=None):
|
20 |
+
if x.shape[0] == 0: return x
|
21 |
+
if smpl_verts is None:
|
22 |
+
weights, outlier_mask = self.query_skinning_weights_smpl_multi(x[None], smpl_verts=self.smpl_verts[0], smpl_weights=self.smpl_weights)
|
23 |
+
else:
|
24 |
+
weights, outlier_mask = self.query_skinning_weights_smpl_multi(x[None], smpl_verts=smpl_verts[0], smpl_weights=self.smpl_weights)
|
25 |
+
if return_weights:
|
26 |
+
return weights
|
27 |
+
|
28 |
+
x_transformed = skinning(x.unsqueeze(0), weights, smpl_tfs, inverse=inverse)
|
29 |
+
|
30 |
+
return x_transformed.squeeze(0), outlier_mask
|
31 |
+
def forward_skinning(self, xc, cond, smpl_tfs):
|
32 |
+
weights, _ = self.query_skinning_weights_smpl_multi(xc, smpl_verts=self.smpl_verts[0], smpl_weights=self.smpl_weights)
|
33 |
+
x_transformed = skinning(xc, weights, smpl_tfs, inverse=False)
|
34 |
+
|
35 |
+
return x_transformed
|
36 |
+
|
37 |
+
def query_skinning_weights_smpl_multi(self, pts, smpl_verts, smpl_weights):
|
38 |
+
|
39 |
+
distance_batch, index_batch, neighbor_points = ops.knn_points(pts, smpl_verts.unsqueeze(0),
|
40 |
+
K=self.K, return_nn=True)
|
41 |
+
distance_batch = torch.clamp(distance_batch, max=4)
|
42 |
+
weights_conf = torch.exp(-distance_batch)
|
43 |
+
distance_batch = torch.sqrt(distance_batch)
|
44 |
+
weights_conf = weights_conf / weights_conf.sum(-1, keepdim=True)
|
45 |
+
index_batch = index_batch[0]
|
46 |
+
weights = smpl_weights[:, index_batch, :]
|
47 |
+
weights = torch.sum(weights * weights_conf.unsqueeze(-1), dim=-2).detach()
|
48 |
+
|
49 |
+
outlier_mask = (distance_batch[..., 0] > self.max_dist)[0]
|
50 |
+
return weights, outlier_mask
|
51 |
+
|
52 |
+
def query_weights(self, xc):
|
53 |
+
weights = self.forward(xc, None, return_weights=True, inverse=False)
|
54 |
+
return weights
|
55 |
+
|
56 |
+
def forward_skinning_normal(self, xc, normal, cond, tfs, inverse = False):
|
57 |
+
if normal.ndim == 2:
|
58 |
+
normal = normal.unsqueeze(0)
|
59 |
+
w = self.query_weights(xc[0], cond)
|
60 |
+
|
61 |
+
p_h = F.pad(normal, (0, 1), value=0)
|
62 |
+
|
63 |
+
if inverse:
|
64 |
+
# p:num_point, n:num_bone, i,j: num_dim+1
|
65 |
+
tf_w = torch.einsum('bpn,bnij->bpij', w.double(), tfs.double())
|
66 |
+
p_h = torch.einsum('bpij,bpj->bpi', tf_w.inverse(), p_h.double()).float()
|
67 |
+
else:
|
68 |
+
p_h = torch.einsum('bpn, bnij, bpj->bpi', w.double(), tfs.double(), p_h.double()).float()
|
69 |
+
|
70 |
+
return p_h[:, :, :3]
|
71 |
+
|
72 |
+
def skinning(x, w, tfs, inverse=False):
|
73 |
+
"""Linear blend skinning
|
74 |
+
Args:
|
75 |
+
x (tensor): canonical points. shape: [B, N, D]
|
76 |
+
w (tensor): conditional input. [B, N, J]
|
77 |
+
tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]
|
78 |
+
Returns:
|
79 |
+
x (tensor): skinned points. shape: [B, N, D]
|
80 |
+
"""
|
81 |
+
x_h = F.pad(x, (0, 1), value=1.0)
|
82 |
+
|
83 |
+
if inverse:
|
84 |
+
# p:n_point, n:n_bone, i,k: n_dim+1
|
85 |
+
w_tf = torch.einsum("bpn,bnij->bpij", w, tfs)
|
86 |
+
x_h = torch.einsum("bpij,bpj->bpi", w_tf.inverse(), x_h)
|
87 |
+
else:
|
88 |
+
x_h = torch.einsum("bpn,bnij,bpj->bpi", w, tfs, x_h)
|
89 |
+
return x_h[:, :, :3]
|
code/lib/model/density.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
|
4 |
+
class Density(nn.Module):
|
5 |
+
def __init__(self, params_init={}):
|
6 |
+
super().__init__()
|
7 |
+
for p in params_init:
|
8 |
+
param = nn.Parameter(torch.tensor(params_init[p]))
|
9 |
+
setattr(self, p, param)
|
10 |
+
|
11 |
+
def forward(self, sdf, beta=None):
|
12 |
+
return self.density_func(sdf, beta=beta)
|
13 |
+
|
14 |
+
|
15 |
+
class LaplaceDensity(Density): # alpha * Laplace(loc=0, scale=beta).cdf(-sdf)
|
16 |
+
def __init__(self, params_init={}, beta_min=0.0001):
|
17 |
+
super().__init__(params_init=params_init)
|
18 |
+
self.beta_min = torch.tensor(beta_min).cuda()
|
19 |
+
|
20 |
+
def density_func(self, sdf, beta=None):
|
21 |
+
if beta is None:
|
22 |
+
beta = self.get_beta()
|
23 |
+
|
24 |
+
alpha = 1 / beta
|
25 |
+
return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta))
|
26 |
+
|
27 |
+
def get_beta(self):
|
28 |
+
beta = self.beta.abs() + self.beta_min
|
29 |
+
return beta
|
30 |
+
|
31 |
+
|
32 |
+
class AbsDensity(Density): # like NeRF++
|
33 |
+
def density_func(self, sdf, beta=None):
|
34 |
+
return torch.abs(sdf)
|
35 |
+
|
36 |
+
|
37 |
+
class SimpleDensity(Density): # like NeRF
|
38 |
+
def __init__(self, params_init={}, noise_std=1.0):
|
39 |
+
super().__init__(params_init=params_init)
|
40 |
+
self.noise_std = noise_std
|
41 |
+
|
42 |
+
def density_func(self, sdf, beta=None):
|
43 |
+
if self.training and self.noise_std > 0.0:
|
44 |
+
noise = torch.randn(sdf.shape).cuda() * self.noise_std
|
45 |
+
sdf = sdf + noise
|
46 |
+
return torch.relu(sdf)
|
code/lib/model/embedders.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class Embedder:
|
4 |
+
def __init__(self, **kwargs):
|
5 |
+
self.kwargs = kwargs
|
6 |
+
self.create_embedding_fn()
|
7 |
+
|
8 |
+
def create_embedding_fn(self):
|
9 |
+
embed_fns = []
|
10 |
+
d = self.kwargs['input_dims']
|
11 |
+
out_dim = 0
|
12 |
+
if self.kwargs['include_input']:
|
13 |
+
embed_fns.append(lambda x: x)
|
14 |
+
out_dim += d
|
15 |
+
|
16 |
+
max_freq = self.kwargs['max_freq_log2']
|
17 |
+
N_freqs = self.kwargs['num_freqs']
|
18 |
+
|
19 |
+
if self.kwargs['log_sampling']:
|
20 |
+
freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
|
21 |
+
else:
|
22 |
+
freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
|
23 |
+
|
24 |
+
for freq in freq_bands:
|
25 |
+
for p_fn in self.kwargs['periodic_fns']:
|
26 |
+
embed_fns.append(lambda x, p_fn=p_fn,
|
27 |
+
freq=freq: p_fn(x * freq))
|
28 |
+
out_dim += d
|
29 |
+
|
30 |
+
self.embed_fns = embed_fns
|
31 |
+
self.out_dim = out_dim
|
32 |
+
|
33 |
+
def embed(self, inputs):
|
34 |
+
return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
|
35 |
+
|
36 |
+
def get_embedder(multires, input_dims=3, mode='fourier'):
|
37 |
+
embed_kwargs = {
|
38 |
+
'include_input': True,
|
39 |
+
'input_dims': input_dims,
|
40 |
+
'max_freq_log2': multires-1,
|
41 |
+
'num_freqs': multires,
|
42 |
+
'log_sampling': True,
|
43 |
+
'periodic_fns': [torch.sin, torch.cos],
|
44 |
+
}
|
45 |
+
if mode == 'fourier':
|
46 |
+
embedder_obj = Embedder(**embed_kwargs)
|
47 |
+
|
48 |
+
|
49 |
+
def embed(x, eo=embedder_obj): return eo.embed(x)
|
50 |
+
return embed, embedder_obj.out_dim
|
code/lib/model/loss.py
ADDED
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
class Loss(nn.Module):
|
6 |
+
def __init__(self, opt):
|
7 |
+
super().__init__()
|
8 |
+
self.eikonal_weight = opt.eikonal_weight
|
9 |
+
self.bce_weight = opt.bce_weight
|
10 |
+
self.opacity_sparse_weight = opt.opacity_sparse_weight
|
11 |
+
self.in_shape_weight = opt.in_shape_weight
|
12 |
+
self.eps = 1e-6
|
13 |
+
self.milestone = 200
|
14 |
+
self.l1_loss = nn.L1Loss(reduction='mean')
|
15 |
+
self.l2_loss = nn.MSELoss(reduction='mean')
|
16 |
+
|
17 |
+
# L1 reconstruction loss for RGB values
|
18 |
+
def get_rgb_loss(self, rgb_values, rgb_gt):
|
19 |
+
rgb_loss = self.l1_loss(rgb_values, rgb_gt)
|
20 |
+
return rgb_loss
|
21 |
+
|
22 |
+
# Eikonal loss introduced in IGR
|
23 |
+
def get_eikonal_loss(self, grad_theta):
|
24 |
+
eikonal_loss = ((grad_theta.norm(2, dim=-1) - 1)**2).mean()
|
25 |
+
return eikonal_loss
|
26 |
+
|
27 |
+
# BCE loss for clear boundary
|
28 |
+
def get_bce_loss(self, acc_map):
|
29 |
+
binary_loss = -1 * (acc_map * (acc_map + self.eps).log() + (1-acc_map) * (1 - acc_map + self.eps).log()).mean() * 2
|
30 |
+
return binary_loss
|
31 |
+
|
32 |
+
# Global opacity sparseness regularization
|
33 |
+
def get_opacity_sparse_loss(self, acc_map, index_off_surface):
|
34 |
+
opacity_sparse_loss = self.l1_loss(acc_map[index_off_surface], torch.zeros_like(acc_map[index_off_surface]))
|
35 |
+
return opacity_sparse_loss
|
36 |
+
|
37 |
+
# Optional: This loss helps to stablize the training in the very beginning
|
38 |
+
def get_in_shape_loss(self, acc_map, index_in_surface):
|
39 |
+
in_shape_loss = self.l1_loss(acc_map[index_in_surface], torch.ones_like(acc_map[index_in_surface]))
|
40 |
+
return in_shape_loss
|
41 |
+
|
42 |
+
def forward(self, model_outputs, ground_truth):
|
43 |
+
nan_filter = ~torch.any(model_outputs['rgb_values'].isnan(), dim=1)
|
44 |
+
rgb_gt = ground_truth['rgb'][0].cuda()
|
45 |
+
rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'][nan_filter], rgb_gt[nan_filter])
|
46 |
+
eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])
|
47 |
+
bce_loss = self.get_bce_loss(model_outputs['acc_map'])
|
48 |
+
opacity_sparse_loss = self.get_opacity_sparse_loss(model_outputs['acc_map'], model_outputs['index_off_surface'])
|
49 |
+
in_shape_loss = self.get_in_shape_loss(model_outputs['acc_map'], model_outputs['index_in_surface'])
|
50 |
+
curr_epoch_for_loss = min(self.milestone, model_outputs['epoch']) # will not increase after the milestone
|
51 |
+
|
52 |
+
loss = rgb_loss + \
|
53 |
+
self.eikonal_weight * eikonal_loss + \
|
54 |
+
self.bce_weight * bce_loss + \
|
55 |
+
self.opacity_sparse_weight * (1 + curr_epoch_for_loss ** 2 / 40) * opacity_sparse_loss + \
|
56 |
+
self.in_shape_weight * (1 - curr_epoch_for_loss / self.milestone) * in_shape_loss
|
57 |
+
return {
|
58 |
+
'loss': loss,
|
59 |
+
'rgb_loss': rgb_loss,
|
60 |
+
'eikonal_loss': eikonal_loss,
|
61 |
+
'bce_loss': bce_loss,
|
62 |
+
'opacity_sparse_loss': opacity_sparse_loss,
|
63 |
+
'in_shape_loss': in_shape_loss,
|
64 |
+
}
|
code/lib/model/networks.py
ADDED
@@ -0,0 +1,178 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
from .embedders import get_embedder
|
5 |
+
|
6 |
+
class ImplicitNet(nn.Module):
|
7 |
+
def __init__(self, opt):
|
8 |
+
super().__init__()
|
9 |
+
|
10 |
+
dims = [opt.d_in] + list(
|
11 |
+
opt.dims) + [opt.d_out + opt.feature_vector_size]
|
12 |
+
self.num_layers = len(dims)
|
13 |
+
self.skip_in = opt.skip_in
|
14 |
+
self.embed_fn = None
|
15 |
+
self.opt = opt
|
16 |
+
|
17 |
+
if opt.multires > 0:
|
18 |
+
embed_fn, input_ch = get_embedder(opt.multires, input_dims=opt.d_in, mode=opt.embedder_mode)
|
19 |
+
self.embed_fn = embed_fn
|
20 |
+
dims[0] = input_ch
|
21 |
+
self.cond = opt.cond
|
22 |
+
if self.cond == 'smpl':
|
23 |
+
self.cond_layer = [0]
|
24 |
+
self.cond_dim = 69
|
25 |
+
elif self.cond == 'frame':
|
26 |
+
self.cond_layer = [0]
|
27 |
+
self.cond_dim = opt.dim_frame_encoding
|
28 |
+
self.dim_pose_embed = 0
|
29 |
+
if self.dim_pose_embed > 0:
|
30 |
+
self.lin_p0 = nn.Linear(self.cond_dim, self.dim_pose_embed)
|
31 |
+
self.cond_dim = self.dim_pose_embed
|
32 |
+
for l in range(0, self.num_layers - 1):
|
33 |
+
if l + 1 in self.skip_in:
|
34 |
+
out_dim = dims[l + 1] - dims[0]
|
35 |
+
else:
|
36 |
+
out_dim = dims[l + 1]
|
37 |
+
|
38 |
+
if self.cond != 'none' and l in self.cond_layer:
|
39 |
+
lin = nn.Linear(dims[l] + self.cond_dim, out_dim)
|
40 |
+
else:
|
41 |
+
lin = nn.Linear(dims[l], out_dim)
|
42 |
+
if opt.init == 'geometry':
|
43 |
+
if l == self.num_layers - 2:
|
44 |
+
torch.nn.init.normal_(lin.weight,
|
45 |
+
mean=np.sqrt(np.pi) /
|
46 |
+
np.sqrt(dims[l]),
|
47 |
+
std=0.0001)
|
48 |
+
torch.nn.init.constant_(lin.bias, -opt.bias)
|
49 |
+
elif opt.multires > 0 and l == 0:
|
50 |
+
torch.nn.init.constant_(lin.bias, 0.0)
|
51 |
+
torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
|
52 |
+
torch.nn.init.normal_(lin.weight[:, :3], 0.0,
|
53 |
+
np.sqrt(2) / np.sqrt(out_dim))
|
54 |
+
elif opt.multires > 0 and l in self.skip_in:
|
55 |
+
torch.nn.init.constant_(lin.bias, 0.0)
|
56 |
+
torch.nn.init.normal_(lin.weight, 0.0,
|
57 |
+
np.sqrt(2) / np.sqrt(out_dim))
|
58 |
+
torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):],
|
59 |
+
0.0)
|
60 |
+
else:
|
61 |
+
torch.nn.init.constant_(lin.bias, 0.0)
|
62 |
+
torch.nn.init.normal_(lin.weight, 0.0,
|
63 |
+
np.sqrt(2) / np.sqrt(out_dim))
|
64 |
+
if opt.init == 'zero':
|
65 |
+
init_val = 1e-5
|
66 |
+
if l == self.num_layers - 2:
|
67 |
+
torch.nn.init.constant_(lin.bias, 0.0)
|
68 |
+
torch.nn.init.uniform_(lin.weight, -init_val, init_val)
|
69 |
+
if opt.weight_norm:
|
70 |
+
lin = nn.utils.weight_norm(lin)
|
71 |
+
setattr(self, "lin" + str(l), lin)
|
72 |
+
self.softplus = nn.Softplus(beta=100)
|
73 |
+
|
74 |
+
def forward(self, input, cond, current_epoch=None):
|
75 |
+
if input.ndim == 2: input = input.unsqueeze(0)
|
76 |
+
|
77 |
+
num_batch, num_point, num_dim = input.shape
|
78 |
+
|
79 |
+
if num_batch * num_point == 0: return input
|
80 |
+
|
81 |
+
input = input.reshape(num_batch * num_point, num_dim)
|
82 |
+
|
83 |
+
if self.cond != 'none':
|
84 |
+
num_batch, num_cond = cond[self.cond].shape
|
85 |
+
|
86 |
+
input_cond = cond[self.cond].unsqueeze(1).expand(num_batch, num_point, num_cond)
|
87 |
+
|
88 |
+
input_cond = input_cond.reshape(num_batch * num_point, num_cond)
|
89 |
+
|
90 |
+
if self.dim_pose_embed:
|
91 |
+
input_cond = self.lin_p0(input_cond)
|
92 |
+
|
93 |
+
if self.embed_fn is not None:
|
94 |
+
input = self.embed_fn(input)
|
95 |
+
|
96 |
+
x = input
|
97 |
+
|
98 |
+
for l in range(0, self.num_layers - 1):
|
99 |
+
lin = getattr(self, "lin" + str(l))
|
100 |
+
if self.cond != 'none' and l in self.cond_layer:
|
101 |
+
x = torch.cat([x, input_cond], dim=-1)
|
102 |
+
if l in self.skip_in:
|
103 |
+
x = torch.cat([x, input], 1) / np.sqrt(2)
|
104 |
+
x = lin(x)
|
105 |
+
if l < self.num_layers - 2:
|
106 |
+
x = self.softplus(x)
|
107 |
+
|
108 |
+
x = x.reshape(num_batch, num_point, -1)
|
109 |
+
|
110 |
+
return x
|
111 |
+
|
112 |
+
def gradient(self, x, cond):
|
113 |
+
x.requires_grad_(True)
|
114 |
+
y = self.forward(x, cond)[:, :1]
|
115 |
+
d_output = torch.ones_like(y, requires_grad=False, device=y.device)
|
116 |
+
gradients = torch.autograd.grad(outputs=y,
|
117 |
+
inputs=x,
|
118 |
+
grad_outputs=d_output,
|
119 |
+
create_graph=True,
|
120 |
+
retain_graph=True,
|
121 |
+
only_inputs=True)[0]
|
122 |
+
return gradients.unsqueeze(1)
|
123 |
+
|
124 |
+
|
125 |
+
class RenderingNet(nn.Module):
|
126 |
+
def __init__(self, opt):
|
127 |
+
super().__init__()
|
128 |
+
|
129 |
+
self.mode = opt.mode
|
130 |
+
dims = [opt.d_in + opt.feature_vector_size] + list(
|
131 |
+
opt.dims) + [opt.d_out]
|
132 |
+
|
133 |
+
self.embedview_fn = None
|
134 |
+
if opt.multires_view > 0:
|
135 |
+
embedview_fn, input_ch = get_embedder(opt.multires_view)
|
136 |
+
self.embedview_fn = embedview_fn
|
137 |
+
dims[0] += (input_ch - 3)
|
138 |
+
if self.mode == 'nerf_frame_encoding':
|
139 |
+
dims[0] += opt.dim_frame_encoding
|
140 |
+
if self.mode == 'pose':
|
141 |
+
self.dim_cond_embed = 8
|
142 |
+
self.cond_dim = 69 # dimension of the body pose, global orientation excluded.
|
143 |
+
# lower the condition dimension
|
144 |
+
self.lin_pose = torch.nn.Linear(self.cond_dim, self.dim_cond_embed)
|
145 |
+
self.num_layers = len(dims)
|
146 |
+
for l in range(0, self.num_layers - 1):
|
147 |
+
out_dim = dims[l + 1]
|
148 |
+
lin = nn.Linear(dims[l], out_dim)
|
149 |
+
if opt.weight_norm:
|
150 |
+
lin = nn.utils.weight_norm(lin)
|
151 |
+
setattr(self, "lin" + str(l), lin)
|
152 |
+
self.relu = nn.ReLU()
|
153 |
+
self.sigmoid = nn.Sigmoid()
|
154 |
+
|
155 |
+
def forward(self, points, normals, view_dirs, body_pose, feature_vectors, frame_latent_code=None):
|
156 |
+
if self.embedview_fn is not None:
|
157 |
+
if self.mode == 'nerf_frame_encoding':
|
158 |
+
view_dirs = self.embedview_fn(view_dirs)
|
159 |
+
|
160 |
+
if self.mode == 'nerf_frame_encoding':
|
161 |
+
frame_latent_code = frame_latent_code.expand(view_dirs.shape[0], -1)
|
162 |
+
rendering_input = torch.cat([view_dirs, frame_latent_code, feature_vectors], dim=-1)
|
163 |
+
elif self.mode == 'pose':
|
164 |
+
num_points = points.shape[0]
|
165 |
+
body_pose = body_pose.unsqueeze(1).expand(-1, num_points, -1).reshape(num_points, -1)
|
166 |
+
body_pose = self.lin_pose(body_pose)
|
167 |
+
rendering_input = torch.cat([points, normals, body_pose, feature_vectors], dim=-1)
|
168 |
+
else:
|
169 |
+
raise NotImplementedError
|
170 |
+
|
171 |
+
x = rendering_input
|
172 |
+
for l in range(0, self.num_layers - 1):
|
173 |
+
lin = getattr(self, "lin" + str(l))
|
174 |
+
x = lin(x)
|
175 |
+
if l < self.num_layers - 2:
|
176 |
+
x = self.relu(x)
|
177 |
+
x = self.sigmoid(x)
|
178 |
+
return x
|
code/lib/model/ray_sampler.py
ADDED
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import abc
|
2 |
+
import torch
|
3 |
+
from lib.utils import utils
|
4 |
+
|
5 |
+
class RaySampler(metaclass=abc.ABCMeta):
|
6 |
+
def __init__(self,near, far):
|
7 |
+
self.near = near
|
8 |
+
self.far = far
|
9 |
+
|
10 |
+
@abc.abstractmethod
|
11 |
+
def get_z_vals(self, ray_dirs, cam_loc, model):
|
12 |
+
pass
|
13 |
+
|
14 |
+
class UniformSampler(RaySampler):
|
15 |
+
"""Samples uniformly in the range [near, far]
|
16 |
+
"""
|
17 |
+
def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1):
|
18 |
+
super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far) # default far is 2*R
|
19 |
+
self.N_samples = N_samples
|
20 |
+
self.scene_bounding_sphere = scene_bounding_sphere
|
21 |
+
self.take_sphere_intersection = take_sphere_intersection
|
22 |
+
|
23 |
+
def get_z_vals(self, ray_dirs, cam_loc, model):
|
24 |
+
if not self.take_sphere_intersection:
|
25 |
+
near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0], 1).cuda()
|
26 |
+
else:
|
27 |
+
sphere_intersections = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)
|
28 |
+
near = self.near * torch.ones(ray_dirs.shape[0], 1).cuda()
|
29 |
+
far = sphere_intersections[:,1:]
|
30 |
+
|
31 |
+
t_vals = torch.linspace(0., 1., steps=self.N_samples).cuda()
|
32 |
+
z_vals = near * (1. - t_vals) + far * (t_vals)
|
33 |
+
|
34 |
+
if model.training:
|
35 |
+
# get intervals between samples
|
36 |
+
mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
|
37 |
+
upper = torch.cat([mids, z_vals[..., -1:]], -1)
|
38 |
+
lower = torch.cat([z_vals[..., :1], mids], -1)
|
39 |
+
# stratified samples in those intervals
|
40 |
+
t_rand = torch.rand(z_vals.shape).cuda()
|
41 |
+
|
42 |
+
z_vals = lower + (upper - lower) * t_rand
|
43 |
+
|
44 |
+
return z_vals
|
45 |
+
|
46 |
+
class ErrorBoundSampler(RaySampler):
|
47 |
+
def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra,
|
48 |
+
eps, beta_iters, max_total_iters,
|
49 |
+
inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=0.0):
|
50 |
+
super().__init__(near, 2.0 * scene_bounding_sphere)
|
51 |
+
self.N_samples = N_samples
|
52 |
+
self.N_samples_eval = N_samples_eval
|
53 |
+
self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=inverse_sphere_bg)
|
54 |
+
|
55 |
+
self.N_samples_extra = N_samples_extra
|
56 |
+
|
57 |
+
self.eps = eps
|
58 |
+
self.beta_iters = beta_iters
|
59 |
+
self.max_total_iters = max_total_iters
|
60 |
+
self.scene_bounding_sphere = scene_bounding_sphere
|
61 |
+
self.add_tiny = add_tiny
|
62 |
+
|
63 |
+
self.inverse_sphere_bg = inverse_sphere_bg
|
64 |
+
if inverse_sphere_bg:
|
65 |
+
N_samples_inverse_sphere = 32
|
66 |
+
self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0)
|
67 |
+
|
68 |
+
def get_z_vals(self, ray_dirs, cam_loc, model, cond, smpl_tfs, eval_mode, smpl_verts):
|
69 |
+
beta0 = model.density.get_beta().detach()
|
70 |
+
|
71 |
+
# Start with uniform sampling
|
72 |
+
z_vals = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model)
|
73 |
+
samples, samples_idx = z_vals, None
|
74 |
+
|
75 |
+
# Get maximum beta from the upper bound (Lemma 2)
|
76 |
+
dists = z_vals[:, 1:] - z_vals[:, :-1]
|
77 |
+
bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1)
|
78 |
+
beta = torch.sqrt(bound)
|
79 |
+
|
80 |
+
total_iters, not_converge = 0, True
|
81 |
+
|
82 |
+
# VolSDF Algorithm 1
|
83 |
+
while not_converge and total_iters < self.max_total_iters:
|
84 |
+
points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1)
|
85 |
+
points_flat = points.reshape(-1, 3)
|
86 |
+
# Calculating the SDF only for the new sampled points
|
87 |
+
model.implicit_network.eval()
|
88 |
+
with torch.no_grad():
|
89 |
+
samples_sdf = model.sdf_func_with_smpl_deformer(points_flat, cond, smpl_tfs, smpl_verts=smpl_verts)[0]
|
90 |
+
model.implicit_network.train()
|
91 |
+
if samples_idx is not None:
|
92 |
+
sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]),
|
93 |
+
samples_sdf.reshape(-1, samples.shape[1])], -1)
|
94 |
+
sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1)
|
95 |
+
else:
|
96 |
+
sdf = samples_sdf
|
97 |
+
|
98 |
+
|
99 |
+
# Calculating the bound d* (Theorem 1)
|
100 |
+
d = sdf.reshape(z_vals.shape)
|
101 |
+
dists = z_vals[:, 1:] - z_vals[:, :-1]
|
102 |
+
a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs()
|
103 |
+
first_cond = a.pow(2) + b.pow(2) <= c.pow(2)
|
104 |
+
second_cond = a.pow(2) + c.pow(2) <= b.pow(2)
|
105 |
+
d_star = torch.zeros(z_vals.shape[0], z_vals.shape[1] - 1).cuda()
|
106 |
+
d_star[first_cond] = b[first_cond]
|
107 |
+
d_star[second_cond] = c[second_cond]
|
108 |
+
s = (a + b + c) / 2.0
|
109 |
+
area_before_sqrt = s * (s - a) * (s - b) * (s - c)
|
110 |
+
mask = ~first_cond & ~second_cond & (b + c - a > 0)
|
111 |
+
d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask])
|
112 |
+
d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star # Fixing the sign
|
113 |
+
|
114 |
+
|
115 |
+
# Updating beta using line search
|
116 |
+
curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star)
|
117 |
+
beta[curr_error <= self.eps] = beta0
|
118 |
+
beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta
|
119 |
+
for j in range(self.beta_iters):
|
120 |
+
beta_mid = (beta_min + beta_max) / 2.
|
121 |
+
curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star)
|
122 |
+
beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps]
|
123 |
+
beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps]
|
124 |
+
beta = beta_max
|
125 |
+
|
126 |
+
|
127 |
+
# Upsample more points
|
128 |
+
density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1))
|
129 |
+
|
130 |
+
dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1)
|
131 |
+
free_energy = dists * density
|
132 |
+
shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1)
|
133 |
+
alpha = 1 - torch.exp(-free_energy)
|
134 |
+
transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))
|
135 |
+
weights = alpha * transmittance # probability of the ray hits something here
|
136 |
+
|
137 |
+
# Check if we are done and this is the last sampling
|
138 |
+
total_iters += 1
|
139 |
+
not_converge = beta.max() > beta0
|
140 |
+
|
141 |
+
if not_converge and total_iters < self.max_total_iters:
|
142 |
+
''' Sample more points proportional to the current error bound'''
|
143 |
+
|
144 |
+
N = self.N_samples_eval
|
145 |
+
|
146 |
+
bins = z_vals
|
147 |
+
error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:,:-1] ** 2.) / (4 * beta.unsqueeze(-1) ** 2)
|
148 |
+
error_integral = torch.cumsum(error_per_section, dim=-1)
|
149 |
+
bound_opacity = (torch.clamp(torch.exp(error_integral),max=1.e6) - 1.0) * transmittance[:,:-1]
|
150 |
+
|
151 |
+
pdf = bound_opacity + self.add_tiny
|
152 |
+
pdf = pdf / torch.sum(pdf, -1, keepdim=True)
|
153 |
+
cdf = torch.cumsum(pdf, -1)
|
154 |
+
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
|
155 |
+
|
156 |
+
else:
|
157 |
+
''' Sample the final sample set to be used in the volume rendering integral '''
|
158 |
+
|
159 |
+
N = self.N_samples
|
160 |
+
|
161 |
+
bins = z_vals
|
162 |
+
pdf = weights[..., :-1]
|
163 |
+
pdf = pdf + 1e-5 # prevent nans
|
164 |
+
pdf = pdf / torch.sum(pdf, -1, keepdim=True)
|
165 |
+
cdf = torch.cumsum(pdf, -1)
|
166 |
+
cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins))
|
167 |
+
|
168 |
+
|
169 |
+
# Invert CDF
|
170 |
+
if (not_converge and total_iters < self.max_total_iters) or (not model.training):
|
171 |
+
u = torch.linspace(0., 1., steps=N).cuda().unsqueeze(0).repeat(cdf.shape[0], 1)
|
172 |
+
else:
|
173 |
+
u = torch.rand(list(cdf.shape[:-1]) + [N]).cuda()
|
174 |
+
u = u.contiguous()
|
175 |
+
|
176 |
+
inds = torch.searchsorted(cdf, u, right=True)
|
177 |
+
below = torch.max(torch.zeros_like(inds - 1), inds - 1)
|
178 |
+
above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
|
179 |
+
inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
|
180 |
+
|
181 |
+
matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
|
182 |
+
cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
183 |
+
bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
|
184 |
+
|
185 |
+
denom = (cdf_g[..., 1] - cdf_g[..., 0])
|
186 |
+
denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
|
187 |
+
t = (u - cdf_g[..., 0]) / denom
|
188 |
+
samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
|
189 |
+
|
190 |
+
|
191 |
+
# Adding samples if we not converged
|
192 |
+
if not_converge and total_iters < self.max_total_iters:
|
193 |
+
z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1)
|
194 |
+
|
195 |
+
|
196 |
+
z_samples = samples
|
197 |
+
|
198 |
+
near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0],1).cuda()
|
199 |
+
if self.inverse_sphere_bg: # if inverse sphere then need to add the far sphere intersection
|
200 |
+
far = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:,1:]
|
201 |
+
|
202 |
+
if self.N_samples_extra > 0:
|
203 |
+
if model.training:
|
204 |
+
sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra]
|
205 |
+
else:
|
206 |
+
sampling_idx = torch.linspace(0, z_vals.shape[1]-1, self.N_samples_extra).long()
|
207 |
+
z_vals_extra = torch.cat([near, far, z_vals[:,sampling_idx]], -1)
|
208 |
+
else:
|
209 |
+
z_vals_extra = torch.cat([near, far], -1)
|
210 |
+
|
211 |
+
z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1)
|
212 |
+
|
213 |
+
# add some of the near surface points
|
214 |
+
idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],)).cuda()
|
215 |
+
z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1))
|
216 |
+
|
217 |
+
if self.inverse_sphere_bg:
|
218 |
+
z_vals_inverse_sphere = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model)
|
219 |
+
z_vals_inverse_sphere = z_vals_inverse_sphere * (1./self.scene_bounding_sphere)
|
220 |
+
z_vals = (z_vals, z_vals_inverse_sphere)
|
221 |
+
|
222 |
+
return z_vals, z_samples_eik
|
223 |
+
|
224 |
+
def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):
|
225 |
+
density = model.density(sdf.reshape(z_vals.shape), beta=beta)
|
226 |
+
shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), dists * density[:, :-1]], dim=-1)
|
227 |
+
integral_estimation = torch.cumsum(shifted_free_energy, dim=-1)
|
228 |
+
error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2)
|
229 |
+
error_integral = torch.cumsum(error_per_section, dim=-1)
|
230 |
+
bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1])
|
231 |
+
|
232 |
+
return bound_opacity.max(-1)[0]
|
233 |
+
|
234 |
+
|
code/lib/model/sampler.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
|
4 |
+
class PointInSpace:
|
5 |
+
def __init__(self, global_sigma=0.5, local_sigma=0.01):
|
6 |
+
self.global_sigma = global_sigma
|
7 |
+
self.local_sigma = local_sigma
|
8 |
+
|
9 |
+
def get_points(self, pc_input=None, local_sigma=None, global_ratio=0.125):
|
10 |
+
"""Sample one point near each of the given point + 1/8 uniformly.
|
11 |
+
Args:
|
12 |
+
pc_input (tensor): sampling centers. shape: [B, N, D]
|
13 |
+
Returns:
|
14 |
+
samples (tensor): sampled points. shape: [B, N + N / 8, D]
|
15 |
+
"""
|
16 |
+
|
17 |
+
batch_size, sample_size, dim = pc_input.shape
|
18 |
+
if local_sigma is None:
|
19 |
+
sample_local = pc_input + (torch.randn_like(pc_input) * self.local_sigma)
|
20 |
+
else:
|
21 |
+
sample_local = pc_input + (torch.randn_like(pc_input) * local_sigma)
|
22 |
+
sample_global = (
|
23 |
+
torch.rand(batch_size, int(sample_size * global_ratio), dim, device=pc_input.device)
|
24 |
+
* (self.global_sigma * 2)
|
25 |
+
) - self.global_sigma
|
26 |
+
|
27 |
+
sample = torch.cat([sample_local, sample_global], dim=1)
|
28 |
+
|
29 |
+
return sample
|
code/lib/model/smpl.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import hydra
|
3 |
+
import numpy as np
|
4 |
+
from ..smpl.body_models import SMPL
|
5 |
+
|
6 |
+
class SMPLServer(torch.nn.Module):
|
7 |
+
|
8 |
+
def __init__(self, gender='neutral', betas=None, v_template=None):
|
9 |
+
super().__init__()
|
10 |
+
|
11 |
+
|
12 |
+
self.smpl = SMPL(model_path=hydra.utils.to_absolute_path('lib/smpl/smpl_model'),
|
13 |
+
gender=gender,
|
14 |
+
batch_size=1,
|
15 |
+
use_hands=False,
|
16 |
+
use_feet_keypoints=False,
|
17 |
+
dtype=torch.float32).cuda()
|
18 |
+
|
19 |
+
self.bone_parents = self.smpl.bone_parents.astype(int)
|
20 |
+
self.bone_parents[0] = -1
|
21 |
+
self.bone_ids = []
|
22 |
+
self.faces = self.smpl.faces
|
23 |
+
for i in range(24): self.bone_ids.append([self.bone_parents[i], i])
|
24 |
+
|
25 |
+
if v_template is not None:
|
26 |
+
self.v_template = torch.tensor(v_template).float().cuda()
|
27 |
+
else:
|
28 |
+
self.v_template = None
|
29 |
+
|
30 |
+
if betas is not None:
|
31 |
+
self.betas = torch.tensor(betas).float().cuda()
|
32 |
+
else:
|
33 |
+
self.betas = None
|
34 |
+
|
35 |
+
# define the canonical pose
|
36 |
+
param_canonical = torch.zeros((1, 86),dtype=torch.float32).cuda()
|
37 |
+
param_canonical[0, 0] = 1
|
38 |
+
param_canonical[0, 9] = np.pi / 6
|
39 |
+
param_canonical[0, 12] = -np.pi / 6
|
40 |
+
if self.betas is not None and self.v_template is None:
|
41 |
+
param_canonical[0,-10:] = self.betas
|
42 |
+
self.param_canonical = param_canonical
|
43 |
+
|
44 |
+
output = self.forward(*torch.split(self.param_canonical, [1, 3, 72, 10], dim=1), absolute=True)
|
45 |
+
self.verts_c = output['smpl_verts']
|
46 |
+
self.joints_c = output['smpl_jnts']
|
47 |
+
self.tfs_c_inv = output['smpl_tfs'].squeeze(0).inverse()
|
48 |
+
|
49 |
+
|
50 |
+
def forward(self, scale, transl, thetas, betas, absolute=False):
|
51 |
+
"""return SMPL output from params
|
52 |
+
Args:
|
53 |
+
scale : scale factor. shape: [B, 1]
|
54 |
+
transl: translation. shape: [B, 3]
|
55 |
+
thetas: pose. shape: [B, 72]
|
56 |
+
betas: shape. shape: [B, 10]
|
57 |
+
absolute (bool): if true return smpl_tfs wrt thetas=0. else wrt thetas=thetas_canonical.
|
58 |
+
Returns:
|
59 |
+
smpl_verts: vertices. shape: [B, 6893. 3]
|
60 |
+
smpl_tfs: bone transformations. shape: [B, 24, 4, 4]
|
61 |
+
smpl_jnts: joint positions. shape: [B, 25, 3]
|
62 |
+
"""
|
63 |
+
|
64 |
+
output = {}
|
65 |
+
|
66 |
+
# ignore betas if v_template is provided
|
67 |
+
if self.v_template is not None:
|
68 |
+
betas = torch.zeros_like(betas)
|
69 |
+
|
70 |
+
|
71 |
+
smpl_output = self.smpl.forward(betas=betas,
|
72 |
+
transl=torch.zeros_like(transl),
|
73 |
+
body_pose=thetas[:, 3:],
|
74 |
+
global_orient=thetas[:, :3],
|
75 |
+
return_verts=True,
|
76 |
+
return_full_pose=True,
|
77 |
+
v_template=self.v_template)
|
78 |
+
|
79 |
+
verts = smpl_output.vertices.clone()
|
80 |
+
output['smpl_verts'] = verts * scale.unsqueeze(1) + transl.unsqueeze(1) * scale.unsqueeze(1)
|
81 |
+
|
82 |
+
joints = smpl_output.joints.clone()
|
83 |
+
output['smpl_jnts'] = joints * scale.unsqueeze(1) + transl.unsqueeze(1) * scale.unsqueeze(1)
|
84 |
+
|
85 |
+
tf_mats = smpl_output.T.clone()
|
86 |
+
tf_mats[:, :, :3, :] = tf_mats[:, :, :3, :] * scale.unsqueeze(1).unsqueeze(1)
|
87 |
+
tf_mats[:, :, :3, 3] = tf_mats[:, :, :3, 3] + transl.unsqueeze(1) * scale.unsqueeze(1)
|
88 |
+
|
89 |
+
if not absolute:
|
90 |
+
tf_mats = torch.einsum('bnij,njk->bnik', tf_mats, self.tfs_c_inv)
|
91 |
+
|
92 |
+
output['smpl_tfs'] = tf_mats
|
93 |
+
output['smpl_weights'] = smpl_output.weights
|
94 |
+
return output
|
code/lib/model/v2a.py
ADDED
@@ -0,0 +1,368 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .networks import ImplicitNet, RenderingNet
|
2 |
+
from .density import LaplaceDensity, AbsDensity
|
3 |
+
from .ray_sampler import ErrorBoundSampler
|
4 |
+
from .deformer import SMPLDeformer
|
5 |
+
from .smpl import SMPLServer
|
6 |
+
|
7 |
+
from .sampler import PointInSpace
|
8 |
+
|
9 |
+
from ..utils import utils
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torch
|
13 |
+
import torch.nn as nn
|
14 |
+
from torch.autograd import grad
|
15 |
+
import hydra
|
16 |
+
import kaolin
|
17 |
+
from kaolin.ops.mesh import index_vertices_by_faces
|
18 |
+
class V2A(nn.Module):
|
19 |
+
def __init__(self, opt, betas_path, gender, num_training_frames):
|
20 |
+
super().__init__()
|
21 |
+
|
22 |
+
# Foreground networks
|
23 |
+
self.implicit_network = ImplicitNet(opt.implicit_network)
|
24 |
+
self.rendering_network = RenderingNet(opt.rendering_network)
|
25 |
+
|
26 |
+
# Background networks
|
27 |
+
self.bg_implicit_network = ImplicitNet(opt.bg_implicit_network)
|
28 |
+
self.bg_rendering_network = RenderingNet(opt.bg_rendering_network)
|
29 |
+
|
30 |
+
# Frame latent encoder
|
31 |
+
self.frame_latent_encoder = nn.Embedding(num_training_frames, opt.bg_rendering_network.dim_frame_encoding)
|
32 |
+
self.sampler = PointInSpace()
|
33 |
+
|
34 |
+
betas = np.load(betas_path)
|
35 |
+
self.use_smpl_deformer = opt.use_smpl_deformer
|
36 |
+
self.gender = gender
|
37 |
+
if self.use_smpl_deformer:
|
38 |
+
self.deformer = SMPLDeformer(betas=betas, gender=self.gender)
|
39 |
+
|
40 |
+
# pre-defined bounding sphere
|
41 |
+
self.sdf_bounding_sphere = 3.0
|
42 |
+
|
43 |
+
# threshold for the out-surface points
|
44 |
+
self.threshold = 0.05
|
45 |
+
|
46 |
+
self.density = LaplaceDensity(**opt.density)
|
47 |
+
self.bg_density = AbsDensity()
|
48 |
+
|
49 |
+
self.ray_sampler = ErrorBoundSampler(self.sdf_bounding_sphere, inverse_sphere_bg=True, **opt.ray_sampler)
|
50 |
+
self.smpl_server = SMPLServer(gender=self.gender, betas=betas)
|
51 |
+
|
52 |
+
if opt.smpl_init:
|
53 |
+
smpl_model_state = torch.load(hydra.utils.to_absolute_path('../assets/smpl_init.pth'))
|
54 |
+
self.implicit_network.load_state_dict(smpl_model_state["model_state_dict"])
|
55 |
+
|
56 |
+
self.smpl_v_cano = self.smpl_server.verts_c
|
57 |
+
self.smpl_f_cano = torch.tensor(self.smpl_server.smpl.faces.astype(np.int64), device=self.smpl_v_cano.device)
|
58 |
+
|
59 |
+
self.mesh_v_cano = self.smpl_server.verts_c
|
60 |
+
self.mesh_f_cano = torch.tensor(self.smpl_server.smpl.faces.astype(np.int64), device=self.smpl_v_cano.device)
|
61 |
+
self.mesh_face_vertices = index_vertices_by_faces(self.mesh_v_cano, self.mesh_f_cano)
|
62 |
+
|
63 |
+
def sdf_func_with_smpl_deformer(self, x, cond, smpl_tfs, smpl_verts):
|
64 |
+
""" sdf_func_with_smpl_deformer method
|
65 |
+
Used to compute SDF values for input points using the SMPL deformer and the implicit network.
|
66 |
+
It handles the deforming of points, network inference, feature extraction, and handling of outlier points.
|
67 |
+
"""
|
68 |
+
if hasattr(self, "deformer"):
|
69 |
+
x_c, outlier_mask = self.deformer.forward(x, smpl_tfs, return_weights=False, inverse=True, smpl_verts=smpl_verts)
|
70 |
+
output = self.implicit_network(x_c, cond)[0]
|
71 |
+
sdf = output[:, 0:1]
|
72 |
+
feature = output[:, 1:]
|
73 |
+
if not self.training:
|
74 |
+
sdf[outlier_mask] = 4. # set a large SDF value for outlier points
|
75 |
+
|
76 |
+
return sdf, x_c, feature
|
77 |
+
|
78 |
+
def check_off_in_surface_points_cano_mesh(self, x_cano, N_samples, threshold=0.05):
|
79 |
+
"""check_off_in_surface_points_cano_mesh method
|
80 |
+
Used to check whether points are off the surface or within the surface of a canonical mesh.
|
81 |
+
It calculates distances, signs, and signed distances to determine the position of points with respect to the mesh surface.
|
82 |
+
The method plays a role in identifying points that might be considered outliers or outside the reconstructed avatar's surface.
|
83 |
+
"""
|
84 |
+
|
85 |
+
distance, _, _ = kaolin.metrics.trianglemesh.point_to_mesh_distance(x_cano.unsqueeze(0).contiguous(), self.mesh_face_vertices)
|
86 |
+
|
87 |
+
distance = torch.sqrt(distance) # kaolin outputs squared distance
|
88 |
+
sign = kaolin.ops.mesh.check_sign(self.mesh_v_cano, self.mesh_f_cano, x_cano.unsqueeze(0)).float()
|
89 |
+
sign = 1 - 2 * sign # -1 for off-surface, 1 for in-surface
|
90 |
+
signed_distance = sign * distance
|
91 |
+
batch_size = x_cano.shape[0] // N_samples
|
92 |
+
signed_distance = signed_distance.reshape(batch_size, N_samples, 1) # The distances are reshaped to match the batch size and the number of samples
|
93 |
+
|
94 |
+
minimum = torch.min(signed_distance, 1)[0]
|
95 |
+
index_off_surface = (minimum > threshold).squeeze(1)
|
96 |
+
index_in_surface = (minimum <= 0.).squeeze(1)
|
97 |
+
return index_off_surface, index_in_surface # Indexes of off-surface points and in-surface points
|
98 |
+
|
99 |
+
def forward(self, input):
|
100 |
+
# Parse model input, prepares the necessary input data and SMPL parameters for subsequent calculations
|
101 |
+
torch.set_grad_enabled(True)
|
102 |
+
intrinsics = input["intrinsics"]
|
103 |
+
pose = input["pose"]
|
104 |
+
uv = input["uv"]
|
105 |
+
|
106 |
+
scale = input['smpl_params'][:, 0]
|
107 |
+
smpl_pose = input["smpl_pose"]
|
108 |
+
smpl_shape = input["smpl_shape"]
|
109 |
+
smpl_trans = input["smpl_trans"]
|
110 |
+
smpl_output = self.smpl_server(scale, smpl_trans, smpl_pose, smpl_shape) # invokes the SMPL model to obtain the transformations for pose and shape changes
|
111 |
+
|
112 |
+
smpl_tfs = smpl_output['smpl_tfs']
|
113 |
+
|
114 |
+
cond = {'smpl': smpl_pose[:, 3:]/np.pi}
|
115 |
+
if self.training:
|
116 |
+
if input['current_epoch'] < 20 or input['current_epoch'] % 20 == 0: # set the pose to zero for the first 20 epochs
|
117 |
+
cond = {'smpl': smpl_pose[:, 3:] * 0.}
|
118 |
+
ray_dirs, cam_loc = utils.get_camera_params(uv, pose, intrinsics) # get the ray directions and camera location
|
119 |
+
batch_size, num_pixels, _ = ray_dirs.shape
|
120 |
+
|
121 |
+
cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3) # reshape to match the batch size and the number of pixels
|
122 |
+
ray_dirs = ray_dirs.reshape(-1, 3) # reshape to match the batch size and the number of pixels
|
123 |
+
|
124 |
+
z_vals, _ = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self, cond, smpl_tfs, eval_mode=True, smpl_verts=smpl_output['smpl_verts']) # get the z values for each pixel
|
125 |
+
|
126 |
+
z_vals, z_vals_bg = z_vals # unpack the z values for the foreground and the background
|
127 |
+
z_max = z_vals[:,-1] # get the maximum z value
|
128 |
+
z_vals = z_vals[:,:-1] # get the z values for the foreground
|
129 |
+
N_samples = z_vals.shape[1] # get the number of samples
|
130 |
+
|
131 |
+
points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1) # 3D points along the rays are calculated by adding z_vals scaled by ray directions to the camera location. The result is stored in the points tensor of shape (batch_size * num_pixels, N_samples, 3)
|
132 |
+
points_flat = points.reshape(-1, 3) # The points tensor is reshaped into a flattened tensor points_flat of shape (batch_size * num_pixels * N_samples, 3)
|
133 |
+
|
134 |
+
dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1) # The dirs tensor is created by repeating ray_dirs for each sample along the rays. The resulting tensor has shape (batch_size * num_pixels, N_samples, 3)
|
135 |
+
sdf_output, canonical_points, feature_vectors = self.sdf_func_with_smpl_deformer(points_flat, cond, smpl_tfs, smpl_output['smpl_verts']) # The sdf_func_with_smpl_deformer method is called to compute the signed distance functions (SDF) for the points
|
136 |
+
|
137 |
+
sdf_output = sdf_output.unsqueeze(1) # The sdf_output tensor is reshaped by unsqueezing along the first dimension
|
138 |
+
|
139 |
+
if self.training:
|
140 |
+
index_off_surface, index_in_surface = self.check_off_in_surface_points_cano_mesh(canonical_points, N_samples, threshold=self.threshold)
|
141 |
+
canonical_points = canonical_points.reshape(num_pixels, N_samples, 3)
|
142 |
+
|
143 |
+
canonical_points = canonical_points.reshape(-1, 3) # The canonical points tensor flattened to shape (-1, 3)
|
144 |
+
|
145 |
+
# sample canonical SMPL surface pnts for the eikonal loss
|
146 |
+
smpl_verts_c = self.smpl_server.verts_c.repeat(batch_size, 1,1) # The canonical SMPL surface vertices are repeated across the batch dimension
|
147 |
+
|
148 |
+
indices = torch.randperm(smpl_verts_c.shape[1])[:num_pixels].cuda() # Random indices are generated to select a subset of vertices for sampling. The number of selected vertices is num_pixels
|
149 |
+
verts_c = torch.index_select(smpl_verts_c, 1, indices) # The selected vertices are gathered from smpl_verts_c, resulting in the tensor verts_c.
|
150 |
+
sample = self.sampler.get_points(verts_c, global_ratio=0.) # The get_points method of the sampler class is called to sample points around the canonical SMPL surface points. The global_ratio is set to 0.0, indicating local sampling
|
151 |
+
|
152 |
+
sample.requires_grad_() # The sampled points are marked as requiring gradients
|
153 |
+
local_pred = self.implicit_network(sample, cond)[..., 0:1] # The sampled points (sample) are passed through the implicit network along with the conditioning (cond). The local prediction (SDF) for each sampled point is extracted using [..., 0:1]
|
154 |
+
grad_theta = gradient(sample, local_pred) # compute gradients with respect to the sampled points and their local predictions (local_pred).
|
155 |
+
|
156 |
+
differentiable_points = canonical_points # The differentiable_points tensor is assigned the value of canonical_points
|
157 |
+
|
158 |
+
else:
|
159 |
+
differentiable_points = canonical_points.reshape(num_pixels, N_samples, 3).reshape(-1, 3)
|
160 |
+
grad_theta = None
|
161 |
+
|
162 |
+
sdf_output = sdf_output.reshape(num_pixels, N_samples, 1).reshape(-1, 1) # flattened to shape (num_pixels * N_samples, )
|
163 |
+
z_vals = z_vals
|
164 |
+
view = -dirs.reshape(-1, 3) # The view vector is calculated as the negation of the reshaped dirs, giving the view directions for points along the rays.
|
165 |
+
|
166 |
+
if differentiable_points.shape[0] > 0: # If there are differentiable points (indicating that gradient information is available)
|
167 |
+
fg_rgb_flat, others = self.get_rbg_value(points_flat, differentiable_points, view,
|
168 |
+
cond, smpl_tfs, feature_vectors=feature_vectors, is_training=self.training) # The returned values include fg_rgb_flat (foreground RGB values) and others (other calculated values, including normals)
|
169 |
+
normal_values = others['normals'] # The normal values are extracted from the others dictionary
|
170 |
+
|
171 |
+
if 'image_id' in input.keys():
|
172 |
+
frame_latent_code = self.frame_latent_encoder(input['image_id'])
|
173 |
+
else:
|
174 |
+
frame_latent_code = self.frame_latent_encoder(input['idx'])
|
175 |
+
|
176 |
+
fg_rgb = fg_rgb_flat.reshape(-1, N_samples, 3)
|
177 |
+
normal_values = normal_values.reshape(-1, N_samples, 3)
|
178 |
+
weights, bg_transmittance = self.volume_rendering(z_vals, z_max, sdf_output)
|
179 |
+
|
180 |
+
fg_rgb_values = torch.sum(weights.unsqueeze(-1) * fg_rgb, 1)
|
181 |
+
|
182 |
+
# Background rendering
|
183 |
+
if input['idx'] is not None:
|
184 |
+
N_bg_samples = z_vals_bg.shape[1]
|
185 |
+
z_vals_bg = torch.flip(z_vals_bg, dims=[-1, ]) # 1--->0
|
186 |
+
|
187 |
+
bg_dirs = ray_dirs.unsqueeze(1).repeat(1,N_bg_samples,1)
|
188 |
+
bg_locs = cam_loc.unsqueeze(1).repeat(1,N_bg_samples,1)
|
189 |
+
|
190 |
+
bg_points = self.depth2pts_outside(bg_locs, bg_dirs, z_vals_bg) # [..., N_samples, 4]
|
191 |
+
bg_points_flat = bg_points.reshape(-1, 4)
|
192 |
+
bg_dirs_flat = bg_dirs.reshape(-1, 3)
|
193 |
+
bg_output = self.bg_implicit_network(bg_points_flat, {'frame': frame_latent_code})[0]
|
194 |
+
bg_sdf = bg_output[:, :1]
|
195 |
+
bg_feature_vectors = bg_output[:, 1:]
|
196 |
+
|
197 |
+
bg_rendering_output = self.bg_rendering_network(None, None, bg_dirs_flat, None, bg_feature_vectors, frame_latent_code)
|
198 |
+
if bg_rendering_output.shape[-1] == 4:
|
199 |
+
bg_rgb_flat = bg_rendering_output[..., :-1]
|
200 |
+
shadow_r = bg_rendering_output[..., -1]
|
201 |
+
bg_rgb = bg_rgb_flat.reshape(-1, N_bg_samples, 3)
|
202 |
+
shadow_r = shadow_r.reshape(-1, N_bg_samples, 1)
|
203 |
+
bg_rgb = (1 - shadow_r) * bg_rgb
|
204 |
+
else:
|
205 |
+
bg_rgb_flat = bg_rendering_output
|
206 |
+
bg_rgb = bg_rgb_flat.reshape(-1, N_bg_samples, 3)
|
207 |
+
bg_weights = self.bg_volume_rendering(z_vals_bg, bg_sdf)
|
208 |
+
bg_rgb_values = torch.sum(bg_weights.unsqueeze(-1) * bg_rgb, 1)
|
209 |
+
else:
|
210 |
+
bg_rgb_values = torch.ones_like(fg_rgb_values, device=fg_rgb_values.device)
|
211 |
+
|
212 |
+
# Composite foreground and background
|
213 |
+
bg_rgb_values = bg_transmittance.unsqueeze(-1) * bg_rgb_values
|
214 |
+
rgb_values = fg_rgb_values + bg_rgb_values
|
215 |
+
|
216 |
+
normal_values = torch.sum(weights.unsqueeze(-1) * normal_values, 1)
|
217 |
+
|
218 |
+
if self.training:
|
219 |
+
output = {
|
220 |
+
'points': points,
|
221 |
+
'rgb_values': rgb_values,
|
222 |
+
'normal_values': normal_values,
|
223 |
+
'index_outside': input['index_outside'],
|
224 |
+
'index_off_surface': index_off_surface,
|
225 |
+
'index_in_surface': index_in_surface,
|
226 |
+
'acc_map': torch.sum(weights, -1),
|
227 |
+
'sdf_output': sdf_output,
|
228 |
+
'grad_theta': grad_theta,
|
229 |
+
'epoch': input['current_epoch'],
|
230 |
+
}
|
231 |
+
else:
|
232 |
+
fg_output_rgb = fg_rgb_values + bg_transmittance.unsqueeze(-1) * torch.ones_like(fg_rgb_values, device=fg_rgb_values.device)
|
233 |
+
output = {
|
234 |
+
'acc_map': torch.sum(weights, -1),
|
235 |
+
'rgb_values': rgb_values,
|
236 |
+
'fg_rgb_values': fg_output_rgb,
|
237 |
+
'normal_values': normal_values,
|
238 |
+
'sdf_output': sdf_output,
|
239 |
+
}
|
240 |
+
return output
|
241 |
+
|
242 |
+
def get_rbg_value(self, x, points, view_dirs, cond, tfs, feature_vectors, is_training=True):
|
243 |
+
pnts_c = points
|
244 |
+
others = {}
|
245 |
+
|
246 |
+
_, gradients, feature_vectors = self.forward_gradient(x, pnts_c, cond, tfs, create_graph=is_training, retain_graph=is_training)
|
247 |
+
# ensure the gradient is normalized
|
248 |
+
normals = nn.functional.normalize(gradients, dim=-1, eps=1e-6)
|
249 |
+
fg_rendering_output = self.rendering_network(pnts_c, normals, view_dirs, cond['smpl'],
|
250 |
+
feature_vectors)
|
251 |
+
|
252 |
+
rgb_vals = fg_rendering_output[:, :3]
|
253 |
+
others['normals'] = normals
|
254 |
+
return rgb_vals, others
|
255 |
+
|
256 |
+
def forward_gradient(self, x, pnts_c, cond, tfs, create_graph=True, retain_graph=True):
|
257 |
+
if pnts_c.shape[0] == 0:
|
258 |
+
return pnts_c.detach()
|
259 |
+
pnts_c.requires_grad_(True)
|
260 |
+
|
261 |
+
pnts_d = self.deformer.forward_skinning(pnts_c.unsqueeze(0), None, tfs).squeeze(0)
|
262 |
+
num_dim = pnts_d.shape[-1]
|
263 |
+
grads = []
|
264 |
+
for i in range(num_dim):
|
265 |
+
d_out = torch.zeros_like(pnts_d, requires_grad=False, device=pnts_d.device)
|
266 |
+
d_out[:, i] = 1
|
267 |
+
grad = torch.autograd.grad(
|
268 |
+
outputs=pnts_d,
|
269 |
+
inputs=pnts_c,
|
270 |
+
grad_outputs=d_out,
|
271 |
+
create_graph=create_graph,
|
272 |
+
retain_graph=True if i < num_dim - 1 else retain_graph,
|
273 |
+
only_inputs=True)[0]
|
274 |
+
grads.append(grad)
|
275 |
+
grads = torch.stack(grads, dim=-2)
|
276 |
+
grads_inv = grads.inverse()
|
277 |
+
|
278 |
+
output = self.implicit_network(pnts_c, cond)[0]
|
279 |
+
sdf = output[:, :1]
|
280 |
+
|
281 |
+
feature = output[:, 1:]
|
282 |
+
d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
|
283 |
+
gradients = torch.autograd.grad(
|
284 |
+
outputs=sdf,
|
285 |
+
inputs=pnts_c,
|
286 |
+
grad_outputs=d_output,
|
287 |
+
create_graph=create_graph,
|
288 |
+
retain_graph=retain_graph,
|
289 |
+
only_inputs=True)[0]
|
290 |
+
|
291 |
+
return grads.reshape(grads.shape[0], -1), torch.nn.functional.normalize(torch.einsum('bi,bij->bj', gradients, grads_inv), dim=1), feature
|
292 |
+
|
293 |
+
def volume_rendering(self, z_vals, z_max, sdf):
|
294 |
+
density_flat = self.density(sdf)
|
295 |
+
density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples
|
296 |
+
|
297 |
+
# included also the dist from the sphere intersection
|
298 |
+
dists = z_vals[:, 1:] - z_vals[:, :-1]
|
299 |
+
dists = torch.cat([dists, z_max.unsqueeze(-1) - z_vals[:, -1:]], -1)
|
300 |
+
|
301 |
+
# LOG SPACE
|
302 |
+
free_energy = dists * density
|
303 |
+
shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy], dim=-1) # add 0 for transperancy 1 at t_0
|
304 |
+
alpha = 1 - torch.exp(-free_energy) # probability of it is not empty here
|
305 |
+
transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) # probability of everything is empty up to now
|
306 |
+
fg_transmittance = transmittance[:, :-1]
|
307 |
+
weights = alpha * fg_transmittance # probability of the ray hits something here
|
308 |
+
bg_transmittance = transmittance[:, -1] # factor to be multiplied with the bg volume rendering
|
309 |
+
|
310 |
+
return weights, bg_transmittance
|
311 |
+
|
312 |
+
def bg_volume_rendering(self, z_vals_bg, bg_sdf):
|
313 |
+
bg_density_flat = self.bg_density(bg_sdf)
|
314 |
+
bg_density = bg_density_flat.reshape(-1, z_vals_bg.shape[1]) # (batch_size * num_pixels) x N_samples
|
315 |
+
|
316 |
+
bg_dists = z_vals_bg[:, :-1] - z_vals_bg[:, 1:]
|
317 |
+
bg_dists = torch.cat([bg_dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(bg_dists.shape[0], 1)], -1)
|
318 |
+
|
319 |
+
# LOG SPACE
|
320 |
+
bg_free_energy = bg_dists * bg_density
|
321 |
+
bg_shifted_free_energy = torch.cat([torch.zeros(bg_dists.shape[0], 1).cuda(), bg_free_energy[:, :-1]], dim=-1) # shift one step
|
322 |
+
bg_alpha = 1 - torch.exp(-bg_free_energy) # probability of it is not empty here
|
323 |
+
bg_transmittance = torch.exp(-torch.cumsum(bg_shifted_free_energy, dim=-1)) # probability of everything is empty up to now
|
324 |
+
bg_weights = bg_alpha * bg_transmittance # probability of the ray hits something here
|
325 |
+
|
326 |
+
return bg_weights
|
327 |
+
|
328 |
+
def depth2pts_outside(self, ray_o, ray_d, depth):
|
329 |
+
|
330 |
+
'''
|
331 |
+
ray_o, ray_d: [..., 3]
|
332 |
+
depth: [...]; inverse of distance to sphere origin
|
333 |
+
'''
|
334 |
+
|
335 |
+
o_dot_d = torch.sum(ray_d * ray_o, dim=-1)
|
336 |
+
under_sqrt = o_dot_d ** 2 - ((ray_o ** 2).sum(-1) - self.sdf_bounding_sphere ** 2)
|
337 |
+
d_sphere = torch.sqrt(under_sqrt) - o_dot_d
|
338 |
+
p_sphere = ray_o + d_sphere.unsqueeze(-1) * ray_d
|
339 |
+
p_mid = ray_o - o_dot_d.unsqueeze(-1) * ray_d
|
340 |
+
p_mid_norm = torch.norm(p_mid, dim=-1)
|
341 |
+
|
342 |
+
rot_axis = torch.cross(ray_o, p_sphere, dim=-1)
|
343 |
+
rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
|
344 |
+
phi = torch.asin(p_mid_norm / self.sdf_bounding_sphere)
|
345 |
+
theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1]
|
346 |
+
rot_angle = (phi - theta).unsqueeze(-1) # [..., 1]
|
347 |
+
|
348 |
+
# now rotate p_sphere
|
349 |
+
# Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
|
350 |
+
p_sphere_new = p_sphere * torch.cos(rot_angle) + \
|
351 |
+
torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \
|
352 |
+
rot_axis * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True) * (1. - torch.cos(rot_angle))
|
353 |
+
p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)
|
354 |
+
pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)
|
355 |
+
|
356 |
+
return pts
|
357 |
+
|
358 |
+
def gradient(inputs, outputs):
|
359 |
+
|
360 |
+
d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
|
361 |
+
points_grad = grad(
|
362 |
+
outputs=outputs,
|
363 |
+
inputs=inputs,
|
364 |
+
grad_outputs=d_points,
|
365 |
+
create_graph=True,
|
366 |
+
retain_graph=True,
|
367 |
+
only_inputs=True)[0][:, :, -3:]
|
368 |
+
return points_grad
|
code/lib/smpl/body_models.py
ADDED
@@ -0,0 +1,365 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
4 |
+
# holder of all proprietary rights on this computer program.
|
5 |
+
# You can only use this computer program if you have closed
|
6 |
+
# a license agreement with MPG or you get the right to use the computer
|
7 |
+
# program from someone who is authorized to grant you that right.
|
8 |
+
# Any use of the computer program without a valid license is prohibited and
|
9 |
+
# liable to prosecution.
|
10 |
+
#
|
11 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
12 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
13 |
+
# for Intelligent Systems and the Max Planck Institute for Biological
|
14 |
+
# Cybernetics. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: ps-license@tuebingen.mpg.de
|
17 |
+
|
18 |
+
from __future__ import absolute_import
|
19 |
+
from __future__ import print_function
|
20 |
+
from __future__ import division
|
21 |
+
|
22 |
+
import os
|
23 |
+
import os.path as osp
|
24 |
+
|
25 |
+
|
26 |
+
import pickle
|
27 |
+
|
28 |
+
import numpy as np
|
29 |
+
|
30 |
+
from collections import namedtuple
|
31 |
+
|
32 |
+
import torch
|
33 |
+
import torch.nn as nn
|
34 |
+
|
35 |
+
from .lbs import (
|
36 |
+
lbs, vertices2joints, blend_shapes)
|
37 |
+
|
38 |
+
from .vertex_ids import vertex_ids as VERTEX_IDS
|
39 |
+
from .utils import Struct, to_np, to_tensor
|
40 |
+
from .vertex_joint_selector import VertexJointSelector
|
41 |
+
|
42 |
+
|
43 |
+
ModelOutput = namedtuple('ModelOutput',
|
44 |
+
['vertices','faces', 'joints', 'full_pose', 'betas',
|
45 |
+
'global_orient',
|
46 |
+
'body_pose', 'expression',
|
47 |
+
'left_hand_pose', 'right_hand_pose',
|
48 |
+
'jaw_pose', 'T', 'T_weighted', 'weights'])
|
49 |
+
ModelOutput.__new__.__defaults__ = (None,) * len(ModelOutput._fields)
|
50 |
+
|
51 |
+
class SMPL(nn.Module):
|
52 |
+
|
53 |
+
NUM_JOINTS = 23
|
54 |
+
NUM_BODY_JOINTS = 23
|
55 |
+
NUM_BETAS = 10
|
56 |
+
|
57 |
+
def __init__(self, model_path, data_struct=None,
|
58 |
+
create_betas=True,
|
59 |
+
betas=None,
|
60 |
+
create_global_orient=True,
|
61 |
+
global_orient=None,
|
62 |
+
create_body_pose=True,
|
63 |
+
body_pose=None,
|
64 |
+
create_transl=True,
|
65 |
+
transl=None,
|
66 |
+
dtype=torch.float32,
|
67 |
+
batch_size=1,
|
68 |
+
joint_mapper=None, gender='neutral',
|
69 |
+
vertex_ids=None,
|
70 |
+
pose_blend=True,
|
71 |
+
**kwargs):
|
72 |
+
''' SMPL model constructor
|
73 |
+
|
74 |
+
Parameters
|
75 |
+
----------
|
76 |
+
model_path: str
|
77 |
+
The path to the folder or to the file where the model
|
78 |
+
parameters are stored
|
79 |
+
data_struct: Strct
|
80 |
+
A struct object. If given, then the parameters of the model are
|
81 |
+
read from the object. Otherwise, the model tries to read the
|
82 |
+
parameters from the given `model_path`. (default = None)
|
83 |
+
create_global_orient: bool, optional
|
84 |
+
Flag for creating a member variable for the global orientation
|
85 |
+
of the body. (default = True)
|
86 |
+
global_orient: torch.tensor, optional, Bx3
|
87 |
+
The default value for the global orientation variable.
|
88 |
+
(default = None)
|
89 |
+
create_body_pose: bool, optional
|
90 |
+
Flag for creating a member variable for the pose of the body.
|
91 |
+
(default = True)
|
92 |
+
body_pose: torch.tensor, optional, Bx(Body Joints * 3)
|
93 |
+
The default value for the body pose variable.
|
94 |
+
(default = None)
|
95 |
+
create_betas: bool, optional
|
96 |
+
Flag for creating a member variable for the shape space
|
97 |
+
(default = True).
|
98 |
+
betas: torch.tensor, optional, Bx10
|
99 |
+
The default value for the shape member variable.
|
100 |
+
(default = None)
|
101 |
+
create_transl: bool, optional
|
102 |
+
Flag for creating a member variable for the translation
|
103 |
+
of the body. (default = True)
|
104 |
+
transl: torch.tensor, optional, Bx3
|
105 |
+
The default value for the transl variable.
|
106 |
+
(default = None)
|
107 |
+
dtype: torch.dtype, optional
|
108 |
+
The data type for the created variables
|
109 |
+
batch_size: int, optional
|
110 |
+
The batch size used for creating the member variables
|
111 |
+
joint_mapper: object, optional
|
112 |
+
An object that re-maps the joints. Useful if one wants to
|
113 |
+
re-order the SMPL joints to some other convention (e.g. MSCOCO)
|
114 |
+
(default = None)
|
115 |
+
gender: str, optional
|
116 |
+
Which gender to load
|
117 |
+
vertex_ids: dict, optional
|
118 |
+
A dictionary containing the indices of the extra vertices that
|
119 |
+
will be selected
|
120 |
+
'''
|
121 |
+
|
122 |
+
self.gender = gender
|
123 |
+
self.pose_blend = pose_blend
|
124 |
+
|
125 |
+
if data_struct is None:
|
126 |
+
if osp.isdir(model_path):
|
127 |
+
model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl')
|
128 |
+
smpl_path = os.path.join(model_path, model_fn)
|
129 |
+
else:
|
130 |
+
smpl_path = model_path
|
131 |
+
assert osp.exists(smpl_path), 'Path {} does not exist!'.format(
|
132 |
+
smpl_path)
|
133 |
+
|
134 |
+
with open(smpl_path, 'rb') as smpl_file:
|
135 |
+
data_struct = Struct(**pickle.load(smpl_file,encoding='latin1'))
|
136 |
+
super(SMPL, self).__init__()
|
137 |
+
self.batch_size = batch_size
|
138 |
+
|
139 |
+
if vertex_ids is None:
|
140 |
+
# SMPL and SMPL-H share the same topology, so any extra joints can
|
141 |
+
# be drawn from the same place
|
142 |
+
vertex_ids = VERTEX_IDS['smplh']
|
143 |
+
|
144 |
+
self.dtype = dtype
|
145 |
+
|
146 |
+
self.joint_mapper = joint_mapper
|
147 |
+
|
148 |
+
self.vertex_joint_selector = VertexJointSelector(
|
149 |
+
vertex_ids=vertex_ids, **kwargs)
|
150 |
+
|
151 |
+
self.faces = data_struct.f
|
152 |
+
self.register_buffer('faces_tensor',
|
153 |
+
to_tensor(to_np(self.faces, dtype=np.int64),
|
154 |
+
dtype=torch.long))
|
155 |
+
|
156 |
+
if create_betas:
|
157 |
+
if betas is None:
|
158 |
+
default_betas = torch.zeros([batch_size, self.NUM_BETAS],
|
159 |
+
dtype=dtype)
|
160 |
+
else:
|
161 |
+
if 'torch.Tensor' in str(type(betas)):
|
162 |
+
default_betas = betas.clone().detach()
|
163 |
+
else:
|
164 |
+
default_betas = torch.tensor(betas,
|
165 |
+
dtype=dtype)
|
166 |
+
|
167 |
+
self.register_parameter('betas', nn.Parameter(default_betas,
|
168 |
+
requires_grad=True))
|
169 |
+
|
170 |
+
# The tensor that contains the global rotation of the model
|
171 |
+
# It is separated from the pose of the joints in case we wish to
|
172 |
+
# optimize only over one of them
|
173 |
+
if create_global_orient:
|
174 |
+
if global_orient is None:
|
175 |
+
default_global_orient = torch.zeros([batch_size, 3],
|
176 |
+
dtype=dtype)
|
177 |
+
else:
|
178 |
+
if 'torch.Tensor' in str(type(global_orient)):
|
179 |
+
default_global_orient = global_orient.clone().detach()
|
180 |
+
else:
|
181 |
+
default_global_orient = torch.tensor(global_orient,
|
182 |
+
dtype=dtype)
|
183 |
+
|
184 |
+
global_orient = nn.Parameter(default_global_orient,
|
185 |
+
requires_grad=True)
|
186 |
+
self.register_parameter('global_orient', global_orient)
|
187 |
+
|
188 |
+
if create_body_pose:
|
189 |
+
if body_pose is None:
|
190 |
+
default_body_pose = torch.zeros(
|
191 |
+
[batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype)
|
192 |
+
else:
|
193 |
+
if 'torch.Tensor' in str(type(body_pose)):
|
194 |
+
default_body_pose = body_pose.clone().detach()
|
195 |
+
else:
|
196 |
+
default_body_pose = torch.tensor(body_pose,
|
197 |
+
dtype=dtype)
|
198 |
+
self.register_parameter(
|
199 |
+
'body_pose',
|
200 |
+
nn.Parameter(default_body_pose, requires_grad=True))
|
201 |
+
|
202 |
+
if create_transl:
|
203 |
+
if transl is None:
|
204 |
+
default_transl = torch.zeros([batch_size, 3],
|
205 |
+
dtype=dtype,
|
206 |
+
requires_grad=True)
|
207 |
+
else:
|
208 |
+
default_transl = torch.tensor(transl, dtype=dtype)
|
209 |
+
self.register_parameter(
|
210 |
+
'transl',
|
211 |
+
nn.Parameter(default_transl, requires_grad=True))
|
212 |
+
|
213 |
+
# The vertices of the template model
|
214 |
+
self.register_buffer('v_template',
|
215 |
+
to_tensor(to_np(data_struct.v_template),
|
216 |
+
dtype=dtype))
|
217 |
+
|
218 |
+
# The shape components
|
219 |
+
shapedirs = data_struct.shapedirs[:, :, :self.NUM_BETAS]
|
220 |
+
# The shape components
|
221 |
+
self.register_buffer(
|
222 |
+
'shapedirs',
|
223 |
+
to_tensor(to_np(shapedirs), dtype=dtype))
|
224 |
+
|
225 |
+
|
226 |
+
j_regressor = to_tensor(to_np(
|
227 |
+
data_struct.J_regressor), dtype=dtype)
|
228 |
+
self.register_buffer('J_regressor', j_regressor)
|
229 |
+
|
230 |
+
# if self.gender == 'neutral':
|
231 |
+
# joint_regressor = to_tensor(to_np(
|
232 |
+
# data_struct.cocoplus_regressor), dtype=dtype).permute(1,0)
|
233 |
+
# self.register_buffer('joint_regressor', joint_regressor)
|
234 |
+
|
235 |
+
# Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207
|
236 |
+
num_pose_basis = data_struct.posedirs.shape[-1]
|
237 |
+
# 207 x 20670
|
238 |
+
posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T
|
239 |
+
self.register_buffer('posedirs',
|
240 |
+
to_tensor(to_np(posedirs), dtype=dtype))
|
241 |
+
|
242 |
+
# indices of parents for each joints
|
243 |
+
parents = to_tensor(to_np(data_struct.kintree_table[0])).long()
|
244 |
+
parents[0] = -1
|
245 |
+
self.register_buffer('parents', parents)
|
246 |
+
|
247 |
+
self.bone_parents = to_np(data_struct.kintree_table[0])
|
248 |
+
|
249 |
+
self.register_buffer('lbs_weights',
|
250 |
+
to_tensor(to_np(data_struct.weights), dtype=dtype))
|
251 |
+
|
252 |
+
def create_mean_pose(self, data_struct):
|
253 |
+
pass
|
254 |
+
|
255 |
+
@torch.no_grad()
|
256 |
+
def reset_params(self, **params_dict):
|
257 |
+
for param_name, param in self.named_parameters():
|
258 |
+
if param_name in params_dict:
|
259 |
+
param[:] = torch.tensor(params_dict[param_name])
|
260 |
+
else:
|
261 |
+
param.fill_(0)
|
262 |
+
|
263 |
+
def get_T_hip(self, betas=None):
|
264 |
+
v_shaped = self.v_template + blend_shapes(betas, self.shapedirs)
|
265 |
+
J = vertices2joints(self.J_regressor, v_shaped)
|
266 |
+
T_hip = J[0,0]
|
267 |
+
return T_hip
|
268 |
+
|
269 |
+
def get_num_verts(self):
|
270 |
+
return self.v_template.shape[0]
|
271 |
+
|
272 |
+
def get_num_faces(self):
|
273 |
+
return self.faces.shape[0]
|
274 |
+
|
275 |
+
def extra_repr(self):
|
276 |
+
return 'Number of betas: {}'.format(self.NUM_BETAS)
|
277 |
+
|
278 |
+
def forward(self, betas=None, body_pose=None, global_orient=None,
|
279 |
+
transl=None, return_verts=True, return_full_pose=False,displacement=None,v_template=None,
|
280 |
+
**kwargs):
|
281 |
+
''' Forward pass for the SMPL model
|
282 |
+
|
283 |
+
Parameters
|
284 |
+
----------
|
285 |
+
global_orient: torch.tensor, optional, shape Bx3
|
286 |
+
If given, ignore the member variable and use it as the global
|
287 |
+
rotation of the body. Useful if someone wishes to predicts this
|
288 |
+
with an external model. (default=None)
|
289 |
+
betas: torch.tensor, optional, shape Bx10
|
290 |
+
If given, ignore the member variable `betas` and use it
|
291 |
+
instead. For example, it can used if shape parameters
|
292 |
+
`betas` are predicted from some external model.
|
293 |
+
(default=None)
|
294 |
+
body_pose: torch.tensor, optional, shape Bx(J*3)
|
295 |
+
If given, ignore the member variable `body_pose` and use it
|
296 |
+
instead. For example, it can used if someone predicts the
|
297 |
+
pose of the body joints are predicted from some external model.
|
298 |
+
It should be a tensor that contains joint rotations in
|
299 |
+
axis-angle format. (default=None)
|
300 |
+
transl: torch.tensor, optional, shape Bx3
|
301 |
+
If given, ignore the member variable `transl` and use it
|
302 |
+
instead. For example, it can used if the translation
|
303 |
+
`transl` is predicted from some external model.
|
304 |
+
(default=None)
|
305 |
+
return_verts: bool, optional
|
306 |
+
Return the vertices. (default=True)
|
307 |
+
return_full_pose: bool, optional
|
308 |
+
Returns the full axis-angle pose vector (default=False)
|
309 |
+
|
310 |
+
Returns
|
311 |
+
-------
|
312 |
+
'''
|
313 |
+
# If no shape and pose parameters are passed along, then use the
|
314 |
+
# ones from the module
|
315 |
+
global_orient = (global_orient if global_orient is not None else
|
316 |
+
self.global_orient)
|
317 |
+
body_pose = body_pose if body_pose is not None else self.body_pose
|
318 |
+
betas = betas if betas is not None else self.betas
|
319 |
+
|
320 |
+
apply_trans = transl is not None or hasattr(self, 'transl')
|
321 |
+
if transl is None and hasattr(self, 'transl'):
|
322 |
+
transl = self.transl
|
323 |
+
|
324 |
+
full_pose = torch.cat([global_orient, body_pose], dim=1)
|
325 |
+
|
326 |
+
# if betas.shape[0] != self.batch_size:
|
327 |
+
# num_repeats = int(self.batch_size / betas.shape[0])
|
328 |
+
# betas = betas.expand(num_repeats, -1)
|
329 |
+
|
330 |
+
if v_template is None:
|
331 |
+
v_template = self.v_template
|
332 |
+
|
333 |
+
if displacement is not None:
|
334 |
+
vertices, joints_smpl, T_weighted, W, T = lbs(betas, full_pose, v_template+displacement,
|
335 |
+
self.shapedirs, self.posedirs,
|
336 |
+
self.J_regressor, self.parents,
|
337 |
+
self.lbs_weights, dtype=self.dtype,pose_blend=self.pose_blend)
|
338 |
+
else:
|
339 |
+
vertices, joints_smpl,T_weighted, W, T = lbs(betas, full_pose, v_template,
|
340 |
+
self.shapedirs, self.posedirs,
|
341 |
+
self.J_regressor, self.parents,
|
342 |
+
self.lbs_weights, dtype=self.dtype,pose_blend=self.pose_blend)
|
343 |
+
|
344 |
+
# if self.gender is not 'neutral':
|
345 |
+
joints = self.vertex_joint_selector(vertices, joints_smpl)
|
346 |
+
# else:
|
347 |
+
# joints = torch.matmul(vertices.permute(0,2,1),self.joint_regressor).permute(0,2,1)
|
348 |
+
# Map the joints to the current dataset
|
349 |
+
if self.joint_mapper is not None:
|
350 |
+
joints = self.joint_mapper(joints)
|
351 |
+
|
352 |
+
if apply_trans:
|
353 |
+
joints_smpl = joints_smpl + transl.unsqueeze(dim=1)
|
354 |
+
joints = joints + transl.unsqueeze(dim=1)
|
355 |
+
vertices = vertices + transl.unsqueeze(dim=1)
|
356 |
+
|
357 |
+
output = ModelOutput(vertices=vertices if return_verts else None,
|
358 |
+
faces=self.faces,
|
359 |
+
global_orient=global_orient,
|
360 |
+
body_pose=body_pose,
|
361 |
+
joints=joints_smpl,
|
362 |
+
betas=self.betas,
|
363 |
+
full_pose=full_pose if return_full_pose else None,
|
364 |
+
T=T, T_weighted=T_weighted, weights=W)
|
365 |
+
return output
|
code/lib/smpl/lbs.py
ADDED
@@ -0,0 +1,377 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
4 |
+
# holder of all proprietary rights on this computer program.
|
5 |
+
# You can only use this computer program if you have closed
|
6 |
+
# a license agreement with MPG or you get the right to use the computer
|
7 |
+
# program from someone who is authorized to grant you that right.
|
8 |
+
# Any use of the computer program without a valid license is prohibited and
|
9 |
+
# liable to prosecution.
|
10 |
+
#
|
11 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
12 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
13 |
+
# for Intelligent Systems and the Max Planck Institute for Biological
|
14 |
+
# Cybernetics. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: ps-license@tuebingen.mpg.de
|
17 |
+
|
18 |
+
from __future__ import absolute_import
|
19 |
+
from __future__ import print_function
|
20 |
+
from __future__ import division
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
import torch
|
25 |
+
import torch.nn.functional as F
|
26 |
+
|
27 |
+
from .utils import rot_mat_to_euler
|
28 |
+
|
29 |
+
|
30 |
+
def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx,
|
31 |
+
dynamic_lmk_b_coords,
|
32 |
+
neck_kin_chain, dtype=torch.float32):
|
33 |
+
''' Compute the faces, barycentric coordinates for the dynamic landmarks
|
34 |
+
|
35 |
+
|
36 |
+
To do so, we first compute the rotation of the neck around the y-axis
|
37 |
+
and then use a pre-computed look-up table to find the faces and the
|
38 |
+
barycentric coordinates that will be used.
|
39 |
+
|
40 |
+
Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de)
|
41 |
+
for providing the original TensorFlow implementation and for the LUT.
|
42 |
+
|
43 |
+
Parameters
|
44 |
+
----------
|
45 |
+
vertices: torch.tensor BxVx3, dtype = torch.float32
|
46 |
+
The tensor of input vertices
|
47 |
+
pose: torch.tensor Bx(Jx3), dtype = torch.float32
|
48 |
+
The current pose of the body model
|
49 |
+
dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long
|
50 |
+
The look-up table from neck rotation to faces
|
51 |
+
dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32
|
52 |
+
The look-up table from neck rotation to barycentric coordinates
|
53 |
+
neck_kin_chain: list
|
54 |
+
A python list that contains the indices of the joints that form the
|
55 |
+
kinematic chain of the neck.
|
56 |
+
dtype: torch.dtype, optional
|
57 |
+
|
58 |
+
Returns
|
59 |
+
-------
|
60 |
+
dyn_lmk_faces_idx: torch.tensor, dtype = torch.long
|
61 |
+
A tensor of size BxL that contains the indices of the faces that
|
62 |
+
will be used to compute the current dynamic landmarks.
|
63 |
+
dyn_lmk_b_coords: torch.tensor, dtype = torch.float32
|
64 |
+
A tensor of size BxL that contains the indices of the faces that
|
65 |
+
will be used to compute the current dynamic landmarks.
|
66 |
+
'''
|
67 |
+
|
68 |
+
batch_size = vertices.shape[0]
|
69 |
+
|
70 |
+
aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
|
71 |
+
neck_kin_chain)
|
72 |
+
rot_mats = batch_rodrigues(
|
73 |
+
aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
|
74 |
+
|
75 |
+
rel_rot_mat = torch.eye(3, device=vertices.device,
|
76 |
+
dtype=dtype).unsqueeze_(dim=0)
|
77 |
+
for idx in range(len(neck_kin_chain)):
|
78 |
+
rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
|
79 |
+
|
80 |
+
y_rot_angle = torch.round(
|
81 |
+
torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
|
82 |
+
max=39)).to(dtype=torch.long)
|
83 |
+
neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
|
84 |
+
mask = y_rot_angle.lt(-39).to(dtype=torch.long)
|
85 |
+
neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
|
86 |
+
y_rot_angle = (neg_mask * neg_vals +
|
87 |
+
(1 - neg_mask) * y_rot_angle)
|
88 |
+
|
89 |
+
dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
|
90 |
+
0, y_rot_angle)
|
91 |
+
dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
|
92 |
+
0, y_rot_angle)
|
93 |
+
|
94 |
+
return dyn_lmk_faces_idx, dyn_lmk_b_coords
|
95 |
+
|
96 |
+
|
97 |
+
def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
|
98 |
+
''' Calculates landmarks by barycentric interpolation
|
99 |
+
|
100 |
+
Parameters
|
101 |
+
----------
|
102 |
+
vertices: torch.tensor BxVx3, dtype = torch.float32
|
103 |
+
The tensor of input vertices
|
104 |
+
faces: torch.tensor Fx3, dtype = torch.long
|
105 |
+
The faces of the mesh
|
106 |
+
lmk_faces_idx: torch.tensor L, dtype = torch.long
|
107 |
+
The tensor with the indices of the faces used to calculate the
|
108 |
+
landmarks.
|
109 |
+
lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32
|
110 |
+
The tensor of barycentric coordinates that are used to interpolate
|
111 |
+
the landmarks
|
112 |
+
|
113 |
+
Returns
|
114 |
+
-------
|
115 |
+
landmarks: torch.tensor BxLx3, dtype = torch.float32
|
116 |
+
The coordinates of the landmarks for each mesh in the batch
|
117 |
+
'''
|
118 |
+
# Extract the indices of the vertices for each face
|
119 |
+
# BxLx3
|
120 |
+
batch_size, num_verts = vertices.shape[:2]
|
121 |
+
device = vertices.device
|
122 |
+
|
123 |
+
lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).expand(
|
124 |
+
batch_size, -1, -1).long()
|
125 |
+
|
126 |
+
lmk_faces = lmk_faces + torch.arange(
|
127 |
+
batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
|
128 |
+
|
129 |
+
lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(
|
130 |
+
batch_size, -1, 3, 3)
|
131 |
+
|
132 |
+
landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
|
133 |
+
return landmarks
|
134 |
+
|
135 |
+
|
136 |
+
def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
|
137 |
+
lbs_weights, pose2rot=True, dtype=torch.float32, pose_blend=True):
|
138 |
+
''' Performs Linear Blend Skinning with the given shape and pose parameters
|
139 |
+
|
140 |
+
Parameters
|
141 |
+
----------
|
142 |
+
betas : torch.tensor BxNB
|
143 |
+
The tensor of shape parameters
|
144 |
+
pose : torch.tensor Bx(J + 1) * 3
|
145 |
+
The pose parameters in axis-angle format
|
146 |
+
v_template torch.tensor BxVx3
|
147 |
+
The template mesh that will be deformed
|
148 |
+
shapedirs : torch.tensor 1xNB
|
149 |
+
The tensor of PCA shape displacements
|
150 |
+
posedirs : torch.tensor Px(V * 3)
|
151 |
+
The pose PCA coefficients
|
152 |
+
J_regressor : torch.tensor JxV
|
153 |
+
The regressor array that is used to calculate the joints from
|
154 |
+
the position of the vertices
|
155 |
+
parents: torch.tensor J
|
156 |
+
The array that describes the kinematic tree for the model
|
157 |
+
lbs_weights: torch.tensor N x V x (J + 1)
|
158 |
+
The linear blend skinning weights that represent how much the
|
159 |
+
rotation matrix of each part affects each vertex
|
160 |
+
pose2rot: bool, optional
|
161 |
+
Flag on whether to convert the input pose tensor to rotation
|
162 |
+
matrices. The default value is True. If False, then the pose tensor
|
163 |
+
should already contain rotation matrices and have a size of
|
164 |
+
Bx(J + 1)x9
|
165 |
+
dtype: torch.dtype, optional
|
166 |
+
|
167 |
+
Returns
|
168 |
+
-------
|
169 |
+
verts: torch.tensor BxVx3
|
170 |
+
The vertices of the mesh after applying the shape and pose
|
171 |
+
displacements.
|
172 |
+
joints: torch.tensor BxJx3
|
173 |
+
The joints of the model
|
174 |
+
'''
|
175 |
+
|
176 |
+
batch_size = max(betas.shape[0], pose.shape[0])
|
177 |
+
device = betas.device
|
178 |
+
|
179 |
+
# Add shape contribution
|
180 |
+
v_shaped = v_template + blend_shapes(betas, shapedirs)
|
181 |
+
|
182 |
+
# Get the joints
|
183 |
+
# NxJx3 array
|
184 |
+
J = vertices2joints(J_regressor, v_shaped)
|
185 |
+
|
186 |
+
# 3. Add pose blend shapes
|
187 |
+
# N x J x 3 x 3
|
188 |
+
ident = torch.eye(3, dtype=dtype, device=device)
|
189 |
+
|
190 |
+
|
191 |
+
if pose2rot:
|
192 |
+
rot_mats = batch_rodrigues(
|
193 |
+
pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])
|
194 |
+
|
195 |
+
pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
|
196 |
+
# (N x P) x (P, V * 3) -> N x V x 3
|
197 |
+
pose_offsets = torch.matmul(pose_feature, posedirs) \
|
198 |
+
.view(batch_size, -1, 3)
|
199 |
+
else:
|
200 |
+
pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
|
201 |
+
rot_mats = pose.view(batch_size, -1, 3, 3)
|
202 |
+
|
203 |
+
pose_offsets = torch.matmul(pose_feature.view(batch_size, -1),
|
204 |
+
posedirs).view(batch_size, -1, 3)
|
205 |
+
|
206 |
+
if pose_blend:
|
207 |
+
v_posed = pose_offsets + v_shaped
|
208 |
+
else:
|
209 |
+
v_posed = v_shaped
|
210 |
+
|
211 |
+
# 4. Get the global joint location
|
212 |
+
J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
|
213 |
+
|
214 |
+
# 5. Do skinning:
|
215 |
+
# W is N x V x (J + 1)
|
216 |
+
W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
|
217 |
+
# (N x V x (J + 1)) x (N x (J + 1) x 16)
|
218 |
+
num_joints = J_regressor.shape[0]
|
219 |
+
T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \
|
220 |
+
.view(batch_size, -1, 4, 4)
|
221 |
+
|
222 |
+
homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
|
223 |
+
dtype=dtype, device=device)
|
224 |
+
v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
|
225 |
+
v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
|
226 |
+
|
227 |
+
verts = v_homo[:, :, :3, 0]
|
228 |
+
|
229 |
+
return verts, J_transformed, T, W, A.view(batch_size, num_joints, 4,4)
|
230 |
+
|
231 |
+
|
232 |
+
def vertices2joints(J_regressor, vertices):
|
233 |
+
''' Calculates the 3D joint locations from the vertices
|
234 |
+
|
235 |
+
Parameters
|
236 |
+
----------
|
237 |
+
J_regressor : torch.tensor JxV
|
238 |
+
The regressor array that is used to calculate the joints from the
|
239 |
+
position of the vertices
|
240 |
+
vertices : torch.tensor BxVx3
|
241 |
+
The tensor of mesh vertices
|
242 |
+
|
243 |
+
Returns
|
244 |
+
-------
|
245 |
+
torch.tensor BxJx3
|
246 |
+
The location of the joints
|
247 |
+
'''
|
248 |
+
|
249 |
+
return torch.einsum('bik,ji->bjk', [vertices, J_regressor])
|
250 |
+
|
251 |
+
|
252 |
+
def blend_shapes(betas, shape_disps):
|
253 |
+
''' Calculates the per vertex displacement due to the blend shapes
|
254 |
+
|
255 |
+
|
256 |
+
Parameters
|
257 |
+
----------
|
258 |
+
betas : torch.tensor Bx(num_betas)
|
259 |
+
Blend shape coefficients
|
260 |
+
shape_disps: torch.tensor Vx3x(num_betas)
|
261 |
+
Blend shapes
|
262 |
+
|
263 |
+
Returns
|
264 |
+
-------
|
265 |
+
torch.tensor BxVx3
|
266 |
+
The per-vertex displacement due to shape deformation
|
267 |
+
'''
|
268 |
+
|
269 |
+
# Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
|
270 |
+
# i.e. Multiply each shape displacement by its corresponding beta and
|
271 |
+
# then sum them.
|
272 |
+
blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
|
273 |
+
return blend_shape
|
274 |
+
|
275 |
+
|
276 |
+
def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
|
277 |
+
''' Calculates the rotation matrices for a batch of rotation vectors
|
278 |
+
Parameters
|
279 |
+
----------
|
280 |
+
rot_vecs: torch.tensor Nx3
|
281 |
+
array of N axis-angle vectors
|
282 |
+
Returns
|
283 |
+
-------
|
284 |
+
R: torch.tensor Nx3x3
|
285 |
+
The rotation matrices for the given axis-angle parameters
|
286 |
+
'''
|
287 |
+
|
288 |
+
batch_size = rot_vecs.shape[0]
|
289 |
+
device = rot_vecs.device
|
290 |
+
|
291 |
+
angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
|
292 |
+
rot_dir = rot_vecs / angle
|
293 |
+
|
294 |
+
cos = torch.unsqueeze(torch.cos(angle), dim=1)
|
295 |
+
sin = torch.unsqueeze(torch.sin(angle), dim=1)
|
296 |
+
|
297 |
+
# Bx1 arrays
|
298 |
+
rx, ry, rz = torch.split(rot_dir, 1, dim=1)
|
299 |
+
K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
|
300 |
+
|
301 |
+
zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
|
302 |
+
K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
|
303 |
+
.view((batch_size, 3, 3))
|
304 |
+
|
305 |
+
ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
|
306 |
+
rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
|
307 |
+
return rot_mat
|
308 |
+
|
309 |
+
|
310 |
+
def transform_mat(R, t):
|
311 |
+
''' Creates a batch of transformation matrices
|
312 |
+
Args:
|
313 |
+
- R: Bx3x3 array of a batch of rotation matrices
|
314 |
+
- t: Bx3x1 array of a batch of translation vectors
|
315 |
+
Returns:
|
316 |
+
- T: Bx4x4 Transformation matrix
|
317 |
+
'''
|
318 |
+
# No padding left or right, only add an extra row
|
319 |
+
return torch.cat([F.pad(R, [0, 0, 0, 1]),
|
320 |
+
F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
|
321 |
+
|
322 |
+
|
323 |
+
def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
|
324 |
+
"""
|
325 |
+
Applies a batch of rigid transformations to the joints
|
326 |
+
|
327 |
+
Parameters
|
328 |
+
----------
|
329 |
+
rot_mats : torch.tensor BxNx3x3
|
330 |
+
Tensor of rotation matrices
|
331 |
+
joints : torch.tensor BxNx3
|
332 |
+
Locations of joints
|
333 |
+
parents : torch.tensor BxN
|
334 |
+
The kinematic tree of each object
|
335 |
+
dtype : torch.dtype, optional:
|
336 |
+
The data type of the created tensors, the default is torch.float32
|
337 |
+
|
338 |
+
Returns
|
339 |
+
-------
|
340 |
+
posed_joints : torch.tensor BxNx3
|
341 |
+
The locations of the joints after applying the pose rotations
|
342 |
+
rel_transforms : torch.tensor BxNx4x4
|
343 |
+
The relative (with respect to the root joint) rigid transformations
|
344 |
+
for all the joints
|
345 |
+
"""
|
346 |
+
|
347 |
+
joints = torch.unsqueeze(joints, dim=-1)
|
348 |
+
|
349 |
+
rel_joints = joints.clone()
|
350 |
+
rel_joints[:, 1:] = rel_joints[:, 1:] - joints[:, parents[1:]]
|
351 |
+
|
352 |
+
transforms_mat = transform_mat(
|
353 |
+
rot_mats.reshape(-1, 3, 3),
|
354 |
+
rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
|
355 |
+
|
356 |
+
transform_chain = [transforms_mat[:, 0]]
|
357 |
+
for i in range(1, parents.shape[0]):
|
358 |
+
# Subtract the joint location at the rest pose
|
359 |
+
# No need for rotation, since it's identity when at rest
|
360 |
+
curr_res = torch.matmul(transform_chain[parents[i]],
|
361 |
+
transforms_mat[:, i])
|
362 |
+
transform_chain.append(curr_res)
|
363 |
+
|
364 |
+
transforms = torch.stack(transform_chain, dim=1)
|
365 |
+
|
366 |
+
# The last column of the transformations contains the posed joints
|
367 |
+
posed_joints = transforms[:, :, :3, 3]
|
368 |
+
|
369 |
+
# The last column of the transformations contains the posed joints
|
370 |
+
posed_joints = transforms[:, :, :3, 3]
|
371 |
+
|
372 |
+
joints_homogen = F.pad(joints, [0, 0, 0, 1])
|
373 |
+
|
374 |
+
rel_transforms = transforms - F.pad(
|
375 |
+
torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])
|
376 |
+
|
377 |
+
return posed_joints, rel_transforms
|
code/lib/smpl/smpl_model/SMPL_FEMALE.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:a583c1b98e4afc19042641f1bae5cd8a1f712a6724886291a7627ec07acd408d
|
3 |
+
size 39056454
|
code/lib/smpl/smpl_model/SMPL_MALE.pkl
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0e8c0bbbbc635dcb166ed29c303fb4bef16ea5f623e5a89263495a9e403575bd
|
3 |
+
size 39056404
|
code/lib/smpl/utils.py
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
4 |
+
# holder of all proprietary rights on this computer program.
|
5 |
+
# You can only use this computer program if you have closed
|
6 |
+
# a license agreement with MPG or you get the right to use the computer
|
7 |
+
# program from someone who is authorized to grant you that right.
|
8 |
+
# Any use of the computer program without a valid license is prohibited and
|
9 |
+
# liable to prosecution.
|
10 |
+
#
|
11 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
12 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
13 |
+
# for Intelligent Systems and the Max Planck Institute for Biological
|
14 |
+
# Cybernetics. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: ps-license@tuebingen.mpg.de
|
17 |
+
|
18 |
+
from __future__ import print_function
|
19 |
+
from __future__ import absolute_import
|
20 |
+
from __future__ import division
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
import torch
|
24 |
+
|
25 |
+
|
26 |
+
def to_tensor(array, dtype=torch.float32):
|
27 |
+
if 'torch.tensor' not in str(type(array)):
|
28 |
+
return torch.tensor(array, dtype=dtype)
|
29 |
+
|
30 |
+
|
31 |
+
class Struct(object):
|
32 |
+
def __init__(self, **kwargs):
|
33 |
+
for key, val in kwargs.items():
|
34 |
+
setattr(self, key, val)
|
35 |
+
|
36 |
+
|
37 |
+
def to_np(array, dtype=np.float32):
|
38 |
+
if 'scipy.sparse' in str(type(array)):
|
39 |
+
array = array.todense()
|
40 |
+
return np.array(array, dtype=dtype)
|
41 |
+
|
42 |
+
|
43 |
+
def rot_mat_to_euler(rot_mats):
|
44 |
+
# Calculates rotation matrix to euler angles
|
45 |
+
# Careful for extreme cases of eular angles like [0.0, pi, 0.0]
|
46 |
+
|
47 |
+
sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
|
48 |
+
rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
|
49 |
+
return torch.atan2(-rot_mats[:, 2, 0], sy)
|
code/lib/smpl/vertex_ids.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
4 |
+
# holder of all proprietary rights on this computer program.
|
5 |
+
# You can only use this computer program if you have closed
|
6 |
+
# a license agreement with MPG or you get the right to use the computer
|
7 |
+
# program from someone who is authorized to grant you that right.
|
8 |
+
# Any use of the computer program without a valid license is prohibited and
|
9 |
+
# liable to prosecution.
|
10 |
+
#
|
11 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
12 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
13 |
+
# for Intelligent Systems and the Max Planck Institute for Biological
|
14 |
+
# Cybernetics. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: ps-license@tuebingen.mpg.de
|
17 |
+
|
18 |
+
from __future__ import print_function
|
19 |
+
from __future__ import absolute_import
|
20 |
+
from __future__ import division
|
21 |
+
|
22 |
+
# Joint name to vertex mapping. SMPL/SMPL-H/SMPL-X vertices that correspond to
|
23 |
+
# MSCOCO and OpenPose joints
|
24 |
+
vertex_ids = {
|
25 |
+
'smplh': {
|
26 |
+
'nose': 332,
|
27 |
+
'reye': 6260,
|
28 |
+
'leye': 2800,
|
29 |
+
'rear': 4071,
|
30 |
+
'lear': 583,
|
31 |
+
'rthumb': 6191,
|
32 |
+
'rindex': 5782,
|
33 |
+
'rmiddle': 5905,
|
34 |
+
'rring': 6016,
|
35 |
+
'rpinky': 6133,
|
36 |
+
'lthumb': 2746,
|
37 |
+
'lindex': 2319,
|
38 |
+
'lmiddle': 2445,
|
39 |
+
'lring': 2556,
|
40 |
+
'lpinky': 2673,
|
41 |
+
'LBigToe': 3216,
|
42 |
+
'LSmallToe': 3226,
|
43 |
+
'LHeel': 3387,
|
44 |
+
'RBigToe': 6617,
|
45 |
+
'RSmallToe': 6624,
|
46 |
+
'RHeel': 6787
|
47 |
+
},
|
48 |
+
'smplx': {
|
49 |
+
'nose': 9120,
|
50 |
+
'reye': 9929,
|
51 |
+
'leye': 9448,
|
52 |
+
'rear': 616,
|
53 |
+
'lear': 6,
|
54 |
+
'rthumb': 8079,
|
55 |
+
'rindex': 7669,
|
56 |
+
'rmiddle': 7794,
|
57 |
+
'rring': 7905,
|
58 |
+
'rpinky': 8022,
|
59 |
+
'lthumb': 5361,
|
60 |
+
'lindex': 4933,
|
61 |
+
'lmiddle': 5058,
|
62 |
+
'lring': 5169,
|
63 |
+
'lpinky': 5286,
|
64 |
+
'LBigToe': 5770,
|
65 |
+
'LSmallToe': 5780,
|
66 |
+
'LHeel': 8846,
|
67 |
+
'RBigToe': 8463,
|
68 |
+
'RSmallToe': 8474,
|
69 |
+
'RHeel': 8635
|
70 |
+
}
|
71 |
+
}
|
code/lib/smpl/vertex_joint_selector.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# -*- coding: utf-8 -*-
|
2 |
+
|
3 |
+
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
|
4 |
+
# holder of all proprietary rights on this computer program.
|
5 |
+
# You can only use this computer program if you have closed
|
6 |
+
# a license agreement with MPG or you get the right to use the computer
|
7 |
+
# program from someone who is authorized to grant you that right.
|
8 |
+
# Any use of the computer program without a valid license is prohibited and
|
9 |
+
# liable to prosecution.
|
10 |
+
#
|
11 |
+
# Copyright©2019 Max-Planck-Gesellschaft zur Förderung
|
12 |
+
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
|
13 |
+
# for Intelligent Systems and the Max Planck Institute for Biological
|
14 |
+
# Cybernetics. All rights reserved.
|
15 |
+
#
|
16 |
+
# Contact: ps-license@tuebingen.mpg.de
|
17 |
+
|
18 |
+
from __future__ import absolute_import
|
19 |
+
from __future__ import print_function
|
20 |
+
from __future__ import division
|
21 |
+
|
22 |
+
import numpy as np
|
23 |
+
|
24 |
+
import torch
|
25 |
+
import torch.nn as nn
|
26 |
+
|
27 |
+
from .utils import to_tensor
|
28 |
+
|
29 |
+
|
30 |
+
class VertexJointSelector(nn.Module):
|
31 |
+
|
32 |
+
def __init__(self, vertex_ids=None,
|
33 |
+
use_hands=True,
|
34 |
+
use_feet_keypoints=True, **kwargs):
|
35 |
+
super(VertexJointSelector, self).__init__()
|
36 |
+
|
37 |
+
extra_joints_idxs = []
|
38 |
+
|
39 |
+
face_keyp_idxs = np.array([
|
40 |
+
vertex_ids['nose'],
|
41 |
+
vertex_ids['reye'],
|
42 |
+
vertex_ids['leye'],
|
43 |
+
vertex_ids['rear'],
|
44 |
+
vertex_ids['lear']], dtype=np.int64)
|
45 |
+
|
46 |
+
extra_joints_idxs = np.concatenate([extra_joints_idxs,
|
47 |
+
face_keyp_idxs])
|
48 |
+
|
49 |
+
if use_feet_keypoints:
|
50 |
+
feet_keyp_idxs = np.array([vertex_ids['LBigToe'],
|
51 |
+
vertex_ids['LSmallToe'],
|
52 |
+
vertex_ids['LHeel'],
|
53 |
+
vertex_ids['RBigToe'],
|
54 |
+
vertex_ids['RSmallToe'],
|
55 |
+
vertex_ids['RHeel']], dtype=np.int32)
|
56 |
+
|
57 |
+
extra_joints_idxs = np.concatenate(
|
58 |
+
[extra_joints_idxs, feet_keyp_idxs])
|
59 |
+
|
60 |
+
if use_hands:
|
61 |
+
self.tip_names = ['thumb', 'index', 'middle', 'ring', 'pinky']
|
62 |
+
|
63 |
+
tips_idxs = []
|
64 |
+
for hand_id in ['l', 'r']:
|
65 |
+
for tip_name in self.tip_names:
|
66 |
+
tips_idxs.append(vertex_ids[hand_id + tip_name])
|
67 |
+
|
68 |
+
extra_joints_idxs = np.concatenate(
|
69 |
+
[extra_joints_idxs, tips_idxs])
|
70 |
+
|
71 |
+
self.register_buffer('extra_joints_idxs',
|
72 |
+
to_tensor(extra_joints_idxs, dtype=torch.long))
|
73 |
+
|
74 |
+
def forward(self, vertices, joints):
|
75 |
+
extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs)
|
76 |
+
joints = torch.cat([joints, extra_joints], dim=1)
|
77 |
+
return joints
|
code/lib/utils/meshing.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from skimage import measure
|
4 |
+
from lib.libmise import mise
|
5 |
+
import trimesh
|
6 |
+
|
7 |
+
def generate_mesh(func, verts, level_set=0, res_init=32, res_up=3, point_batch=5000):
|
8 |
+
|
9 |
+
scale = 1.1 # Scale of the padded bbox regarding the tight one.
|
10 |
+
verts = verts.data.cpu().numpy()
|
11 |
+
|
12 |
+
gt_bbox = np.stack([verts.min(axis=0), verts.max(axis=0)], axis=0)
|
13 |
+
gt_center = (gt_bbox[0] + gt_bbox[1]) * 0.5
|
14 |
+
gt_scale = (gt_bbox[1] - gt_bbox[0]).max()
|
15 |
+
|
16 |
+
mesh_extractor = mise.MISE(res_init, res_up, level_set)
|
17 |
+
points = mesh_extractor.query()
|
18 |
+
|
19 |
+
# query occupancy grid
|
20 |
+
while points.shape[0] != 0:
|
21 |
+
|
22 |
+
orig_points = points
|
23 |
+
points = points.astype(np.float32)
|
24 |
+
points = (points / mesh_extractor.resolution - 0.5) * scale
|
25 |
+
points = points * gt_scale + gt_center
|
26 |
+
points = torch.tensor(points).float().cuda()
|
27 |
+
|
28 |
+
values = []
|
29 |
+
for _, pnts in enumerate((torch.split(points,point_batch,dim=0))):
|
30 |
+
out = func(pnts)
|
31 |
+
values.append(out['sdf'].data.cpu().numpy())
|
32 |
+
values = np.concatenate(values, axis=0).astype(np.float64)[:,0]
|
33 |
+
|
34 |
+
mesh_extractor.update(orig_points, values)
|
35 |
+
|
36 |
+
points = mesh_extractor.query()
|
37 |
+
|
38 |
+
value_grid = mesh_extractor.to_dense()
|
39 |
+
|
40 |
+
# marching cube
|
41 |
+
verts, faces, normals, values = measure.marching_cubes_lewiner(
|
42 |
+
volume=value_grid,
|
43 |
+
gradient_direction='ascent',
|
44 |
+
level=level_set)
|
45 |
+
|
46 |
+
verts = (verts / mesh_extractor.resolution - 0.5) * scale
|
47 |
+
verts = verts * gt_scale + gt_center
|
48 |
+
faces = faces[:, [0,2,1]]
|
49 |
+
meshexport = trimesh.Trimesh(verts, faces, normals, vertex_colors=values)
|
50 |
+
|
51 |
+
#remove disconnect part
|
52 |
+
connected_comp = meshexport.split(only_watertight=False)
|
53 |
+
max_area = 0
|
54 |
+
max_comp = None
|
55 |
+
for comp in connected_comp:
|
56 |
+
if comp.area > max_area:
|
57 |
+
max_area = comp.area
|
58 |
+
max_comp = comp
|
59 |
+
meshexport = max_comp
|
60 |
+
|
61 |
+
return meshexport
|
62 |
+
|
63 |
+
|
code/lib/utils/utils.py
ADDED
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import cv2
|
3 |
+
import torch
|
4 |
+
from torch.nn import functional as F
|
5 |
+
|
6 |
+
|
7 |
+
def split_input(model_input, total_pixels, n_pixels = 10000):
|
8 |
+
'''
|
9 |
+
Split the input to fit Cuda memory for large resolution.
|
10 |
+
Can decrease the value of n_pixels in case of cuda out of memory error.
|
11 |
+
'''
|
12 |
+
|
13 |
+
split = []
|
14 |
+
|
15 |
+
for i, indx in enumerate(torch.split(torch.arange(total_pixels).cuda(), n_pixels, dim=0)):
|
16 |
+
data = model_input.copy()
|
17 |
+
data['uv'] = torch.index_select(model_input['uv'], 1, indx)
|
18 |
+
split.append(data)
|
19 |
+
return split
|
20 |
+
|
21 |
+
|
22 |
+
def merge_output(res, total_pixels, batch_size):
|
23 |
+
''' Merge the split output. '''
|
24 |
+
|
25 |
+
model_outputs = {}
|
26 |
+
for entry in res[0]:
|
27 |
+
if res[0][entry] is None:
|
28 |
+
continue
|
29 |
+
if len(res[0][entry].shape) == 1:
|
30 |
+
model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, 1) for r in res],
|
31 |
+
1).reshape(batch_size * total_pixels)
|
32 |
+
else:
|
33 |
+
model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res],
|
34 |
+
1).reshape(batch_size * total_pixels, -1)
|
35 |
+
return model_outputs
|
36 |
+
|
37 |
+
|
38 |
+
def get_psnr(img1, img2, normalize_rgb=False):
|
39 |
+
if normalize_rgb: # [-1,1] --> [0,1]
|
40 |
+
img1 = (img1 + 1.) / 2.
|
41 |
+
img2 = (img2 + 1. ) / 2.
|
42 |
+
|
43 |
+
mse = torch.mean((img1 - img2) ** 2)
|
44 |
+
psnr = -10. * torch.log(mse) / torch.log(torch.Tensor([10.]).cuda())
|
45 |
+
|
46 |
+
return psnr
|
47 |
+
|
48 |
+
|
49 |
+
def load_K_Rt_from_P(filename, P=None):
|
50 |
+
if P is None:
|
51 |
+
lines = open(filename).read().splitlines()
|
52 |
+
if len(lines) == 4:
|
53 |
+
lines = lines[1:]
|
54 |
+
lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
|
55 |
+
P = np.asarray(lines).astype(np.float32).squeeze()
|
56 |
+
|
57 |
+
out = cv2.decomposeProjectionMatrix(P)
|
58 |
+
K = out[0]
|
59 |
+
R = out[1]
|
60 |
+
t = out[2]
|
61 |
+
|
62 |
+
K = K/K[2,2]
|
63 |
+
intrinsics = np.eye(4)
|
64 |
+
intrinsics[:3, :3] = K
|
65 |
+
|
66 |
+
pose = np.eye(4, dtype=np.float32)
|
67 |
+
pose[:3, :3] = R.transpose()
|
68 |
+
pose[:3,3] = (t[:3] / t[3])[:,0]
|
69 |
+
|
70 |
+
return intrinsics, pose
|
71 |
+
|
72 |
+
|
73 |
+
def get_camera_params(uv, pose, intrinsics):
|
74 |
+
if pose.shape[1] == 7: #In case of quaternion vector representation
|
75 |
+
cam_loc = pose[:, 4:]
|
76 |
+
R = quat_to_rot(pose[:,:4])
|
77 |
+
p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float()
|
78 |
+
p[:, :3, :3] = R
|
79 |
+
p[:, :3, 3] = cam_loc
|
80 |
+
else: # In case of pose matrix representation
|
81 |
+
cam_loc = pose[:, :3, 3]
|
82 |
+
p = pose
|
83 |
+
|
84 |
+
batch_size, num_samples, _ = uv.shape
|
85 |
+
|
86 |
+
depth = torch.ones((batch_size, num_samples)).cuda()
|
87 |
+
x_cam = uv[:, :, 0].view(batch_size, -1)
|
88 |
+
y_cam = uv[:, :, 1].view(batch_size, -1)
|
89 |
+
z_cam = depth.view(batch_size, -1)
|
90 |
+
|
91 |
+
pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)
|
92 |
+
|
93 |
+
# permute for batch matrix product
|
94 |
+
pixel_points_cam = pixel_points_cam.permute(0, 2, 1)
|
95 |
+
|
96 |
+
world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
|
97 |
+
ray_dirs = world_coords - cam_loc[:, None, :]
|
98 |
+
ray_dirs = F.normalize(ray_dirs, dim=2)
|
99 |
+
|
100 |
+
return ray_dirs, cam_loc
|
101 |
+
|
102 |
+
def lift(x, y, z, intrinsics):
|
103 |
+
# parse intrinsics
|
104 |
+
intrinsics = intrinsics.cuda()
|
105 |
+
fx = intrinsics[:, 0, 0]
|
106 |
+
fy = intrinsics[:, 1, 1]
|
107 |
+
cx = intrinsics[:, 0, 2]
|
108 |
+
cy = intrinsics[:, 1, 2]
|
109 |
+
sk = intrinsics[:, 0, 1]
|
110 |
+
|
111 |
+
x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
|
112 |
+
y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
|
113 |
+
|
114 |
+
# homogeneous
|
115 |
+
return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1)
|
116 |
+
|
117 |
+
|
118 |
+
def quat_to_rot(q):
|
119 |
+
batch_size, _ = q.shape
|
120 |
+
q = F.normalize(q, dim=1)
|
121 |
+
R = torch.ones((batch_size, 3,3)).cuda()
|
122 |
+
qr=q[:,0]
|
123 |
+
qi = q[:, 1]
|
124 |
+
qj = q[:, 2]
|
125 |
+
qk = q[:, 3]
|
126 |
+
R[:, 0, 0]=1-2 * (qj**2 + qk**2)
|
127 |
+
R[:, 0, 1] = 2 * (qj *qi -qk*qr)
|
128 |
+
R[:, 0, 2] = 2 * (qi * qk + qr * qj)
|
129 |
+
R[:, 1, 0] = 2 * (qj * qi + qk * qr)
|
130 |
+
R[:, 1, 1] = 1-2 * (qi**2 + qk**2)
|
131 |
+
R[:, 1, 2] = 2*(qj*qk - qi*qr)
|
132 |
+
R[:, 2, 0] = 2 * (qk * qi-qj * qr)
|
133 |
+
R[:, 2, 1] = 2 * (qj*qk + qi*qr)
|
134 |
+
R[:, 2, 2] = 1-2 * (qi**2 + qj**2)
|
135 |
+
return R
|
136 |
+
|
137 |
+
|
138 |
+
def rot_to_quat(R):
|
139 |
+
batch_size, _,_ = R.shape
|
140 |
+
q = torch.ones((batch_size, 4)).cuda()
|
141 |
+
|
142 |
+
R00 = R[:, 0,0]
|
143 |
+
R01 = R[:, 0, 1]
|
144 |
+
R02 = R[:, 0, 2]
|
145 |
+
R10 = R[:, 1, 0]
|
146 |
+
R11 = R[:, 1, 1]
|
147 |
+
R12 = R[:, 1, 2]
|
148 |
+
R20 = R[:, 2, 0]
|
149 |
+
R21 = R[:, 2, 1]
|
150 |
+
R22 = R[:, 2, 2]
|
151 |
+
|
152 |
+
q[:,0]=torch.sqrt(1.0+R00+R11+R22)/2
|
153 |
+
q[:, 1]=(R21-R12)/(4*q[:,0])
|
154 |
+
q[:, 2] = (R02 - R20) / (4 * q[:, 0])
|
155 |
+
q[:, 3] = (R10 - R01) / (4 * q[:, 0])
|
156 |
+
return q
|
157 |
+
|
158 |
+
|
159 |
+
def get_sphere_intersections(cam_loc, ray_directions, r = 1.0):
|
160 |
+
# Input: n_rays x 3 ; n_rays x 3
|
161 |
+
# Output: n_rays x 1, n_rays x 1 (close and far)
|
162 |
+
|
163 |
+
ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3),
|
164 |
+
cam_loc.view(-1, 3, 1)).squeeze(-1)
|
165 |
+
under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2)
|
166 |
+
|
167 |
+
# sanity check
|
168 |
+
if (under_sqrt <= 0).sum() > 0:
|
169 |
+
print('BOUNDING SPHERE PROBLEM!')
|
170 |
+
exit()
|
171 |
+
|
172 |
+
sphere_intersections = torch.sqrt(under_sqrt) * torch.Tensor([-1, 1]).cuda().float() - ray_cam_dot
|
173 |
+
sphere_intersections = sphere_intersections.clamp_min(0.0)
|
174 |
+
|
175 |
+
return sphere_intersections
|
176 |
+
|
177 |
+
def bilinear_interpolation(xs, ys, dist_map):
|
178 |
+
x1 = np.floor(xs).astype(np.int32)
|
179 |
+
y1 = np.floor(ys).astype(np.int32)
|
180 |
+
x2 = x1 + 1
|
181 |
+
y2 = y1 + 1
|
182 |
+
|
183 |
+
dx = np.expand_dims(np.stack([x2 - xs, xs - x1], axis=1), axis=1)
|
184 |
+
dy = np.expand_dims(np.stack([y2 - ys, ys - y1], axis=1), axis=2)
|
185 |
+
Q = np.stack([
|
186 |
+
dist_map[x1, y1], dist_map[x1, y2], dist_map[x2, y1], dist_map[x2, y2]
|
187 |
+
], axis=1).reshape(-1, 2, 2)
|
188 |
+
return np.squeeze(dx @ Q @ dy) # ((x2 - x1) * (y2 - y1)) = 1
|
189 |
+
|
190 |
+
def get_index_outside_of_bbox(samples_uniform, bbox_min, bbox_max):
|
191 |
+
samples_uniform_row = samples_uniform[:, 0]
|
192 |
+
samples_uniform_col = samples_uniform[:, 1]
|
193 |
+
index_outside = np.where((samples_uniform_row < bbox_min[0]) | (samples_uniform_row > bbox_max[0]) | (samples_uniform_col < bbox_min[1]) | (samples_uniform_col > bbox_max[1]))[0]
|
194 |
+
return index_outside
|
195 |
+
|
196 |
+
|
197 |
+
def weighted_sampling(data, img_size, num_sample, bbox_ratio=0.9):
|
198 |
+
"""
|
199 |
+
More sampling within the bounding box
|
200 |
+
"""
|
201 |
+
|
202 |
+
# calculate bounding box
|
203 |
+
mask = data["object_mask"]
|
204 |
+
where = np.asarray(np.where(mask))
|
205 |
+
bbox_min = where.min(axis=1)
|
206 |
+
bbox_max = where.max(axis=1)
|
207 |
+
|
208 |
+
num_sample_bbox = int(num_sample * bbox_ratio)
|
209 |
+
samples_bbox = np.random.rand(num_sample_bbox, 2)
|
210 |
+
samples_bbox = samples_bbox * (bbox_max - bbox_min) + bbox_min
|
211 |
+
|
212 |
+
num_sample_uniform = num_sample - num_sample_bbox
|
213 |
+
samples_uniform = np.random.rand(num_sample_uniform, 2)
|
214 |
+
samples_uniform *= (img_size[0] - 1, img_size[1] - 1)
|
215 |
+
|
216 |
+
# get indices for uniform samples outside of bbox
|
217 |
+
index_outside = get_index_outside_of_bbox(samples_uniform, bbox_min, bbox_max) + num_sample_bbox
|
218 |
+
|
219 |
+
indices = np.concatenate([samples_bbox, samples_uniform], axis=0)
|
220 |
+
output = {}
|
221 |
+
for key, val in data.items():
|
222 |
+
if len(val.shape) == 3:
|
223 |
+
new_val = np.stack([
|
224 |
+
bilinear_interpolation(indices[:, 0], indices[:, 1], val[:, :, i])
|
225 |
+
for i in range(val.shape[2])
|
226 |
+
], axis=-1)
|
227 |
+
else:
|
228 |
+
new_val = bilinear_interpolation(indices[:, 0], indices[:, 1], val)
|
229 |
+
new_val = new_val.reshape(-1, *val.shape[2:])
|
230 |
+
output[key] = new_val
|
231 |
+
|
232 |
+
return output, index_outside
|
code/setup.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2020 The TensorFlow Authors
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# https://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
"""Set-up script for installing extension modules."""
|
15 |
+
from Cython.Build import cythonize
|
16 |
+
import numpy
|
17 |
+
from setuptools import Extension
|
18 |
+
from setuptools import setup
|
19 |
+
|
20 |
+
# Get the numpy include directory.
|
21 |
+
numpy_include_dir = numpy.get_include()
|
22 |
+
|
23 |
+
# mise (efficient mesh extraction)
|
24 |
+
mise_module = Extension(
|
25 |
+
"lib.libmise.mise",
|
26 |
+
sources=["lib/libmise/mise.pyx"],
|
27 |
+
)
|
28 |
+
|
29 |
+
# Gather all extension modules
|
30 |
+
ext_modules = [
|
31 |
+
mise_module,
|
32 |
+
]
|
33 |
+
|
34 |
+
setup(ext_modules=cythonize(ext_modules),)
|
code/test.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from v2a_model import V2AModel
|
2 |
+
from lib.datasets import create_dataset
|
3 |
+
import hydra
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
from pytorch_lightning.loggers import WandbLogger
|
6 |
+
import os
|
7 |
+
import glob
|
8 |
+
|
9 |
+
@hydra.main(config_path="confs", config_name="base")
|
10 |
+
def main(opt):
|
11 |
+
pl.seed_everything(42)
|
12 |
+
print("Working dir:", os.getcwd())
|
13 |
+
|
14 |
+
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
15 |
+
dirpath="checkpoints/",
|
16 |
+
filename="{epoch:04d}-{loss}",
|
17 |
+
save_on_train_epoch_end=True,
|
18 |
+
save_last=True)
|
19 |
+
logger = WandbLogger(project=opt.project_name, name=f"{opt.exp}/{opt.run}")
|
20 |
+
|
21 |
+
trainer = pl.Trainer(
|
22 |
+
gpus=1,
|
23 |
+
accelerator="gpu",
|
24 |
+
callbacks=[checkpoint_callback],
|
25 |
+
max_epochs=8000,
|
26 |
+
check_val_every_n_epoch=50,
|
27 |
+
logger=logger,
|
28 |
+
log_every_n_steps=1,
|
29 |
+
num_sanity_val_steps=0
|
30 |
+
)
|
31 |
+
|
32 |
+
model = V2AModel(opt)
|
33 |
+
checkpoint = sorted(glob.glob("checkpoints/*.ckpt"))[-1]
|
34 |
+
testset = create_dataset(opt.dataset.metainfo, opt.dataset.test)
|
35 |
+
|
36 |
+
trainer.test(model, testset, ckpt_path=checkpoint)
|
37 |
+
|
38 |
+
if __name__ == '__main__':
|
39 |
+
main()
|
code/train.py
ADDED
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from v2a_model import V2AModel
|
2 |
+
from lib.datasets import create_dataset
|
3 |
+
import hydra
|
4 |
+
import pytorch_lightning as pl
|
5 |
+
from pytorch_lightning.loggers import WandbLogger
|
6 |
+
import os
|
7 |
+
import glob
|
8 |
+
|
9 |
+
@hydra.main(config_path="confs", config_name="base")
|
10 |
+
def main(opt):
|
11 |
+
pl.seed_everything(42)
|
12 |
+
print("Working dir:", os.getcwd())
|
13 |
+
|
14 |
+
checkpoint_callback = pl.callbacks.ModelCheckpoint(
|
15 |
+
dirpath="checkpoints/",
|
16 |
+
filename="{epoch:04d}-{loss}",
|
17 |
+
save_on_train_epoch_end=True,
|
18 |
+
save_last=True)
|
19 |
+
logger = WandbLogger(project=opt.project_name, name=f"{opt.exp}/{opt.run}")
|
20 |
+
|
21 |
+
trainer = pl.Trainer(
|
22 |
+
gpus=1,
|
23 |
+
accelerator="gpu",
|
24 |
+
callbacks=[checkpoint_callback],
|
25 |
+
max_epochs=8000,
|
26 |
+
check_val_every_n_epoch=50,
|
27 |
+
logger=logger,
|
28 |
+
log_every_n_steps=1,
|
29 |
+
num_sanity_val_steps=0
|
30 |
+
)
|
31 |
+
|
32 |
+
|
33 |
+
model = V2AModel(opt)
|
34 |
+
trainset = create_dataset(opt.dataset.metainfo, opt.dataset.train)
|
35 |
+
validset = create_dataset(opt.dataset.metainfo, opt.dataset.valid)
|
36 |
+
|
37 |
+
if opt.model.is_continue == True:
|
38 |
+
checkpoint = sorted(glob.glob("checkpoints/*.ckpt"))[-1]
|
39 |
+
trainer.fit(model, trainset, validset, ckpt_path=checkpoint)
|
40 |
+
else:
|
41 |
+
trainer.fit(model, trainset, validset)
|
42 |
+
|
43 |
+
|
44 |
+
if __name__ == '__main__':
|
45 |
+
main()
|
code/v2a_model.py
ADDED
@@ -0,0 +1,311 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pytorch_lightning as pl
|
2 |
+
import torch.optim as optim
|
3 |
+
from lib.model.v2a import V2A
|
4 |
+
from lib.model.body_model_params import BodyModelParams
|
5 |
+
from lib.model.deformer import SMPLDeformer
|
6 |
+
import cv2
|
7 |
+
import torch
|
8 |
+
from lib.model.loss import Loss
|
9 |
+
import hydra
|
10 |
+
import os
|
11 |
+
import numpy as np
|
12 |
+
from lib.utils.meshing import generate_mesh
|
13 |
+
from kaolin.ops.mesh import index_vertices_by_faces
|
14 |
+
import trimesh
|
15 |
+
from lib.model.deformer import skinning
|
16 |
+
from lib.utils import utils
|
17 |
+
class V2AModel(pl.LightningModule):
|
18 |
+
def __init__(self, opt) -> None:
|
19 |
+
super().__init__()
|
20 |
+
|
21 |
+
self.opt = opt
|
22 |
+
num_training_frames = opt.dataset.metainfo.end_frame - opt.dataset.metainfo.start_frame
|
23 |
+
self.betas_path = os.path.join(hydra.utils.to_absolute_path('..'), 'data', opt.dataset.metainfo.data_dir, 'mean_shape.npy')
|
24 |
+
self.gender = opt.dataset.metainfo.gender
|
25 |
+
self.model = V2A(opt.model, self.betas_path, self.gender, num_training_frames)
|
26 |
+
self.start_frame = opt.dataset.metainfo.start_frame
|
27 |
+
self.end_frame = opt.dataset.metainfo.end_frame
|
28 |
+
self.training_modules = ["model"]
|
29 |
+
|
30 |
+
self.training_indices = list(range(self.start_frame, self.end_frame))
|
31 |
+
self.body_model_params = BodyModelParams(num_training_frames, model_type='smpl')
|
32 |
+
self.load_body_model_params()
|
33 |
+
optim_params = self.body_model_params.param_names
|
34 |
+
for param_name in optim_params:
|
35 |
+
self.body_model_params.set_requires_grad(param_name, requires_grad=True)
|
36 |
+
self.training_modules += ['body_model_params']
|
37 |
+
|
38 |
+
self.loss = Loss(opt.model.loss)
|
39 |
+
|
40 |
+
def load_body_model_params(self):
|
41 |
+
body_model_params = {param_name: [] for param_name in self.body_model_params.param_names}
|
42 |
+
data_root = os.path.join('../data', self.opt.dataset.metainfo.data_dir)
|
43 |
+
data_root = hydra.utils.to_absolute_path(data_root)
|
44 |
+
|
45 |
+
body_model_params['betas'] = torch.tensor(np.load(os.path.join(data_root, 'mean_shape.npy'))[None], dtype=torch.float32)
|
46 |
+
body_model_params['global_orient'] = torch.tensor(np.load(os.path.join(data_root, 'poses.npy'))[self.training_indices][:, :3], dtype=torch.float32)
|
47 |
+
body_model_params['body_pose'] = torch.tensor(np.load(os.path.join(data_root, 'poses.npy'))[self.training_indices] [:, 3:], dtype=torch.float32)
|
48 |
+
body_model_params['transl'] = torch.tensor(np.load(os.path.join(data_root, 'normalize_trans.npy'))[self.training_indices], dtype=torch.float32)
|
49 |
+
|
50 |
+
for param_name in body_model_params.keys():
|
51 |
+
self.body_model_params.init_parameters(param_name, body_model_params[param_name], requires_grad=False)
|
52 |
+
|
53 |
+
def configure_optimizers(self):
|
54 |
+
params = [{'params': self.model.parameters(), 'lr':self.opt.model.learning_rate}]
|
55 |
+
params.append({'params': self.body_model_params.parameters(), 'lr':self.opt.model.learning_rate*0.1})
|
56 |
+
self.optimizer = optim.Adam(params, lr=self.opt.model.learning_rate, eps=1e-8)
|
57 |
+
self.scheduler = optim.lr_scheduler.MultiStepLR(
|
58 |
+
self.optimizer, milestones=self.opt.model.sched_milestones, gamma=self.opt.model.sched_factor)
|
59 |
+
return [self.optimizer], [self.scheduler]
|
60 |
+
|
61 |
+
def training_step(self, batch):
|
62 |
+
inputs, targets = batch
|
63 |
+
|
64 |
+
batch_idx = inputs["idx"]
|
65 |
+
|
66 |
+
body_model_params = self.body_model_params(batch_idx)
|
67 |
+
inputs['smpl_pose'] = torch.cat((body_model_params['global_orient'], body_model_params['body_pose']), dim=1)
|
68 |
+
inputs['smpl_shape'] = body_model_params['betas']
|
69 |
+
inputs['smpl_trans'] = body_model_params['transl']
|
70 |
+
|
71 |
+
inputs['current_epoch'] = self.current_epoch
|
72 |
+
model_outputs = self.model(inputs)
|
73 |
+
|
74 |
+
loss_output = self.loss(model_outputs, targets)
|
75 |
+
for k, v in loss_output.items():
|
76 |
+
if k in ["loss"]:
|
77 |
+
self.log(k, v.item(), prog_bar=True, on_step=True)
|
78 |
+
else:
|
79 |
+
self.log(k, v.item(), prog_bar=True, on_step=True)
|
80 |
+
return loss_output["loss"]
|
81 |
+
|
82 |
+
def training_epoch_end(self, outputs) -> None:
|
83 |
+
# Canonical mesh update every 20 epochs
|
84 |
+
if self.current_epoch != 0 and self.current_epoch % 20 == 0:
|
85 |
+
cond = {'smpl': torch.zeros(1, 69).float().cuda()}
|
86 |
+
mesh_canonical = generate_mesh(lambda x: self.query_oc(x, cond), self.model.smpl_server.verts_c[0], point_batch=10000, res_up=2)
|
87 |
+
self.model.mesh_v_cano = torch.tensor(mesh_canonical.vertices[None], device = self.model.smpl_v_cano.device).float()
|
88 |
+
self.model.mesh_f_cano = torch.tensor(mesh_canonical.faces.astype(np.int64), device=self.model.smpl_v_cano.device)
|
89 |
+
self.model.mesh_face_vertices = index_vertices_by_faces(self.model.mesh_v_cano, self.model.mesh_f_cano)
|
90 |
+
return super().training_epoch_end(outputs)
|
91 |
+
|
92 |
+
def query_oc(self, x, cond):
|
93 |
+
|
94 |
+
x = x.reshape(-1, 3)
|
95 |
+
mnfld_pred = self.model.implicit_network(x, cond)[:,:,0].reshape(-1,1)
|
96 |
+
return {'sdf':mnfld_pred}
|
97 |
+
|
98 |
+
def query_wc(self, x):
|
99 |
+
|
100 |
+
x = x.reshape(-1, 3)
|
101 |
+
w = self.model.deformer.query_weights(x)
|
102 |
+
|
103 |
+
return w
|
104 |
+
|
105 |
+
def query_od(self, x, cond, smpl_tfs, smpl_verts):
|
106 |
+
|
107 |
+
x = x.reshape(-1, 3)
|
108 |
+
x_c, _ = self.model.deformer.forward(x, smpl_tfs, return_weights=False, inverse=True, smpl_verts=smpl_verts)
|
109 |
+
output = self.model.implicit_network(x_c, cond)[0]
|
110 |
+
sdf = output[:, 0:1]
|
111 |
+
|
112 |
+
return {'sdf': sdf}
|
113 |
+
|
114 |
+
def get_deformed_mesh_fast_mode(self, verts, smpl_tfs):
|
115 |
+
verts = torch.tensor(verts).cuda().float()
|
116 |
+
weights = self.model.deformer.query_weights(verts)
|
117 |
+
verts_deformed = skinning(verts.unsqueeze(0), weights, smpl_tfs).data.cpu().numpy()[0]
|
118 |
+
return verts_deformed
|
119 |
+
|
120 |
+
def validation_step(self, batch, *args, **kwargs):
|
121 |
+
|
122 |
+
output = {}
|
123 |
+
inputs, targets = batch
|
124 |
+
inputs['current_epoch'] = self.current_epoch
|
125 |
+
self.model.eval()
|
126 |
+
|
127 |
+
body_model_params = self.body_model_params(inputs['image_id'])
|
128 |
+
inputs['smpl_pose'] = torch.cat((body_model_params['global_orient'], body_model_params['body_pose']), dim=1)
|
129 |
+
inputs['smpl_shape'] = body_model_params['betas']
|
130 |
+
inputs['smpl_trans'] = body_model_params['transl']
|
131 |
+
|
132 |
+
cond = {'smpl': inputs["smpl_pose"][:, 3:]/np.pi}
|
133 |
+
mesh_canonical = generate_mesh(lambda x: self.query_oc(x, cond), self.model.smpl_server.verts_c[0], point_batch=10000, res_up=3)
|
134 |
+
|
135 |
+
mesh_canonical = trimesh.Trimesh(mesh_canonical.vertices, mesh_canonical.faces)
|
136 |
+
|
137 |
+
output.update({
|
138 |
+
'canonical_mesh':mesh_canonical
|
139 |
+
})
|
140 |
+
|
141 |
+
split = utils.split_input(inputs, targets["total_pixels"][0], n_pixels=min(targets['pixel_per_batch'], targets["img_size"][0] * targets["img_size"][1]))
|
142 |
+
|
143 |
+
res = []
|
144 |
+
for s in split:
|
145 |
+
|
146 |
+
out = self.model(s)
|
147 |
+
|
148 |
+
for k, v in out.items():
|
149 |
+
try:
|
150 |
+
out[k] = v.detach()
|
151 |
+
except:
|
152 |
+
out[k] = v
|
153 |
+
|
154 |
+
res.append({
|
155 |
+
'rgb_values': out['rgb_values'].detach(),
|
156 |
+
'normal_values': out['normal_values'].detach(),
|
157 |
+
'fg_rgb_values': out['fg_rgb_values'].detach(),
|
158 |
+
})
|
159 |
+
batch_size = targets['rgb'].shape[0]
|
160 |
+
|
161 |
+
model_outputs = utils.merge_output(res, targets["total_pixels"][0], batch_size)
|
162 |
+
|
163 |
+
output.update({
|
164 |
+
"rgb_values": model_outputs["rgb_values"].detach().clone(),
|
165 |
+
"normal_values": model_outputs["normal_values"].detach().clone(),
|
166 |
+
"fg_rgb_values": model_outputs["fg_rgb_values"].detach().clone(),
|
167 |
+
**targets,
|
168 |
+
})
|
169 |
+
|
170 |
+
return output
|
171 |
+
|
172 |
+
def validation_step_end(self, batch_parts):
|
173 |
+
return batch_parts
|
174 |
+
|
175 |
+
def validation_epoch_end(self, outputs) -> None:
|
176 |
+
img_size = outputs[0]["img_size"]
|
177 |
+
|
178 |
+
rgb_pred = torch.cat([output["rgb_values"] for output in outputs], dim=0)
|
179 |
+
rgb_pred = rgb_pred.reshape(*img_size, -1)
|
180 |
+
|
181 |
+
fg_rgb_pred = torch.cat([output["fg_rgb_values"] for output in outputs], dim=0)
|
182 |
+
fg_rgb_pred = fg_rgb_pred.reshape(*img_size, -1)
|
183 |
+
|
184 |
+
normal_pred = torch.cat([output["normal_values"] for output in outputs], dim=0)
|
185 |
+
normal_pred = (normal_pred.reshape(*img_size, -1) + 1) / 2
|
186 |
+
|
187 |
+
rgb_gt = torch.cat([output["rgb"] for output in outputs], dim=1).squeeze(0)
|
188 |
+
rgb_gt = rgb_gt.reshape(*img_size, -1)
|
189 |
+
if 'normal' in outputs[0].keys():
|
190 |
+
normal_gt = torch.cat([output["normal"] for output in outputs], dim=1).squeeze(0)
|
191 |
+
normal_gt = (normal_gt.reshape(*img_size, -1) + 1) / 2
|
192 |
+
normal = torch.cat([normal_gt, normal_pred], dim=0).cpu().numpy()
|
193 |
+
else:
|
194 |
+
normal = torch.cat([normal_pred], dim=0).cpu().numpy()
|
195 |
+
|
196 |
+
rgb = torch.cat([rgb_gt, rgb_pred], dim=0).cpu().numpy()
|
197 |
+
rgb = (rgb * 255).astype(np.uint8)
|
198 |
+
|
199 |
+
fg_rgb = torch.cat([fg_rgb_pred], dim=0).cpu().numpy()
|
200 |
+
fg_rgb = (fg_rgb * 255).astype(np.uint8)
|
201 |
+
|
202 |
+
normal = (normal * 255).astype(np.uint8)
|
203 |
+
|
204 |
+
os.makedirs("rendering", exist_ok=True)
|
205 |
+
os.makedirs("normal", exist_ok=True)
|
206 |
+
os.makedirs('fg_rendering', exist_ok=True)
|
207 |
+
|
208 |
+
canonical_mesh = outputs[0]['canonical_mesh']
|
209 |
+
canonical_mesh.export(f"rendering/{self.current_epoch}.ply")
|
210 |
+
|
211 |
+
cv2.imwrite(f"rendering/{self.current_epoch}.png", rgb[:, :, ::-1])
|
212 |
+
cv2.imwrite(f"normal/{self.current_epoch}.png", normal[:, :, ::-1])
|
213 |
+
cv2.imwrite(f"fg_rendering/{self.current_epoch}.png", fg_rgb[:, :, ::-1])
|
214 |
+
|
215 |
+
def test_step(self, batch, *args, **kwargs):
|
216 |
+
inputs, targets, pixel_per_batch, total_pixels, idx = batch
|
217 |
+
num_splits = (total_pixels + pixel_per_batch -
|
218 |
+
1) // pixel_per_batch
|
219 |
+
results = []
|
220 |
+
|
221 |
+
scale, smpl_trans, smpl_pose, smpl_shape = torch.split(inputs["smpl_params"], [1, 3, 72, 10], dim=1)
|
222 |
+
|
223 |
+
body_model_params = self.body_model_params(inputs['idx'])
|
224 |
+
smpl_shape = body_model_params['betas'] if body_model_params['betas'].dim() == 2 else body_model_params['betas'].unsqueeze(0)
|
225 |
+
smpl_trans = body_model_params['transl']
|
226 |
+
smpl_pose = torch.cat((body_model_params['global_orient'], body_model_params['body_pose']), dim=1)
|
227 |
+
|
228 |
+
smpl_outputs = self.model.smpl_server(scale, smpl_trans, smpl_pose, smpl_shape)
|
229 |
+
smpl_tfs = smpl_outputs['smpl_tfs']
|
230 |
+
cond = {'smpl': smpl_pose[:, 3:]/np.pi}
|
231 |
+
|
232 |
+
mesh_canonical = generate_mesh(lambda x: self.query_oc(x, cond), self.model.smpl_server.verts_c[0], point_batch=10000, res_up=4)
|
233 |
+
self.model.deformer = SMPLDeformer(betas=np.load(self.betas_path), gender=self.gender, K=7)
|
234 |
+
verts_deformed = self.get_deformed_mesh_fast_mode(mesh_canonical.vertices, smpl_tfs)
|
235 |
+
mesh_deformed = trimesh.Trimesh(vertices=verts_deformed, faces=mesh_canonical.faces, process=False)
|
236 |
+
|
237 |
+
os.makedirs("test_mask", exist_ok=True)
|
238 |
+
os.makedirs("test_rendering", exist_ok=True)
|
239 |
+
os.makedirs("test_fg_rendering", exist_ok=True)
|
240 |
+
os.makedirs("test_normal", exist_ok=True)
|
241 |
+
os.makedirs("test_mesh", exist_ok=True)
|
242 |
+
|
243 |
+
mesh_canonical.export(f"test_mesh/{int(idx.cpu().numpy()):04d}_canonical.ply")
|
244 |
+
mesh_deformed.export(f"test_mesh/{int(idx.cpu().numpy()):04d}_deformed.ply")
|
245 |
+
self.model.deformer = SMPLDeformer(betas=np.load(self.betas_path), gender=self.gender)
|
246 |
+
for i in range(num_splits):
|
247 |
+
indices = list(range(i * pixel_per_batch,
|
248 |
+
min((i + 1) * pixel_per_batch, total_pixels)))
|
249 |
+
batch_inputs = {"uv": inputs["uv"][:, indices],
|
250 |
+
"intrinsics": inputs['intrinsics'],
|
251 |
+
"pose": inputs['pose'],
|
252 |
+
"smpl_params": inputs["smpl_params"],
|
253 |
+
"smpl_pose": inputs["smpl_params"][:, 4:76],
|
254 |
+
"smpl_shape": inputs["smpl_params"][:, 76:],
|
255 |
+
"smpl_trans": inputs["smpl_params"][:, 1:4],
|
256 |
+
"idx": inputs["idx"] if 'idx' in inputs.keys() else None}
|
257 |
+
|
258 |
+
body_model_params = self.body_model_params(inputs['idx'])
|
259 |
+
|
260 |
+
batch_inputs.update({'smpl_pose': torch.cat((body_model_params['global_orient'], body_model_params['body_pose']), dim=1)})
|
261 |
+
batch_inputs.update({'smpl_shape': body_model_params['betas']})
|
262 |
+
batch_inputs.update({'smpl_trans': body_model_params['transl']})
|
263 |
+
|
264 |
+
batch_targets = {"rgb": targets["rgb"][:, indices].detach().clone() if 'rgb' in targets.keys() else None,
|
265 |
+
"img_size": targets["img_size"]}
|
266 |
+
|
267 |
+
with torch.no_grad():
|
268 |
+
model_outputs = self.model(batch_inputs)
|
269 |
+
results.append({"rgb_values":model_outputs["rgb_values"].detach().clone(),
|
270 |
+
"fg_rgb_values":model_outputs["fg_rgb_values"].detach().clone(),
|
271 |
+
"normal_values": model_outputs["normal_values"].detach().clone(),
|
272 |
+
"acc_map": model_outputs["acc_map"].detach().clone(),
|
273 |
+
**batch_targets})
|
274 |
+
|
275 |
+
img_size = results[0]["img_size"]
|
276 |
+
rgb_pred = torch.cat([result["rgb_values"] for result in results], dim=0)
|
277 |
+
rgb_pred = rgb_pred.reshape(*img_size, -1)
|
278 |
+
|
279 |
+
fg_rgb_pred = torch.cat([result["fg_rgb_values"] for result in results], dim=0)
|
280 |
+
fg_rgb_pred = fg_rgb_pred.reshape(*img_size, -1)
|
281 |
+
|
282 |
+
normal_pred = torch.cat([result["normal_values"] for result in results], dim=0)
|
283 |
+
normal_pred = (normal_pred.reshape(*img_size, -1) + 1) / 2
|
284 |
+
|
285 |
+
pred_mask = torch.cat([result["acc_map"] for result in results], dim=0)
|
286 |
+
pred_mask = pred_mask.reshape(*img_size, -1)
|
287 |
+
|
288 |
+
if results[0]['rgb'] is not None:
|
289 |
+
rgb_gt = torch.cat([result["rgb"] for result in results], dim=1).squeeze(0)
|
290 |
+
rgb_gt = rgb_gt.reshape(*img_size, -1)
|
291 |
+
rgb = torch.cat([rgb_gt, rgb_pred], dim=0).cpu().numpy()
|
292 |
+
else:
|
293 |
+
rgb = torch.cat([rgb_pred], dim=0).cpu().numpy()
|
294 |
+
if 'normal' in results[0].keys():
|
295 |
+
normal_gt = torch.cat([result["normal"] for result in results], dim=1).squeeze(0)
|
296 |
+
normal_gt = (normal_gt.reshape(*img_size, -1) + 1) / 2
|
297 |
+
normal = torch.cat([normal_gt, normal_pred], dim=0).cpu().numpy()
|
298 |
+
else:
|
299 |
+
normal = torch.cat([normal_pred], dim=0).cpu().numpy()
|
300 |
+
|
301 |
+
rgb = (rgb * 255).astype(np.uint8)
|
302 |
+
|
303 |
+
fg_rgb = torch.cat([fg_rgb_pred], dim=0).cpu().numpy()
|
304 |
+
fg_rgb = (fg_rgb * 255).astype(np.uint8)
|
305 |
+
|
306 |
+
normal = (normal * 255).astype(np.uint8)
|
307 |
+
|
308 |
+
cv2.imwrite(f"test_mask/{int(idx.cpu().numpy()):04d}.png", pred_mask.cpu().numpy() * 255)
|
309 |
+
cv2.imwrite(f"test_rendering/{int(idx.cpu().numpy()):04d}.png", rgb[:, :, ::-1])
|
310 |
+
cv2.imwrite(f"test_normal/{int(idx.cpu().numpy()):04d}.png", normal[:, :, ::-1])
|
311 |
+
cv2.imwrite(f"test_fg_rendering/{int(idx.cpu().numpy()):04d}.png", fg_rgb[:, :, ::-1])
|
data/parkinglot/cameras.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:ef3e45c7a142677332eda018e986d156f8a496e90a84308d395cbb44c9c0f686
|
3 |
+
size 15626
|
data/parkinglot/cameras_normalize.npz
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cc2622c844d57ee385a6f2e85c89d70a27c2181d397f7556e562f4633267e935
|
3 |
+
size 29550
|
data/parkinglot/checkpoints/epoch=6299-loss=0.01887552998960018.ckpt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b82917e0738e2a3a297870275480b77d324fb08e2eacfec35b7121ba5d4e8f28
|
3 |
+
size 36504439
|
data/parkinglot/image/0000.png
ADDED
![]() |
Git LFS Details
|
data/parkinglot/image/0001.png
ADDED
![]() |
Git LFS Details
|
data/parkinglot/image/0002.png
ADDED
![]() |
Git LFS Details
|
data/parkinglot/image/0003.png
ADDED
![]() |
Git LFS Details
|
data/parkinglot/image/0004.png
ADDED
![]() |
Git LFS Details
|