leobcc committed on
Commit
6325697
·
1 Parent(s): b5b41d5

vid2avatar baseline

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50)
  1. .gitattributes +43 -0
  2. .gitignore +19 -0
  3. LICENSE +399 -0
  4. README.md +74 -11
  5. assets/exstrimalik.gif +3 -0
  6. assets/martial.gif +3 -0
  7. assets/parkinglot_360.gif +3 -0
  8. assets/roger.gif +3 -0
  9. assets/smpl_init.pth +3 -0
  10. assets/teaser.png +3 -0
  11. code/check_cuda.py +11 -0
  12. code/confs/base.yaml +13 -0
  13. code/confs/dataset/video.yaml +37 -0
  14. code/confs/model/model_w_bg.yaml +77 -0
  15. code/lib/datasets/__init__.py +26 -0
  16. code/lib/datasets/dataset.py +175 -0
  17. code/lib/libmise/mise.cp37-win_amd64.pyd +0 -0
  18. code/lib/libmise/mise.cpp +0 -0
  19. code/lib/libmise/mise.pyx +370 -0
  20. code/lib/model/body_model_params.py +49 -0
  21. code/lib/model/deformer.py +89 -0
  22. code/lib/model/density.py +46 -0
  23. code/lib/model/embedders.py +50 -0
  24. code/lib/model/loss.py +64 -0
  25. code/lib/model/networks.py +178 -0
  26. code/lib/model/ray_sampler.py +234 -0
  27. code/lib/model/sampler.py +29 -0
  28. code/lib/model/smpl.py +94 -0
  29. code/lib/model/v2a.py +368 -0
  30. code/lib/smpl/body_models.py +365 -0
  31. code/lib/smpl/lbs.py +377 -0
  32. code/lib/smpl/smpl_model/SMPL_FEMALE.pkl +3 -0
  33. code/lib/smpl/smpl_model/SMPL_MALE.pkl +3 -0
  34. code/lib/smpl/utils.py +49 -0
  35. code/lib/smpl/vertex_ids.py +71 -0
  36. code/lib/smpl/vertex_joint_selector.py +77 -0
  37. code/lib/utils/meshing.py +63 -0
  38. code/lib/utils/utils.py +232 -0
  39. code/setup.py +34 -0
  40. code/test.py +39 -0
  41. code/train.py +45 -0
  42. code/v2a_model.py +311 -0
  43. data/parkinglot/cameras.npz +3 -0
  44. data/parkinglot/cameras_normalize.npz +3 -0
  45. data/parkinglot/checkpoints/epoch=6299-loss=0.01887552998960018.ckpt +3 -0
  46. data/parkinglot/image/0000.png +3 -0
  47. data/parkinglot/image/0001.png +3 -0
  48. data/parkinglot/image/0002.png +3 -0
  49. data/parkinglot/image/0003.png +3 -0
  50. data/parkinglot/image/0004.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,46 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ assets/exstrimalik.gif filter=lfs diff=lfs merge=lfs -text
37
+ assets/martial.gif filter=lfs diff=lfs merge=lfs -text
38
+ assets/parkinglot_360.gif filter=lfs diff=lfs merge=lfs -text
39
+ assets/roger.gif filter=lfs diff=lfs merge=lfs -text
40
+ assets/teaser.png filter=lfs diff=lfs merge=lfs -text
41
+ data/parkinglot/image/0000.png filter=lfs diff=lfs merge=lfs -text
42
+ data/parkinglot/image/0001.png filter=lfs diff=lfs merge=lfs -text
43
+ data/parkinglot/image/0002.png filter=lfs diff=lfs merge=lfs -text
44
+ data/parkinglot/image/0003.png filter=lfs diff=lfs merge=lfs -text
45
+ data/parkinglot/image/0004.png filter=lfs diff=lfs merge=lfs -text
46
+ data/parkinglot/image/0005.png filter=lfs diff=lfs merge=lfs -text
47
+ data/parkinglot/image/0006.png filter=lfs diff=lfs merge=lfs -text
48
+ data/parkinglot/image/0007.png filter=lfs diff=lfs merge=lfs -text
49
+ data/parkinglot/image/0010.png filter=lfs diff=lfs merge=lfs -text
50
+ data/parkinglot/image/0011.png filter=lfs diff=lfs merge=lfs -text
51
+ data/parkinglot/image/0012.png filter=lfs diff=lfs merge=lfs -text
52
+ data/parkinglot/image/0013.png filter=lfs diff=lfs merge=lfs -text
53
+ data/parkinglot/image/0014.png filter=lfs diff=lfs merge=lfs -text
54
+ data/parkinglot/image/0015.png filter=lfs diff=lfs merge=lfs -text
55
+ data/parkinglot/image/0016.png filter=lfs diff=lfs merge=lfs -text
56
+ data/parkinglot/image/0017.png filter=lfs diff=lfs merge=lfs -text
57
+ data/parkinglot/image/0018.png filter=lfs diff=lfs merge=lfs -text
58
+ data/parkinglot/image/0020.png filter=lfs diff=lfs merge=lfs -text
59
+ data/parkinglot/image/0021.png filter=lfs diff=lfs merge=lfs -text
60
+ data/parkinglot/image/0022.png filter=lfs diff=lfs merge=lfs -text
61
+ data/parkinglot/image/0023.png filter=lfs diff=lfs merge=lfs -text
62
+ data/parkinglot/image/0024.png filter=lfs diff=lfs merge=lfs -text
63
+ data/parkinglot/image/0025.png filter=lfs diff=lfs merge=lfs -text
64
+ data/parkinglot/image/0027.png filter=lfs diff=lfs merge=lfs -text
65
+ data/parkinglot/image/0028.png filter=lfs diff=lfs merge=lfs -text
66
+ data/parkinglot/image/0029.png filter=lfs diff=lfs merge=lfs -text
67
+ data/parkinglot/image/0030.png filter=lfs diff=lfs merge=lfs -text
68
+ data/parkinglot/image/0031.png filter=lfs diff=lfs merge=lfs -text
69
+ data/parkinglot/image/0032.png filter=lfs diff=lfs merge=lfs -text
70
+ data/parkinglot/image/0033.png filter=lfs diff=lfs merge=lfs -text
71
+ data/parkinglot/image/0034.png filter=lfs diff=lfs merge=lfs -text
72
+ data/parkinglot/image/0035.png filter=lfs diff=lfs merge=lfs -text
73
+ data/parkinglot/image/0036.png filter=lfs diff=lfs merge=lfs -text
74
+ data/parkinglot/image/0037.png filter=lfs diff=lfs merge=lfs -text
75
+ data/parkinglot/image/0038.png filter=lfs diff=lfs merge=lfs -text
76
+ data/parkinglot/image/0039.png filter=lfs diff=lfs merge=lfs -text
77
+ data/parkinglot/image/0040.png filter=lfs diff=lfs merge=lfs -text
78
+ data/parkinglot/image/0041.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,19 @@
1
+ .vscode/
2
+ env/
3
+ __pycache__/
4
+ *.ply
5
+ *.npz
6
+ *.npy
7
+ *.pkl
8
+ outputs
9
+ *.obj
10
+ *.ipynb
11
+ *.so
12
+ data
13
+ code/UNKNOWN.egg-info
14
+ code/dist
15
+ code/build
16
+ visualization/imgui.ini
17
+ export
18
+ preprocessing/raw_data
19
+ preprocessing/romp
LICENSE ADDED
@@ -0,0 +1,399 @@
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+ Section 1 -- Definitions.
71
+
72
+ a. Adapted Material means material subject to Copyright and Similar
73
+ Rights that is derived from or based upon the Licensed Material
74
+ and in which the Licensed Material is translated, altered,
75
+ arranged, transformed, or otherwise modified in a manner requiring
76
+ permission under the Copyright and Similar Rights held by the
77
+ Licensor. For purposes of this Public License, where the Licensed
78
+ Material is a musical work, performance, or sound recording,
79
+ Adapted Material is always produced where the Licensed Material is
80
+ synched in timed relation with a moving image.
81
+
82
+ b. Adapter's License means the license You apply to Your Copyright
83
+ and Similar Rights in Your contributions to Adapted Material in
84
+ accordance with the terms and conditions of this Public License.
85
+
86
+ c. Copyright and Similar Rights means copyright and/or similar rights
87
+ closely related to copyright including, without limitation,
88
+ performance, broadcast, sound recording, and Sui Generis Database
89
+ Rights, without regard to how the rights are labeled or
90
+ categorized. For purposes of this Public License, the rights
91
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
92
+ Rights.
93
+ d. Effective Technological Measures means those measures that, in the
94
+ absence of proper authority, may not be circumvented under laws
95
+ fulfilling obligations under Article 11 of the WIPO Copyright
96
+ Treaty adopted on December 20, 1996, and/or similar international
97
+ agreements.
98
+
99
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
100
+ any other exception or limitation to Copyright and Similar Rights
101
+ that applies to Your use of the Licensed Material.
102
+
103
+ f. Licensed Material means the artistic or literary work, database,
104
+ or other material to which the Licensor applied this Public
105
+ License.
106
+
107
+ g. Licensed Rights means the rights granted to You subject to the
108
+ terms and conditions of this Public License, which are limited to
109
+ all Copyright and Similar Rights that apply to Your use of the
110
+ Licensed Material and that the Licensor has authority to license.
111
+
112
+ h. Licensor means the individual(s) or entity(ies) granting rights
113
+ under this Public License.
114
+
115
+ i. NonCommercial means not primarily intended for or directed towards
116
+ commercial advantage or monetary compensation. For purposes of
117
+ this Public License, the exchange of the Licensed Material for
118
+ other material subject to Copyright and Similar Rights by digital
119
+ file-sharing or similar means is NonCommercial provided there is
120
+ no payment of monetary compensation in connection with the
121
+ exchange.
122
+
123
+ j. Share means to provide material to the public by any means or
124
+ process that requires permission under the Licensed Rights, such
125
+ as reproduction, public display, public performance, distribution,
126
+ dissemination, communication, or importation, and to make material
127
+ available to the public including in ways that members of the
128
+ public may access the material from a place and at a time
129
+ individually chosen by them.
130
+
131
+ k. Sui Generis Database Rights means rights other than copyright
132
+ resulting from Directive 96/9/EC of the European Parliament and of
133
+ the Council of 11 March 1996 on the legal protection of databases,
134
+ as amended and/or succeeded, as well as other essentially
135
+ equivalent rights anywhere in the world.
136
+
137
+ l. You means the individual or entity exercising the Licensed Rights
138
+ under this Public License. Your has a corresponding meaning.
139
+
140
+ Section 2 -- Scope.
141
+
142
+ a. License grant.
143
+
144
+ 1. Subject to the terms and conditions of this Public License,
145
+ the Licensor hereby grants You a worldwide, royalty-free,
146
+ non-sublicensable, non-exclusive, irrevocable license to
147
+ exercise the Licensed Rights in the Licensed Material to:
148
+
149
+ a. reproduce and Share the Licensed Material, in whole or
150
+ in part, for NonCommercial purposes only; and
151
+
152
+ b. produce, reproduce, and Share Adapted Material for
153
+ NonCommercial purposes only.
154
+
155
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
156
+ Exceptions and Limitations apply to Your use, this Public
157
+ License does not apply, and You do not need to comply with
158
+ its terms and conditions.
159
+
160
+ 3. Term. The term of this Public License is specified in Section
161
+ 6(a).
162
+
163
+ 4. Media and formats; technical modifications allowed. The
164
+ Licensor authorizes You to exercise the Licensed Rights in
165
+ all media and formats whether now known or hereafter created,
166
+ and to make technical modifications necessary to do so. The
167
+ Licensor waives and/or agrees not to assert any right or
168
+ authority to forbid You from making technical modifications
169
+ necessary to exercise the Licensed Rights, including
170
+ technical modifications necessary to circumvent Effective
171
+ Technological Measures. For purposes of this Public License,
172
+ simply making modifications authorized by this Section 2(a)
173
+ (4) never produces Adapted Material.
174
+
175
+ 5. Downstream recipients.
176
+
177
+ a. Offer from the Licensor -- Licensed Material. Every
178
+ recipient of the Licensed Material automatically
179
+ receives an offer from the Licensor to exercise the
180
+ Licensed Rights under the terms and conditions of this
181
+ Public License.
182
+
183
+ b. No downstream restrictions. You may not offer or impose
184
+ any additional or different terms or conditions on, or
185
+ apply any Effective Technological Measures to, the
186
+ Licensed Material if doing so restricts exercise of the
187
+ Licensed Rights by any recipient of the Licensed
188
+ Material.
189
+
190
+ 6. No endorsement. Nothing in this Public License constitutes or
191
+ may be construed as permission to assert or imply that You
192
+ are, or that Your use of the Licensed Material is, connected
193
+ with, or sponsored, endorsed, or granted official status by,
194
+ the Licensor or others designated to receive attribution as
195
+ provided in Section 3(a)(1)(A)(i).
196
+
197
+ b. Other rights.
198
+
199
+ 1. Moral rights, such as the right of integrity, are not
200
+ licensed under this Public License, nor are publicity,
201
+ privacy, and/or other similar personality rights; however, to
202
+ the extent possible, the Licensor waives and/or agrees not to
203
+ assert any such rights held by the Licensor to the limited
204
+ extent necessary to allow You to exercise the Licensed
205
+ Rights, but not otherwise.
206
+
207
+ 2. Patent and trademark rights are not licensed under this
208
+ Public License.
209
+
210
+ 3. To the extent possible, the Licensor waives any right to
211
+ collect royalties from You for the exercise of the Licensed
212
+ Rights, whether directly or through a collecting society
213
+ under any voluntary or waivable statutory or compulsory
214
+ licensing scheme. In all other cases the Licensor expressly
215
+ reserves any right to collect such royalties, including when
216
+ the Licensed Material is used other than for NonCommercial
217
+ purposes.
218
+
219
+ Section 3 -- License Conditions.
220
+
221
+ Your exercise of the Licensed Rights is expressly made subject to the
222
+ following conditions.
223
+
224
+ a. Attribution.
225
+
226
+ 1. If You Share the Licensed Material (including in modified
227
+ form), You must:
228
+
229
+ a. retain the following if it is supplied by the Licensor
230
+ with the Licensed Material:
231
+
232
+ i. identification of the creator(s) of the Licensed
233
+ Material and any others designated to receive
234
+ attribution, in any reasonable manner requested by
235
+ the Licensor (including by pseudonym if
236
+ designated);
237
+
238
+ ii. a copyright notice;
239
+
240
+ iii. a notice that refers to this Public License;
241
+
242
+ iv. a notice that refers to the disclaimer of
243
+ warranties;
244
+
245
+ v. a URI or hyperlink to the Licensed Material to the
246
+ extent reasonably practicable;
247
+
248
+ b. indicate if You modified the Licensed Material and
249
+ retain an indication of any previous modifications; and
250
+
251
+ c. indicate the Licensed Material is licensed under this
252
+ Public License, and include the text of, or the URI or
253
+ hyperlink to, this Public License.
254
+
255
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
256
+ reasonable manner based on the medium, means, and context in
257
+ which You Share the Licensed Material. For example, it may be
258
+ reasonable to satisfy the conditions by providing a URI or
259
+ hyperlink to a resource that includes the required
260
+ information.
261
+
262
+ 3. If requested by the Licensor, You must remove any of the
263
+ information required by Section 3(a)(1)(A) to the extent
264
+ reasonably practicable.
265
+
266
+ 4. If You Share Adapted Material You produce, the Adapter's
267
+ License You apply must not prevent recipients of the Adapted
268
+ Material from complying with this Public License.
269
+
270
+ Section 4 -- Sui Generis Database Rights.
271
+
272
+ Where the Licensed Rights include Sui Generis Database Rights that
273
+ apply to Your use of the Licensed Material:
274
+
275
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
276
+ to extract, reuse, reproduce, and Share all or a substantial
277
+ portion of the contents of the database for NonCommercial purposes
278
+ only;
279
+
280
+ b. if You include all or a substantial portion of the database
281
+ contents in a database in which You have Sui Generis Database
282
+ Rights, then the database in which You have Sui Generis Database
283
+ Rights (but not its individual contents) is Adapted Material; and
284
+
285
+ c. You must comply with the conditions in Section 3(a) if You Share
286
+ all or a substantial portion of the contents of the database.
287
+
288
+ For the avoidance of doubt, this Section 4 supplements and does not
289
+ replace Your obligations under this Public License where the Licensed
290
+ Rights include other Copyright and Similar Rights.
291
+
292
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
293
+
294
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
295
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
296
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
297
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
298
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
299
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
300
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
301
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
302
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
303
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
304
+
305
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
306
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
307
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
308
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
309
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
310
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
311
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
312
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
313
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
314
+
315
+ c. The disclaimer of warranties and limitation of liability provided
316
+ above shall be interpreted in a manner that, to the extent
317
+ possible, most closely approximates an absolute disclaimer and
318
+ waiver of all liability.
319
+
320
+ Section 6 -- Term and Termination.
321
+
322
+ a. This Public License applies for the term of the Copyright and
323
+ Similar Rights licensed here. However, if You fail to comply with
324
+ this Public License, then Your rights under this Public License
325
+ terminate automatically.
326
+
327
+ b. Where Your right to use the Licensed Material has terminated under
328
+ Section 6(a), it reinstates:
329
+
330
+ 1. automatically as of the date the violation is cured, provided
331
+ it is cured within 30 days of Your discovery of the
332
+ violation; or
333
+
334
+ 2. upon express reinstatement by the Licensor.
335
+
336
+ For the avoidance of doubt, this Section 6(b) does not affect any
337
+ right the Licensor may have to seek remedies for Your violations
338
+ of this Public License.
339
+
340
+ c. For the avoidance of doubt, the Licensor may also offer the
341
+ Licensed Material under separate terms or conditions or stop
342
+ distributing the Licensed Material at any time; however, doing so
343
+ will not terminate this Public License.
344
+
345
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
346
+ License.
347
+
348
+ Section 7 -- Other Terms and Conditions.
349
+
350
+ a. The Licensor shall not be bound by any additional or different
351
+ terms or conditions communicated by You unless expressly agreed.
352
+
353
+ b. Any arrangements, understandings, or agreements regarding the
354
+ Licensed Material not stated herein are separate from and
355
+ independent of the terms and conditions of this Public License.
356
+
357
+ Section 8 -- Interpretation.
358
+
359
+ a. For the avoidance of doubt, this Public License does not, and
360
+ shall not be interpreted to, reduce, limit, restrict, or impose
361
+ conditions on any use of the Licensed Material that could lawfully
362
+ be made without permission under this Public License.
363
+
364
+ b. To the extent possible, if any provision of this Public License is
365
+ deemed unenforceable, it shall be automatically reformed to the
366
+ minimum extent necessary to make it enforceable. If the provision
367
+ cannot be reformed, it shall be severed from this Public License
368
+ without affecting the enforceability of the remaining terms and
369
+ conditions.
370
+
371
+ c. No term or condition of this Public License will be waived and no
372
+ failure to comply consented to unless expressly agreed to by the
373
+ Licensor.
374
+
375
+ d. Nothing in this Public License constitutes or may be interpreted
376
+ as a limitation upon, or waiver of, any privileges and immunities
377
+ that apply to the Licensor or You, including from the legal
378
+ processes of any jurisdiction or authority.
379
+
380
+ =======================================================================
381
+
382
+ Creative Commons is not a party to its public
383
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
384
+ its public licenses to material it publishes and in those instances
385
+ will be considered the “Licensor.” The text of the Creative Commons
386
+ public licenses is dedicated to the public domain under the CC0 Public
387
+ Domain Dedication. Except for the limited purpose of indicating that
388
+ material is shared under a Creative Commons public license or as
389
+ otherwise permitted by the Creative Commons policies published at
390
+ creativecommons.org/policies, Creative Commons does not authorize the
391
+ use of the trademark "Creative Commons" or any other trademark or logo
392
+ of Creative Commons without its prior written consent including,
393
+ without limitation, in connection with any unauthorized modifications
394
+ to any of its public licenses or any other arrangements,
395
+ understandings, or agreements concerning use of licensed material. For
396
+ the avoidance of doubt, this paragraph does not form part of the
397
+ public licenses.
398
+
399
+ Creative Commons may be contacted at creativecommons.org.
README.md CHANGED
@@ -1,11 +1,74 @@
1
- ---
2
- title: IF3D
3
- emoji: 👁
4
- colorFrom: blue
5
- colorTo: pink
6
- sdk: docker
7
- pinned: false
8
- license: mit
9
- ---
10
-
11
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
1
+ # Vid2Avatar: 3D Avatar Reconstruction from Videos in the Wild via Self-supervised Scene Decomposition
2
+ ## [Paper](https://arxiv.org/abs/2302.11566) | [Video Youtube](https://youtu.be/EGi47YeIeGQ) | [Project Page](https://moygcc.github.io/vid2avatar/) | [SynWild Data](https://synwild.ait.ethz.ch/)
3
+
4
+
5
+ Official Repository for CVPR 2023 paper [*Vid2Avatar: 3D Avatar Reconstruction from Videos in the Wild via Self-supervised Scene Decomposition*](https://arxiv.org/abs/2302.11566).
6
+
7
+ <img src="assets/teaser.png" width="800" height="223"/>
8
+
9
+ ## Getting Started
10
+ * Clone this repo: `git clone https://github.com/MoyGcc/vid2avatar`
11
+ * Create a Python virtual environment and activate it: `conda create -n v2a python=3.7` and `conda activate v2a`
12
+ * Install dependencies: `cd vid2avatar`, `pip install -r requirement.txt` and `cd code; python setup.py develop`
13
+ * Install [Kaolin](https://kaolin.readthedocs.io/en/v0.10.0/notes/installation.html). We use version 0.10.0.
14
+ * Download [SMPL model](https://smpl.is.tue.mpg.de/download.php) (1.0.0 for Python 2.7 (10 shape PCs)) and move them to the corresponding places:
15
+ ```
16
+ mkdir code/lib/smpl/smpl_model/
17
+ mv /path/to/smpl/models/basicModel_f_lbs_10_207_0_v1.0.0.pkl code/lib/smpl/smpl_model/SMPL_FEMALE.pkl
18
+ mv /path/to/smpl/models/basicmodel_m_lbs_10_207_0_v1.0.0.pkl code/lib/smpl/smpl_model/SMPL_MALE.pkl
19
+ ```
20
+ ## Download preprocessed demo data
21
+ You can quickly try out Vid2Avatar on a preprocessed demo sequence (originally a video clip provided by [NeuMan](https://github.com/apple/ml-neuman)) together with a pre-trained checkpoint, both available from [Google drive](https://drive.google.com/drive/u/1/folders/1AUtKSmib7CvpWBCFO6mQ9spVrga_CTU4). Put the preprocessed demo data under the folder `data/` and put the folder `checkpoints` under `outputs/parkinglot/`.
22
+
23
+ ## Training
24
+ Before training, make sure that the `metainfo` in the data config file `/code/confs/dataset/video.yaml` matches the expected training video. You can also continue a previous training run by setting the flag `is_continue` in the model config file `code/confs/model/model_w_bg.yaml`. Then run:
25
+ ```
26
+ cd code
27
+ python train.py
28
+ ```
29
+ The training usually takes 24-48 hours. The validation results can be found at `outputs/`.
30
+ ## Test
31
+ Run the following command to obtain the final outputs. By default, this loads the latest checkpoint.
32
+ ```
33
+ cd code
34
+ python test.py
35
+ ```
36
+ ## 3D Visualization
37
+ We use [AITViewer](https://github.com/eth-ait/aitviewer) to visualize the human models in 3D. First install AITViewer: `pip install aitviewer imgui==1.4.1`, and then run the following command to visualize the canonical mesh (--mode static) or deformed mesh sequence (--mode dynamic):
38
+ ```
39
+ cd visualization
40
+ python vis.py --mode {MODE} --path {PATH}
41
+ ```
42
+ <p align="center">
43
+ <img src="assets/parkinglot_360.gif" width="623" height="346"/>
44
+ </p>
45
+
46
+ ## Play on custom video
47
+ * We use [ROMP](https://github.com/Arthur151/ROMP#installation) to obtain initial SMPL shape and poses: `pip install --upgrade simple-romp`
48
+ * Install [OpenPose](https://github.com/CMU-Perceptual-Computing-Lab/openpose/blob/master/doc/installation/0_index.md) as well as the python bindings.
49
+ * Put the video frames under the folder `preprocessing/raw_data/{SEQUENCE_NAME}/frames`
50
+ * Modify the preprocessing script `preprocessing/run_preprocessing.sh` accordingly: set the data source, sequence name, and gender. The data source defaults to "custom", which estimates the camera intrinsics. If the camera intrinsics are known, it is better to provide the true camera parameters.
51
+ * Run preprocessing: `cd preprocessing` and `bash run_preprocessing.sh`. The processed data will be stored in `data/`. The intermediate outputs of the preprocessing can be found at `preprocessing/raw_data/{SEQUENCE_NAME}/`
52
+ * Launch training and test in the same way as above. The `metainfo` in the data config file `/code/confs/dataset/video.yaml` should be changed according to the custom video.
53
+
54
+ <p align="center">
55
+ <img src="assets/roger.gif" width="240" height="270"/> <img src="assets/exstrimalik.gif" width="240" height="270"/> <img src="assets/martial.gif" width="240" height="270"/>
56
+ </p>
57
+
58
+ ## Acknowledgement
59
+ We have used code from other great research works, including [VolSDF](https://github.com/lioryariv/volsdf), [NeRF++](https://github.com/Kai-46/nerfplusplus), [SMPL-X](https://github.com/vchoutas/smplx), [Anim-NeRF](https://github.com/JanaldoChen/Anim-NeRF), [I M Avatar](https://github.com/zhengyuf/IMavatar) and [SNARF](https://github.com/xuchen-ethz/snarf). We sincerely thank the authors for their awesome work! We also thank the authors of [ICON](https://github.com/YuliangXiu/ICON) and [SelfRecon](https://github.com/jby1993/SelfReconCode) for discussing the experiments.
60
+
61
+ ## Related Works
62
+ Here are more recent related human body reconstruction projects from our team:
63
+ * [Jiang and Chen et al. - InstantAvatar: Learning Avatars from Monocular Video in 60 Seconds](https://github.com/tijiang13/InstantAvatar)
64
+ * [Shen and Guo et al. - X-Avatar: Expressive Human Avatars](https://skype-line.github.io/projects/X-Avatar/)
65
+ * [Yin et al. - Hi4D: 4D Instance Segmentation of Close Human Interaction](https://yifeiyin04.github.io/Hi4D/)
66
+
67
+ ```
68
+ @inproceedings{guo2023vid2avatar,
69
+ title={Vid2Avatar: 3D Avatar Reconstruction from Videos in the Wild via Self-supervised Scene Decomposition},
70
+ author={Guo, Chen and Jiang, Tianjian and Chen, Xu and Song, Jie and Hilliges, Otmar},
71
+ booktitle = {Computer Vision and Pattern Recognition (CVPR)},
72
+ year = {2023}
73
+ }
74
+ ```
assets/exstrimalik.gif ADDED

Git LFS Details

  • SHA256: ee39544241ba64040c7eeef85f8c4f4b855edb4d2532ea4a42c54dcdf21730a1
  • Pointer size: 133 Bytes
  • Size of remote file: 20.9 MB
assets/martial.gif ADDED

Git LFS Details

  • SHA256: 3d97a85a0fd61d8d0c28ee81c29337d5821da8e07aabdbc3b1e0085a64b9f165
  • Pointer size: 133 Bytes
  • Size of remote file: 33.3 MB
assets/parkinglot_360.gif ADDED

Git LFS Details

  • SHA256: 54dd084d63386eb4de38e070f90786d53bc63669be461f3e150d1f3c08a4805b
  • Pointer size: 133 Bytes
  • Size of remote file: 57.1 MB
assets/roger.gif ADDED

Git LFS Details

  • SHA256: 85d5897cec8b32aa554ae1660f48bdcebeb298c7881b1212c522331c11595d82
  • Pointer size: 133 Bytes
  • Size of remote file: 42.9 MB
assets/smpl_init.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:93541fbf3eb32ade3201565d4b0d793c851aa4173b58a46f1e62f3da292037ce
3
+ size 2415862
assets/teaser.png ADDED

Git LFS Details

  • SHA256: 353410716a784c880a252328b10743981a61737a3efa02db7055801347c35b28
  • Pointer size: 132 Bytes
  • Size of remote file: 1.07 MB
code/check_cuda.py ADDED
@@ -0,0 +1,11 @@
1
+ import torch
2
+
3
+ print("Number of GPUs:", torch.cuda.device_count())
4
+
5
+ print("Torch version:",torch.__version__)
6
+
7
+ print("Is CUDA enabled?",torch.cuda.is_available())
8
+
9
+ print(torch.cuda.device_count())
10
+
11
+ # pip install torch==1.9.1+cu111 torchvision==0.10.1+cu111 torchaudio==0.9.1 -f https://download.pytorch.org/whl/torch_stable.html
code/confs/base.yaml ADDED
@@ -0,0 +1,13 @@
1
+ hydra:
2
+ run:
3
+ dir: "../outputs/${exp}/${run}"
4
+
5
+ defaults:
6
+ - model: model_w_bg
7
+ - dataset: video
8
+ - _self_
9
+
10
+ seed: 42
11
+ project_name: "model_w_bg"
12
+ exp: ${dataset.train.type}
13
+ run: ${dataset.metainfo.subject}
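For orientation (this example is not part of the commit): `base.yaml` is a Hydra root config that composes the `model` and `dataset` groups and points the run directory to `../outputs/${exp}/${run}`. A minimal sketch of how an entry point such as `code/train.py` could consume it is shown below; the `config_path`/`config_name` arguments are assumptions for illustration, not read from this commit.
```
# Hedged sketch of a Hydra entry point consuming confs/base.yaml.
import hydra
from omegaconf import DictConfig, OmegaConf

@hydra.main(config_path="confs", config_name="base")
def main(opt: DictConfig) -> None:
    # Hydra resolves ${exp}/${run} from the dataset group and switches the
    # working directory to ../outputs/<exp>/<run> before this function runs.
    print(OmegaConf.to_yaml(opt))
    print(opt.dataset.metainfo.subject, opt.model.learning_rate)

if __name__ == "__main__":
    main()
```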
code/confs/dataset/video.yaml ADDED
@@ -0,0 +1,37 @@
1
+ metainfo:
2
+ gender: 'male'
3
+ data_dir : C:\Users\leob3\vid2avatar\data\parkinglot
4
+ subject: "parkinglot"
5
+ start_frame: 0
6
+ end_frame: 42
7
+
8
+ train:
9
+ type: "Video"
10
+ batch_size: 1
11
+ drop_last: False
12
+ shuffle: True
13
+ worker: 8
14
+
15
+ num_sample : 512
16
+
17
+ valid:
18
+ type: "VideoVal"
19
+ image_id: 0
20
+ batch_size: 1
21
+ drop_last: False
22
+ shuffle: False
23
+ worker: 8
24
+
25
+ num_sample : -1
26
+ pixel_per_batch: 2048
27
+
28
+ test:
29
+ type: "VideoTest"
30
+ image_id: 0
31
+ batch_size: 1
32
+ drop_last: False
33
+ shuffle: False
34
+ worker: 8
35
+
36
+ num_sample : -1
37
+ pixel_per_batch: 2048
code/confs/model/model_w_bg.yaml ADDED
@@ -0,0 +1,77 @@
1
+ learning_rate : 5.0e-4
2
+ sched_milestones : [200,500]
3
+ sched_factor : 0.5
4
+ smpl_init: True
5
+ is_continue: False
6
+ use_body_parsing: False
7
+ with_bkgd: True
8
+ using_inpainting: False
9
+ use_smpl_deformer: True
10
+ use_bbox_sampler: False
11
+
12
+ implicit_network:
13
+ feature_vector_size: 256
14
+ d_in: 3
15
+ d_out: 1
16
+ dims: [ 256, 256, 256, 256, 256, 256, 256, 256 ]
17
+ init: 'geometry'
18
+ bias: 0.6
19
+ skip_in: [4]
20
+ weight_norm: True
21
+ embedder_mode: 'fourier'
22
+ multires: 6
23
+ cond: 'smpl'
24
+ scene_bounding_sphere: 3.0
25
+ rendering_network:
26
+ feature_vector_size: 256
27
+ mode: "pose"
28
+ d_in: 14
29
+ d_out: 3
30
+ dims: [ 256, 256, 256, 256]
31
+ weight_norm: True
32
+ multires_view: -1
33
+ bg_implicit_network:
34
+ feature_vector_size: 256
35
+ d_in: 4
36
+ d_out: 1
37
+ dims: [ 256, 256, 256, 256, 256, 256, 256, 256 ]
38
+ init: 'none'
39
+ bias: 0.0
40
+ skip_in: [4]
41
+ weight_norm: False
42
+ embedder_mode: 'fourier'
43
+ multires: 10
44
+ cond: 'frame'
45
+ dim_frame_encoding: 32
46
+ bg_rendering_network:
47
+ feature_vector_size: 256
48
+ mode: 'nerf_frame_encoding'
49
+ d_in: 3
50
+ d_out: 3
51
+ dims: [128]
52
+ weight_norm: False
53
+ multires_view: 4
54
+ dim_frame_encoding: 32
55
+ shadow_network:
56
+ d_in: 3
57
+ d_out: 1
58
+ dims: [128, 128]
59
+ weight_norm: False
60
+ density:
61
+ params_init: {beta: 0.1}
62
+ beta_min: 0.0001
63
+ ray_sampler:
64
+ near: 0.0
65
+ N_samples: 64
66
+ N_samples_eval: 128
67
+ N_samples_extra: 32
68
+ eps: 0.1
69
+ beta_iters: 10
70
+ max_total_iters: 5
71
+ N_samples_inverse_sphere: 32
72
+ add_tiny: 1.0e-6
73
+ loss:
74
+ eikonal_weight : 0.1
75
+ bce_weight: 5.0e-3
76
+ opacity_sparse_weight: 3.0e-3
77
+ in_shape_weight: 1.0e-2
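As a reading aid (not part of the diff): the `loss` block above only stores weights; the corresponding terms live in `code/lib/model/loss.py`. A hedged, runnable sketch of the weighted sum, using stand-in tensors and illustrative term names rather than the actual identifiers from that file:
```
# Illustrative only: how the four loss weights above could combine.
import torch

weights = {"eikonal": 0.1, "bce": 5.0e-3, "opacity_sparse": 3.0e-3, "in_shape": 1.0e-2}
terms = {name: torch.rand(()) for name in weights}   # stand-in regularization terms
rgb_loss = torch.rand(())                            # stand-in reconstruction term
total_loss = rgb_loss + sum(w * terms[name] for name, w in weights.items())
print(float(total_loss))
```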
code/lib/datasets/__init__.py ADDED
@@ -0,0 +1,26 @@
1
+ from .dataset import Dataset, ValDataset, TestDataset
2
+ from torch.utils.data import DataLoader
3
+
4
+ def find_dataset_using_name(name):
5
+ mapping = {
6
+ "Video": Dataset,
7
+ "VideoVal": ValDataset,
8
+ "VideoTest": TestDataset,
9
+ }
10
+ cls = mapping.get(name, None)
11
+ if cls is None:
12
+ raise ValueError(f"Fail to find dataset {name}")
13
+ return cls
14
+
15
+
16
+ def create_dataset(metainfo, split):
17
+ dataset_cls = find_dataset_using_name(split.type)
18
+ dataset = dataset_cls(metainfo, split)
19
+ return DataLoader(
20
+ dataset,
21
+ batch_size=split.batch_size,
22
+ drop_last=split.drop_last,
23
+ shuffle=split.shuffle,
24
+ num_workers=split.worker,
25
+ pin_memory=True
26
+ )
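A minimal usage sketch (not in this commit) of the `create_dataset` factory above. The config values mirror `code/confs/dataset/video.yaml` but are hard-coded for illustration; running it requires the preprocessed `parkinglot` sequence under `data/` and should be done from inside `code/` after `python setup.py develop`.
```
# Hedged example: building the training DataLoader through the factory above.
from omegaconf import OmegaConf
from lib.datasets import create_dataset

cfg = OmegaConf.create({
    "metainfo": {"gender": "male", "data_dir": "parkinglot", "subject": "parkinglot",
                 "start_frame": 0, "end_frame": 42},
    "train": {"type": "Video", "batch_size": 1, "drop_last": False,
              "shuffle": True, "worker": 8, "num_sample": 512},
})
loader = create_dataset(cfg.metainfo, cfg.train)
inputs, images = next(iter(loader))          # dict of rays/params, dict of target RGB
print(inputs["uv"].shape, images["rgb"].shape)
```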
code/lib/datasets/dataset.py ADDED
@@ -0,0 +1,175 @@
1
+ import os
2
+ import glob
3
+ import hydra
4
+ import cv2
5
+ import numpy as np
6
+ import torch
7
+ from lib.utils import utils
8
+
9
+
10
+ class Dataset(torch.utils.data.Dataset):
11
+ def __init__(self, metainfo, split):
12
+ root = os.path.join("../data", metainfo.data_dir)
13
+ root = hydra.utils.to_absolute_path(root)
14
+
15
+ self.start_frame = metainfo.start_frame
16
+ self.end_frame = metainfo.end_frame
17
+ self.skip_step = 1
18
+ self.images, self.img_sizes = [], []
19
+ self.training_indices = list(range(metainfo.start_frame, metainfo.end_frame, self.skip_step))
20
+
21
+ # images
22
+ img_dir = os.path.join(root, "image")
23
+ self.img_paths = sorted(glob.glob(f"{img_dir}/*.png"))
24
+
25
+ # only store the image paths to avoid OOM
26
+ self.img_paths = [self.img_paths[i] for i in self.training_indices]
27
+ self.img_size = cv2.imread(self.img_paths[0]).shape[:2]
28
+ self.n_images = len(self.img_paths)
29
+
30
+ # coarse projected SMPL masks, only for sampling
31
+ mask_dir = os.path.join(root, "mask")
32
+ self.mask_paths = sorted(glob.glob(f"{mask_dir}/*.png"))
33
+ self.mask_paths = [self.mask_paths[i] for i in self.training_indices]
34
+
35
+ self.shape = np.load(os.path.join(root, "mean_shape.npy"))
36
+ self.poses = np.load(os.path.join(root, 'poses.npy'))[self.training_indices]
37
+ self.trans = np.load(os.path.join(root, 'normalize_trans.npy'))[self.training_indices]
38
+ # cameras
39
+ camera_dict = np.load(os.path.join(root, "cameras_normalize.npz"))
40
+ scale_mats = [camera_dict['scale_mat_%d' % idx].astype(np.float32) for idx in self.training_indices]
41
+ world_mats = [camera_dict['world_mat_%d' % idx].astype(np.float32) for idx in self.training_indices]
42
+
43
+ self.scale = 1 / scale_mats[0][0, 0]
44
+
45
+ self.intrinsics_all = []
46
+ self.pose_all = []
47
+ for scale_mat, world_mat in zip(scale_mats, world_mats):
48
+ P = world_mat @ scale_mat
49
+ P = P[:3, :4]
50
+ intrinsics, pose = utils.load_K_Rt_from_P(None, P)
51
+ self.intrinsics_all.append(torch.from_numpy(intrinsics).float())
52
+ self.pose_all.append(torch.from_numpy(pose).float())
53
+ assert len(self.intrinsics_all) == len(self.pose_all)
54
+
55
+ # other properties
56
+ self.num_sample = split.num_sample
57
+ self.sampling_strategy = "weighted"
58
+
59
+ def __len__(self):
60
+ return self.n_images
61
+
62
+ def __getitem__(self, idx):
63
+ # normalize RGB
64
+ img = cv2.imread(self.img_paths[idx])
65
+ # preprocess: BGR -> RGB -> Normalize
66
+
67
+ img = img[:, :, ::-1] / 255
68
+
69
+ mask = cv2.imread(self.mask_paths[idx])
70
+ # preprocess: BGR -> Gray -> Mask
71
+ mask = cv2.cvtColor(mask, cv2.COLOR_BGR2GRAY) > 0
72
+
73
+ img_size = self.img_size
74
+
75
+ uv = np.mgrid[:img_size[0], :img_size[1]].astype(np.int32)
76
+ uv = np.flip(uv, axis=0).copy().transpose(1, 2, 0).astype(np.float32)
77
+
78
+ smpl_params = torch.zeros([86]).float()
79
+ smpl_params[0] = torch.from_numpy(np.asarray(self.scale)).float()
80
+
81
+ smpl_params[1:4] = torch.from_numpy(self.trans[idx]).float()
82
+ smpl_params[4:76] = torch.from_numpy(self.poses[idx]).float()
83
+ smpl_params[76:] = torch.from_numpy(self.shape).float()
84
+
85
+ if self.num_sample > 0:
86
+ data = {
87
+ "rgb": img,
88
+ "uv": uv,
89
+ "object_mask": mask,
90
+ }
91
+
92
+ samples, index_outside = utils.weighted_sampling(data, img_size, self.num_sample)
93
+ inputs = {
94
+ "uv": samples["uv"].astype(np.float32),
95
+ "intrinsics": self.intrinsics_all[idx],
96
+ "pose": self.pose_all[idx],
97
+ "smpl_params": smpl_params,
98
+ 'index_outside': index_outside,
99
+ "idx": idx
100
+ }
101
+ images = {"rgb": samples["rgb"].astype(np.float32)}
102
+ return inputs, images
103
+ else:
104
+ inputs = {
105
+ "uv": uv.reshape(-1, 2).astype(np.float32),
106
+ "intrinsics": self.intrinsics_all[idx],
107
+ "pose": self.pose_all[idx],
108
+ "smpl_params": smpl_params,
109
+ "idx": idx
110
+ }
111
+ images = {
112
+ "rgb": img.reshape(-1, 3).astype(np.float32),
113
+ "img_size": self.img_size
114
+ }
115
+ return inputs, images
116
+
117
+ class ValDataset(torch.utils.data.Dataset):
118
+ def __init__(self, metainfo, split):
119
+ self.dataset = Dataset(metainfo, split)
120
+ self.img_size = self.dataset.img_size
121
+
122
+ self.total_pixels = np.prod(self.img_size)
123
+ self.pixel_per_batch = split.pixel_per_batch
124
+
125
+ def __len__(self):
126
+ return 1
127
+
128
+ def __getitem__(self, idx):
129
+ image_id = int(np.random.choice(len(self.dataset), 1))
130
+ self.data = self.dataset[image_id]
131
+ inputs, images = self.data
132
+
133
+ inputs = {
134
+ "uv": inputs["uv"],
135
+ "intrinsics": inputs['intrinsics'],
136
+ "pose": inputs['pose'],
137
+ "smpl_params": inputs["smpl_params"],
138
+ 'image_id': image_id,
139
+ "idx": inputs['idx']
140
+ }
141
+ images = {
142
+ "rgb": images["rgb"],
143
+ "img_size": images["img_size"],
144
+ 'pixel_per_batch': self.pixel_per_batch,
145
+ 'total_pixels': self.total_pixels
146
+ }
147
+ return inputs, images
148
+
149
+ class TestDataset(torch.utils.data.Dataset):
150
+ def __init__(self, metainfo, split):
151
+ self.dataset = Dataset(metainfo, split)
152
+
153
+ self.img_size = self.dataset.img_size
154
+
155
+ self.total_pixels = np.prod(self.img_size)
156
+ self.pixel_per_batch = split.pixel_per_batch
157
+ def __len__(self):
158
+ return len(self.dataset)
159
+
160
+ def __getitem__(self, idx):
161
+ data = self.dataset[idx]
162
+
163
+ inputs, images = data
164
+ inputs = {
165
+ "uv": inputs["uv"],
166
+ "intrinsics": inputs['intrinsics'],
167
+ "pose": inputs['pose'],
168
+ "smpl_params": inputs["smpl_params"],
169
+ "idx": inputs['idx']
170
+ }
171
+ images = {
172
+ "rgb": images["rgb"],
173
+ "img_size": images["img_size"]
174
+ }
175
+ return inputs, images, self.pixel_per_batch, self.total_pixels, idx
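One detail worth spelling out (not part of the diff): `Dataset.__getitem__` above packs the per-frame SMPL parameters into a flat 86-dim vector. A short sketch of that layout for anyone consuming `smpl_params` downstream; the segment names are descriptive, not identifiers from the code.
```
# Layout of the 86-dim smpl_params vector assembled in Dataset.__getitem__.
import torch

smpl_params = torch.zeros(86)
scale  = smpl_params[0:1]     # 1  global scale (from cameras_normalize.npz)
transl = smpl_params[1:4]     # 3  normalized translation
poses  = smpl_params[4:76]    # 72 axis-angle pose (global orientation + 23 joints)
betas  = smpl_params[76:86]   # 10 SMPL shape coefficients (mean_shape.npy)
print(scale.shape, transl.shape, poses.shape, betas.shape)
```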
code/lib/libmise/mise.cp37-win_amd64.pyd ADDED
Binary file (180 kB).
 
code/lib/libmise/mise.cpp ADDED
The diff for this file is too large to render. See raw diff
 
code/lib/libmise/mise.pyx ADDED
@@ -0,0 +1,370 @@
1
+ # distutils: language = c++
2
+ cimport cython
3
+ from libc.stdint cimport int32_t, int64_t
4
+ from cython.operator cimport dereference as dref
5
+ from libcpp.vector cimport vector
6
+ from libcpp.map cimport map
7
+ from libc.math cimport isnan, NAN
8
+ import numpy as np
9
+
10
+
11
+ cdef struct Vector3D:
12
+ int x, y, z
13
+
14
+
15
+ cdef struct Voxel:
16
+ Vector3D loc
17
+ unsigned int level
18
+ bint is_leaf
19
+ unsigned long children[2][2][2]
20
+
21
+
22
+ cdef struct GridPoint:
23
+ Vector3D loc
24
+ double value
25
+ bint known
26
+
27
+
28
+ cdef inline unsigned long vec_to_idx(Vector3D coord, long resolution):
29
+ cdef unsigned long idx
30
+ idx = resolution * resolution * coord.x + resolution * coord.y + coord.z
31
+ return idx
32
+
33
+
34
+ cdef class MISE:
35
+ cdef vector[Voxel] voxels
36
+ cdef vector[GridPoint] grid_points
37
+ cdef map[long, long] grid_point_hash
38
+ cdef readonly int resolution_0
39
+ cdef readonly int depth
40
+ cdef readonly double threshold
41
+ cdef readonly int voxel_size_0
42
+ cdef readonly int resolution
43
+
44
+ def __cinit__(self, int resolution_0, int depth, double threshold):
45
+ self.resolution_0 = resolution_0
46
+ self.depth = depth
47
+ self.threshold = threshold
48
+ self.voxel_size_0 = (1 << depth)
49
+ self.resolution = resolution_0 * self.voxel_size_0
50
+
51
+ # Create initial voxels
52
+ self.voxels.reserve(resolution_0 * resolution_0 * resolution_0)
53
+
54
+ cdef Voxel voxel
55
+ cdef GridPoint point
56
+ cdef Vector3D loc
57
+ cdef int i, j, k
58
+ for i in range(resolution_0):
59
+ for j in range(resolution_0):
60
+ for k in range (resolution_0):
61
+ loc = Vector3D(
62
+ i * self.voxel_size_0,
63
+ j * self.voxel_size_0,
64
+ k * self.voxel_size_0,
65
+ )
66
+ voxel = Voxel(
67
+ loc=loc,
68
+ level=0,
69
+ is_leaf=True,
70
+ )
71
+
72
+ assert(self.voxels.size() == vec_to_idx(Vector3D(i, j, k), resolution_0))
73
+ self.voxels.push_back(voxel)
74
+
75
+ # Create initial grid points
76
+ self.grid_points.reserve((resolution_0 + 1) * (resolution_0 + 1) * (resolution_0 + 1))
77
+ for i in range(resolution_0 + 1):
78
+ for j in range(resolution_0 + 1):
79
+ for k in range(resolution_0 + 1):
80
+ loc = Vector3D(
81
+ i * self.voxel_size_0,
82
+ j * self.voxel_size_0,
83
+ k * self.voxel_size_0,
84
+ )
85
+ assert(self.grid_points.size() == vec_to_idx(Vector3D(i, j, k), resolution_0 + 1))
86
+ self.add_grid_point(loc)
87
+
88
+ def update(self, int64_t[:, :] points, double[:] values):
89
+ """Update points and set their values. Also determine all active voxels and subdivide them."""
90
+ assert(points.shape[0] == values.shape[0])
91
+ assert(points.shape[1] == 3)
92
+ cdef Vector3D loc
93
+ cdef long idx
94
+ cdef int i
95
+
96
+ # Find all indices of point and set value
97
+ for i in range(points.shape[0]):
98
+ loc = Vector3D(points[i, 0], points[i, 1], points[i, 2])
99
+ idx = self.get_grid_point_idx(loc)
100
+ if idx == -1:
101
+ raise ValueError('Point not in grid!')
102
+ self.grid_points[idx].value = values[i]
103
+ self.grid_points[idx].known = True
104
+ # Subdivide activate voxels and add new points
105
+ self.subdivide_voxels()
106
+
107
+ def query(self):
108
+ """Query points to evaluate."""
109
+ # Find all points with unknown value
110
+ cdef vector[Vector3D] points
111
+ cdef int n_unknown = 0
112
+ for p in self.grid_points:
113
+ if not p.known:
114
+ n_unknown += 1
115
+
116
+ points.reserve(n_unknown)
117
+ for p in self.grid_points:
118
+ if not p.known:
119
+ points.push_back(p.loc)
120
+
121
+ # Convert to numpy
122
+ points_np = np.zeros((points.size(), 3), dtype=np.int64)
123
+ cdef int64_t[:, :] points_view = points_np
124
+ for i in range(points.size()):
125
+ points_view[i, 0] = points[i].x
126
+ points_view[i, 1] = points[i].y
127
+ points_view[i, 2] = points[i].z
128
+
129
+ return points_np
130
+
131
+ def to_dense(self):
132
+ """Output dense matrix at highest resolution."""
133
+ out_array = np.full((self.resolution + 1,) * 3, np.nan)
134
+ cdef double[:, :, :] out_view = out_array
135
+ cdef GridPoint point
136
+ cdef int i, j, k
137
+
138
+ for point in self.grid_points:
139
+ # Take voxel for which points is upper left corner
140
+ # assert(point.known)
141
+ out_view[point.loc.x, point.loc.y, point.loc.z] = point.value
142
+
143
+ # Complete along x axis
144
+ for i in range(1, self.resolution + 1):
145
+ for j in range(self.resolution + 1):
146
+ for k in range(self.resolution + 1):
147
+ if isnan(out_view[i, j, k]):
148
+ out_view[i, j, k] = out_view[i-1, j, k]
149
+
150
+ # Complete along y axis
151
+ for i in range(self.resolution + 1):
152
+ for j in range(1, self.resolution + 1):
153
+ for k in range(self.resolution + 1):
154
+ if isnan(out_view[i, j, k]):
155
+ out_view[i, j, k] = out_view[i, j-1, k]
156
+
157
+
158
+ # Complete along z axis
159
+ for i in range(self.resolution + 1):
160
+ for j in range(self.resolution + 1):
161
+ for k in range(1, self.resolution + 1):
162
+ if isnan(out_view[i, j, k]):
163
+ out_view[i, j, k] = out_view[i, j, k-1]
164
+ assert(not isnan(out_view[i, j, k]))
165
+ return out_array
166
+
167
+ def get_points(self):
168
+ points_np = np.zeros((self.grid_points.size(), 3), dtype=np.int64)
169
+ values_np = np.zeros((self.grid_points.size()), dtype=np.float64)
170
+
171
+ cdef long[:, :] points_view = points_np
172
+ cdef double[:] values_view = values_np
173
+ cdef Vector3D loc
174
+ cdef int i
175
+
176
+ for i in range(self.grid_points.size()):
177
+ loc = self.grid_points[i].loc
178
+ points_view[i, 0] = loc.x
179
+ points_view[i, 1] = loc.y
180
+ points_view[i, 2] = loc.z
181
+ values_view[i] = self.grid_points[i].value
182
+
183
+ return points_np, values_np
184
+
185
+ cdef void subdivide_voxels(self) except +:
186
+ cdef vector[bint] next_to_positive
187
+ cdef vector[bint] next_to_negative
188
+ cdef int i, j, k
189
+ cdef long idx
190
+ cdef Vector3D loc, adj_loc
191
+
192
+ # Initialize vectors
193
+ next_to_positive.resize(self.voxels.size(), False)
194
+ next_to_negative.resize(self.voxels.size(), False)
195
+
196
+ # Iterate over grid points and mark voxels active
197
+ # TODO: can move this to update operation and add attribute to voxel
198
+ for grid_point in self.grid_points:
199
+ loc = grid_point.loc
200
+ if not grid_point.known:
201
+ continue
202
+
203
+ # Iterate over the 8 adjacent voxels
204
+ for i in range(-1, 1):
205
+ for j in range(-1, 1):
206
+ for k in range(-1, 1):
207
+ adj_loc = Vector3D(
208
+ x=loc.x + i,
209
+ y=loc.y + j,
210
+ z=loc.z + k,
211
+ )
212
+ idx = self.get_voxel_idx(adj_loc)
213
+ if idx == -1:
214
+ continue
215
+
216
+ if grid_point.value >= self.threshold:
217
+ next_to_positive[idx] = True
218
+ if grid_point.value <= self.threshold:
219
+ next_to_negative[idx] = True
220
+
221
+ cdef int n_subdivide = 0
222
+
223
+ for idx in range(self.voxels.size()):
224
+ if not self.voxels[idx].is_leaf or self.voxels[idx].level == self.depth:
225
+ continue
226
+ if next_to_positive[idx] and next_to_negative[idx]:
227
+ n_subdivide += 1
228
+
229
+ self.voxels.reserve(self.voxels.size() + 8 * n_subdivide)
230
+ self.grid_points.reserve(self.voxels.size() + 19 * n_subdivide)
231
+
232
+ for idx in range(self.voxels.size()):
233
+ if not self.voxels[idx].is_leaf or self.voxels[idx].level == self.depth:
234
+ continue
235
+ if next_to_positive[idx] and next_to_negative[idx]:
236
+ self.subdivide_voxel(idx)
237
+
238
+ cdef void subdivide_voxel(self, long idx):
239
+ cdef Voxel voxel
240
+ cdef GridPoint point
241
+ cdef Vector3D loc0 = self.voxels[idx].loc
242
+ cdef Vector3D loc
243
+ cdef int new_level = self.voxels[idx].level + 1
244
+ cdef int new_size = 1 << (self.depth - new_level)
245
+ assert(new_level <= self.depth)
246
+ assert(1 <= new_size <= self.voxel_size_0)
247
+
248
+ # Current voxel is not leaf anymore
249
+ self.voxels[idx].is_leaf = False
250
+ # Add new voxels
251
+ cdef int i, j, k
252
+ for i in range(2):
253
+ for j in range(2):
254
+ for k in range(2):
255
+ loc = Vector3D(
256
+ x=loc0.x + i * new_size,
257
+ y=loc0.y + j * new_size,
258
+ z=loc0.z + k * new_size,
259
+ )
260
+ voxel = Voxel(
261
+ loc=loc,
262
+ level=new_level,
263
+ is_leaf=True
264
+ )
265
+
266
+ self.voxels[idx].children[i][j][k] = self.voxels.size()
267
+ self.voxels.push_back(voxel)
268
+
269
+ # Add new grid points
270
+ for i in range(3):
271
+ for j in range(3):
272
+ for k in range(3):
273
+ loc = Vector3D(
274
+ loc0.x + i * new_size,
275
+ loc0.y + j * new_size,
276
+ loc0.z + k * new_size,
277
+ )
278
+
279
+ # Only add new grid points
280
+ if self.get_grid_point_idx(loc) == -1:
281
+ self.add_grid_point(loc)
282
+
283
+
284
+ @cython.cdivision(True)
285
+ cdef long get_voxel_idx(self, Vector3D loc) except +:
286
+ """Utility function for getting voxel index corresponding to 3D coordinates."""
287
+ # Shorthands
288
+ cdef long resolution = self.resolution
289
+ cdef long resolution_0 = self.resolution_0
290
+ cdef long depth = self.depth
291
+ cdef long voxel_size_0 = self.voxel_size_0
292
+
293
+ # Return -1 if point lies outside bounds
294
+ if not (0 <= loc.x < resolution and 0<= loc.y < resolution and 0 <= loc.z < resolution):
295
+ return -1
296
+
297
+ # Coordinates in coarse voxel grid
298
+ cdef Vector3D loc0 = Vector3D(
299
+ x=loc.x >> depth,
300
+ y=loc.y >> depth,
301
+ z=loc.z >> depth,
302
+ )
303
+
304
+ # Initial voxels
305
+ cdef int idx = vec_to_idx(loc0, resolution_0)
306
+ cdef Voxel voxel = self.voxels[idx]
307
+ assert(voxel.loc.x == loc0.x * voxel_size_0)
308
+ assert(voxel.loc.y == loc0.y * voxel_size_0)
309
+ assert(voxel.loc.z == loc0.z * voxel_size_0)
310
+
311
+ # Relative coordinates
312
+ cdef Vector3D loc_rel = Vector3D(
313
+ x=loc.x - (loc0.x << depth),
314
+ y=loc.y - (loc0.y << depth),
315
+ z=loc.z - (loc0.z << depth),
316
+ )
317
+
318
+ cdef Vector3D loc_offset
319
+ cdef long voxel_size = voxel_size_0
320
+
321
+ while not voxel.is_leaf:
322
+ voxel_size = voxel_size >> 1
323
+ assert(voxel_size >= 1)
324
+
325
+ # Determine child
326
+ loc_offset = Vector3D(
327
+ x=1 if (loc_rel.x >= voxel_size) else 0,
328
+ y=1 if (loc_rel.y >= voxel_size) else 0,
329
+ z=1 if (loc_rel.z >= voxel_size) else 0,
330
+ )
331
+ # New voxel
332
+ idx = voxel.children[loc_offset.x][loc_offset.y][loc_offset.z]
333
+ voxel = self.voxels[idx]
334
+
335
+ # New relative coordinates
336
+ loc_rel = Vector3D(
337
+ x=loc_rel.x - loc_offset.x * voxel_size,
338
+ y=loc_rel.y - loc_offset.y * voxel_size,
339
+ z=loc_rel.z - loc_offset.z * voxel_size,
340
+ )
341
+
342
+ assert(0<= loc_rel.x < voxel_size)
343
+ assert(0<= loc_rel.y < voxel_size)
344
+ assert(0<= loc_rel.z < voxel_size)
345
+
346
+
347
+ # Return idx
348
+ return idx
349
+
350
+
351
+ cdef inline void add_grid_point(self, Vector3D loc):
352
+ cdef GridPoint point = GridPoint(
353
+ loc=loc,
354
+ value=0.,
355
+ known=False,
356
+ )
357
+ self.grid_point_hash[vec_to_idx(loc, self.resolution + 1)] = self.grid_points.size()
358
+ self.grid_points.push_back(point)
359
+
360
+ cdef inline int get_grid_point_idx(self, Vector3D loc):
361
+ p_idx = self.grid_point_hash.find(vec_to_idx(loc, self.resolution + 1))
362
+ if p_idx == self.grid_point_hash.end():
363
+ return -1
364
+
365
+ cdef int idx = dref(p_idx).second
366
+ assert(self.grid_points[idx].loc.x == loc.x)
367
+ assert(self.grid_points[idx].loc.y == loc.y)
368
+ assert(self.grid_points[idx].loc.z == loc.z)
369
+
370
+ return idx
code/lib/model/body_model_params.py ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
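+ # Learnable per-frame SMPL parameters: shape (betas, shared across all frames), global
+ # orientation, translation and body pose are stored in nn.Embedding tables indexed by
+ # frame id, so they can optionally be refined jointly with the networks during training.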
4
+ class BodyModelParams(nn.Module):
5
+ def __init__(self, num_frames, model_type='smpl'):
6
+ super(BodyModelParams, self).__init__()
7
+ self.num_frames = num_frames
8
+ self.model_type = model_type
9
+ self.params_dim = {
10
+ 'betas': 10,
11
+ 'global_orient': 3,
12
+ 'transl': 3,
13
+ }
14
+ if model_type == 'smpl':
15
+ self.params_dim.update({
16
+ 'body_pose': 69,
17
+ })
18
+ else:
19
+ raise ValueError(f'Unknown model type {model_type}, exiting!')
20
+
21
+ self.param_names = self.params_dim.keys()
22
+
23
+ for param_name in self.param_names:
24
+ if param_name == 'betas':
25
+ param = nn.Embedding(1, self.params_dim[param_name])
26
+ param.weight.data.fill_(0)
27
+ param.weight.requires_grad = False
28
+ setattr(self, param_name, param)
29
+ else:
30
+ param = nn.Embedding(num_frames, self.params_dim[param_name])
31
+ param.weight.data.fill_(0)
32
+ param.weight.requires_grad = False
33
+ setattr(self, param_name, param)
34
+
35
+ def init_parameters(self, param_name, data, requires_grad=False):
36
+ getattr(self, param_name).weight.data = data[..., :self.params_dim[param_name]]
37
+ getattr(self, param_name).weight.requires_grad = requires_grad
38
+
39
+ def set_requires_grad(self, param_name, requires_grad=True):
40
+ getattr(self, param_name).weight.requires_grad = requires_grad
41
+
42
+ def forward(self, frame_ids):
43
+ params = {}
44
+ for param_name in self.param_names:
45
+ if param_name == 'betas':
46
+ params[param_name] = getattr(self, param_name)(torch.zeros_like(frame_ids))
47
+ else:
48
+ params[param_name] = getattr(self, param_name)(frame_ids)
49
+ return params
code/lib/model/deformer.py ADDED
@@ -0,0 +1,89 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from .smpl import SMPLServer
4
+ from pytorch3d import ops
5
+
6
+ class SMPLDeformer():
7
+ def __init__(self, max_dist=0.1, K=1, gender='female', betas=None):
8
+ super().__init__()
9
+
10
+ self.max_dist = max_dist
11
+ self.K = K
12
+ self.smpl = SMPLServer(gender=gender)
13
+ smpl_params_canonical = self.smpl.param_canonical.clone()
14
+ smpl_params_canonical[:, 76:] = torch.tensor(betas).float().to(self.smpl.param_canonical.device)
15
+ cano_scale, cano_transl, cano_thetas, cano_betas = torch.split(smpl_params_canonical, [1, 3, 72, 10], dim=1)
16
+ smpl_output = self.smpl(cano_scale, cano_transl, cano_thetas, cano_betas)
17
+ self.smpl_verts = smpl_output['smpl_verts']
18
+ self.smpl_weights = smpl_output['smpl_weights']
19
+ def forward(self, x, smpl_tfs, return_weights=True, inverse=False, smpl_verts=None):
20
+ if x.shape[0] == 0: return x
21
+ if smpl_verts is None:
22
+ weights, outlier_mask = self.query_skinning_weights_smpl_multi(x[None], smpl_verts=self.smpl_verts[0], smpl_weights=self.smpl_weights)
23
+ else:
24
+ weights, outlier_mask = self.query_skinning_weights_smpl_multi(x[None], smpl_verts=smpl_verts[0], smpl_weights=self.smpl_weights)
25
+ if return_weights:
26
+ return weights
27
+
28
+ x_transformed = skinning(x.unsqueeze(0), weights, smpl_tfs, inverse=inverse)
29
+
30
+ return x_transformed.squeeze(0), outlier_mask
31
+ def forward_skinning(self, xc, cond, smpl_tfs):
32
+ weights, _ = self.query_skinning_weights_smpl_multi(xc, smpl_verts=self.smpl_verts[0], smpl_weights=self.smpl_weights)
33
+ x_transformed = skinning(xc, weights, smpl_tfs, inverse=False)
34
+
35
+ return x_transformed
36
+
37
+ def query_skinning_weights_smpl_multi(self, pts, smpl_verts, smpl_weights):
38
+
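+ # For each query point, find the K nearest SMPL vertices and blend their skinning
+ # weights, with confidences that decay exponentially with the squared distance.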
39
+ distance_batch, index_batch, neighbor_points = ops.knn_points(pts, smpl_verts.unsqueeze(0),
40
+ K=self.K, return_nn=True)
41
+ distance_batch = torch.clamp(distance_batch, max=4)
42
+ weights_conf = torch.exp(-distance_batch)
43
+ distance_batch = torch.sqrt(distance_batch)
44
+ weights_conf = weights_conf / weights_conf.sum(-1, keepdim=True)
45
+ index_batch = index_batch[0]
46
+ weights = smpl_weights[:, index_batch, :]
47
+ weights = torch.sum(weights * weights_conf.unsqueeze(-1), dim=-2).detach()
48
+
49
+ outlier_mask = (distance_batch[..., 0] > self.max_dist)[0]
50
+ return weights, outlier_mask
51
+
52
+ def query_weights(self, xc):
53
+ weights = self.forward(xc, None, return_weights=True, inverse=False)
54
+ return weights
55
+
56
+ def forward_skinning_normal(self, xc, normal, cond, tfs, inverse = False):
57
+ if normal.ndim == 2:
58
+ normal = normal.unsqueeze(0)
59
+ w = self.query_weights(xc[0])
60
+
61
+ p_h = F.pad(normal, (0, 1), value=0)
62
+
63
+ if inverse:
64
+ # p:num_point, n:num_bone, i,j: num_dim+1
65
+ tf_w = torch.einsum('bpn,bnij->bpij', w.double(), tfs.double())
66
+ p_h = torch.einsum('bpij,bpj->bpi', tf_w.inverse(), p_h.double()).float()
67
+ else:
68
+ p_h = torch.einsum('bpn, bnij, bpj->bpi', w.double(), tfs.double(), p_h.double()).float()
69
+
70
+ return p_h[:, :, :3]
71
+
72
+ def skinning(x, w, tfs, inverse=False):
73
+ """Linear blend skinning
74
+ Args:
75
+ x (tensor): canonical points. shape: [B, N, D]
76
+ w (tensor): skinning weights. shape: [B, N, J]
77
+ tfs (tensor): bone transformation matrices. shape: [B, J, D+1, D+1]
78
+ Returns:
79
+ x (tensor): skinned points. shape: [B, N, D]
80
+ """
81
+ x_h = F.pad(x, (0, 1), value=1.0)
82
+
83
+ if inverse:
84
+ # p:n_point, n:n_bone, i,k: n_dim+1
85
+ w_tf = torch.einsum("bpn,bnij->bpij", w, tfs)
86
+ x_h = torch.einsum("bpij,bpj->bpi", w_tf.inverse(), x_h)
87
+ else:
88
+ x_h = torch.einsum("bpn,bnij,bpj->bpi", w, tfs, x_h)
89
+ return x_h[:, :, :3]
code/lib/model/density.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+
4
+ class Density(nn.Module):
5
+ def __init__(self, params_init={}):
6
+ super().__init__()
7
+ for p in params_init:
8
+ param = nn.Parameter(torch.tensor(params_init[p]))
9
+ setattr(self, p, param)
10
+
11
+ def forward(self, sdf, beta=None):
12
+ return self.density_func(sdf, beta=beta)
13
+
14
+
15
+ class LaplaceDensity(Density): # alpha * Laplace(loc=0, scale=beta).cdf(-sdf)
16
+ def __init__(self, params_init={}, beta_min=0.0001):
17
+ super().__init__(params_init=params_init)
18
+ self.beta_min = torch.tensor(beta_min).cuda()
19
+
20
+ def density_func(self, sdf, beta=None):
21
+ if beta is None:
22
+ beta = self.get_beta()
23
+
24
+ alpha = 1 / beta
25
+ return alpha * (0.5 + 0.5 * sdf.sign() * torch.expm1(-sdf.abs() / beta))
26
+
27
+ def get_beta(self):
28
+ beta = self.beta.abs() + self.beta_min
29
+ return beta
30
+
31
+
32
+ class AbsDensity(Density): # like NeRF++
33
+ def density_func(self, sdf, beta=None):
34
+ return torch.abs(sdf)
35
+
36
+
37
+ class SimpleDensity(Density): # like NeRF
38
+ def __init__(self, params_init={}, noise_std=1.0):
39
+ super().__init__(params_init=params_init)
40
+ self.noise_std = noise_std
41
+
42
+ def density_func(self, sdf, beta=None):
43
+ if self.training and self.noise_std > 0.0:
44
+ noise = torch.randn(sdf.shape).cuda() * self.noise_std
45
+ sdf = sdf + noise
46
+ return torch.relu(sdf)
code/lib/model/embedders.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class Embedder:
4
+ def __init__(self, **kwargs):
5
+ self.kwargs = kwargs
6
+ self.create_embedding_fn()
7
+
8
+ def create_embedding_fn(self):
9
+ embed_fns = []
10
+ d = self.kwargs['input_dims']
11
+ out_dim = 0
12
+ if self.kwargs['include_input']:
13
+ embed_fns.append(lambda x: x)
14
+ out_dim += d
15
+
16
+ max_freq = self.kwargs['max_freq_log2']
17
+ N_freqs = self.kwargs['num_freqs']
18
+
19
+ if self.kwargs['log_sampling']:
20
+ freq_bands = 2. ** torch.linspace(0., max_freq, N_freqs)
21
+ else:
22
+ freq_bands = torch.linspace(2.**0., 2.**max_freq, N_freqs)
23
+
24
+ for freq in freq_bands:
25
+ for p_fn in self.kwargs['periodic_fns']:
26
+ embed_fns.append(lambda x, p_fn=p_fn,
27
+ freq=freq: p_fn(x * freq))
28
+ out_dim += d
29
+
30
+ self.embed_fns = embed_fns
31
+ self.out_dim = out_dim
32
+
33
+ def embed(self, inputs):
34
+ return torch.cat([fn(inputs) for fn in self.embed_fns], -1)
35
+
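+ # NeRF-style positional encoding: each input coordinate x is mapped to
+ # [x, sin(2^0 x), cos(2^0 x), ..., sin(2^(multires-1) x), cos(2^(multires-1) x)].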
36
+ def get_embedder(multires, input_dims=3, mode='fourier'):
37
+ embed_kwargs = {
38
+ 'include_input': True,
39
+ 'input_dims': input_dims,
40
+ 'max_freq_log2': multires-1,
41
+ 'num_freqs': multires,
42
+ 'log_sampling': True,
43
+ 'periodic_fns': [torch.sin, torch.cos],
44
+ }
45
+ if mode == 'fourier':
46
+ embedder_obj = Embedder(**embed_kwargs)
47
+
48
+
49
+ def embed(x, eo=embedder_obj): return eo.embed(x)
50
+ return embed, embedder_obj.out_dim
code/lib/model/loss.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+
5
+ class Loss(nn.Module):
6
+ def __init__(self, opt):
7
+ super().__init__()
8
+ self.eikonal_weight = opt.eikonal_weight
9
+ self.bce_weight = opt.bce_weight
10
+ self.opacity_sparse_weight = opt.opacity_sparse_weight
11
+ self.in_shape_weight = opt.in_shape_weight
12
+ self.eps = 1e-6
13
+ self.milestone = 200
14
+ self.l1_loss = nn.L1Loss(reduction='mean')
15
+ self.l2_loss = nn.MSELoss(reduction='mean')
16
+
17
+ # L1 reconstruction loss for RGB values
18
+ def get_rgb_loss(self, rgb_values, rgb_gt):
19
+ rgb_loss = self.l1_loss(rgb_values, rgb_gt)
20
+ return rgb_loss
21
+
22
+ # Eikonal loss introduced in IGR
23
+ def get_eikonal_loss(self, grad_theta):
24
+ eikonal_loss = ((grad_theta.norm(2, dim=-1) - 1)**2).mean()
25
+ return eikonal_loss
26
+
27
+ # BCE loss for clear boundary
28
+ def get_bce_loss(self, acc_map):
29
+ binary_loss = -1 * (acc_map * (acc_map + self.eps).log() + (1-acc_map) * (1 - acc_map + self.eps).log()).mean() * 2
30
+ return binary_loss
31
+
32
+ # Global opacity sparseness regularization
33
+ def get_opacity_sparse_loss(self, acc_map, index_off_surface):
34
+ opacity_sparse_loss = self.l1_loss(acc_map[index_off_surface], torch.zeros_like(acc_map[index_off_surface]))
35
+ return opacity_sparse_loss
36
+
37
+ # Optional: this loss helps to stabilize the training in the very beginning
38
+ def get_in_shape_loss(self, acc_map, index_in_surface):
39
+ in_shape_loss = self.l1_loss(acc_map[index_in_surface], torch.ones_like(acc_map[index_in_surface]))
40
+ return in_shape_loss
41
+
42
+ def forward(self, model_outputs, ground_truth):
43
+ nan_filter = ~torch.any(model_outputs['rgb_values'].isnan(), dim=1)
44
+ rgb_gt = ground_truth['rgb'][0].cuda()
45
+ rgb_loss = self.get_rgb_loss(model_outputs['rgb_values'][nan_filter], rgb_gt[nan_filter])
46
+ eikonal_loss = self.get_eikonal_loss(model_outputs['grad_theta'])
47
+ bce_loss = self.get_bce_loss(model_outputs['acc_map'])
48
+ opacity_sparse_loss = self.get_opacity_sparse_loss(model_outputs['acc_map'], model_outputs['index_off_surface'])
49
+ in_shape_loss = self.get_in_shape_loss(model_outputs['acc_map'], model_outputs['index_in_surface'])
50
+ curr_epoch_for_loss = min(self.milestone, model_outputs['epoch']) # will not increase after the milestone
51
+
52
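+ # Loss schedule: the opacity sparseness weight grows with the (clamped) epoch, while the
+ # in-shape weight is annealed linearly to zero at the milestone epoch.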
+ loss = rgb_loss + \
53
+ self.eikonal_weight * eikonal_loss + \
54
+ self.bce_weight * bce_loss + \
55
+ self.opacity_sparse_weight * (1 + curr_epoch_for_loss ** 2 / 40) * opacity_sparse_loss + \
56
+ self.in_shape_weight * (1 - curr_epoch_for_loss / self.milestone) * in_shape_loss
57
+ return {
58
+ 'loss': loss,
59
+ 'rgb_loss': rgb_loss,
60
+ 'eikonal_loss': eikonal_loss,
61
+ 'bce_loss': bce_loss,
62
+ 'opacity_sparse_loss': opacity_sparse_loss,
63
+ 'in_shape_loss': in_shape_loss,
64
+ }
code/lib/model/networks.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch.nn as nn
2
+ import torch
3
+ import numpy as np
4
+ from .embedders import get_embedder
5
+
6
+ class ImplicitNet(nn.Module):
7
+ def __init__(self, opt):
8
+ super().__init__()
9
+
10
+ dims = [opt.d_in] + list(
11
+ opt.dims) + [opt.d_out + opt.feature_vector_size]
12
+ self.num_layers = len(dims)
13
+ self.skip_in = opt.skip_in
14
+ self.embed_fn = None
15
+ self.opt = opt
16
+
17
+ if opt.multires > 0:
18
+ embed_fn, input_ch = get_embedder(opt.multires, input_dims=opt.d_in, mode=opt.embedder_mode)
19
+ self.embed_fn = embed_fn
20
+ dims[0] = input_ch
21
+ self.cond = opt.cond
22
+ if self.cond == 'smpl':
23
+ self.cond_layer = [0]
24
+ self.cond_dim = 69
25
+ elif self.cond == 'frame':
26
+ self.cond_layer = [0]
27
+ self.cond_dim = opt.dim_frame_encoding
28
+ self.dim_pose_embed = 0
29
+ if self.dim_pose_embed > 0:
30
+ self.lin_p0 = nn.Linear(self.cond_dim, self.dim_pose_embed)
31
+ self.cond_dim = self.dim_pose_embed
32
+ for l in range(0, self.num_layers - 1):
33
+ if l + 1 in self.skip_in:
34
+ out_dim = dims[l + 1] - dims[0]
35
+ else:
36
+ out_dim = dims[l + 1]
37
+
38
+ if self.cond != 'none' and l in self.cond_layer:
39
+ lin = nn.Linear(dims[l] + self.cond_dim, out_dim)
40
+ else:
41
+ lin = nn.Linear(dims[l], out_dim)
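+ # Geometric initialization (as in SAL/IGR): the layers are initialized so that the
+ # untrained network approximates the SDF of a sphere of radius opt.bias around the origin.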
42
+ if opt.init == 'geometry':
43
+ if l == self.num_layers - 2:
44
+ torch.nn.init.normal_(lin.weight,
45
+ mean=np.sqrt(np.pi) /
46
+ np.sqrt(dims[l]),
47
+ std=0.0001)
48
+ torch.nn.init.constant_(lin.bias, -opt.bias)
49
+ elif opt.multires > 0 and l == 0:
50
+ torch.nn.init.constant_(lin.bias, 0.0)
51
+ torch.nn.init.constant_(lin.weight[:, 3:], 0.0)
52
+ torch.nn.init.normal_(lin.weight[:, :3], 0.0,
53
+ np.sqrt(2) / np.sqrt(out_dim))
54
+ elif opt.multires > 0 and l in self.skip_in:
55
+ torch.nn.init.constant_(lin.bias, 0.0)
56
+ torch.nn.init.normal_(lin.weight, 0.0,
57
+ np.sqrt(2) / np.sqrt(out_dim))
58
+ torch.nn.init.constant_(lin.weight[:, -(dims[0] - 3):],
59
+ 0.0)
60
+ else:
61
+ torch.nn.init.constant_(lin.bias, 0.0)
62
+ torch.nn.init.normal_(lin.weight, 0.0,
63
+ np.sqrt(2) / np.sqrt(out_dim))
64
+ if opt.init == 'zero':
65
+ init_val = 1e-5
66
+ if l == self.num_layers - 2:
67
+ torch.nn.init.constant_(lin.bias, 0.0)
68
+ torch.nn.init.uniform_(lin.weight, -init_val, init_val)
69
+ if opt.weight_norm:
70
+ lin = nn.utils.weight_norm(lin)
71
+ setattr(self, "lin" + str(l), lin)
72
+ self.softplus = nn.Softplus(beta=100)
73
+
74
+ def forward(self, input, cond, current_epoch=None):
75
+ if input.ndim == 2: input = input.unsqueeze(0)
76
+
77
+ num_batch, num_point, num_dim = input.shape
78
+
79
+ if num_batch * num_point == 0: return input
80
+
81
+ input = input.reshape(num_batch * num_point, num_dim)
82
+
83
+ if self.cond != 'none':
84
+ num_batch, num_cond = cond[self.cond].shape
85
+
86
+ input_cond = cond[self.cond].unsqueeze(1).expand(num_batch, num_point, num_cond)
87
+
88
+ input_cond = input_cond.reshape(num_batch * num_point, num_cond)
89
+
90
+ if self.dim_pose_embed:
91
+ input_cond = self.lin_p0(input_cond)
92
+
93
+ if self.embed_fn is not None:
94
+ input = self.embed_fn(input)
95
+
96
+ x = input
97
+
98
+ for l in range(0, self.num_layers - 1):
99
+ lin = getattr(self, "lin" + str(l))
100
+ if self.cond != 'none' and l in self.cond_layer:
101
+ x = torch.cat([x, input_cond], dim=-1)
102
+ if l in self.skip_in:
103
+ x = torch.cat([x, input], 1) / np.sqrt(2)
104
+ x = lin(x)
105
+ if l < self.num_layers - 2:
106
+ x = self.softplus(x)
107
+
108
+ x = x.reshape(num_batch, num_point, -1)
109
+
110
+ return x
111
+
112
+ def gradient(self, x, cond):
113
+ x.requires_grad_(True)
114
+ y = self.forward(x, cond)[:, :1]
115
+ d_output = torch.ones_like(y, requires_grad=False, device=y.device)
116
+ gradients = torch.autograd.grad(outputs=y,
117
+ inputs=x,
118
+ grad_outputs=d_output,
119
+ create_graph=True,
120
+ retain_graph=True,
121
+ only_inputs=True)[0]
122
+ return gradients.unsqueeze(1)
123
+
124
+
125
+ class RenderingNet(nn.Module):
126
+ def __init__(self, opt):
127
+ super().__init__()
128
+
129
+ self.mode = opt.mode
130
+ dims = [opt.d_in + opt.feature_vector_size] + list(
131
+ opt.dims) + [opt.d_out]
132
+
133
+ self.embedview_fn = None
134
+ if opt.multires_view > 0:
135
+ embedview_fn, input_ch = get_embedder(opt.multires_view)
136
+ self.embedview_fn = embedview_fn
137
+ dims[0] += (input_ch - 3)
138
+ if self.mode == 'nerf_frame_encoding':
139
+ dims[0] += opt.dim_frame_encoding
140
+ if self.mode == 'pose':
141
+ self.dim_cond_embed = 8
142
+ self.cond_dim = 69 # dimension of the body pose, global orientation excluded.
143
+ # lower the condition dimension
144
+ self.lin_pose = torch.nn.Linear(self.cond_dim, self.dim_cond_embed)
145
+ self.num_layers = len(dims)
146
+ for l in range(0, self.num_layers - 1):
147
+ out_dim = dims[l + 1]
148
+ lin = nn.Linear(dims[l], out_dim)
149
+ if opt.weight_norm:
150
+ lin = nn.utils.weight_norm(lin)
151
+ setattr(self, "lin" + str(l), lin)
152
+ self.relu = nn.ReLU()
153
+ self.sigmoid = nn.Sigmoid()
154
+
155
+ def forward(self, points, normals, view_dirs, body_pose, feature_vectors, frame_latent_code=None):
156
+ if self.embedview_fn is not None:
157
+ if self.mode == 'nerf_frame_encoding':
158
+ view_dirs = self.embedview_fn(view_dirs)
159
+
160
+ if self.mode == 'nerf_frame_encoding':
161
+ frame_latent_code = frame_latent_code.expand(view_dirs.shape[0], -1)
162
+ rendering_input = torch.cat([view_dirs, frame_latent_code, feature_vectors], dim=-1)
163
+ elif self.mode == 'pose':
164
+ num_points = points.shape[0]
165
+ body_pose = body_pose.unsqueeze(1).expand(-1, num_points, -1).reshape(num_points, -1)
166
+ body_pose = self.lin_pose(body_pose)
167
+ rendering_input = torch.cat([points, normals, body_pose, feature_vectors], dim=-1)
168
+ else:
169
+ raise NotImplementedError
170
+
171
+ x = rendering_input
172
+ for l in range(0, self.num_layers - 1):
173
+ lin = getattr(self, "lin" + str(l))
174
+ x = lin(x)
175
+ if l < self.num_layers - 2:
176
+ x = self.relu(x)
177
+ x = self.sigmoid(x)
178
+ return x
code/lib/model/ray_sampler.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+ import torch
3
+ from lib.utils import utils
4
+
5
+ class RaySampler(metaclass=abc.ABCMeta):
6
+ def __init__(self,near, far):
7
+ self.near = near
8
+ self.far = far
9
+
10
+ @abc.abstractmethod
11
+ def get_z_vals(self, ray_dirs, cam_loc, model):
12
+ pass
13
+
14
+ class UniformSampler(RaySampler):
15
+ """Samples uniformly in the range [near, far]
16
+ """
17
+ def __init__(self, scene_bounding_sphere, near, N_samples, take_sphere_intersection=False, far=-1):
18
+ super().__init__(near, 2.0 * scene_bounding_sphere if far == -1 else far) # default far is 2*R
19
+ self.N_samples = N_samples
20
+ self.scene_bounding_sphere = scene_bounding_sphere
21
+ self.take_sphere_intersection = take_sphere_intersection
22
+
23
+ def get_z_vals(self, ray_dirs, cam_loc, model):
24
+ if not self.take_sphere_intersection:
25
+ near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0], 1).cuda()
26
+ else:
27
+ sphere_intersections = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)
28
+ near = self.near * torch.ones(ray_dirs.shape[0], 1).cuda()
29
+ far = sphere_intersections[:,1:]
30
+
31
+ t_vals = torch.linspace(0., 1., steps=self.N_samples).cuda()
32
+ z_vals = near * (1. - t_vals) + far * (t_vals)
33
+
34
+ if model.training:
35
+ # get intervals between samples
36
+ mids = .5 * (z_vals[..., 1:] + z_vals[..., :-1])
37
+ upper = torch.cat([mids, z_vals[..., -1:]], -1)
38
+ lower = torch.cat([z_vals[..., :1], mids], -1)
39
+ # stratified samples in those intervals
40
+ t_rand = torch.rand(z_vals.shape).cuda()
41
+
42
+ z_vals = lower + (upper - lower) * t_rand
43
+
44
+ return z_vals
45
+
46
+ class ErrorBoundSampler(RaySampler):
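+ """Error-bounded sampler following VolSDF: samples along each ray are iteratively
+ refined until the error bound on the rendered opacity falls below eps."""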
47
+ def __init__(self, scene_bounding_sphere, near, N_samples, N_samples_eval, N_samples_extra,
48
+ eps, beta_iters, max_total_iters,
49
+ inverse_sphere_bg=False, N_samples_inverse_sphere=0, add_tiny=0.0):
50
+ super().__init__(near, 2.0 * scene_bounding_sphere)
51
+ self.N_samples = N_samples
52
+ self.N_samples_eval = N_samples_eval
53
+ self.uniform_sampler = UniformSampler(scene_bounding_sphere, near, N_samples_eval, take_sphere_intersection=inverse_sphere_bg)
54
+
55
+ self.N_samples_extra = N_samples_extra
56
+
57
+ self.eps = eps
58
+ self.beta_iters = beta_iters
59
+ self.max_total_iters = max_total_iters
60
+ self.scene_bounding_sphere = scene_bounding_sphere
61
+ self.add_tiny = add_tiny
62
+
63
+ self.inverse_sphere_bg = inverse_sphere_bg
64
+ if inverse_sphere_bg:
65
+ N_samples_inverse_sphere = 32
66
+ self.inverse_sphere_sampler = UniformSampler(1.0, 0.0, N_samples_inverse_sphere, False, far=1.0)
67
+
68
+ def get_z_vals(self, ray_dirs, cam_loc, model, cond, smpl_tfs, eval_mode, smpl_verts):
69
+ beta0 = model.density.get_beta().detach()
70
+
71
+ # Start with uniform sampling
72
+ z_vals = self.uniform_sampler.get_z_vals(ray_dirs, cam_loc, model)
73
+ samples, samples_idx = z_vals, None
74
+
75
+ # Get maximum beta from the upper bound (Lemma 2)
76
+ dists = z_vals[:, 1:] - z_vals[:, :-1]
77
+ bound = (1.0 / (4.0 * torch.log(torch.tensor(self.eps + 1.0)))) * (dists ** 2.).sum(-1)
78
+ beta = torch.sqrt(bound)
79
+
80
+ total_iters, not_converge = 0, True
81
+
82
+ # VolSDF Algorithm 1
83
+ while not_converge and total_iters < self.max_total_iters:
84
+ points = cam_loc.unsqueeze(1) + samples.unsqueeze(2) * ray_dirs.unsqueeze(1)
85
+ points_flat = points.reshape(-1, 3)
86
+ # Calculating the SDF only for the new sampled points
87
+ model.implicit_network.eval()
88
+ with torch.no_grad():
89
+ samples_sdf = model.sdf_func_with_smpl_deformer(points_flat, cond, smpl_tfs, smpl_verts=smpl_verts)[0]
90
+ model.implicit_network.train()
91
+ if samples_idx is not None:
92
+ sdf_merge = torch.cat([sdf.reshape(-1, z_vals.shape[1] - samples.shape[1]),
93
+ samples_sdf.reshape(-1, samples.shape[1])], -1)
94
+ sdf = torch.gather(sdf_merge, 1, samples_idx).reshape(-1, 1)
95
+ else:
96
+ sdf = samples_sdf
97
+
98
+
99
+ # Calculating the bound d* (Theorem 1)
100
+ d = sdf.reshape(z_vals.shape)
101
+ dists = z_vals[:, 1:] - z_vals[:, :-1]
102
+ a, b, c = dists, d[:, :-1].abs(), d[:, 1:].abs()
103
+ first_cond = a.pow(2) + b.pow(2) <= c.pow(2)
104
+ second_cond = a.pow(2) + c.pow(2) <= b.pow(2)
105
+ d_star = torch.zeros(z_vals.shape[0], z_vals.shape[1] - 1).cuda()
106
+ d_star[first_cond] = b[first_cond]
107
+ d_star[second_cond] = c[second_cond]
108
+ s = (a + b + c) / 2.0
109
+ area_before_sqrt = s * (s - a) * (s - b) * (s - c)
110
+ mask = ~first_cond & ~second_cond & (b + c - a > 0)
111
+ d_star[mask] = (2.0 * torch.sqrt(area_before_sqrt[mask])) / (a[mask])
112
+ d_star = (d[:, 1:].sign() * d[:, :-1].sign() == 1) * d_star # Fixing the sign
113
+
114
+
115
+ # Updating beta using line search
116
+ curr_error = self.get_error_bound(beta0, model, sdf, z_vals, dists, d_star)
117
+ beta[curr_error <= self.eps] = beta0
118
+ beta_min, beta_max = beta0.unsqueeze(0).repeat(z_vals.shape[0]), beta
119
+ for j in range(self.beta_iters):
120
+ beta_mid = (beta_min + beta_max) / 2.
121
+ curr_error = self.get_error_bound(beta_mid.unsqueeze(-1), model, sdf, z_vals, dists, d_star)
122
+ beta_max[curr_error <= self.eps] = beta_mid[curr_error <= self.eps]
123
+ beta_min[curr_error > self.eps] = beta_mid[curr_error > self.eps]
124
+ beta = beta_max
125
+
126
+
127
+ # Upsample more points
128
+ density = model.density(sdf.reshape(z_vals.shape), beta=beta.unsqueeze(-1))
129
+
130
+ dists = torch.cat([dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(dists.shape[0], 1)], -1)
131
+ free_energy = dists * density
132
+ shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy[:, :-1]], dim=-1)
133
+ alpha = 1 - torch.exp(-free_energy)
134
+ transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1))
135
+ weights = alpha * transmittance # probability of the ray hits something here
136
+
137
+ # Check if we are done and this is the last sampling
138
+ total_iters += 1
139
+ not_converge = beta.max() > beta0
140
+
141
+ if not_converge and total_iters < self.max_total_iters:
142
+ ''' Sample more points proportional to the current error bound'''
143
+
144
+ N = self.N_samples_eval
145
+
146
+ bins = z_vals
147
+ error_per_section = torch.exp(-d_star / beta.unsqueeze(-1)) * (dists[:,:-1] ** 2.) / (4 * beta.unsqueeze(-1) ** 2)
148
+ error_integral = torch.cumsum(error_per_section, dim=-1)
149
+ bound_opacity = (torch.clamp(torch.exp(error_integral),max=1.e6) - 1.0) * transmittance[:,:-1]
150
+
151
+ pdf = bound_opacity + self.add_tiny
152
+ pdf = pdf / torch.sum(pdf, -1, keepdim=True)
153
+ cdf = torch.cumsum(pdf, -1)
154
+ cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1)
155
+
156
+ else:
157
+ ''' Sample the final sample set to be used in the volume rendering integral '''
158
+
159
+ N = self.N_samples
160
+
161
+ bins = z_vals
162
+ pdf = weights[..., :-1]
163
+ pdf = pdf + 1e-5 # prevent nans
164
+ pdf = pdf / torch.sum(pdf, -1, keepdim=True)
165
+ cdf = torch.cumsum(pdf, -1)
166
+ cdf = torch.cat([torch.zeros_like(cdf[..., :1]), cdf], -1) # (batch, len(bins))
167
+
168
+
169
+ # Invert CDF
170
+ if (not_converge and total_iters < self.max_total_iters) or (not model.training):
171
+ u = torch.linspace(0., 1., steps=N).cuda().unsqueeze(0).repeat(cdf.shape[0], 1)
172
+ else:
173
+ u = torch.rand(list(cdf.shape[:-1]) + [N]).cuda()
174
+ u = u.contiguous()
175
+
176
+ inds = torch.searchsorted(cdf, u, right=True)
177
+ below = torch.max(torch.zeros_like(inds - 1), inds - 1)
178
+ above = torch.min((cdf.shape[-1] - 1) * torch.ones_like(inds), inds)
179
+ inds_g = torch.stack([below, above], -1) # (batch, N_samples, 2)
180
+
181
+ matched_shape = [inds_g.shape[0], inds_g.shape[1], cdf.shape[-1]]
182
+ cdf_g = torch.gather(cdf.unsqueeze(1).expand(matched_shape), 2, inds_g)
183
+ bins_g = torch.gather(bins.unsqueeze(1).expand(matched_shape), 2, inds_g)
184
+
185
+ denom = (cdf_g[..., 1] - cdf_g[..., 0])
186
+ denom = torch.where(denom < 1e-5, torch.ones_like(denom), denom)
187
+ t = (u - cdf_g[..., 0]) / denom
188
+ samples = bins_g[..., 0] + t * (bins_g[..., 1] - bins_g[..., 0])
189
+
190
+
191
+ # Add the newly drawn samples if we have not yet converged
192
+ if not_converge and total_iters < self.max_total_iters:
193
+ z_vals, samples_idx = torch.sort(torch.cat([z_vals, samples], -1), -1)
194
+
195
+
196
+ z_samples = samples
197
+
198
+ near, far = self.near * torch.ones(ray_dirs.shape[0], 1).cuda(), self.far * torch.ones(ray_dirs.shape[0],1).cuda()
199
+ if self.inverse_sphere_bg: # if inverse sphere then need to add the far sphere intersection
200
+ far = utils.get_sphere_intersections(cam_loc, ray_dirs, r=self.scene_bounding_sphere)[:,1:]
201
+
202
+ if self.N_samples_extra > 0:
203
+ if model.training:
204
+ sampling_idx = torch.randperm(z_vals.shape[1])[:self.N_samples_extra]
205
+ else:
206
+ sampling_idx = torch.linspace(0, z_vals.shape[1]-1, self.N_samples_extra).long()
207
+ z_vals_extra = torch.cat([near, far, z_vals[:,sampling_idx]], -1)
208
+ else:
209
+ z_vals_extra = torch.cat([near, far], -1)
210
+
211
+ z_vals, _ = torch.sort(torch.cat([z_samples, z_vals_extra], -1), -1)
212
+
213
+ # add some of the near surface points
214
+ idx = torch.randint(z_vals.shape[-1], (z_vals.shape[0],)).cuda()
215
+ z_samples_eik = torch.gather(z_vals, 1, idx.unsqueeze(-1))
216
+
217
+ if self.inverse_sphere_bg:
218
+ z_vals_inverse_sphere = self.inverse_sphere_sampler.get_z_vals(ray_dirs, cam_loc, model)
219
+ z_vals_inverse_sphere = z_vals_inverse_sphere * (1./self.scene_bounding_sphere)
220
+ z_vals = (z_vals, z_vals_inverse_sphere)
221
+
222
+ return z_vals, z_samples_eik
223
+
224
+ def get_error_bound(self, beta, model, sdf, z_vals, dists, d_star):
225
+ density = model.density(sdf.reshape(z_vals.shape), beta=beta)
226
+ shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), dists * density[:, :-1]], dim=-1)
227
+ integral_estimation = torch.cumsum(shifted_free_energy, dim=-1)
228
+ error_per_section = torch.exp(-d_star / beta) * (dists ** 2.) / (4 * beta ** 2)
229
+ error_integral = torch.cumsum(error_per_section, dim=-1)
230
+ bound_opacity = (torch.clamp(torch.exp(error_integral), max=1.e6) - 1.0) * torch.exp(-integral_estimation[:, :-1])
231
+
232
+ return bound_opacity.max(-1)[0]
233
+
234
+
code/lib/model/sampler.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class PointInSpace:
5
+ def __init__(self, global_sigma=0.5, local_sigma=0.01):
6
+ self.global_sigma = global_sigma
7
+ self.local_sigma = local_sigma
8
+
9
+ def get_points(self, pc_input=None, local_sigma=None, global_ratio=0.125):
10
+ """Sample one point near each of the given point + 1/8 uniformly.
11
+ Args:
12
+ pc_input (tensor): sampling centers. shape: [B, N, D]
13
+ Returns:
14
+ samples (tensor): sampled points. shape: [B, N + N / 8, D]
15
+ """
16
+
17
+ batch_size, sample_size, dim = pc_input.shape
18
+ if local_sigma is None:
19
+ sample_local = pc_input + (torch.randn_like(pc_input) * self.local_sigma)
20
+ else:
21
+ sample_local = pc_input + (torch.randn_like(pc_input) * local_sigma)
22
+ sample_global = (
23
+ torch.rand(batch_size, int(sample_size * global_ratio), dim, device=pc_input.device)
24
+ * (self.global_sigma * 2)
25
+ ) - self.global_sigma
26
+
27
+ sample = torch.cat([sample_local, sample_global], dim=1)
28
+
29
+ return sample
code/lib/model/smpl.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import hydra
3
+ import numpy as np
4
+ from ..smpl.body_models import SMPL
5
+
6
+ class SMPLServer(torch.nn.Module):
7
+
8
+ def __init__(self, gender='neutral', betas=None, v_template=None):
9
+ super().__init__()
10
+
11
+
12
+ self.smpl = SMPL(model_path=hydra.utils.to_absolute_path('lib/smpl/smpl_model'),
13
+ gender=gender,
14
+ batch_size=1,
15
+ use_hands=False,
16
+ use_feet_keypoints=False,
17
+ dtype=torch.float32).cuda()
18
+
19
+ self.bone_parents = self.smpl.bone_parents.astype(int)
20
+ self.bone_parents[0] = -1
21
+ self.bone_ids = []
22
+ self.faces = self.smpl.faces
23
+ for i in range(24): self.bone_ids.append([self.bone_parents[i], i])
24
+
25
+ if v_template is not None:
26
+ self.v_template = torch.tensor(v_template).float().cuda()
27
+ else:
28
+ self.v_template = None
29
+
30
+ if betas is not None:
31
+ self.betas = torch.tensor(betas).float().cuda()
32
+ else:
33
+ self.betas = None
34
+
35
+ # define the canonical pose
36
+ param_canonical = torch.zeros((1, 86),dtype=torch.float32).cuda()
37
+ param_canonical[0, 0] = 1
38
+ param_canonical[0, 9] = np.pi / 6
39
+ param_canonical[0, 12] = -np.pi / 6
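+ # canonical pose: unit scale, hips rotated by +/- pi/6 so the legs are slightly spread,
+ # which avoids skinning ambiguities between the thighs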
40
+ if self.betas is not None and self.v_template is None:
41
+ param_canonical[0,-10:] = self.betas
42
+ self.param_canonical = param_canonical
43
+
44
+ output = self.forward(*torch.split(self.param_canonical, [1, 3, 72, 10], dim=1), absolute=True)
45
+ self.verts_c = output['smpl_verts']
46
+ self.joints_c = output['smpl_jnts']
47
+ self.tfs_c_inv = output['smpl_tfs'].squeeze(0).inverse()
48
+
49
+
50
+ def forward(self, scale, transl, thetas, betas, absolute=False):
51
+ """return SMPL output from params
52
+ Args:
53
+ scale : scale factor. shape: [B, 1]
54
+ transl: translation. shape: [B, 3]
55
+ thetas: pose. shape: [B, 72]
56
+ betas: shape. shape: [B, 10]
57
+ absolute (bool): if true return smpl_tfs wrt thetas=0. else wrt thetas=thetas_canonical.
58
+ Returns:
59
+ smpl_verts: vertices. shape: [B, 6890, 3]
60
+ smpl_tfs: bone transformations. shape: [B, 24, 4, 4]
61
+ smpl_jnts: joint positions. shape: [B, 25, 3]
62
+ """
63
+
64
+ output = {}
65
+
66
+ # ignore betas if v_template is provided
67
+ if self.v_template is not None:
68
+ betas = torch.zeros_like(betas)
69
+
70
+
71
+ smpl_output = self.smpl.forward(betas=betas,
72
+ transl=torch.zeros_like(transl),
73
+ body_pose=thetas[:, 3:],
74
+ global_orient=thetas[:, :3],
75
+ return_verts=True,
76
+ return_full_pose=True,
77
+ v_template=self.v_template)
78
+
79
+ verts = smpl_output.vertices.clone()
80
+ output['smpl_verts'] = verts * scale.unsqueeze(1) + transl.unsqueeze(1) * scale.unsqueeze(1)
81
+
82
+ joints = smpl_output.joints.clone()
83
+ output['smpl_jnts'] = joints * scale.unsqueeze(1) + transl.unsqueeze(1) * scale.unsqueeze(1)
84
+
85
+ tf_mats = smpl_output.T.clone()
86
+ tf_mats[:, :, :3, :] = tf_mats[:, :, :3, :] * scale.unsqueeze(1).unsqueeze(1)
87
+ tf_mats[:, :, :3, 3] = tf_mats[:, :, :3, 3] + transl.unsqueeze(1) * scale.unsqueeze(1)
88
+
89
+ if not absolute:
90
+ tf_mats = torch.einsum('bnij,njk->bnik', tf_mats, self.tfs_c_inv)
91
+
92
+ output['smpl_tfs'] = tf_mats
93
+ output['smpl_weights'] = smpl_output.weights
94
+ return output
code/lib/model/v2a.py ADDED
@@ -0,0 +1,368 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .networks import ImplicitNet, RenderingNet
2
+ from .density import LaplaceDensity, AbsDensity
3
+ from .ray_sampler import ErrorBoundSampler
4
+ from .deformer import SMPLDeformer
5
+ from .smpl import SMPLServer
6
+
7
+ from .sampler import PointInSpace
8
+
9
+ from ..utils import utils
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ from torch.autograd import grad
15
+ import hydra
16
+ import kaolin
17
+ from kaolin.ops.mesh import index_vertices_by_faces
18
+ class V2A(nn.Module):
19
+ def __init__(self, opt, betas_path, gender, num_training_frames):
20
+ super().__init__()
21
+
22
+ # Foreground networks
23
+ self.implicit_network = ImplicitNet(opt.implicit_network)
24
+ self.rendering_network = RenderingNet(opt.rendering_network)
25
+
26
+ # Background networks
27
+ self.bg_implicit_network = ImplicitNet(opt.bg_implicit_network)
28
+ self.bg_rendering_network = RenderingNet(opt.bg_rendering_network)
29
+
30
+ # Frame latent encoder
31
+ self.frame_latent_encoder = nn.Embedding(num_training_frames, opt.bg_rendering_network.dim_frame_encoding)
32
+ self.sampler = PointInSpace()
33
+
34
+ betas = np.load(betas_path)
35
+ self.use_smpl_deformer = opt.use_smpl_deformer
36
+ self.gender = gender
37
+ if self.use_smpl_deformer:
38
+ self.deformer = SMPLDeformer(betas=betas, gender=self.gender)
39
+
40
+ # pre-defined bounding sphere
41
+ self.sdf_bounding_sphere = 3.0
42
+
43
+ # threshold for the out-surface points
44
+ self.threshold = 0.05
45
+
46
+ self.density = LaplaceDensity(**opt.density)
47
+ self.bg_density = AbsDensity()
48
+
49
+ self.ray_sampler = ErrorBoundSampler(self.sdf_bounding_sphere, inverse_sphere_bg=True, **opt.ray_sampler)
50
+ self.smpl_server = SMPLServer(gender=self.gender, betas=betas)
51
+
52
+ if opt.smpl_init:
53
+ smpl_model_state = torch.load(hydra.utils.to_absolute_path('../assets/smpl_init.pth'))
54
+ self.implicit_network.load_state_dict(smpl_model_state["model_state_dict"])
55
+
56
+ self.smpl_v_cano = self.smpl_server.verts_c
57
+ self.smpl_f_cano = torch.tensor(self.smpl_server.smpl.faces.astype(np.int64), device=self.smpl_v_cano.device)
58
+
59
+ self.mesh_v_cano = self.smpl_server.verts_c
60
+ self.mesh_f_cano = torch.tensor(self.smpl_server.smpl.faces.astype(np.int64), device=self.smpl_v_cano.device)
61
+ self.mesh_face_vertices = index_vertices_by_faces(self.mesh_v_cano, self.mesh_f_cano)
62
+
63
+ def sdf_func_with_smpl_deformer(self, x, cond, smpl_tfs, smpl_verts):
64
+ """ sdf_func_with_smpl_deformer method
65
+ Used to compute SDF values for input points using the SMPL deformer and the implicit network.
66
+ It deforms the query points into canonical space, runs the implicit network, extracts features, and masks outlier points.
67
+ """
68
+ if hasattr(self, "deformer"):
69
+ x_c, outlier_mask = self.deformer.forward(x, smpl_tfs, return_weights=False, inverse=True, smpl_verts=smpl_verts)
70
+ output = self.implicit_network(x_c, cond)[0]
71
+ sdf = output[:, 0:1]
72
+ feature = output[:, 1:]
73
+ if not self.training:
74
+ sdf[outlier_mask] = 4. # set a large SDF value for outlier points
75
+
76
+ return sdf, x_c, feature
77
+
78
+ def check_off_in_surface_points_cano_mesh(self, x_cano, N_samples, threshold=0.05):
79
+ """check_off_in_surface_points_cano_mesh method
80
+ Used to check whether points are off the surface or within the surface of a canonical mesh.
81
+ It calculates distances, signs, and signed distances to determine the position of points with respect to the mesh surface.
82
+ The method plays a role in identifying points that might be considered outliers or outside the reconstructed avatar's surface.
83
+ """
84
+
85
+ distance, _, _ = kaolin.metrics.trianglemesh.point_to_mesh_distance(x_cano.unsqueeze(0).contiguous(), self.mesh_face_vertices)
86
+
87
+ distance = torch.sqrt(distance) # kaolin outputs squared distance
88
+ sign = kaolin.ops.mesh.check_sign(self.mesh_v_cano, self.mesh_f_cano, x_cano.unsqueeze(0)).float()
89
+ sign = 1 - 2 * sign # -1 inside the mesh, +1 outside, so the signed distance is negative inside
90
+ signed_distance = sign * distance
91
+ batch_size = x_cano.shape[0] // N_samples
92
+ signed_distance = signed_distance.reshape(batch_size, N_samples, 1) # The distances are reshaped to match the batch size and the number of samples
93
+
94
+ minimum = torch.min(signed_distance, 1)[0]
95
+ index_off_surface = (minimum > threshold).squeeze(1)
96
+ index_in_surface = (minimum <= 0.).squeeze(1)
97
+ return index_off_surface, index_in_surface # Indexes of off-surface points and in-surface points
98
+
99
+ def forward(self, input):
100
+ # Parse model input, prepares the necessary input data and SMPL parameters for subsequent calculations
101
+ torch.set_grad_enabled(True)
102
+ intrinsics = input["intrinsics"]
103
+ pose = input["pose"]
104
+ uv = input["uv"]
105
+
106
+ scale = input['smpl_params'][:, 0]
107
+ smpl_pose = input["smpl_pose"]
108
+ smpl_shape = input["smpl_shape"]
109
+ smpl_trans = input["smpl_trans"]
110
+ smpl_output = self.smpl_server(scale, smpl_trans, smpl_pose, smpl_shape) # invokes the SMPL model to obtain the transformations for pose and shape changes
111
+
112
+ smpl_tfs = smpl_output['smpl_tfs']
113
+
114
+ cond = {'smpl': smpl_pose[:, 3:]/np.pi}
115
+ if self.training:
116
+ if input['current_epoch'] < 20 or input['current_epoch'] % 20 == 0: # zero the pose condition for the first 20 epochs and on every 20th epoch thereafter
117
+ cond = {'smpl': smpl_pose[:, 3:] * 0.}
118
+ ray_dirs, cam_loc = utils.get_camera_params(uv, pose, intrinsics) # get the ray directions and camera location
119
+ batch_size, num_pixels, _ = ray_dirs.shape
120
+
121
+ cam_loc = cam_loc.unsqueeze(1).repeat(1, num_pixels, 1).reshape(-1, 3) # reshape to match the batch size and the number of pixels
122
+ ray_dirs = ray_dirs.reshape(-1, 3) # reshape to match the batch size and the number of pixels
123
+
124
+ z_vals, _ = self.ray_sampler.get_z_vals(ray_dirs, cam_loc, self, cond, smpl_tfs, eval_mode=True, smpl_verts=smpl_output['smpl_verts']) # get the z values for each pixel
125
+
126
+ z_vals, z_vals_bg = z_vals # unpack the z values for the foreground and the background
127
+ z_max = z_vals[:,-1] # get the maximum z value
128
+ z_vals = z_vals[:,:-1] # get the z values for the foreground
129
+ N_samples = z_vals.shape[1] # get the number of samples
130
+
131
+ points = cam_loc.unsqueeze(1) + z_vals.unsqueeze(2) * ray_dirs.unsqueeze(1) # 3D points along the rays are calculated by adding z_vals scaled by ray directions to the camera location. The result is stored in the points tensor of shape (batch_size * num_pixels, N_samples, 3)
132
+ points_flat = points.reshape(-1, 3) # The points tensor is reshaped into a flattened tensor points_flat of shape (batch_size * num_pixels * N_samples, 3)
133
+
134
+ dirs = ray_dirs.unsqueeze(1).repeat(1,N_samples,1) # The dirs tensor is created by repeating ray_dirs for each sample along the rays. The resulting tensor has shape (batch_size * num_pixels, N_samples, 3)
135
+ sdf_output, canonical_points, feature_vectors = self.sdf_func_with_smpl_deformer(points_flat, cond, smpl_tfs, smpl_output['smpl_verts']) # The sdf_func_with_smpl_deformer method is called to compute the signed distance functions (SDF) for the points
136
+
137
+ sdf_output = sdf_output.unsqueeze(1) # The sdf_output tensor is reshaped by unsqueezing along the first dimension
138
+
139
+ if self.training:
140
+ index_off_surface, index_in_surface = self.check_off_in_surface_points_cano_mesh(canonical_points, N_samples, threshold=self.threshold)
141
+ canonical_points = canonical_points.reshape(num_pixels, N_samples, 3)
142
+
143
+ canonical_points = canonical_points.reshape(-1, 3) # The canonical points tensor flattened to shape (-1, 3)
144
+
145
+ # sample canonical SMPL surface pnts for the eikonal loss
146
+ smpl_verts_c = self.smpl_server.verts_c.repeat(batch_size, 1,1) # The canonical SMPL surface vertices are repeated across the batch dimension
147
+
148
+ indices = torch.randperm(smpl_verts_c.shape[1])[:num_pixels].cuda() # Random indices are generated to select a subset of vertices for sampling. The number of selected vertices is num_pixels
149
+ verts_c = torch.index_select(smpl_verts_c, 1, indices) # The selected vertices are gathered from smpl_verts_c, resulting in the tensor verts_c.
150
+ sample = self.sampler.get_points(verts_c, global_ratio=0.) # The get_points method of the sampler class is called to sample points around the canonical SMPL surface points. The global_ratio is set to 0.0, indicating local sampling
151
+
152
+ sample.requires_grad_() # The sampled points are marked as requiring gradients
153
+ local_pred = self.implicit_network(sample, cond)[..., 0:1] # The sampled points (sample) are passed through the implicit network along with the conditioning (cond). The local prediction (SDF) for each sampled point is extracted using [..., 0:1]
154
+ grad_theta = gradient(sample, local_pred) # compute gradients with respect to the sampled points and their local predictions (local_pred).
155
+
156
+ differentiable_points = canonical_points # The differentiable_points tensor is assigned the value of canonical_points
157
+
158
+ else:
159
+ differentiable_points = canonical_points.reshape(num_pixels, N_samples, 3).reshape(-1, 3)
160
+ grad_theta = None
161
+
162
+ sdf_output = sdf_output.reshape(num_pixels, N_samples, 1).reshape(-1, 1) # flattened to shape (num_pixels * N_samples, 1)
163
+ z_vals = z_vals
164
+ view = -dirs.reshape(-1, 3) # The view vector is calculated as the negation of the reshaped dirs, giving the view directions for points along the rays.
165
+
166
+ if differentiable_points.shape[0] > 0: # If there are differentiable points (indicating that gradient information is available)
167
+ fg_rgb_flat, others = self.get_rgb_value(points_flat, differentiable_points, view,
168
+ cond, smpl_tfs, feature_vectors=feature_vectors, is_training=self.training) # The returned values include fg_rgb_flat (foreground RGB values) and others (other calculated values, including normals)
169
+ normal_values = others['normals'] # The normal values are extracted from the others dictionary
170
+
171
+ if 'image_id' in input.keys():
172
+ frame_latent_code = self.frame_latent_encoder(input['image_id'])
173
+ else:
174
+ frame_latent_code = self.frame_latent_encoder(input['idx'])
175
+
176
+ fg_rgb = fg_rgb_flat.reshape(-1, N_samples, 3)
177
+ normal_values = normal_values.reshape(-1, N_samples, 3)
178
+ weights, bg_transmittance = self.volume_rendering(z_vals, z_max, sdf_output)
179
+
180
+ fg_rgb_values = torch.sum(weights.unsqueeze(-1) * fg_rgb, 1)
181
+
182
+ # Background rendering
183
+ if input['idx'] is not None:
184
+ N_bg_samples = z_vals_bg.shape[1]
185
+ z_vals_bg = torch.flip(z_vals_bg, dims=[-1, ]) # 1--->0
186
+
187
+ bg_dirs = ray_dirs.unsqueeze(1).repeat(1,N_bg_samples,1)
188
+ bg_locs = cam_loc.unsqueeze(1).repeat(1,N_bg_samples,1)
189
+
190
+ bg_points = self.depth2pts_outside(bg_locs, bg_dirs, z_vals_bg) # [..., N_samples, 4]
191
+ bg_points_flat = bg_points.reshape(-1, 4)
192
+ bg_dirs_flat = bg_dirs.reshape(-1, 3)
193
+ bg_output = self.bg_implicit_network(bg_points_flat, {'frame': frame_latent_code})[0]
194
+ bg_sdf = bg_output[:, :1]
195
+ bg_feature_vectors = bg_output[:, 1:]
196
+
197
+ bg_rendering_output = self.bg_rendering_network(None, None, bg_dirs_flat, None, bg_feature_vectors, frame_latent_code)
198
+ if bg_rendering_output.shape[-1] == 4:
199
+ bg_rgb_flat = bg_rendering_output[..., :-1]
200
+ shadow_r = bg_rendering_output[..., -1]
201
+ bg_rgb = bg_rgb_flat.reshape(-1, N_bg_samples, 3)
202
+ shadow_r = shadow_r.reshape(-1, N_bg_samples, 1)
203
+ bg_rgb = (1 - shadow_r) * bg_rgb
204
+ else:
205
+ bg_rgb_flat = bg_rendering_output
206
+ bg_rgb = bg_rgb_flat.reshape(-1, N_bg_samples, 3)
207
+ bg_weights = self.bg_volume_rendering(z_vals_bg, bg_sdf)
208
+ bg_rgb_values = torch.sum(bg_weights.unsqueeze(-1) * bg_rgb, 1)
209
+ else:
210
+ bg_rgb_values = torch.ones_like(fg_rgb_values, device=fg_rgb_values.device)
211
+
212
+ # Composite foreground and background
213
+ bg_rgb_values = bg_transmittance.unsqueeze(-1) * bg_rgb_values
214
+ rgb_values = fg_rgb_values + bg_rgb_values
215
+
216
+ normal_values = torch.sum(weights.unsqueeze(-1) * normal_values, 1)
217
+
218
+ if self.training:
219
+ output = {
220
+ 'points': points,
221
+ 'rgb_values': rgb_values,
222
+ 'normal_values': normal_values,
223
+ 'index_outside': input['index_outside'],
224
+ 'index_off_surface': index_off_surface,
225
+ 'index_in_surface': index_in_surface,
226
+ 'acc_map': torch.sum(weights, -1),
227
+ 'sdf_output': sdf_output,
228
+ 'grad_theta': grad_theta,
229
+ 'epoch': input['current_epoch'],
230
+ }
231
+ else:
232
+ fg_output_rgb = fg_rgb_values + bg_transmittance.unsqueeze(-1) * torch.ones_like(fg_rgb_values, device=fg_rgb_values.device)
233
+ output = {
234
+ 'acc_map': torch.sum(weights, -1),
235
+ 'rgb_values': rgb_values,
236
+ 'fg_rgb_values': fg_output_rgb,
237
+ 'normal_values': normal_values,
238
+ 'sdf_output': sdf_output,
239
+ }
240
+ return output
241
+
242
+ def get_rgb_value(self, x, points, view_dirs, cond, tfs, feature_vectors, is_training=True):
243
+ pnts_c = points
244
+ others = {}
245
+
246
+ _, gradients, feature_vectors = self.forward_gradient(x, pnts_c, cond, tfs, create_graph=is_training, retain_graph=is_training)
247
+ # ensure the gradient is normalized
248
+ normals = nn.functional.normalize(gradients, dim=-1, eps=1e-6)
249
+ fg_rendering_output = self.rendering_network(pnts_c, normals, view_dirs, cond['smpl'],
250
+ feature_vectors)
251
+
252
+ rgb_vals = fg_rendering_output[:, :3]
253
+ others['normals'] = normals
254
+ return rgb_vals, others
255
+
256
+ def forward_gradient(self, x, pnts_c, cond, tfs, create_graph=True, retain_graph=True):
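+ # Computes the Jacobian of the forward skinning map at each canonical point and uses its
+ # inverse to map the canonical SDF gradient to a surface normal in deformed (posed) space.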
257
+ if pnts_c.shape[0] == 0:
258
+ return pnts_c.detach()
259
+ pnts_c.requires_grad_(True)
260
+
261
+ pnts_d = self.deformer.forward_skinning(pnts_c.unsqueeze(0), None, tfs).squeeze(0)
262
+ num_dim = pnts_d.shape[-1]
263
+ grads = []
264
+ for i in range(num_dim):
265
+ d_out = torch.zeros_like(pnts_d, requires_grad=False, device=pnts_d.device)
266
+ d_out[:, i] = 1
267
+ grad = torch.autograd.grad(
268
+ outputs=pnts_d,
269
+ inputs=pnts_c,
270
+ grad_outputs=d_out,
271
+ create_graph=create_graph,
272
+ retain_graph=True if i < num_dim - 1 else retain_graph,
273
+ only_inputs=True)[0]
274
+ grads.append(grad)
275
+ grads = torch.stack(grads, dim=-2)
276
+ grads_inv = grads.inverse()
277
+
278
+ output = self.implicit_network(pnts_c, cond)[0]
279
+ sdf = output[:, :1]
280
+
281
+ feature = output[:, 1:]
282
+ d_output = torch.ones_like(sdf, requires_grad=False, device=sdf.device)
283
+ gradients = torch.autograd.grad(
284
+ outputs=sdf,
285
+ inputs=pnts_c,
286
+ grad_outputs=d_output,
287
+ create_graph=create_graph,
288
+ retain_graph=retain_graph,
289
+ only_inputs=True)[0]
290
+
291
+ return grads.reshape(grads.shape[0], -1), torch.nn.functional.normalize(torch.einsum('bi,bij->bj', gradients, grads_inv), dim=1), feature
292
+
293
+ def volume_rendering(self, z_vals, z_max, sdf):
294
+ density_flat = self.density(sdf)
295
+ density = density_flat.reshape(-1, z_vals.shape[1]) # (batch_size * num_pixels) x N_samples
296
+
297
+ # included also the dist from the sphere intersection
298
+ dists = z_vals[:, 1:] - z_vals[:, :-1]
299
+ dists = torch.cat([dists, z_max.unsqueeze(-1) - z_vals[:, -1:]], -1)
300
+
301
+ # LOG SPACE
302
+ free_energy = dists * density
303
+ shifted_free_energy = torch.cat([torch.zeros(dists.shape[0], 1).cuda(), free_energy], dim=-1) # prepend 0 so the transmittance is 1 at t_0
304
+ alpha = 1 - torch.exp(-free_energy) # probability that this segment is not empty
305
+ transmittance = torch.exp(-torch.cumsum(shifted_free_energy, dim=-1)) # probability that everything up to this point is empty
306
+ fg_transmittance = transmittance[:, :-1]
307
+ weights = alpha * fg_transmittance # probability that the ray hits something in this segment
308
+ bg_transmittance = transmittance[:, -1] # factor to be multiplied with the bg volume rendering
309
+
310
+ return weights, bg_transmittance
311
+
312
+ def bg_volume_rendering(self, z_vals_bg, bg_sdf):
313
+ bg_density_flat = self.bg_density(bg_sdf)
314
+ bg_density = bg_density_flat.reshape(-1, z_vals_bg.shape[1]) # (batch_size * num_pixels) x N_samples
315
+
316
+ bg_dists = z_vals_bg[:, :-1] - z_vals_bg[:, 1:]
317
+ bg_dists = torch.cat([bg_dists, torch.tensor([1e10]).cuda().unsqueeze(0).repeat(bg_dists.shape[0], 1)], -1)
318
+
319
+ # LOG SPACE
320
+ bg_free_energy = bg_dists * bg_density
321
+ bg_shifted_free_energy = torch.cat([torch.zeros(bg_dists.shape[0], 1).cuda(), bg_free_energy[:, :-1]], dim=-1) # shift one step
322
+ bg_alpha = 1 - torch.exp(-bg_free_energy) # probability that this segment is not empty
323
+ bg_transmittance = torch.exp(-torch.cumsum(bg_shifted_free_energy, dim=-1)) # probability that everything up to this point is empty
324
+ bg_weights = bg_alpha * bg_transmittance # probability that the ray hits something in this segment
325
+
326
+ return bg_weights
327
+
328
+ def depth2pts_outside(self, ray_o, ray_d, depth):
329
+
330
+ '''
331
+ ray_o, ray_d: [..., 3]
332
+ depth: [...]; inverse of distance to sphere origin
333
+ '''
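+ # NeRF++-style inverted-sphere parameterization: each background sample is represented by
+ # a unit point on the bounding sphere plus its inverse distance, giving a bounded 4D input.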
334
+
335
+ o_dot_d = torch.sum(ray_d * ray_o, dim=-1)
336
+ under_sqrt = o_dot_d ** 2 - ((ray_o ** 2).sum(-1) - self.sdf_bounding_sphere ** 2)
337
+ d_sphere = torch.sqrt(under_sqrt) - o_dot_d
338
+ p_sphere = ray_o + d_sphere.unsqueeze(-1) * ray_d
339
+ p_mid = ray_o - o_dot_d.unsqueeze(-1) * ray_d
340
+ p_mid_norm = torch.norm(p_mid, dim=-1)
341
+
342
+ rot_axis = torch.cross(ray_o, p_sphere, dim=-1)
343
+ rot_axis = rot_axis / torch.norm(rot_axis, dim=-1, keepdim=True)
344
+ phi = torch.asin(p_mid_norm / self.sdf_bounding_sphere)
345
+ theta = torch.asin(p_mid_norm * depth) # depth is inside [0, 1]
346
+ rot_angle = (phi - theta).unsqueeze(-1) # [..., 1]
347
+
348
+ # now rotate p_sphere
349
+ # Rodrigues formula: https://en.wikipedia.org/wiki/Rodrigues%27_rotation_formula
350
+ p_sphere_new = p_sphere * torch.cos(rot_angle) + \
351
+ torch.cross(rot_axis, p_sphere, dim=-1) * torch.sin(rot_angle) + \
352
+ rot_axis * torch.sum(rot_axis * p_sphere, dim=-1, keepdim=True) * (1. - torch.cos(rot_angle))
353
+ p_sphere_new = p_sphere_new / torch.norm(p_sphere_new, dim=-1, keepdim=True)
354
+ pts = torch.cat((p_sphere_new, depth.unsqueeze(-1)), dim=-1)
355
+
356
+ return pts
357
+
358
+ def gradient(inputs, outputs):
359
+
360
+ d_points = torch.ones_like(outputs, requires_grad=False, device=outputs.device)
361
+ points_grad = grad(
362
+ outputs=outputs,
363
+ inputs=inputs,
364
+ grad_outputs=d_points,
365
+ create_graph=True,
366
+ retain_graph=True,
367
+ only_inputs=True)[0][:, :, -3:]
368
+ return points_grad
code/lib/smpl/body_models.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems and the Max Planck Institute for Biological
14
+ # Cybernetics. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import print_function
20
+ from __future__ import division
21
+
22
+ import os
23
+ import os.path as osp
24
+
25
+
26
+ import pickle
27
+
28
+ import numpy as np
29
+
30
+ from collections import namedtuple
31
+
32
+ import torch
33
+ import torch.nn as nn
34
+
35
+ from .lbs import (
36
+ lbs, vertices2joints, blend_shapes)
37
+
38
+ from .vertex_ids import vertex_ids as VERTEX_IDS
39
+ from .utils import Struct, to_np, to_tensor
40
+ from .vertex_joint_selector import VertexJointSelector
41
+
42
+
43
+ ModelOutput = namedtuple('ModelOutput',
44
+ ['vertices','faces', 'joints', 'full_pose', 'betas',
45
+ 'global_orient',
46
+ 'body_pose', 'expression',
47
+ 'left_hand_pose', 'right_hand_pose',
48
+ 'jaw_pose', 'T', 'T_weighted', 'weights'])
49
+ ModelOutput.__new__.__defaults__ = (None,) * len(ModelOutput._fields)
50
+
51
+ class SMPL(nn.Module):
52
+
53
+ NUM_JOINTS = 23
54
+ NUM_BODY_JOINTS = 23
55
+ NUM_BETAS = 10
56
+
57
+ def __init__(self, model_path, data_struct=None,
58
+ create_betas=True,
59
+ betas=None,
60
+ create_global_orient=True,
61
+ global_orient=None,
62
+ create_body_pose=True,
63
+ body_pose=None,
64
+ create_transl=True,
65
+ transl=None,
66
+ dtype=torch.float32,
67
+ batch_size=1,
68
+ joint_mapper=None, gender='neutral',
69
+ vertex_ids=None,
70
+ pose_blend=True,
71
+ **kwargs):
72
+ ''' SMPL model constructor
73
+
74
+ Parameters
75
+ ----------
76
+ model_path: str
77
+ The path to the folder or to the file where the model
78
+ parameters are stored
79
+ data_struct: Strct
80
+ A struct object. If given, then the parameters of the model are
81
+ read from the object. Otherwise, the model tries to read the
82
+ parameters from the given `model_path`. (default = None)
83
+ create_global_orient: bool, optional
84
+ Flag for creating a member variable for the global orientation
85
+ of the body. (default = True)
86
+ global_orient: torch.tensor, optional, Bx3
87
+ The default value for the global orientation variable.
88
+ (default = None)
89
+ create_body_pose: bool, optional
90
+ Flag for creating a member variable for the pose of the body.
91
+ (default = True)
92
+ body_pose: torch.tensor, optional, Bx(Body Joints * 3)
93
+ The default value for the body pose variable.
94
+ (default = None)
95
+ create_betas: bool, optional
96
+ Flag for creating a member variable for the shape space
97
+ (default = True).
98
+ betas: torch.tensor, optional, Bx10
99
+ The default value for the shape member variable.
100
+ (default = None)
101
+ create_transl: bool, optional
102
+ Flag for creating a member variable for the translation
103
+ of the body. (default = True)
104
+ transl: torch.tensor, optional, Bx3
105
+ The default value for the transl variable.
106
+ (default = None)
107
+ dtype: torch.dtype, optional
108
+ The data type for the created variables
109
+ batch_size: int, optional
110
+ The batch size used for creating the member variables
111
+ joint_mapper: object, optional
112
+ An object that re-maps the joints. Useful if one wants to
113
+ re-order the SMPL joints to some other convention (e.g. MSCOCO)
114
+ (default = None)
115
+ gender: str, optional
116
+ Which gender to load
117
+ vertex_ids: dict, optional
118
+ A dictionary containing the indices of the extra vertices that
119
+ will be selected
120
+ '''
121
+
122
+ self.gender = gender
123
+ self.pose_blend = pose_blend
124
+
125
+ if data_struct is None:
126
+ if osp.isdir(model_path):
127
+ model_fn = 'SMPL_{}.{ext}'.format(gender.upper(), ext='pkl')
128
+ smpl_path = os.path.join(model_path, model_fn)
129
+ else:
130
+ smpl_path = model_path
131
+ assert osp.exists(smpl_path), 'Path {} does not exist!'.format(
132
+ smpl_path)
133
+
134
+ with open(smpl_path, 'rb') as smpl_file:
135
+ data_struct = Struct(**pickle.load(smpl_file,encoding='latin1'))
136
+ super(SMPL, self).__init__()
137
+ self.batch_size = batch_size
138
+
139
+ if vertex_ids is None:
140
+ # SMPL and SMPL-H share the same topology, so any extra joints can
141
+ # be drawn from the same place
142
+ vertex_ids = VERTEX_IDS['smplh']
143
+
144
+ self.dtype = dtype
145
+
146
+ self.joint_mapper = joint_mapper
147
+
148
+ self.vertex_joint_selector = VertexJointSelector(
149
+ vertex_ids=vertex_ids, **kwargs)
150
+
151
+ self.faces = data_struct.f
152
+ self.register_buffer('faces_tensor',
153
+ to_tensor(to_np(self.faces, dtype=np.int64),
154
+ dtype=torch.long))
155
+
156
+ if create_betas:
157
+ if betas is None:
158
+ default_betas = torch.zeros([batch_size, self.NUM_BETAS],
159
+ dtype=dtype)
160
+ else:
161
+ if 'torch.Tensor' in str(type(betas)):
162
+ default_betas = betas.clone().detach()
163
+ else:
164
+ default_betas = torch.tensor(betas,
165
+ dtype=dtype)
166
+
167
+ self.register_parameter('betas', nn.Parameter(default_betas,
168
+ requires_grad=True))
169
+
170
+ # The tensor that contains the global rotation of the model
171
+ # It is separated from the pose of the joints in case we wish to
172
+ # optimize only over one of them
173
+ if create_global_orient:
174
+ if global_orient is None:
175
+ default_global_orient = torch.zeros([batch_size, 3],
176
+ dtype=dtype)
177
+ else:
178
+ if 'torch.Tensor' in str(type(global_orient)):
179
+ default_global_orient = global_orient.clone().detach()
180
+ else:
181
+ default_global_orient = torch.tensor(global_orient,
182
+ dtype=dtype)
183
+
184
+ global_orient = nn.Parameter(default_global_orient,
185
+ requires_grad=True)
186
+ self.register_parameter('global_orient', global_orient)
187
+
188
+ if create_body_pose:
189
+ if body_pose is None:
190
+ default_body_pose = torch.zeros(
191
+ [batch_size, self.NUM_BODY_JOINTS * 3], dtype=dtype)
192
+ else:
193
+ if 'torch.Tensor' in str(type(body_pose)):
194
+ default_body_pose = body_pose.clone().detach()
195
+ else:
196
+ default_body_pose = torch.tensor(body_pose,
197
+ dtype=dtype)
198
+ self.register_parameter(
199
+ 'body_pose',
200
+ nn.Parameter(default_body_pose, requires_grad=True))
201
+
202
+ if create_transl:
203
+ if transl is None:
204
+ default_transl = torch.zeros([batch_size, 3],
205
+ dtype=dtype,
206
+ requires_grad=True)
207
+ else:
208
+ default_transl = torch.tensor(transl, dtype=dtype)
209
+ self.register_parameter(
210
+ 'transl',
211
+ nn.Parameter(default_transl, requires_grad=True))
212
+
213
+ # The vertices of the template model
214
+ self.register_buffer('v_template',
215
+ to_tensor(to_np(data_struct.v_template),
216
+ dtype=dtype))
217
+
218
+ # The shape components
219
+ shapedirs = data_struct.shapedirs[:, :, :self.NUM_BETAS]
220
+ # The shape components
221
+ self.register_buffer(
222
+ 'shapedirs',
223
+ to_tensor(to_np(shapedirs), dtype=dtype))
224
+
225
+
226
+ j_regressor = to_tensor(to_np(
227
+ data_struct.J_regressor), dtype=dtype)
228
+ self.register_buffer('J_regressor', j_regressor)
229
+
230
+ # if self.gender == 'neutral':
231
+ # joint_regressor = to_tensor(to_np(
232
+ # data_struct.cocoplus_regressor), dtype=dtype).permute(1,0)
233
+ # self.register_buffer('joint_regressor', joint_regressor)
234
+
235
+ # Pose blend shape basis: 6890 x 3 x 207, reshaped to 6890*3 x 207
236
+ num_pose_basis = data_struct.posedirs.shape[-1]
237
+ # 207 x 20670
238
+ posedirs = np.reshape(data_struct.posedirs, [-1, num_pose_basis]).T
239
+ self.register_buffer('posedirs',
240
+ to_tensor(to_np(posedirs), dtype=dtype))
241
+
242
+ # indices of parents for each joint
243
+ parents = to_tensor(to_np(data_struct.kintree_table[0])).long()
244
+ parents[0] = -1
245
+ self.register_buffer('parents', parents)
246
+
247
+ self.bone_parents = to_np(data_struct.kintree_table[0])
248
+
249
+ self.register_buffer('lbs_weights',
250
+ to_tensor(to_np(data_struct.weights), dtype=dtype))
251
+
252
+ def create_mean_pose(self, data_struct):
253
+ pass
254
+
255
+ @torch.no_grad()
256
+ def reset_params(self, **params_dict):
257
+ for param_name, param in self.named_parameters():
258
+ if param_name in params_dict:
259
+ param[:] = torch.tensor(params_dict[param_name])
260
+ else:
261
+ param.fill_(0)
262
+
263
+ def get_T_hip(self, betas=None):
264
+ v_shaped = self.v_template + blend_shapes(betas, self.shapedirs)
265
+ J = vertices2joints(self.J_regressor, v_shaped)
266
+ T_hip = J[0,0]
267
+ return T_hip
268
+
269
+ def get_num_verts(self):
270
+ return self.v_template.shape[0]
271
+
272
+ def get_num_faces(self):
273
+ return self.faces.shape[0]
274
+
275
+ def extra_repr(self):
276
+ return 'Number of betas: {}'.format(self.NUM_BETAS)
277
+
278
+ def forward(self, betas=None, body_pose=None, global_orient=None,
279
+ transl=None, return_verts=True, return_full_pose=False,displacement=None,v_template=None,
280
+ **kwargs):
281
+ ''' Forward pass for the SMPL model
282
+
283
+ Parameters
284
+ ----------
285
+ global_orient: torch.tensor, optional, shape Bx3
286
+ If given, ignore the member variable and use it as the global
287
+ rotation of the body. Useful if someone wishes to predict this
288
+ with an external model. (default=None)
289
+ betas: torch.tensor, optional, shape Bx10
290
+ If given, ignore the member variable `betas` and use it
291
+ instead. For example, it can be used if shape parameters
292
+ `betas` are predicted from some external model.
293
+ (default=None)
294
+ body_pose: torch.tensor, optional, shape Bx(J*3)
295
+ If given, ignore the member variable `body_pose` and use it
296
+ instead. For example, it can be used if the pose of the body
297
+ joints is predicted from some external model.
298
+ It should be a tensor that contains joint rotations in
299
+ axis-angle format. (default=None)
300
+ transl: torch.tensor, optional, shape Bx3
301
+ If given, ignore the member variable `transl` and use it
302
+ instead. For example, it can be used if the translation
303
+ `transl` is predicted from some external model.
304
+ (default=None)
305
+ return_verts: bool, optional
306
+ Return the vertices. (default=True)
307
+ return_full_pose: bool, optional
308
+ Returns the full axis-angle pose vector (default=False)
309
+
310
+ Returns
311
+ -------
312
+ '''
313
+ # If no shape and pose parameters are passed along, then use the
314
+ # ones from the module
315
+ global_orient = (global_orient if global_orient is not None else
316
+ self.global_orient)
317
+ body_pose = body_pose if body_pose is not None else self.body_pose
318
+ betas = betas if betas is not None else self.betas
319
+
320
+ apply_trans = transl is not None or hasattr(self, 'transl')
321
+ if transl is None and hasattr(self, 'transl'):
322
+ transl = self.transl
323
+
324
+ full_pose = torch.cat([global_orient, body_pose], dim=1)
325
+
326
+ # if betas.shape[0] != self.batch_size:
327
+ # num_repeats = int(self.batch_size / betas.shape[0])
328
+ # betas = betas.expand(num_repeats, -1)
329
+
330
+ if v_template is None:
331
+ v_template = self.v_template
332
+
333
+ if displacement is not None:
334
+ vertices, joints_smpl, T_weighted, W, T = lbs(betas, full_pose, v_template+displacement,
335
+ self.shapedirs, self.posedirs,
336
+ self.J_regressor, self.parents,
337
+ self.lbs_weights, dtype=self.dtype,pose_blend=self.pose_blend)
338
+ else:
339
+ vertices, joints_smpl,T_weighted, W, T = lbs(betas, full_pose, v_template,
340
+ self.shapedirs, self.posedirs,
341
+ self.J_regressor, self.parents,
342
+ self.lbs_weights, dtype=self.dtype,pose_blend=self.pose_blend)
343
+
344
+ # if self.gender is not 'neutral':
345
+ joints = self.vertex_joint_selector(vertices, joints_smpl)
346
+ # else:
347
+ # joints = torch.matmul(vertices.permute(0,2,1),self.joint_regressor).permute(0,2,1)
348
+ # Map the joints to the current dataset
349
+ if self.joint_mapper is not None:
350
+ joints = self.joint_mapper(joints)
351
+
352
+ if apply_trans:
353
+ joints_smpl = joints_smpl + transl.unsqueeze(dim=1)
354
+ joints = joints + transl.unsqueeze(dim=1)
355
+ vertices = vertices + transl.unsqueeze(dim=1)
356
+
357
+ output = ModelOutput(vertices=vertices if return_verts else None,
358
+ faces=self.faces,
359
+ global_orient=global_orient,
360
+ body_pose=body_pose,
361
+ joints=joints_smpl,
362
+ betas=self.betas,
363
+ full_pose=full_pose if return_full_pose else None,
364
+ T=T, T_weighted=T_weighted, weights=W)
365
+ return output
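
A minimal usage sketch of the SMPL wrapper above (illustrative only; it assumes the interpreter is started from the code/ directory so that the lib package resolves, as in train.py, and that the bundled .pkl models load in the current environment):

import torch
from lib.smpl.body_models import SMPL

smpl = SMPL('lib/smpl/smpl_model', gender='male', batch_size=1)
out = smpl(betas=torch.zeros(1, 10),           # shape coefficients
           global_orient=torch.zeros(1, 3),    # root rotation, axis-angle
           body_pose=torch.zeros(1, 69),       # 23 body joints x 3 axis-angle values
           transl=torch.zeros(1, 3))
print(out.vertices.shape)                      # torch.Size([1, 6890, 3])
print(out.joints.shape)                        # torch.Size([1, 24, 3]) SMPL joints
print(out.T.shape)                             # (1, 24, 4, 4) per-joint rigid transforms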
code/lib/smpl/lbs.py ADDED
@@ -0,0 +1,377 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems and the Max Planck Institute for Biological
14
+ # Cybernetics. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import print_function
20
+ from __future__ import division
21
+
22
+ import numpy as np
23
+
24
+ import torch
25
+ import torch.nn.functional as F
26
+
27
+ from .utils import rot_mat_to_euler
28
+
29
+
30
+ def find_dynamic_lmk_idx_and_bcoords(vertices, pose, dynamic_lmk_faces_idx,
31
+ dynamic_lmk_b_coords,
32
+ neck_kin_chain, dtype=torch.float32):
33
+ ''' Compute the faces, barycentric coordinates for the dynamic landmarks
34
+
35
+
36
+ To do so, we first compute the rotation of the neck around the y-axis
37
+ and then use a pre-computed look-up table to find the faces and the
38
+ barycentric coordinates that will be used.
39
+
40
+ Special thanks to Soubhik Sanyal (soubhik.sanyal@tuebingen.mpg.de)
41
+ for providing the original TensorFlow implementation and for the LUT.
42
+
43
+ Parameters
44
+ ----------
45
+ vertices: torch.tensor BxVx3, dtype = torch.float32
46
+ The tensor of input vertices
47
+ pose: torch.tensor Bx(Jx3), dtype = torch.float32
48
+ The current pose of the body model
49
+ dynamic_lmk_faces_idx: torch.tensor L, dtype = torch.long
50
+ The look-up table from neck rotation to faces
51
+ dynamic_lmk_b_coords: torch.tensor Lx3, dtype = torch.float32
52
+ The look-up table from neck rotation to barycentric coordinates
53
+ neck_kin_chain: list
54
+ A python list that contains the indices of the joints that form the
55
+ kinematic chain of the neck.
56
+ dtype: torch.dtype, optional
57
+
58
+ Returns
59
+ -------
60
+ dyn_lmk_faces_idx: torch.tensor, dtype = torch.long
61
+ A tensor of size BxL that contains the indices of the faces that
62
+ will be used to compute the current dynamic landmarks.
63
+ dyn_lmk_b_coords: torch.tensor, dtype = torch.float32
64
+ A tensor of size BxL that contains the indices of the faces that
65
+ will be used to compute the current dynamic landmarks.
66
+ '''
67
+
68
+ batch_size = vertices.shape[0]
69
+
70
+ aa_pose = torch.index_select(pose.view(batch_size, -1, 3), 1,
71
+ neck_kin_chain)
72
+ rot_mats = batch_rodrigues(
73
+ aa_pose.view(-1, 3), dtype=dtype).view(batch_size, -1, 3, 3)
74
+
75
+ rel_rot_mat = torch.eye(3, device=vertices.device,
76
+ dtype=dtype).unsqueeze_(dim=0)
77
+ for idx in range(len(neck_kin_chain)):
78
+ rel_rot_mat = torch.bmm(rot_mats[:, idx], rel_rot_mat)
79
+
80
+ y_rot_angle = torch.round(
81
+ torch.clamp(-rot_mat_to_euler(rel_rot_mat) * 180.0 / np.pi,
82
+ max=39)).to(dtype=torch.long)
83
+ neg_mask = y_rot_angle.lt(0).to(dtype=torch.long)
84
+ mask = y_rot_angle.lt(-39).to(dtype=torch.long)
85
+ neg_vals = mask * 78 + (1 - mask) * (39 - y_rot_angle)
86
+ y_rot_angle = (neg_mask * neg_vals +
87
+ (1 - neg_mask) * y_rot_angle)
88
+
89
+ dyn_lmk_faces_idx = torch.index_select(dynamic_lmk_faces_idx,
90
+ 0, y_rot_angle)
91
+ dyn_lmk_b_coords = torch.index_select(dynamic_lmk_b_coords,
92
+ 0, y_rot_angle)
93
+
94
+ return dyn_lmk_faces_idx, dyn_lmk_b_coords
95
+
96
+
97
+ def vertices2landmarks(vertices, faces, lmk_faces_idx, lmk_bary_coords):
98
+ ''' Calculates landmarks by barycentric interpolation
99
+
100
+ Parameters
101
+ ----------
102
+ vertices: torch.tensor BxVx3, dtype = torch.float32
103
+ The tensor of input vertices
104
+ faces: torch.tensor Fx3, dtype = torch.long
105
+ The faces of the mesh
106
+ lmk_faces_idx: torch.tensor L, dtype = torch.long
107
+ The tensor with the indices of the faces used to calculate the
108
+ landmarks.
109
+ lmk_bary_coords: torch.tensor Lx3, dtype = torch.float32
110
+ The tensor of barycentric coordinates that are used to interpolate
111
+ the landmarks
112
+
113
+ Returns
114
+ -------
115
+ landmarks: torch.tensor BxLx3, dtype = torch.float32
116
+ The coordinates of the landmarks for each mesh in the batch
117
+ '''
118
+ # Extract the indices of the vertices for each face
119
+ # BxLx3
120
+ batch_size, num_verts = vertices.shape[:2]
121
+ device = vertices.device
122
+
123
+ lmk_faces = torch.index_select(faces, 0, lmk_faces_idx.view(-1)).expand(
124
+ batch_size, -1, -1).long()
125
+
126
+ lmk_faces = lmk_faces + torch.arange(
127
+ batch_size, dtype=torch.long, device=device).view(-1, 1, 1) * num_verts
128
+
129
+ lmk_vertices = vertices.view(-1, 3)[lmk_faces].view(
130
+ batch_size, -1, 3, 3)
131
+
132
+ landmarks = torch.einsum('blfi,blf->bli', [lmk_vertices, lmk_bary_coords])
133
+ return landmarks
134
+
135
+
136
+ def lbs(betas, pose, v_template, shapedirs, posedirs, J_regressor, parents,
137
+ lbs_weights, pose2rot=True, dtype=torch.float32, pose_blend=True):
138
+ ''' Performs Linear Blend Skinning with the given shape and pose parameters
139
+
140
+ Parameters
141
+ ----------
142
+ betas : torch.tensor BxNB
143
+ The tensor of shape parameters
144
+ pose : torch.tensor Bx(J + 1) * 3
145
+ The pose parameters in axis-angle format
146
+ v_template torch.tensor BxVx3
147
+ The template mesh that will be deformed
148
+ shapedirs : torch.tensor Vx3xNB
149
+ The tensor of PCA shape displacements
150
+ posedirs : torch.tensor Px(V * 3)
151
+ The pose PCA coefficients
152
+ J_regressor : torch.tensor JxV
153
+ The regressor array that is used to calculate the joints from
154
+ the position of the vertices
155
+ parents: torch.tensor J
156
+ The array that describes the kinematic tree for the model
157
+ lbs_weights: torch.tensor N x V x (J + 1)
158
+ The linear blend skinning weights that represent how much the
159
+ rotation matrix of each part affects each vertex
160
+ pose2rot: bool, optional
161
+ Flag on whether to convert the input pose tensor to rotation
162
+ matrices. The default value is True. If False, then the pose tensor
163
+ should already contain rotation matrices and have a size of
164
+ Bx(J + 1)x9
165
+ dtype: torch.dtype, optional
166
+
167
+ Returns
168
+ -------
169
+ verts: torch.tensor BxVx3
170
+ The vertices of the mesh after applying the shape and pose
171
+ displacements.
172
+ joints: torch.tensor BxJx3
173
+ The joints of the model
174
+ '''
175
+
176
+ batch_size = max(betas.shape[0], pose.shape[0])
177
+ device = betas.device
178
+
179
+ # Add shape contribution
180
+ v_shaped = v_template + blend_shapes(betas, shapedirs)
181
+
182
+ # Get the joints
183
+ # NxJx3 array
184
+ J = vertices2joints(J_regressor, v_shaped)
185
+
186
+ # 3. Add pose blend shapes
187
+ # N x J x 3 x 3
188
+ ident = torch.eye(3, dtype=dtype, device=device)
189
+
190
+
191
+ if pose2rot:
192
+ rot_mats = batch_rodrigues(
193
+ pose.view(-1, 3), dtype=dtype).view([batch_size, -1, 3, 3])
194
+
195
+ pose_feature = (rot_mats[:, 1:, :, :] - ident).view([batch_size, -1])
196
+ # (N x P) x (P, V * 3) -> N x V x 3
197
+ pose_offsets = torch.matmul(pose_feature, posedirs) \
198
+ .view(batch_size, -1, 3)
199
+ else:
200
+ pose_feature = pose[:, 1:].view(batch_size, -1, 3, 3) - ident
201
+ rot_mats = pose.view(batch_size, -1, 3, 3)
202
+
203
+ pose_offsets = torch.matmul(pose_feature.view(batch_size, -1),
204
+ posedirs).view(batch_size, -1, 3)
205
+
206
+ if pose_blend:
207
+ v_posed = pose_offsets + v_shaped
208
+ else:
209
+ v_posed = v_shaped
210
+
211
+ # 4. Get the global joint location
212
+ J_transformed, A = batch_rigid_transform(rot_mats, J, parents, dtype=dtype)
213
+
214
+ # 5. Do skinning:
215
+ # W is N x V x (J + 1)
216
+ W = lbs_weights.unsqueeze(dim=0).expand([batch_size, -1, -1])
217
+ # (N x V x (J + 1)) x (N x (J + 1) x 16)
218
+ num_joints = J_regressor.shape[0]
219
+ T = torch.matmul(W, A.view(batch_size, num_joints, 16)) \
220
+ .view(batch_size, -1, 4, 4)
221
+
222
+ homogen_coord = torch.ones([batch_size, v_posed.shape[1], 1],
223
+ dtype=dtype, device=device)
224
+ v_posed_homo = torch.cat([v_posed, homogen_coord], dim=2)
225
+ v_homo = torch.matmul(T, torch.unsqueeze(v_posed_homo, dim=-1))
226
+
227
+ verts = v_homo[:, :, :3, 0]
228
+
229
+ return verts, J_transformed, T, W, A.view(batch_size, num_joints, 4,4)
230
+
231
+
232
+ def vertices2joints(J_regressor, vertices):
233
+ ''' Calculates the 3D joint locations from the vertices
234
+
235
+ Parameters
236
+ ----------
237
+ J_regressor : torch.tensor JxV
238
+ The regressor array that is used to calculate the joints from the
239
+ position of the vertices
240
+ vertices : torch.tensor BxVx3
241
+ The tensor of mesh vertices
242
+
243
+ Returns
244
+ -------
245
+ torch.tensor BxJx3
246
+ The location of the joints
247
+ '''
248
+
249
+ return torch.einsum('bik,ji->bjk', [vertices, J_regressor])
250
+
251
+
252
+ def blend_shapes(betas, shape_disps):
253
+ ''' Calculates the per vertex displacement due to the blend shapes
254
+
255
+
256
+ Parameters
257
+ ----------
258
+ betas : torch.tensor Bx(num_betas)
259
+ Blend shape coefficients
260
+ shape_disps: torch.tensor Vx3x(num_betas)
261
+ Blend shapes
262
+
263
+ Returns
264
+ -------
265
+ torch.tensor BxVx3
266
+ The per-vertex displacement due to shape deformation
267
+ '''
268
+
269
+ # Displacement[b, m, k] = sum_{l} betas[b, l] * shape_disps[m, k, l]
270
+ # i.e. Multiply each shape displacement by its corresponding beta and
271
+ # then sum them.
272
+ blend_shape = torch.einsum('bl,mkl->bmk', [betas, shape_disps])
273
+ return blend_shape
274
+
275
+
276
+ def batch_rodrigues(rot_vecs, epsilon=1e-8, dtype=torch.float32):
277
+ ''' Calculates the rotation matrices for a batch of rotation vectors
278
+ Parameters
279
+ ----------
280
+ rot_vecs: torch.tensor Nx3
281
+ array of N axis-angle vectors
282
+ Returns
283
+ -------
284
+ R: torch.tensor Nx3x3
285
+ The rotation matrices for the given axis-angle parameters
286
+ '''
287
+
288
+ batch_size = rot_vecs.shape[0]
289
+ device = rot_vecs.device
290
+
291
+ angle = torch.norm(rot_vecs + 1e-8, dim=1, keepdim=True)
292
+ rot_dir = rot_vecs / angle
293
+
294
+ cos = torch.unsqueeze(torch.cos(angle), dim=1)
295
+ sin = torch.unsqueeze(torch.sin(angle), dim=1)
296
+
297
+ # Bx1 arrays
298
+ rx, ry, rz = torch.split(rot_dir, 1, dim=1)
299
+ K = torch.zeros((batch_size, 3, 3), dtype=dtype, device=device)
300
+
301
+ zeros = torch.zeros((batch_size, 1), dtype=dtype, device=device)
302
+ K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=1) \
303
+ .view((batch_size, 3, 3))
304
+
305
+ ident = torch.eye(3, dtype=dtype, device=device).unsqueeze(dim=0)
306
+ rot_mat = ident + sin * K + (1 - cos) * torch.bmm(K, K)
307
+ return rot_mat
308
+
309
+
310
+ def transform_mat(R, t):
311
+ ''' Creates a batch of transformation matrices
312
+ Args:
313
+ - R: Bx3x3 array of a batch of rotation matrices
314
+ - t: Bx3x1 array of a batch of translation vectors
315
+ Returns:
316
+ - T: Bx4x4 Transformation matrix
317
+ '''
318
+ # No padding left or right, only add an extra row
319
+ return torch.cat([F.pad(R, [0, 0, 0, 1]),
320
+ F.pad(t, [0, 0, 0, 1], value=1)], dim=2)
321
+
322
+
323
+ def batch_rigid_transform(rot_mats, joints, parents, dtype=torch.float32):
324
+ """
325
+ Applies a batch of rigid transformations to the joints
326
+
327
+ Parameters
328
+ ----------
329
+ rot_mats : torch.tensor BxNx3x3
330
+ Tensor of rotation matrices
331
+ joints : torch.tensor BxNx3
332
+ Locations of joints
333
+ parents : torch.tensor BxN
334
+ The kinematic tree of each object
335
+ dtype : torch.dtype, optional:
336
+ The data type of the created tensors, the default is torch.float32
337
+
338
+ Returns
339
+ -------
340
+ posed_joints : torch.tensor BxNx3
341
+ The locations of the joints after applying the pose rotations
342
+ rel_transforms : torch.tensor BxNx4x4
343
+ The relative (with respect to the root joint) rigid transformations
344
+ for all the joints
345
+ """
346
+
347
+ joints = torch.unsqueeze(joints, dim=-1)
348
+
349
+ rel_joints = joints.clone()
350
+ rel_joints[:, 1:] = rel_joints[:, 1:] - joints[:, parents[1:]]
351
+
352
+ transforms_mat = transform_mat(
353
+ rot_mats.reshape(-1, 3, 3),
354
+ rel_joints.reshape(-1, 3, 1)).reshape(-1, joints.shape[1], 4, 4)
355
+
356
+ transform_chain = [transforms_mat[:, 0]]
357
+ for i in range(1, parents.shape[0]):
358
+ # Subtract the joint location at the rest pose
359
+ # No need for rotation, since it's identity when at rest
360
+ curr_res = torch.matmul(transform_chain[parents[i]],
361
+ transforms_mat[:, i])
362
+ transform_chain.append(curr_res)
363
+
364
+ transforms = torch.stack(transform_chain, dim=1)
365
+
366
+ # The last column of the transformations contains the posed joints
367
+ posed_joints = transforms[:, :, :3, 3]
371
+
372
+ joints_homogen = F.pad(joints, [0, 0, 0, 1])
373
+
374
+ rel_transforms = transforms - F.pad(
375
+ torch.matmul(transforms, joints_homogen), [3, 0, 0, 0, 0, 0, 0, 0])
376
+
377
+ return posed_joints, rel_transforms
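
A quick sanity check of the rotation helpers above (a sketch, assuming it is run from code/ so that lib is importable): a zero axis-angle vector should map to the identity rotation, and transform_mat appends the homogeneous row.

import torch
from lib.smpl.lbs import batch_rodrigues, transform_mat

aa = torch.zeros(2, 3)                      # two zero axis-angle vectors
R = batch_rodrigues(aa)                     # (2, 3, 3), each close to the identity
assert torch.allclose(R, torch.eye(3).expand(2, 3, 3), atol=1e-6)
T = transform_mat(R, torch.zeros(2, 3, 1))  # (2, 4, 4); last row is [0, 0, 0, 1]
print(T[0])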
code/lib/smpl/smpl_model/SMPL_FEMALE.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:a583c1b98e4afc19042641f1bae5cd8a1f712a6724886291a7627ec07acd408d
3
+ size 39056454
code/lib/smpl/smpl_model/SMPL_MALE.pkl ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0e8c0bbbbc635dcb166ed29c303fb4bef16ea5f623e5a89263495a9e403575bd
3
+ size 39056404
code/lib/smpl/utils.py ADDED
@@ -0,0 +1,49 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems and the Max Planck Institute for Biological
14
+ # Cybernetics. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ from __future__ import print_function
19
+ from __future__ import absolute_import
20
+ from __future__ import division
21
+
22
+ import numpy as np
23
+ import torch
24
+
25
+
26
+ def to_tensor(array, dtype=torch.float32):
27
+ if 'torch.tensor' not in str(type(array)):
28
+ return torch.tensor(array, dtype=dtype)
29
+
30
+
31
+ class Struct(object):
32
+ def __init__(self, **kwargs):
33
+ for key, val in kwargs.items():
34
+ setattr(self, key, val)
35
+
36
+
37
+ def to_np(array, dtype=np.float32):
38
+ if 'scipy.sparse' in str(type(array)):
39
+ array = array.todense()
40
+ return np.array(array, dtype=dtype)
41
+
42
+
43
+ def rot_mat_to_euler(rot_mats):
44
+ # Calculates rotation matrix to euler angles
45
+ # Be careful with extreme cases of Euler angles like [0.0, pi, 0.0]
46
+
47
+ sy = torch.sqrt(rot_mats[:, 0, 0] * rot_mats[:, 0, 0] +
48
+ rot_mats[:, 1, 0] * rot_mats[:, 1, 0])
49
+ return torch.atan2(-rot_mats[:, 2, 0], sy)
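
A small illustration of how these helpers behave (sketch only; run from code/ so that lib resolves):

import numpy as np
from lib.smpl.utils import Struct, to_np, to_tensor

s = Struct(v_template=np.zeros((6890, 3)), f=np.zeros((13776, 3), dtype=np.int64))
print(s.v_template.shape)          # keyword arguments become attributes, as for the SMPL pkl contents
arr = to_np([[1, 2], [3, 4]])      # array-like (incl. scipy.sparse) -> float32 ndarray
print(to_tensor(arr).dtype)        # torch.float32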
code/lib/smpl/vertex_ids.py ADDED
@@ -0,0 +1,71 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems and the Max Planck Institute for Biological
14
+ # Cybernetics. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ from __future__ import print_function
19
+ from __future__ import absolute_import
20
+ from __future__ import division
21
+
22
+ # Joint name to vertex mapping. SMPL/SMPL-H/SMPL-X vertices that correspond to
23
+ # MSCOCO and OpenPose joints
24
+ vertex_ids = {
25
+ 'smplh': {
26
+ 'nose': 332,
27
+ 'reye': 6260,
28
+ 'leye': 2800,
29
+ 'rear': 4071,
30
+ 'lear': 583,
31
+ 'rthumb': 6191,
32
+ 'rindex': 5782,
33
+ 'rmiddle': 5905,
34
+ 'rring': 6016,
35
+ 'rpinky': 6133,
36
+ 'lthumb': 2746,
37
+ 'lindex': 2319,
38
+ 'lmiddle': 2445,
39
+ 'lring': 2556,
40
+ 'lpinky': 2673,
41
+ 'LBigToe': 3216,
42
+ 'LSmallToe': 3226,
43
+ 'LHeel': 3387,
44
+ 'RBigToe': 6617,
45
+ 'RSmallToe': 6624,
46
+ 'RHeel': 6787
47
+ },
48
+ 'smplx': {
49
+ 'nose': 9120,
50
+ 'reye': 9929,
51
+ 'leye': 9448,
52
+ 'rear': 616,
53
+ 'lear': 6,
54
+ 'rthumb': 8079,
55
+ 'rindex': 7669,
56
+ 'rmiddle': 7794,
57
+ 'rring': 7905,
58
+ 'rpinky': 8022,
59
+ 'lthumb': 5361,
60
+ 'lindex': 4933,
61
+ 'lmiddle': 5058,
62
+ 'lring': 5169,
63
+ 'lpinky': 5286,
64
+ 'LBigToe': 5770,
65
+ 'LSmallToe': 5780,
66
+ 'LHeel': 8846,
67
+ 'RBigToe': 8463,
68
+ 'RSmallToe': 8474,
69
+ 'RHeel': 8635
70
+ }
71
+ }
code/lib/smpl/vertex_joint_selector.py ADDED
@@ -0,0 +1,77 @@
1
+ # -*- coding: utf-8 -*-
2
+
3
+ # Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
4
+ # holder of all proprietary rights on this computer program.
5
+ # You can only use this computer program if you have closed
6
+ # a license agreement with MPG or you get the right to use the computer
7
+ # program from someone who is authorized to grant you that right.
8
+ # Any use of the computer program without a valid license is prohibited and
9
+ # liable to prosecution.
10
+ #
11
+ # Copyright©2019 Max-Planck-Gesellschaft zur Förderung
12
+ # der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
13
+ # for Intelligent Systems and the Max Planck Institute for Biological
14
+ # Cybernetics. All rights reserved.
15
+ #
16
+ # Contact: ps-license@tuebingen.mpg.de
17
+
18
+ from __future__ import absolute_import
19
+ from __future__ import print_function
20
+ from __future__ import division
21
+
22
+ import numpy as np
23
+
24
+ import torch
25
+ import torch.nn as nn
26
+
27
+ from .utils import to_tensor
28
+
29
+
30
+ class VertexJointSelector(nn.Module):
31
+
32
+ def __init__(self, vertex_ids=None,
33
+ use_hands=True,
34
+ use_feet_keypoints=True, **kwargs):
35
+ super(VertexJointSelector, self).__init__()
36
+
37
+ extra_joints_idxs = []
38
+
39
+ face_keyp_idxs = np.array([
40
+ vertex_ids['nose'],
41
+ vertex_ids['reye'],
42
+ vertex_ids['leye'],
43
+ vertex_ids['rear'],
44
+ vertex_ids['lear']], dtype=np.int64)
45
+
46
+ extra_joints_idxs = np.concatenate([extra_joints_idxs,
47
+ face_keyp_idxs])
48
+
49
+ if use_feet_keypoints:
50
+ feet_keyp_idxs = np.array([vertex_ids['LBigToe'],
51
+ vertex_ids['LSmallToe'],
52
+ vertex_ids['LHeel'],
53
+ vertex_ids['RBigToe'],
54
+ vertex_ids['RSmallToe'],
55
+ vertex_ids['RHeel']], dtype=np.int32)
56
+
57
+ extra_joints_idxs = np.concatenate(
58
+ [extra_joints_idxs, feet_keyp_idxs])
59
+
60
+ if use_hands:
61
+ self.tip_names = ['thumb', 'index', 'middle', 'ring', 'pinky']
62
+
63
+ tips_idxs = []
64
+ for hand_id in ['l', 'r']:
65
+ for tip_name in self.tip_names:
66
+ tips_idxs.append(vertex_ids[hand_id + tip_name])
67
+
68
+ extra_joints_idxs = np.concatenate(
69
+ [extra_joints_idxs, tips_idxs])
70
+
71
+ self.register_buffer('extra_joints_idxs',
72
+ to_tensor(extra_joints_idxs, dtype=torch.long))
73
+
74
+ def forward(self, vertices, joints):
75
+ extra_joints = torch.index_select(vertices, 1, self.extra_joints_idxs)
76
+ joints = torch.cat([joints, extra_joints], dim=1)
77
+ return joints
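
For reference, a minimal driving example for the selector above (illustrative; run from code/, with random stand-in geometry):

import torch
from lib.smpl.vertex_ids import vertex_ids as VERTEX_IDS
from lib.smpl.vertex_joint_selector import VertexJointSelector

selector = VertexJointSelector(vertex_ids=VERTEX_IDS['smplh'])
vertices = torch.rand(1, 6890, 3)        # SMPL mesh vertices
joints = torch.rand(1, 24, 3)            # SMPL joints from lbs()
joints_ext = selector(vertices, joints)  # 24 + 5 face + 6 feet + 10 fingertip joints
print(joints_ext.shape)                  # torch.Size([1, 45, 3])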
code/lib/utils/meshing.py ADDED
@@ -0,0 +1,63 @@
1
+ import numpy as np
2
+ import torch
3
+ from skimage import measure
4
+ from lib.libmise import mise
5
+ import trimesh
6
+
7
+ def generate_mesh(func, verts, level_set=0, res_init=32, res_up=3, point_batch=5000):
8
+
9
+ scale = 1.1 # Scale of the padded bbox relative to the tight one.
10
+ verts = verts.data.cpu().numpy()
11
+
12
+ gt_bbox = np.stack([verts.min(axis=0), verts.max(axis=0)], axis=0)
13
+ gt_center = (gt_bbox[0] + gt_bbox[1]) * 0.5
14
+ gt_scale = (gt_bbox[1] - gt_bbox[0]).max()
15
+
16
+ mesh_extractor = mise.MISE(res_init, res_up, level_set)
17
+ points = mesh_extractor.query()
18
+
19
+ # query occupancy grid
20
+ while points.shape[0] != 0:
21
+
22
+ orig_points = points
23
+ points = points.astype(np.float32)
24
+ points = (points / mesh_extractor.resolution - 0.5) * scale
25
+ points = points * gt_scale + gt_center
26
+ points = torch.tensor(points).float().cuda()
27
+
28
+ values = []
29
+ for _, pnts in enumerate((torch.split(points,point_batch,dim=0))):
30
+ out = func(pnts)
31
+ values.append(out['sdf'].data.cpu().numpy())
32
+ values = np.concatenate(values, axis=0).astype(np.float64)[:,0]
33
+
34
+ mesh_extractor.update(orig_points, values)
35
+
36
+ points = mesh_extractor.query()
37
+
38
+ value_grid = mesh_extractor.to_dense()
39
+
40
+ # marching cube
41
+ verts, faces, normals, values = measure.marching_cubes_lewiner(
42
+ volume=value_grid,
43
+ gradient_direction='ascent',
44
+ level=level_set)
45
+
46
+ verts = (verts / mesh_extractor.resolution - 0.5) * scale
47
+ verts = verts * gt_scale + gt_center
48
+ faces = faces[:, [0,2,1]]
49
+ meshexport = trimesh.Trimesh(verts, faces, normals, vertex_colors=values)
50
+
51
+ # remove disconnected parts, keeping only the largest component
52
+ connected_comp = meshexport.split(only_watertight=False)
53
+ max_area = 0
54
+ max_comp = None
55
+ for comp in connected_comp:
56
+ if comp.area > max_area:
57
+ max_area = comp.area
58
+ max_comp = comp
59
+ meshexport = max_comp
60
+
61
+ return meshexport
62
+
63
+
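
The call pattern mirrors how v2a_model.py drives generate_mesh with the implicit network; the sphere SDF below is only a toy stand-in, and the snippet assumes a CUDA device and the compiled mise extension.

import torch
from lib.utils.meshing import generate_mesh

def sphere_sdf(points):                            # points: (N, 3) CUDA tensor
    return {'sdf': 0.3 - points.norm(dim=-1, keepdim=True)}

seed_verts = torch.rand(1000, 3).cuda() - 0.5      # rough bounding geometry (e.g. canonical SMPL verts)
mesh = generate_mesh(sphere_sdf, seed_verts, point_batch=10000, res_up=2)
mesh.export('sphere.ply')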
code/lib/utils/utils.py ADDED
@@ -0,0 +1,232 @@
1
+ import numpy as np
2
+ import cv2
3
+ import torch
4
+ from torch.nn import functional as F
5
+
6
+
7
+ def split_input(model_input, total_pixels, n_pixels = 10000):
8
+ '''
9
+ Split the input into chunks so that large resolutions fit in CUDA memory.
10
+ Decrease n_pixels if a CUDA out-of-memory error occurs.
11
+ '''
12
+
13
+ split = []
14
+
15
+ for i, indx in enumerate(torch.split(torch.arange(total_pixels).cuda(), n_pixels, dim=0)):
16
+ data = model_input.copy()
17
+ data['uv'] = torch.index_select(model_input['uv'], 1, indx)
18
+ split.append(data)
19
+ return split
20
+
21
+
22
+ def merge_output(res, total_pixels, batch_size):
23
+ ''' Merge the split output. '''
24
+
25
+ model_outputs = {}
26
+ for entry in res[0]:
27
+ if res[0][entry] is None:
28
+ continue
29
+ if len(res[0][entry].shape) == 1:
30
+ model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, 1) for r in res],
31
+ 1).reshape(batch_size * total_pixels)
32
+ else:
33
+ model_outputs[entry] = torch.cat([r[entry].reshape(batch_size, -1, r[entry].shape[-1]) for r in res],
34
+ 1).reshape(batch_size * total_pixels, -1)
35
+ return model_outputs
36
+
37
+
38
+ def get_psnr(img1, img2, normalize_rgb=False):
39
+ if normalize_rgb: # [-1,1] --> [0,1]
40
+ img1 = (img1 + 1.) / 2.
41
+ img2 = (img2 + 1. ) / 2.
42
+
43
+ mse = torch.mean((img1 - img2) ** 2)
44
+ psnr = -10. * torch.log(mse) / torch.log(torch.Tensor([10.]).cuda())
45
+
46
+ return psnr
47
+
48
+
49
+ def load_K_Rt_from_P(filename, P=None):
50
+ if P is None:
51
+ lines = open(filename).read().splitlines()
52
+ if len(lines) == 4:
53
+ lines = lines[1:]
54
+ lines = [[x[0], x[1], x[2], x[3]] for x in (x.split(" ") for x in lines)]
55
+ P = np.asarray(lines).astype(np.float32).squeeze()
56
+
57
+ out = cv2.decomposeProjectionMatrix(P)
58
+ K = out[0]
59
+ R = out[1]
60
+ t = out[2]
61
+
62
+ K = K/K[2,2]
63
+ intrinsics = np.eye(4)
64
+ intrinsics[:3, :3] = K
65
+
66
+ pose = np.eye(4, dtype=np.float32)
67
+ pose[:3, :3] = R.transpose()
68
+ pose[:3,3] = (t[:3] / t[3])[:,0]
69
+
70
+ return intrinsics, pose
71
+
72
+
73
+ def get_camera_params(uv, pose, intrinsics):
74
+ if pose.shape[1] == 7: #In case of quaternion vector representation
75
+ cam_loc = pose[:, 4:]
76
+ R = quat_to_rot(pose[:,:4])
77
+ p = torch.eye(4).repeat(pose.shape[0],1,1).cuda().float()
78
+ p[:, :3, :3] = R
79
+ p[:, :3, 3] = cam_loc
80
+ else: # In case of pose matrix representation
81
+ cam_loc = pose[:, :3, 3]
82
+ p = pose
83
+
84
+ batch_size, num_samples, _ = uv.shape
85
+
86
+ depth = torch.ones((batch_size, num_samples)).cuda()
87
+ x_cam = uv[:, :, 0].view(batch_size, -1)
88
+ y_cam = uv[:, :, 1].view(batch_size, -1)
89
+ z_cam = depth.view(batch_size, -1)
90
+
91
+ pixel_points_cam = lift(x_cam, y_cam, z_cam, intrinsics=intrinsics)
92
+
93
+ # permute for batch matrix product
94
+ pixel_points_cam = pixel_points_cam.permute(0, 2, 1)
95
+
96
+ world_coords = torch.bmm(p, pixel_points_cam).permute(0, 2, 1)[:, :, :3]
97
+ ray_dirs = world_coords - cam_loc[:, None, :]
98
+ ray_dirs = F.normalize(ray_dirs, dim=2)
99
+
100
+ return ray_dirs, cam_loc
101
+
102
+ def lift(x, y, z, intrinsics):
103
+ # parse intrinsics
104
+ intrinsics = intrinsics.cuda()
105
+ fx = intrinsics[:, 0, 0]
106
+ fy = intrinsics[:, 1, 1]
107
+ cx = intrinsics[:, 0, 2]
108
+ cy = intrinsics[:, 1, 2]
109
+ sk = intrinsics[:, 0, 1]
110
+
111
+ x_lift = (x - cx.unsqueeze(-1) + cy.unsqueeze(-1)*sk.unsqueeze(-1)/fy.unsqueeze(-1) - sk.unsqueeze(-1)*y/fy.unsqueeze(-1)) / fx.unsqueeze(-1) * z
112
+ y_lift = (y - cy.unsqueeze(-1)) / fy.unsqueeze(-1) * z
113
+
114
+ # homogeneous
115
+ return torch.stack((x_lift, y_lift, z, torch.ones_like(z).cuda()), dim=-1)
116
+
117
+
118
+ def quat_to_rot(q):
119
+ batch_size, _ = q.shape
120
+ q = F.normalize(q, dim=1)
121
+ R = torch.ones((batch_size, 3,3)).cuda()
122
+ qr=q[:,0]
123
+ qi = q[:, 1]
124
+ qj = q[:, 2]
125
+ qk = q[:, 3]
126
+ R[:, 0, 0]=1-2 * (qj**2 + qk**2)
127
+ R[:, 0, 1] = 2 * (qj *qi -qk*qr)
128
+ R[:, 0, 2] = 2 * (qi * qk + qr * qj)
129
+ R[:, 1, 0] = 2 * (qj * qi + qk * qr)
130
+ R[:, 1, 1] = 1-2 * (qi**2 + qk**2)
131
+ R[:, 1, 2] = 2*(qj*qk - qi*qr)
132
+ R[:, 2, 0] = 2 * (qk * qi-qj * qr)
133
+ R[:, 2, 1] = 2 * (qj*qk + qi*qr)
134
+ R[:, 2, 2] = 1-2 * (qi**2 + qj**2)
135
+ return R
136
+
137
+
138
+ def rot_to_quat(R):
139
+ batch_size, _,_ = R.shape
140
+ q = torch.ones((batch_size, 4)).cuda()
141
+
142
+ R00 = R[:, 0,0]
143
+ R01 = R[:, 0, 1]
144
+ R02 = R[:, 0, 2]
145
+ R10 = R[:, 1, 0]
146
+ R11 = R[:, 1, 1]
147
+ R12 = R[:, 1, 2]
148
+ R20 = R[:, 2, 0]
149
+ R21 = R[:, 2, 1]
150
+ R22 = R[:, 2, 2]
151
+
152
+ q[:,0]=torch.sqrt(1.0+R00+R11+R22)/2
153
+ q[:, 1]=(R21-R12)/(4*q[:,0])
154
+ q[:, 2] = (R02 - R20) / (4 * q[:, 0])
155
+ q[:, 3] = (R10 - R01) / (4 * q[:, 0])
156
+ return q
157
+
158
+
159
+ def get_sphere_intersections(cam_loc, ray_directions, r = 1.0):
160
+ # Input: n_rays x 3 ; n_rays x 3
161
+ # Output: n_rays x 1, n_rays x 1 (close and far)
162
+
163
+ ray_cam_dot = torch.bmm(ray_directions.view(-1, 1, 3),
164
+ cam_loc.view(-1, 3, 1)).squeeze(-1)
165
+ under_sqrt = ray_cam_dot ** 2 - (cam_loc.norm(2, 1, keepdim=True) ** 2 - r ** 2)
166
+
167
+ # sanity check
168
+ if (under_sqrt <= 0).sum() > 0:
169
+ print('BOUNDING SPHERE PROBLEM!')
170
+ exit()
171
+
172
+ sphere_intersections = torch.sqrt(under_sqrt) * torch.Tensor([-1, 1]).cuda().float() - ray_cam_dot
173
+ sphere_intersections = sphere_intersections.clamp_min(0.0)
174
+
175
+ return sphere_intersections
176
+
177
+ def bilinear_interpolation(xs, ys, dist_map):
178
+ x1 = np.floor(xs).astype(np.int32)
179
+ y1 = np.floor(ys).astype(np.int32)
180
+ x2 = x1 + 1
181
+ y2 = y1 + 1
182
+
183
+ dx = np.expand_dims(np.stack([x2 - xs, xs - x1], axis=1), axis=1)
184
+ dy = np.expand_dims(np.stack([y2 - ys, ys - y1], axis=1), axis=2)
185
+ Q = np.stack([
186
+ dist_map[x1, y1], dist_map[x1, y2], dist_map[x2, y1], dist_map[x2, y2]
187
+ ], axis=1).reshape(-1, 2, 2)
188
+ return np.squeeze(dx @ Q @ dy) # ((x2 - x1) * (y2 - y1)) = 1
189
+
190
+ def get_index_outside_of_bbox(samples_uniform, bbox_min, bbox_max):
191
+ samples_uniform_row = samples_uniform[:, 0]
192
+ samples_uniform_col = samples_uniform[:, 1]
193
+ index_outside = np.where((samples_uniform_row < bbox_min[0]) | (samples_uniform_row > bbox_max[0]) | (samples_uniform_col < bbox_min[1]) | (samples_uniform_col > bbox_max[1]))[0]
194
+ return index_outside
195
+
196
+
197
+ def weighted_sampling(data, img_size, num_sample, bbox_ratio=0.9):
198
+ """
199
+ Sample pixels with a bias towards the bounding box of the foreground mask.
200
+ """
201
+
202
+ # calculate bounding box
203
+ mask = data["object_mask"]
204
+ where = np.asarray(np.where(mask))
205
+ bbox_min = where.min(axis=1)
206
+ bbox_max = where.max(axis=1)
207
+
208
+ num_sample_bbox = int(num_sample * bbox_ratio)
209
+ samples_bbox = np.random.rand(num_sample_bbox, 2)
210
+ samples_bbox = samples_bbox * (bbox_max - bbox_min) + bbox_min
211
+
212
+ num_sample_uniform = num_sample - num_sample_bbox
213
+ samples_uniform = np.random.rand(num_sample_uniform, 2)
214
+ samples_uniform *= (img_size[0] - 1, img_size[1] - 1)
215
+
216
+ # get indices for uniform samples outside of bbox
217
+ index_outside = get_index_outside_of_bbox(samples_uniform, bbox_min, bbox_max) + num_sample_bbox
218
+
219
+ indices = np.concatenate([samples_bbox, samples_uniform], axis=0)
220
+ output = {}
221
+ for key, val in data.items():
222
+ if len(val.shape) == 3:
223
+ new_val = np.stack([
224
+ bilinear_interpolation(indices[:, 0], indices[:, 1], val[:, :, i])
225
+ for i in range(val.shape[2])
226
+ ], axis=-1)
227
+ else:
228
+ new_val = bilinear_interpolation(indices[:, 0], indices[:, 1], val)
229
+ new_val = new_val.reshape(-1, *val.shape[2:])
230
+ output[key] = new_val
231
+
232
+ return output, index_outside
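
A toy illustration of weighted_sampling (a sketch with synthetic data; the real inputs are per-frame images and masks):

import numpy as np
from lib.utils.utils import weighted_sampling

img_size = (120, 160)
data = {
    "rgb": np.random.rand(*img_size, 3).astype(np.float32),
    "object_mask": np.zeros(img_size, dtype=bool),
}
data["object_mask"][40:80, 60:100] = True        # fake foreground silhouette

samples, index_outside = weighted_sampling(data, img_size, num_sample=512)
print(samples["rgb"].shape)       # (512, 3) bilinearly interpolated pixel values
print(len(index_outside))         # uniform samples that landed outside the mask bbox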
code/setup.py ADDED
@@ -0,0 +1,34 @@
1
+ # Copyright 2020 The TensorFlow Authors
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # https://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ """Set-up script for installing extension modules."""
15
+ from Cython.Build import cythonize
16
+ import numpy
17
+ from setuptools import Extension
18
+ from setuptools import setup
19
+
20
+ # Get the numpy include directory.
21
+ numpy_include_dir = numpy.get_include()
22
+
23
+ # mise (efficient mesh extraction)
24
+ mise_module = Extension(
25
+ "lib.libmise.mise",
26
+ sources=["lib/libmise/mise.pyx"],
27
+ )
28
+
29
+ # Gather all extension modules
30
+ ext_modules = [
31
+ mise_module,
32
+ ]
33
+
34
+ setup(ext_modules=cythonize(ext_modules),)
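
The extension is typically compiled in place before training with the standard setuptools/Cython invocation, run from the code/ directory: python setup.py build_ext --inplace. This builds lib.libmise.mise as a native module so that lib.utils.meshing can import it.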
code/test.py ADDED
@@ -0,0 +1,39 @@
1
+ from v2a_model import V2AModel
2
+ from lib.datasets import create_dataset
3
+ import hydra
4
+ import pytorch_lightning as pl
5
+ from pytorch_lightning.loggers import WandbLogger
6
+ import os
7
+ import glob
8
+
9
+ @hydra.main(config_path="confs", config_name="base")
10
+ def main(opt):
11
+ pl.seed_everything(42)
12
+ print("Working dir:", os.getcwd())
13
+
14
+ checkpoint_callback = pl.callbacks.ModelCheckpoint(
15
+ dirpath="checkpoints/",
16
+ filename="{epoch:04d}-{loss}",
17
+ save_on_train_epoch_end=True,
18
+ save_last=True)
19
+ logger = WandbLogger(project=opt.project_name, name=f"{opt.exp}/{opt.run}")
20
+
21
+ trainer = pl.Trainer(
22
+ gpus=1,
23
+ accelerator="gpu",
24
+ callbacks=[checkpoint_callback],
25
+ max_epochs=8000,
26
+ check_val_every_n_epoch=50,
27
+ logger=logger,
28
+ log_every_n_steps=1,
29
+ num_sanity_val_steps=0
30
+ )
31
+
32
+ model = V2AModel(opt)
33
+ checkpoint = sorted(glob.glob("checkpoints/*.ckpt"))[-1]
34
+ testset = create_dataset(opt.dataset.metainfo, opt.dataset.test)
35
+
36
+ trainer.test(model, testset, ckpt_path=checkpoint)
37
+
38
+ if __name__ == '__main__':
39
+ main()
code/train.py ADDED
@@ -0,0 +1,45 @@
1
+ from v2a_model import V2AModel
2
+ from lib.datasets import create_dataset
3
+ import hydra
4
+ import pytorch_lightning as pl
5
+ from pytorch_lightning.loggers import WandbLogger
6
+ import os
7
+ import glob
8
+
9
+ @hydra.main(config_path="confs", config_name="base")
10
+ def main(opt):
11
+ pl.seed_everything(42)
12
+ print("Working dir:", os.getcwd())
13
+
14
+ checkpoint_callback = pl.callbacks.ModelCheckpoint(
15
+ dirpath="checkpoints/",
16
+ filename="{epoch:04d}-{loss}",
17
+ save_on_train_epoch_end=True,
18
+ save_last=True)
19
+ logger = WandbLogger(project=opt.project_name, name=f"{opt.exp}/{opt.run}")
20
+
21
+ trainer = pl.Trainer(
22
+ gpus=1,
23
+ accelerator="gpu",
24
+ callbacks=[checkpoint_callback],
25
+ max_epochs=8000,
26
+ check_val_every_n_epoch=50,
27
+ logger=logger,
28
+ log_every_n_steps=1,
29
+ num_sanity_val_steps=0
30
+ )
31
+
32
+
33
+ model = V2AModel(opt)
34
+ trainset = create_dataset(opt.dataset.metainfo, opt.dataset.train)
35
+ validset = create_dataset(opt.dataset.metainfo, opt.dataset.valid)
36
+
37
+ if opt.model.is_continue == True:
38
+ checkpoint = sorted(glob.glob("checkpoints/*.ckpt"))[-1]
39
+ trainer.fit(model, trainset, validset, ckpt_path=checkpoint)
40
+ else:
41
+ trainer.fit(model, trainset, validset)
42
+
43
+
44
+ if __name__ == '__main__':
45
+ main()
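
Training is launched through Hydra with the confs/base.yaml configuration, e.g. simply python train.py from the code/ directory. By default hydra.main switches the process into a per-run output directory (hence the working-directory print above), so relative paths such as checkpoints/, rendering/ and normal/ used here and in v2a_model.py are created inside that run directory rather than in the source tree.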
code/v2a_model.py ADDED
@@ -0,0 +1,311 @@
1
+ import pytorch_lightning as pl
2
+ import torch.optim as optim
3
+ from lib.model.v2a import V2A
4
+ from lib.model.body_model_params import BodyModelParams
5
+ from lib.model.deformer import SMPLDeformer
6
+ import cv2
7
+ import torch
8
+ from lib.model.loss import Loss
9
+ import hydra
10
+ import os
11
+ import numpy as np
12
+ from lib.utils.meshing import generate_mesh
13
+ from kaolin.ops.mesh import index_vertices_by_faces
14
+ import trimesh
15
+ from lib.model.deformer import skinning
16
+ from lib.utils import utils
17
+ class V2AModel(pl.LightningModule):
18
+ def __init__(self, opt) -> None:
19
+ super().__init__()
20
+
21
+ self.opt = opt
22
+ num_training_frames = opt.dataset.metainfo.end_frame - opt.dataset.metainfo.start_frame
23
+ self.betas_path = os.path.join(hydra.utils.to_absolute_path('..'), 'data', opt.dataset.metainfo.data_dir, 'mean_shape.npy')
24
+ self.gender = opt.dataset.metainfo.gender
25
+ self.model = V2A(opt.model, self.betas_path, self.gender, num_training_frames)
26
+ self.start_frame = opt.dataset.metainfo.start_frame
27
+ self.end_frame = opt.dataset.metainfo.end_frame
28
+ self.training_modules = ["model"]
29
+
30
+ self.training_indices = list(range(self.start_frame, self.end_frame))
31
+ self.body_model_params = BodyModelParams(num_training_frames, model_type='smpl')
32
+ self.load_body_model_params()
33
+ optim_params = self.body_model_params.param_names
34
+ for param_name in optim_params:
35
+ self.body_model_params.set_requires_grad(param_name, requires_grad=True)
36
+ self.training_modules += ['body_model_params']
37
+
38
+ self.loss = Loss(opt.model.loss)
39
+
40
+ def load_body_model_params(self):
41
+ body_model_params = {param_name: [] for param_name in self.body_model_params.param_names}
42
+ data_root = os.path.join('../data', self.opt.dataset.metainfo.data_dir)
43
+ data_root = hydra.utils.to_absolute_path(data_root)
44
+
45
+ body_model_params['betas'] = torch.tensor(np.load(os.path.join(data_root, 'mean_shape.npy'))[None], dtype=torch.float32)
46
+ body_model_params['global_orient'] = torch.tensor(np.load(os.path.join(data_root, 'poses.npy'))[self.training_indices][:, :3], dtype=torch.float32)
47
+ body_model_params['body_pose'] = torch.tensor(np.load(os.path.join(data_root, 'poses.npy'))[self.training_indices] [:, 3:], dtype=torch.float32)
48
+ body_model_params['transl'] = torch.tensor(np.load(os.path.join(data_root, 'normalize_trans.npy'))[self.training_indices], dtype=torch.float32)
49
+
50
+ for param_name in body_model_params.keys():
51
+ self.body_model_params.init_parameters(param_name, body_model_params[param_name], requires_grad=False)
52
+
53
+ def configure_optimizers(self):
54
+ params = [{'params': self.model.parameters(), 'lr':self.opt.model.learning_rate}]
55
+ params.append({'params': self.body_model_params.parameters(), 'lr':self.opt.model.learning_rate*0.1})
56
+ self.optimizer = optim.Adam(params, lr=self.opt.model.learning_rate, eps=1e-8)
57
+ self.scheduler = optim.lr_scheduler.MultiStepLR(
58
+ self.optimizer, milestones=self.opt.model.sched_milestones, gamma=self.opt.model.sched_factor)
59
+ return [self.optimizer], [self.scheduler]
60
+
61
+ def training_step(self, batch):
62
+ inputs, targets = batch
63
+
64
+ batch_idx = inputs["idx"]
65
+
66
+ body_model_params = self.body_model_params(batch_idx)
67
+ inputs['smpl_pose'] = torch.cat((body_model_params['global_orient'], body_model_params['body_pose']), dim=1)
68
+ inputs['smpl_shape'] = body_model_params['betas']
69
+ inputs['smpl_trans'] = body_model_params['transl']
70
+
71
+ inputs['current_epoch'] = self.current_epoch
72
+ model_outputs = self.model(inputs)
73
+
74
+ loss_output = self.loss(model_outputs, targets)
75
+ for k, v in loss_output.items():
76
+ if k in ["loss"]:
77
+ self.log(k, v.item(), prog_bar=True, on_step=True)
78
+ else:
79
+ self.log(k, v.item(), prog_bar=True, on_step=True)
80
+ return loss_output["loss"]
81
+
82
+ def training_epoch_end(self, outputs) -> None:
83
+ # Canonical mesh update every 20 epochs
84
+ if self.current_epoch != 0 and self.current_epoch % 20 == 0:
85
+ cond = {'smpl': torch.zeros(1, 69).float().cuda()}
86
+ mesh_canonical = generate_mesh(lambda x: self.query_oc(x, cond), self.model.smpl_server.verts_c[0], point_batch=10000, res_up=2)
87
+ self.model.mesh_v_cano = torch.tensor(mesh_canonical.vertices[None], device = self.model.smpl_v_cano.device).float()
88
+ self.model.mesh_f_cano = torch.tensor(mesh_canonical.faces.astype(np.int64), device=self.model.smpl_v_cano.device)
89
+ self.model.mesh_face_vertices = index_vertices_by_faces(self.model.mesh_v_cano, self.model.mesh_f_cano)
90
+ return super().training_epoch_end(outputs)
91
+
92
+ def query_oc(self, x, cond):
93
+
94
+ x = x.reshape(-1, 3)
95
+ mnfld_pred = self.model.implicit_network(x, cond)[:,:,0].reshape(-1,1)
96
+ return {'sdf':mnfld_pred}
97
+
98
+ def query_wc(self, x):
99
+
100
+ x = x.reshape(-1, 3)
101
+ w = self.model.deformer.query_weights(x)
102
+
103
+ return w
104
+
105
+ def query_od(self, x, cond, smpl_tfs, smpl_verts):
106
+
107
+ x = x.reshape(-1, 3)
108
+ x_c, _ = self.model.deformer.forward(x, smpl_tfs, return_weights=False, inverse=True, smpl_verts=smpl_verts)
109
+ output = self.model.implicit_network(x_c, cond)[0]
110
+ sdf = output[:, 0:1]
111
+
112
+ return {'sdf': sdf}
113
+
114
+ def get_deformed_mesh_fast_mode(self, verts, smpl_tfs):
115
+ verts = torch.tensor(verts).cuda().float()
116
+ weights = self.model.deformer.query_weights(verts)
117
+ verts_deformed = skinning(verts.unsqueeze(0), weights, smpl_tfs).data.cpu().numpy()[0]
118
+ return verts_deformed
119
+
120
+ def validation_step(self, batch, *args, **kwargs):
121
+
122
+ output = {}
123
+ inputs, targets = batch
124
+ inputs['current_epoch'] = self.current_epoch
125
+ self.model.eval()
126
+
127
+ body_model_params = self.body_model_params(inputs['image_id'])
128
+ inputs['smpl_pose'] = torch.cat((body_model_params['global_orient'], body_model_params['body_pose']), dim=1)
129
+ inputs['smpl_shape'] = body_model_params['betas']
130
+ inputs['smpl_trans'] = body_model_params['transl']
131
+
132
+ cond = {'smpl': inputs["smpl_pose"][:, 3:]/np.pi}
133
+ mesh_canonical = generate_mesh(lambda x: self.query_oc(x, cond), self.model.smpl_server.verts_c[0], point_batch=10000, res_up=3)
134
+
135
+ mesh_canonical = trimesh.Trimesh(mesh_canonical.vertices, mesh_canonical.faces)
136
+
137
+ output.update({
138
+ 'canonical_mesh':mesh_canonical
139
+ })
140
+
141
+ split = utils.split_input(inputs, targets["total_pixels"][0], n_pixels=min(targets['pixel_per_batch'], targets["img_size"][0] * targets["img_size"][1]))
142
+
143
+ res = []
144
+ for s in split:
145
+
146
+ out = self.model(s)
147
+
148
+ for k, v in out.items():
149
+ try:
150
+ out[k] = v.detach()
151
+ except:
152
+ out[k] = v
153
+
154
+ res.append({
155
+ 'rgb_values': out['rgb_values'].detach(),
156
+ 'normal_values': out['normal_values'].detach(),
157
+ 'fg_rgb_values': out['fg_rgb_values'].detach(),
158
+ })
159
+ batch_size = targets['rgb'].shape[0]
160
+
161
+ model_outputs = utils.merge_output(res, targets["total_pixels"][0], batch_size)
162
+
163
+ output.update({
164
+ "rgb_values": model_outputs["rgb_values"].detach().clone(),
165
+ "normal_values": model_outputs["normal_values"].detach().clone(),
166
+ "fg_rgb_values": model_outputs["fg_rgb_values"].detach().clone(),
167
+ **targets,
168
+ })
169
+
170
+ return output
171
+
172
+ def validation_step_end(self, batch_parts):
173
+ return batch_parts
174
+
175
+ def validation_epoch_end(self, outputs) -> None:
176
+ img_size = outputs[0]["img_size"]
177
+
178
+ rgb_pred = torch.cat([output["rgb_values"] for output in outputs], dim=0)
179
+ rgb_pred = rgb_pred.reshape(*img_size, -1)
180
+
181
+ fg_rgb_pred = torch.cat([output["fg_rgb_values"] for output in outputs], dim=0)
182
+ fg_rgb_pred = fg_rgb_pred.reshape(*img_size, -1)
183
+
184
+ normal_pred = torch.cat([output["normal_values"] for output in outputs], dim=0)
185
+ normal_pred = (normal_pred.reshape(*img_size, -1) + 1) / 2
186
+
187
+ rgb_gt = torch.cat([output["rgb"] for output in outputs], dim=1).squeeze(0)
188
+ rgb_gt = rgb_gt.reshape(*img_size, -1)
189
+ if 'normal' in outputs[0].keys():
190
+ normal_gt = torch.cat([output["normal"] for output in outputs], dim=1).squeeze(0)
191
+ normal_gt = (normal_gt.reshape(*img_size, -1) + 1) / 2
192
+ normal = torch.cat([normal_gt, normal_pred], dim=0).cpu().numpy()
193
+ else:
194
+ normal = torch.cat([normal_pred], dim=0).cpu().numpy()
195
+
196
+ rgb = torch.cat([rgb_gt, rgb_pred], dim=0).cpu().numpy()
197
+ rgb = (rgb * 255).astype(np.uint8)
198
+
199
+ fg_rgb = torch.cat([fg_rgb_pred], dim=0).cpu().numpy()
200
+ fg_rgb = (fg_rgb * 255).astype(np.uint8)
201
+
202
+ normal = (normal * 255).astype(np.uint8)
203
+
204
+ os.makedirs("rendering", exist_ok=True)
205
+ os.makedirs("normal", exist_ok=True)
206
+ os.makedirs('fg_rendering', exist_ok=True)
207
+
208
+ canonical_mesh = outputs[0]['canonical_mesh']
209
+ canonical_mesh.export(f"rendering/{self.current_epoch}.ply")
210
+
211
+ cv2.imwrite(f"rendering/{self.current_epoch}.png", rgb[:, :, ::-1])
212
+ cv2.imwrite(f"normal/{self.current_epoch}.png", normal[:, :, ::-1])
213
+ cv2.imwrite(f"fg_rendering/{self.current_epoch}.png", fg_rgb[:, :, ::-1])
214
+
+    def test_step(self, batch, *args, **kwargs):
+        inputs, targets, pixel_per_batch, total_pixels, idx = batch
+        num_splits = (total_pixels + pixel_per_batch - 1) // pixel_per_batch
+        results = []
+
+        scale, smpl_trans, smpl_pose, smpl_shape = torch.split(inputs["smpl_params"], [1, 3, 72, 10], dim=1)
+
+        body_model_params = self.body_model_params(inputs['idx'])
+        smpl_shape = body_model_params['betas'] if body_model_params['betas'].dim() == 2 else body_model_params['betas'].unsqueeze(0)
+        smpl_trans = body_model_params['transl']
+        smpl_pose = torch.cat((body_model_params['global_orient'], body_model_params['body_pose']), dim=1)
+
+        smpl_outputs = self.model.smpl_server(scale, smpl_trans, smpl_pose, smpl_shape)
+        smpl_tfs = smpl_outputs['smpl_tfs']
+        cond = {'smpl': smpl_pose[:, 3:]/np.pi}
+
+        mesh_canonical = generate_mesh(lambda x: self.query_oc(x, cond), self.model.smpl_server.verts_c[0], point_batch=10000, res_up=4)
+        self.model.deformer = SMPLDeformer(betas=np.load(self.betas_path), gender=self.gender, K=7)
+        verts_deformed = self.get_deformed_mesh_fast_mode(mesh_canonical.vertices, smpl_tfs)
+        mesh_deformed = trimesh.Trimesh(vertices=verts_deformed, faces=mesh_canonical.faces, process=False)
+
+        os.makedirs("test_mask", exist_ok=True)
+        os.makedirs("test_rendering", exist_ok=True)
+        os.makedirs("test_fg_rendering", exist_ok=True)
+        os.makedirs("test_normal", exist_ok=True)
+        os.makedirs("test_mesh", exist_ok=True)
+
+        mesh_canonical.export(f"test_mesh/{int(idx.cpu().numpy()):04d}_canonical.ply")
+        mesh_deformed.export(f"test_mesh/{int(idx.cpu().numpy()):04d}_deformed.ply")
+        self.model.deformer = SMPLDeformer(betas=np.load(self.betas_path), gender=self.gender)
+        for i in range(num_splits):
+            indices = list(range(i * pixel_per_batch, min((i + 1) * pixel_per_batch, total_pixels)))
+            batch_inputs = {"uv": inputs["uv"][:, indices],
+                            "intrinsics": inputs['intrinsics'],
+                            "pose": inputs['pose'],
+                            "smpl_params": inputs["smpl_params"],
+                            "smpl_pose": inputs["smpl_params"][:, 4:76],
+                            "smpl_shape": inputs["smpl_params"][:, 76:],
+                            "smpl_trans": inputs["smpl_params"][:, 1:4],
+                            "idx": inputs["idx"] if 'idx' in inputs.keys() else None}
+
+            body_model_params = self.body_model_params(inputs['idx'])
+
+            batch_inputs.update({'smpl_pose': torch.cat((body_model_params['global_orient'], body_model_params['body_pose']), dim=1)})
+            batch_inputs.update({'smpl_shape': body_model_params['betas']})
+            batch_inputs.update({'smpl_trans': body_model_params['transl']})
+
+            batch_targets = {"rgb": targets["rgb"][:, indices].detach().clone() if 'rgb' in targets.keys() else None,
+                             "img_size": targets["img_size"]}
+
+            with torch.no_grad():
+                model_outputs = self.model(batch_inputs)
+            results.append({"rgb_values": model_outputs["rgb_values"].detach().clone(),
+                            "fg_rgb_values": model_outputs["fg_rgb_values"].detach().clone(),
+                            "normal_values": model_outputs["normal_values"].detach().clone(),
+                            "acc_map": model_outputs["acc_map"].detach().clone(),
+                            **batch_targets})
+
+        img_size = results[0]["img_size"]
+        rgb_pred = torch.cat([result["rgb_values"] for result in results], dim=0)
+        rgb_pred = rgb_pred.reshape(*img_size, -1)
+
+        fg_rgb_pred = torch.cat([result["fg_rgb_values"] for result in results], dim=0)
+        fg_rgb_pred = fg_rgb_pred.reshape(*img_size, -1)
+
+        normal_pred = torch.cat([result["normal_values"] for result in results], dim=0)
+        normal_pred = (normal_pred.reshape(*img_size, -1) + 1) / 2
+
+        pred_mask = torch.cat([result["acc_map"] for result in results], dim=0)
+        pred_mask = pred_mask.reshape(*img_size, -1)
+
+        if results[0]['rgb'] is not None:
+            rgb_gt = torch.cat([result["rgb"] for result in results], dim=1).squeeze(0)
+            rgb_gt = rgb_gt.reshape(*img_size, -1)
+            rgb = torch.cat([rgb_gt, rgb_pred], dim=0).cpu().numpy()
+        else:
+            rgb = torch.cat([rgb_pred], dim=0).cpu().numpy()
+        if 'normal' in results[0].keys():
+            normal_gt = torch.cat([result["normal"] for result in results], dim=1).squeeze(0)
+            normal_gt = (normal_gt.reshape(*img_size, -1) + 1) / 2
+            normal = torch.cat([normal_gt, normal_pred], dim=0).cpu().numpy()
+        else:
+            normal = torch.cat([normal_pred], dim=0).cpu().numpy()
+
+        rgb = (rgb * 255).astype(np.uint8)
+
+        fg_rgb = torch.cat([fg_rgb_pred], dim=0).cpu().numpy()
+        fg_rgb = (fg_rgb * 255).astype(np.uint8)
+
+        normal = (normal * 255).astype(np.uint8)
+
+        cv2.imwrite(f"test_mask/{int(idx.cpu().numpy()):04d}.png", pred_mask.cpu().numpy() * 255)
+        cv2.imwrite(f"test_rendering/{int(idx.cpu().numpy()):04d}.png", rgb[:, :, ::-1])
+        cv2.imwrite(f"test_normal/{int(idx.cpu().numpy()):04d}.png", normal[:, :, ::-1])
+        cv2.imwrite(f"test_fg_rendering/{int(idx.cpu().numpy()):04d}.png", fg_rgb[:, :, ::-1])
data/parkinglot/cameras.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ef3e45c7a142677332eda018e986d156f8a496e90a84308d395cbb44c9c0f686
+ size 15626
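
The cameras.npz above is stored as a Git LFS pointer; after a `git lfs pull`, its contents can be inspected without assuming any particular key names. A minimal sketch:

# List whatever arrays the preprocessing stored in the camera file (run after `git lfs pull`).
import numpy as np

cams = np.load("data/parkinglot/cameras.npz")
for key in cams.files:
    print(key, cams[key].shape)
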
data/parkinglot/cameras_normalize.npz ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:cc2622c844d57ee385a6f2e85c89d70a27c2181d397f7556e562f4633267e935
+ size 29550
data/parkinglot/checkpoints/epoch=6299-loss=0.01887552998960018.ckpt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b82917e0738e2a3a297870275480b77d324fb08e2eacfec35b7121ba5d4e8f28
+ size 36504439
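
The checkpoint is likewise an LFS pointer; once fetched, it is a PyTorch Lightning .ckpt and can be opened with torch.load to see what it contains. A hedged sketch (the path is the one in this repo; the key layout shown is Lightning's usual default, not something this diff guarantees):

# Peek inside the pretrained checkpoint without instantiating the model.
import torch

ckpt = torch.load("data/parkinglot/checkpoints/epoch=6299-loss=0.01887552998960018.ckpt", map_location="cpu")
print(list(ckpt.keys()))  # Lightning checkpoints typically expose 'state_dict', 'epoch', ...
if "state_dict" in ckpt:
    print(len(ckpt["state_dict"]), "saved parameter tensors")
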
data/parkinglot/image/0000.png ADDED

Git LFS Details

  • SHA256: d35c55dcecdba64f390dbd4d293102982b00b0523084830dd71a584ba139218f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.38 MB
data/parkinglot/image/0001.png ADDED

Git LFS Details

  • SHA256: bb4671faecbd05c7155754253b4ecc79645968601632bd35694ebeffa5fce9e4
  • Pointer size: 132 Bytes
  • Size of remote file: 1.37 MB
data/parkinglot/image/0002.png ADDED

Git LFS Details

  • SHA256: a6a06640174ea8a557be42e3fe1af535556fd1085f5b78bfc1d7a215f49bfb6b
  • Pointer size: 132 Bytes
  • Size of remote file: 1.38 MB
data/parkinglot/image/0003.png ADDED

Git LFS Details

  • SHA256: 760bcc5aa1b295464b56eb532bed40051ca8df36239093fa4da4569ba5242ad1
  • Pointer size: 132 Bytes
  • Size of remote file: 1.35 MB
data/parkinglot/image/0004.png ADDED

Git LFS Details

  • SHA256: 46f71622950d7707a1490b614d2e8bc3010e343a7dce2e6bed8f681b441c6410
  • Pointer size: 132 Bytes
  • Size of remote file: 1.34 MB