Luigi Piccinelli committed
Commit 1ea89dd · 1 Parent(s): 6b96309
This view is limited to 50 files because the commit contains too many changes. See the raw diff for the complete list.
Files changed (50)
  1. .gitignore +16 -0
  2. LICENSE +407 -0
  3. README.md +1 -0
  4. app.py +800 -0
  5. assets/demo/bears.jpg +0 -0
  6. assets/demo/berzirk.jpg +0 -0
  7. assets/demo/dl3dv.json +4 -0
  8. assets/demo/dl3dv.png +0 -0
  9. assets/demo/equirectangular.jpg +0 -0
  10. assets/demo/kitti360.json +14 -0
  11. assets/demo/kitti360.png +0 -0
  12. assets/demo/luke.webp +0 -0
  13. assets/demo/naruto.jpg +0 -0
  14. assets/demo/poorthings.jpg +0 -0
  15. assets/demo/scannet.jpg +0 -0
  16. assets/demo/scannet.json +21 -0
  17. assets/demo/venice.jpg +0 -0
  18. assets/docs/unik3d-banner.png +0 -0
  19. assets/docs/unik3d-teaser.png +0 -0
  20. configs/config_vitb.json +159 -0
  21. configs/config_vitl.json +159 -0
  22. configs/config_vits.json +159 -0
  23. gradio_demo.py +796 -0
  24. hubconf.py +29 -0
  25. pyproject.toml +25 -0
  26. requirements.txt +84 -0
  27. requirements_demo.txt +84 -0
  28. scripts/README.md +55 -0
  29. scripts/demo.py +150 -0
  30. scripts/train.py +630 -0
  31. unik3d/__init__.py +1 -0
  32. unik3d/datasets/_2d3ds.py +67 -0
  33. unik3d/datasets/_4dor.py +52 -0
  34. unik3d/datasets/__init__.py +161 -0
  35. unik3d/datasets/a2d2.py +78 -0
  36. unik3d/datasets/adt.py +68 -0
  37. unik3d/datasets/aimotive.py +51 -0
  38. unik3d/datasets/argoverse.py +73 -0
  39. unik3d/datasets/argoverse2.py +49 -0
  40. unik3d/datasets/arkit.py +49 -0
  41. unik3d/datasets/ase.py +66 -0
  42. unik3d/datasets/base_dataset.py +344 -0
  43. unik3d/datasets/bdd.py +82 -0
  44. unik3d/datasets/bedlam.py +50 -0
  45. unik3d/datasets/behave.py +52 -0
  46. unik3d/datasets/blendedmvg.py +50 -0
  47. unik3d/datasets/cityscape.py +78 -0
  48. unik3d/datasets/ddad.py +84 -0
  49. unik3d/datasets/deep360.py +56 -0
  50. unik3d/datasets/dense.py +91 -0
.gitignore ADDED
@@ -0,0 +1,16 @@
+ **/__pycache__/
+ **/build/
+ **/dist/
+ **/*egg-info
+ .gradio/
+
+ # ignore scripts
+ _*.sh
+ __*.png
+ __*.jpg
+ __*.webp
+ ___*.py
+ **/___*.py
+
+ # ignore pcds
+ *.ply
LICENSE ADDED
@@ -0,0 +1,407 @@
1
+ Attribution-NonCommercial 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial 4.0 International Public
58
+ License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial 4.0 International Public License ("Public
63
+ License"). To the extent this Public License may be interpreted as a
64
+ contract, You are granted the Licensed Rights in consideration of Your
65
+ acceptance of these terms and conditions, and the Licensor grants You
66
+ such rights in consideration of benefits the Licensor receives from
67
+ making the Licensed Material available under these terms and
68
+ conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. Copyright and Similar Rights means copyright and/or similar rights
88
+ closely related to copyright including, without limitation,
89
+ performance, broadcast, sound recording, and Sui Generis Database
90
+ Rights, without regard to how the rights are labeled or
91
+ categorized. For purposes of this Public License, the rights
92
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
93
+ Rights.
94
+ d. Effective Technological Measures means those measures that, in the
95
+ absence of proper authority, may not be circumvented under laws
96
+ fulfilling obligations under Article 11 of the WIPO Copyright
97
+ Treaty adopted on December 20, 1996, and/or similar international
98
+ agreements.
99
+
100
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
101
+ any other exception or limitation to Copyright and Similar Rights
102
+ that applies to Your use of the Licensed Material.
103
+
104
+ f. Licensed Material means the artistic or literary work, database,
105
+ or other material to which the Licensor applied this Public
106
+ License.
107
+
108
+ g. Licensed Rights means the rights granted to You subject to the
109
+ terms and conditions of this Public License, which are limited to
110
+ all Copyright and Similar Rights that apply to Your use of the
111
+ Licensed Material and that the Licensor has authority to license.
112
+
113
+ h. Licensor means the individual(s) or entity(ies) granting rights
114
+ under this Public License.
115
+
116
+ i. NonCommercial means not primarily intended for or directed towards
117
+ commercial advantage or monetary compensation. For purposes of
118
+ this Public License, the exchange of the Licensed Material for
119
+ other material subject to Copyright and Similar Rights by digital
120
+ file-sharing or similar means is NonCommercial provided there is
121
+ no payment of monetary compensation in connection with the
122
+ exchange.
123
+
124
+ j. Share means to provide material to the public by any means or
125
+ process that requires permission under the Licensed Rights, such
126
+ as reproduction, public display, public performance, distribution,
127
+ dissemination, communication, or importation, and to make material
128
+ available to the public including in ways that members of the
129
+ public may access the material from a place and at a time
130
+ individually chosen by them.
131
+
132
+ k. Sui Generis Database Rights means rights other than copyright
133
+ resulting from Directive 96/9/EC of the European Parliament and of
134
+ the Council of 11 March 1996 on the legal protection of databases,
135
+ as amended and/or succeeded, as well as other essentially
136
+ equivalent rights anywhere in the world.
137
+
138
+ l. You means the individual or entity exercising the Licensed Rights
139
+ under this Public License. Your has a corresponding meaning.
140
+
141
+
142
+ Section 2 -- Scope.
143
+
144
+ a. License grant.
145
+
146
+ 1. Subject to the terms and conditions of this Public License,
147
+ the Licensor hereby grants You a worldwide, royalty-free,
148
+ non-sublicensable, non-exclusive, irrevocable license to
149
+ exercise the Licensed Rights in the Licensed Material to:
150
+
151
+ a. reproduce and Share the Licensed Material, in whole or
152
+ in part, for NonCommercial purposes only; and
153
+
154
+ b. produce, reproduce, and Share Adapted Material for
155
+ NonCommercial purposes only.
156
+
157
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
158
+ Exceptions and Limitations apply to Your use, this Public
159
+ License does not apply, and You do not need to comply with
160
+ its terms and conditions.
161
+
162
+ 3. Term. The term of this Public License is specified in Section
163
+ 6(a).
164
+
165
+ 4. Media and formats; technical modifications allowed. The
166
+ Licensor authorizes You to exercise the Licensed Rights in
167
+ all media and formats whether now known or hereafter created,
168
+ and to make technical modifications necessary to do so. The
169
+ Licensor waives and/or agrees not to assert any right or
170
+ authority to forbid You from making technical modifications
171
+ necessary to exercise the Licensed Rights, including
172
+ technical modifications necessary to circumvent Effective
173
+ Technological Measures. For purposes of this Public License,
174
+ simply making modifications authorized by this Section 2(a)
175
+ (4) never produces Adapted Material.
176
+
177
+ 5. Downstream recipients.
178
+
179
+ a. Offer from the Licensor -- Licensed Material. Every
180
+ recipient of the Licensed Material automatically
181
+ receives an offer from the Licensor to exercise the
182
+ Licensed Rights under the terms and conditions of this
183
+ Public License.
184
+
185
+ b. No downstream restrictions. You may not offer or impose
186
+ any additional or different terms or conditions on, or
187
+ apply any Effective Technological Measures to, the
188
+ Licensed Material if doing so restricts exercise of the
189
+ Licensed Rights by any recipient of the Licensed
190
+ Material.
191
+
192
+ 6. No endorsement. Nothing in this Public License constitutes or
193
+ may be construed as permission to assert or imply that You
194
+ are, or that Your use of the Licensed Material is, connected
195
+ with, or sponsored, endorsed, or granted official status by,
196
+ the Licensor or others designated to receive attribution as
197
+ provided in Section 3(a)(1)(A)(i).
198
+
199
+ b. Other rights.
200
+
201
+ 1. Moral rights, such as the right of integrity, are not
202
+ licensed under this Public License, nor are publicity,
203
+ privacy, and/or other similar personality rights; however, to
204
+ the extent possible, the Licensor waives and/or agrees not to
205
+ assert any such rights held by the Licensor to the limited
206
+ extent necessary to allow You to exercise the Licensed
207
+ Rights, but not otherwise.
208
+
209
+ 2. Patent and trademark rights are not licensed under this
210
+ Public License.
211
+
212
+ 3. To the extent possible, the Licensor waives any right to
213
+ collect royalties from You for the exercise of the Licensed
214
+ Rights, whether directly or through a collecting society
215
+ under any voluntary or waivable statutory or compulsory
216
+ licensing scheme. In all other cases the Licensor expressly
217
+ reserves any right to collect such royalties, including when
218
+ the Licensed Material is used other than for NonCommercial
219
+ purposes.
220
+
221
+
222
+ Section 3 -- License Conditions.
223
+
224
+ Your exercise of the Licensed Rights is expressly made subject to the
225
+ following conditions.
226
+
227
+ a. Attribution.
228
+
229
+ 1. If You Share the Licensed Material (including in modified
230
+ form), You must:
231
+
232
+ a. retain the following if it is supplied by the Licensor
233
+ with the Licensed Material:
234
+
235
+ i. identification of the creator(s) of the Licensed
236
+ Material and any others designated to receive
237
+ attribution, in any reasonable manner requested by
238
+ the Licensor (including by pseudonym if
239
+ designated);
240
+
241
+ ii. a copyright notice;
242
+
243
+ iii. a notice that refers to this Public License;
244
+
245
+ iv. a notice that refers to the disclaimer of
246
+ warranties;
247
+
248
+ v. a URI or hyperlink to the Licensed Material to the
249
+ extent reasonably practicable;
250
+
251
+ b. indicate if You modified the Licensed Material and
252
+ retain an indication of any previous modifications; and
253
+
254
+ c. indicate the Licensed Material is licensed under this
255
+ Public License, and include the text of, or the URI or
256
+ hyperlink to, this Public License.
257
+
258
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
259
+ reasonable manner based on the medium, means, and context in
260
+ which You Share the Licensed Material. For example, it may be
261
+ reasonable to satisfy the conditions by providing a URI or
262
+ hyperlink to a resource that includes the required
263
+ information.
264
+
265
+ 3. If requested by the Licensor, You must remove any of the
266
+ information required by Section 3(a)(1)(A) to the extent
267
+ reasonably practicable.
268
+
269
+ 4. If You Share Adapted Material You produce, the Adapter's
270
+ License You apply must not prevent recipients of the Adapted
271
+ Material from complying with this Public License.
272
+
273
+
274
+ Section 4 -- Sui Generis Database Rights.
275
+
276
+ Where the Licensed Rights include Sui Generis Database Rights that
277
+ apply to Your use of the Licensed Material:
278
+
279
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
280
+ to extract, reuse, reproduce, and Share all or a substantial
281
+ portion of the contents of the database for NonCommercial purposes
282
+ only;
283
+
284
+ b. if You include all or a substantial portion of the database
285
+ contents in a database in which You have Sui Generis Database
286
+ Rights, then the database in which You have Sui Generis Database
287
+ Rights (but not its individual contents) is Adapted Material; and
288
+
289
+ c. You must comply with the conditions in Section 3(a) if You Share
290
+ all or a substantial portion of the contents of the database.
291
+
292
+ For the avoidance of doubt, this Section 4 supplements and does not
293
+ replace Your obligations under this Public License where the Licensed
294
+ Rights include other Copyright and Similar Rights.
295
+
296
+
297
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
298
+
299
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
300
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
301
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
302
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
303
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
304
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
305
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
306
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
307
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
308
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
309
+
310
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
311
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
312
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
313
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
314
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
315
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
316
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
317
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
318
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
319
+
320
+ c. The disclaimer of warranties and limitation of liability provided
321
+ above shall be interpreted in a manner that, to the extent
322
+ possible, most closely approximates an absolute disclaimer and
323
+ waiver of all liability.
324
+
325
+
326
+ Section 6 -- Term and Termination.
327
+
328
+ a. This Public License applies for the term of the Copyright and
329
+ Similar Rights licensed here. However, if You fail to comply with
330
+ this Public License, then Your rights under this Public License
331
+ terminate automatically.
332
+
333
+ b. Where Your right to use the Licensed Material has terminated under
334
+ Section 6(a), it reinstates:
335
+
336
+ 1. automatically as of the date the violation is cured, provided
337
+ it is cured within 30 days of Your discovery of the
338
+ violation; or
339
+
340
+ 2. upon express reinstatement by the Licensor.
341
+
342
+ For the avoidance of doubt, this Section 6(b) does not affect any
343
+ right the Licensor may have to seek remedies for Your violations
344
+ of this Public License.
345
+
346
+ c. For the avoidance of doubt, the Licensor may also offer the
347
+ Licensed Material under separate terms or conditions or stop
348
+ distributing the Licensed Material at any time; however, doing so
349
+ will not terminate this Public License.
350
+
351
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
352
+ License.
353
+
354
+
355
+ Section 7 -- Other Terms and Conditions.
356
+
357
+ a. The Licensor shall not be bound by any additional or different
358
+ terms or conditions communicated by You unless expressly agreed.
359
+
360
+ b. Any arrangements, understandings, or agreements regarding the
361
+ Licensed Material not stated herein are separate from and
362
+ independent of the terms and conditions of this Public License.
363
+
364
+
365
+ Section 8 -- Interpretation.
366
+
367
+ a. For the avoidance of doubt, this Public License does not, and
368
+ shall not be interpreted to, reduce, limit, restrict, or impose
369
+ conditions on any use of the Licensed Material that could lawfully
370
+ be made without permission under this Public License.
371
+
372
+ b. To the extent possible, if any provision of this Public License is
373
+ deemed unenforceable, it shall be automatically reformed to the
374
+ minimum extent necessary to make it enforceable. If the provision
375
+ cannot be reformed, it shall be severed from this Public License
376
+ without affecting the enforceability of the remaining terms and
377
+ conditions.
378
+
379
+ c. No term or condition of this Public License will be waived and no
380
+ failure to comply consented to unless expressly agreed to by the
381
+ Licensor.
382
+
383
+ d. Nothing in this Public License constitutes or may be interpreted
384
+ as a limitation upon, or waiver of, any privileges and immunities
385
+ that apply to the Licensor or You, including from the legal
386
+ processes of any jurisdiction or authority.
387
+
388
+ =======================================================================
389
+
390
+ Creative Commons is not a party to its public
391
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
392
+ its public licenses to material it publishes and in those instances
393
+ will be considered the “Licensor.” The text of the Creative Commons
394
+ public licenses is dedicated to the public domain under the CC0 Public
395
+ Domain Dedication. Except for the limited purpose of indicating that
396
+ material is shared under a Creative Commons public license or as
397
+ otherwise permitted by the Creative Commons policies published at
398
+ creativecommons.org/policies, Creative Commons does not authorize the
399
+ use of the trademark "Creative Commons" or any other trademark or logo
400
+ of Creative Commons without its prior written consent including,
401
+ without limitation, in connection with any unauthorized modifications
402
+ to any of its public licenses or any other arrangements,
403
+ understandings, or agreements concerning use of licensed material. For
404
+ the avoidance of doubt, this paragraph does not form part of the
405
+ public licenses.
406
+
407
+ Creative Commons may be contacted at creativecommons.org.
README.md CHANGED
@@ -8,6 +8,7 @@ sdk_version: 5.22.0
  app_file: app.py
  pinned: false
  license: cc-by-nc-4.0
+ short_description: UniK3D (CVPR 2025)
  ---

  Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py ADDED
@@ -0,0 +1,800 @@
1
+ import gc
2
+ import os
3
+ import shutil
4
+ import time
5
+ from datetime import datetime
6
+ from math import pi
7
+ import sys
8
+
9
+ import gradio as gr
10
+ import numpy as np
11
+ import torch
12
+ import trimesh
13
+ from PIL import Image
14
+
15
+
16
+ sys.path.append("unik3d/")
17
+
18
+ from unik3d.models import UniK3D
19
+ from unik3d.utils.camera import OPENCV, Fisheye624, Pinhole, Spherical
20
+ from unik3d.utils.visualization import colorize
21
+
22
+
23
+ def predictions_to_glb(
24
+ predictions,
25
+ mask_black_bg=False,
26
+ mask_far_points=False,
27
+ ) -> trimesh.Scene:
28
+ print("Building GLB scene")
29
+ images = predictions["image"].squeeze().permute(1, 2, 0).cpu().numpy()
30
+ world_points = predictions["points"].squeeze().permute(1, 2, 0).cpu().numpy()
31
+
32
+ vertices_3d = world_points.reshape(-1, 3)
33
+ # flip x and y
34
+ vertices_3d[:, 1] *= -1
35
+ vertices_3d[:, 0] *= -1
36
+ colors_rgb = (images.reshape(-1, 3)).astype(np.uint8)
37
+
38
+ if mask_black_bg:
39
+ black_bg_mask = colors_rgb.sum(axis=1) >= 16
40
+ vertices_3d = vertices_3d[black_bg_mask]
41
+ colors_rgb = colors_rgb[black_bg_mask]
42
+
43
+ if mask_far_points:
44
+ far_points_mask = np.linalg.norm(vertices_3d, axis=-1) < 100.0
45
+ vertices_3d = vertices_3d[far_points_mask]
46
+ colors_rgb = colors_rgb[far_points_mask]
47
+
48
+ scene_3d = trimesh.Scene()
49
+ point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
50
+ scene_3d.add_geometry(point_cloud_data)
51
+
52
+ return scene_3d
53
+
54
+
55
+ def instantiate_model(model_name):
56
+ type_ = model_name[0].lower()
57
+
58
+ name = f"unik3d-vit{type_}"
59
+ model = UniK3D.from_pretrained(f"lpiccinelli/{name}")
60
+
61
+ # Set resolution level and interpolation mode as specified.
62
+ model.resolution_level = 9
63
+ model.interpolation_mode = "bilinear"
64
+
65
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
66
+ model = model.to(device).eval()
67
+ return model
68
+
69
+
70
+ def instantiate_camera(camera_name, params, device):
71
+ if camera_name == "Predicted":
72
+ return None
73
+ fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov, H, W = params
74
+ if camera_name == "Pinhole":
75
+ params = [fx, fy, cx, cy]
76
+ elif camera_name == "Fisheye624":
77
+ params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2]
78
+ elif camera_name == "OPENCV":
79
+ params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2]
80
+ elif camera_name == "Equirectangular":
81
+ # dummy intrinsics for spherical camera, assume hfov -> vfov based on input shapes
82
+ hfov2 = hfov * pi / 180.0 / 2
83
+ params = [fx, fy, cx, cy, W, H, hfov2, H / W * hfov2]
84
+ camera_name = "Spherical"
85
+
86
+ return eval(camera_name)(params=torch.tensor(params).float()).to(device)
87
+
88
+
89
+ def run_model(target_dir, model_name, camera_name, params):
90
+
91
+ print("Instantiating model and camera...")
92
+ model = instantiate_model(model_name)
93
+
94
+ image_names = [x for x in os.listdir(target_dir) if x.endswith(".png")]
95
+ input_image = np.array(Image.open(os.path.join(target_dir, image_names[-1])))
96
+ image_tensor = torch.from_numpy(input_image).permute(2, 0, 1).unsqueeze(0).float()
97
+ device = next(model.parameters()).device
98
+ image_tensor = image_tensor.to(device)
99
+ H, W = image_tensor.shape[-2:]
100
+ params = params + [H, W]
101
+ camera = instantiate_camera(camera_name, params=params, device=device)
102
+
103
+ # Perform inference with the model.
104
+ print("Running inference...")
105
+ outputs = model.infer(image_tensor, camera=camera, normalize=True)
106
+ outputs["image"] = image_tensor
107
+
108
+ return outputs
109
+
110
+
111
+ def gradio_demo(
112
+ target_dir,
113
+ model_name,
114
+ camera_name,
115
+ fx,
116
+ fy,
117
+ cx,
118
+ cy,
119
+ k1,
120
+ k2,
121
+ k3,
122
+ k4,
123
+ k5,
124
+ k6,
125
+ t1,
126
+ t2,
127
+ hfov,
128
+ mask_black_bg,
129
+ mask_far_points,
130
+ ):
131
+ print(target_dir)
132
+ if not os.path.isdir(target_dir) or target_dir == "None":
133
+ return None, "No valid target directory found. Please upload first.", None
134
+
135
+ start_time = time.time()
136
+ gc.collect()
137
+
138
+ print("Running run_model...")
139
+ params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov]
140
+ with torch.no_grad():
141
+ outputs = run_model(target_dir, model_name, camera_name, params)
142
+
143
+ # Save predictions
144
+ points = outputs["points"].squeeze().permute(1, 2, 0).cpu().numpy()
145
+ rgb = outputs["image"].squeeze().permute(1, 2, 0).cpu().numpy()
146
+
147
+ prediction_save_path = os.path.join(target_dir, "predictions.npz")
148
+ np.savez(prediction_save_path, {"points": points, "image": rgb})
149
+
150
+ # Build a GLB file name
151
+ glbfile = os.path.join(
152
+ target_dir,
153
+ f"glbscene.glb",
154
+ )
155
+
156
+ # Convert predictions to GLB
157
+ glbscene = predictions_to_glb(
158
+ outputs,
159
+ mask_black_bg=mask_black_bg,
160
+ mask_far_points=mask_far_points,
161
+ )
162
+ glbscene.export(file_obj=glbfile)
163
+
164
+ # Cleanup
165
+ del outputs
166
+ gc.collect()
167
+
168
+ end_time = time.time()
169
+ print(f"Total time: {end_time - start_time:.2f} seconds")
170
+ log_msg = "Success. Waiting for visualization."
171
+
172
+ return glbfile, log_msg, prediction_save_path
173
+
174
+
175
+ def handle_uploads(input_image):
176
+ gc.collect()
177
+
178
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
179
+ tmpdir = os.environ.get("TMPDIR", "/tmp")
180
+ target_dir = os.path.join(tmpdir, f"input_images_{timestamp}")
181
+
182
+ if os.path.exists(target_dir):
183
+ shutil.rmtree(target_dir)
184
+ os.makedirs(target_dir)
185
+
186
+ dst_path = os.path.join(target_dir, "image.png")
187
+ Image.fromarray(input_image).save(dst_path)
188
+ image_paths = [dst_path]
189
+
190
+ print(f"Files uploaded.")
191
+ return target_dir, image_paths
192
+
193
+
194
+ def update_gallery_on_upload(input_images):
195
+ if input_images is None:
196
+ return None, None
197
+ target_dir, image_path = handle_uploads(input_images)
198
+ return target_dir, "Upload complete. Click 'Run UniK3D' to get the 3D pointcloud."
199
+
200
+
201
+ def update_parameters(camera):
202
+ if camera == "Pinhole":
203
+ return (
204
+ gr.update(visible=True), # fx
205
+ gr.update(visible=True), # fy
206
+ gr.update(visible=True), # cx
207
+ gr.update(visible=True), # cy
208
+ gr.update(visible=False), # k1
209
+ gr.update(visible=False), # k2
210
+ gr.update(visible=False), # k3
211
+ gr.update(visible=False), # k4
212
+ gr.update(visible=False), # k5
213
+ gr.update(visible=False), # k6
214
+ gr.update(visible=False), # t1
215
+ gr.update(visible=False), # t2
216
+ gr.update(visible=False), # hfov
217
+ )
218
+ elif camera == "OPENCV":
219
+ return (
220
+ gr.update(visible=True), # fx
221
+ gr.update(visible=True), # fy
222
+ gr.update(visible=True), # cx
223
+ gr.update(visible=True), # cy
224
+ gr.update(visible=True), # k1
225
+ gr.update(visible=True), # k2
226
+ gr.update(visible=True), # k3
227
+ gr.update(visible=False), # k4
228
+ gr.update(visible=False), # k5
229
+ gr.update(visible=False), # k6
230
+ gr.update(visible=True), # t1
231
+ gr.update(visible=True), # t2
232
+ gr.update(visible=False), # hfov
233
+ )
234
+ elif camera == "Fisheye624":
235
+ return (
236
+ gr.update(visible=True), # fx
237
+ gr.update(visible=True), # fy
238
+ gr.update(visible=True), # cx
239
+ gr.update(visible=True), # cy
240
+ gr.update(visible=True), # k1
241
+ gr.update(visible=True), # k2
242
+ gr.update(visible=True), # k3
243
+ gr.update(visible=True), # k4
244
+ gr.update(visible=True), # k5
245
+ gr.update(visible=True), # k6
246
+ gr.update(visible=True), # t1
247
+ gr.update(visible=True), # t2
248
+ gr.update(visible=False), # hfov
249
+ )
250
+ elif camera == "Equirectangular":
251
+ return (
252
+ gr.update(visible=False), # fx
253
+ gr.update(visible=False), # fy
254
+ gr.update(visible=False), # cx
255
+ gr.update(visible=False), # cy
256
+ gr.update(visible=False), # k1
257
+ gr.update(visible=False), # k2
258
+ gr.update(visible=False), # k3
259
+ gr.update(visible=False), # k4
260
+ gr.update(visible=False), # k5
261
+ gr.update(visible=False), # k6
262
+ gr.update(visible=False), # t1
263
+ gr.update(visible=False), # t2
264
+ gr.update(visible=True), # hfov
265
+ )
266
+ elif camera == "Predicted":
267
+ return (
268
+ gr.update(visible=False), # fx
269
+ gr.update(visible=False), # fy
270
+ gr.update(visible=False), # cx
271
+ gr.update(visible=False), # cy
272
+ gr.update(visible=False), # k1
273
+ gr.update(visible=False), # k2
274
+ gr.update(visible=False), # k3
275
+ gr.update(visible=False), # k4
276
+ gr.update(visible=False), # k5
277
+ gr.update(visible=False), # k6
278
+ gr.update(visible=False), # t1
279
+ gr.update(visible=False), # t2
280
+ gr.update(visible=False), # hfov
281
+ )
282
+ else:
283
+ raise ValueError(f"Invalid camera type: {camera}")
284
+
285
+
286
+ def clear_fields():
287
+ return None
288
+
289
+
290
+ def update_log():
291
+ return "Loading Model and Running Inference..."
292
+
293
+
294
+ def update_visualization(target_dir, mask_black_bg, mask_far_points, is_example):
295
+
296
+ if is_example == "True":
297
+ return (
298
+ None,
299
+ "No reconstruction available. Please click the Reconstruct button first.",
300
+ )
301
+
302
+ if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
303
+ return (
304
+ None,
305
+ "No reconstruction available. Please click the Reconstruct button first.",
306
+ )
307
+
308
+ predictions_path = os.path.join(target_dir, "predictions.npz")
309
+ if not os.path.exists(predictions_path):
310
+ return (
311
+ None,
312
+ f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
313
+ )
314
+
315
+ loaded = np.load(predictions_path, allow_pickle=True)
316
+ predictions = {key: loaded[key] for key in loaded.keys()}
317
+
318
+ glbfile = os.path.join(
319
+ target_dir,
320
+ f"glbscene.glb",
321
+ )
322
+
323
+ if not os.path.exists(glbfile):
324
+ glbscene = predictions_to_glb(
325
+ predictions,
326
+ mask_black_bg=mask_black_bg,
327
+ mask_far_points=mask_far_points,
328
+ )
329
+ glbscene.export(file_obj=glbfile)
330
+
331
+ return glbfile, "Updating Visualization"
332
+
333
+
334
+ if __name__ == "__main__":
335
+ theme = gr.themes.Citrus()
336
+ theme.set(
337
+ checkbox_label_background_fill_selected="*button_primary_background_fill",
338
+ checkbox_label_text_color_selected="*button_primary_text_color",
339
+ )
340
+
341
+ with gr.Blocks(
342
+ theme=theme,
343
+ css="""
344
+ .custom-log * {
345
+ font-style: italic;
346
+ font-size: 22px !important;
347
+ background-image: linear-gradient(120deg, #ff7e26 0%, #ff9c59 60%, #fff4d6 100%);
348
+ -webkit-background-clip: text;
349
+ background-clip: text;
350
+ font-weight: bold !important;
351
+ color: transparent !important;
352
+ text-align: center !important;
353
+ }
354
+
355
+ .example-log * {
356
+ font-style: italic;
357
+ font-size: 16px !important;
358
+ background-image: linear-gradient(120deg, #ff7e26 0%, #ff9c59 60%, #fff4d6 100%);
359
+ -webkit-background-clip: text;
360
+ background-clip: text;
361
+ color: transparent !important;
362
+ }
363
+
364
+ #my_radio .wrap {
365
+ display: flex;
366
+ flex-wrap: nowrap;
367
+ justify-content: center;
368
+ align-items: center;
369
+ }
370
+
371
+ #my_radio .wrap label {
372
+ display: flex;
373
+ width: 50%;
374
+ justify-content: center;
375
+ align-items: center;
376
+ margin: 0;
377
+ padding: 10px 0;
378
+ box-sizing: border-box;
379
+ }
380
+ """,
381
+ ) as demo:
382
+
383
+ # Instead of gr.State, we use a hidden Textbox:
384
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
385
+
386
+ gr.HTML(
387
+ """
388
+ <h1>UniK3D: Universal Camera Monocular 3D Estimation</h1>
389
+ <p>
390
+ <a href="https://github.com/lpiccinelli-eth/UniK3D">🌟 GitHub Repository</a> |
391
+ <a href="">🚀 Project Page</a>
392
+ </p>
393
+
394
+ <div style="font-size: 16px; line-height: 1.5;">
395
+ <p>Upload one image to create a 3D estimation of a scene or object. UniK3D allows to predict directly 3D of any camera and scene.</p>
396
+
397
+ <h3>Getting Started:</h3>
398
+ <ol>
399
+ <li><strong>Upload Your Image:</strong> Use the "Upload Images" panel to provide your input.</li>
400
+ <li><strong>Run:</strong> Click the "Run UniK3D" button to start the 3D estimation process.</li>
401
+ <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file.</li>
402
+ </ol>
403
+ <p><strong style="color: #ff7e26;">Please note:</strong> <span style="color: #ff7e26; font-weight: bold;">Our model runs on CPU on HuggingFace Space. Actual inference is less than 100ms second per image on consumer-level GPUs. Web-based 3D pointcloud visualization may be slow due to Gradio's rendering. For faster visualization, use a local machine to run our demo from our <a href="https://github.com/lpiccinelli-eth/UniK3D">GitHub repository</a>. </span></p>
404
+ </div>
405
+ """
406
+ )
407
+
408
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
409
+
410
+ with gr.Row():
411
+ with gr.Column():
412
+ camera_dropdown = gr.Dropdown(
413
+ choices=[
414
+ "Predicted",
415
+ "Pinhole",
416
+ "Fisheye624",
417
+ "OPENCV",
418
+ "Equirectangular",
419
+ ],
420
+ label="Input Camera",
421
+ )
422
+ model_dropdown = gr.Dropdown(
423
+ choices=["Large", "Base", "Small"], label="Utilized Model"
424
+ )
425
+ mask_black_bg = gr.Checkbox(
426
+ label="Filter Black Background", value=False
427
+ )
428
+ mask_far_points = gr.Checkbox(label="Filter Far Points", value=False)
429
+
430
+ with gr.Column():
431
+ fx = gr.Number(label="Focal length x", value=500.0, visible=False)
432
+ fy = gr.Number(label="Focal length y", value=500.0, visible=False)
433
+ cx = gr.Number(label="Center projection x", value=320.0, visible=False)
434
+ cy = gr.Number(label="Center projection y", value=240.0, visible=False)
435
+ hfov = gr.Number(
436
+ label="Horizontal FoV (degree)", value=0.0, visible=False
437
+ )
438
+
439
+ with gr.Column():
440
+ k1 = gr.Number(label="Radial 1", value=0.0, visible=False)
441
+ k2 = gr.Number(label="Radial 2", value=0.0, visible=False)
442
+ k3 = gr.Number(label="Radial 3", value=0.0, visible=False)
443
+ k4 = gr.Number(label="Radial 4", value=0.0, visible=False)
444
+
445
+ with gr.Column():
446
+ k5 = gr.Number(label="Radial 5", value=0.0, visible=False)
447
+ k6 = gr.Number(label="Radial 6", value=0.0, visible=False)
448
+ t1 = gr.Number(label="Tangential 1", value=0.0, visible=False)
449
+ t2 = gr.Number(label="Tangential 2", value=0.0, visible=False)
450
+
451
+ with gr.Row():
452
+ with gr.Column(scale=1):
453
+ input_image = gr.Image(label="Upload Images")
454
+ gr.Markdown("**3D Estimation**")
455
+ with gr.Row():
456
+ log_output = gr.Markdown(
457
+ "Please upload one image at a time, then click `Run UniK3D`.",
458
+ elem_classes=["custom-log"],
459
+ )
460
+ reconstruction_npy = gr.File(
461
+ label="Download 3D Pointcloud", type="filepath"
462
+ )
463
+
464
+ with gr.Column(scale=2):
465
+ reconstruction_output = gr.Model3D(
466
+ height=520, zoom_speed=0.5, pan_speed=0.5
467
+ )
468
+ with gr.Row():
469
+ submit_btn = gr.Button("Run UniK3D", scale=1, variant="primary")
470
+ clear_btn = gr.ClearButton(
471
+ [
472
+ input_image,
473
+ reconstruction_output,
474
+ log_output,
475
+ target_dir_output,
476
+ reconstruction_npy,
477
+ ],
478
+ scale=1,
479
+ )
480
+
481
+ examples = [
482
+ [
483
+ "assets/demo/poorthings.jpg",
484
+ "Large",
485
+ "Predicted",
486
+ 0.0,
487
+ 0.0,
488
+ 0.0,
489
+ 0.0,
490
+ 0.0,
491
+ 0.0,
492
+ 0.0,
493
+ 0.0,
494
+ 0.0,
495
+ 0.0,
496
+ 0.0,
497
+ 0.0,
498
+ 0.0,
499
+ True,
500
+ False,
501
+ ],
502
+ [
503
+ "assets/demo/naruto.jpg",
504
+ "Large",
505
+ "Predicted",
506
+ 0.0,
507
+ 0.0,
508
+ 0.0,
509
+ 0.0,
510
+ 0.0,
511
+ 0.0,
512
+ 0.0,
513
+ 0.0,
514
+ 0.0,
515
+ 0.0,
516
+ 0.0,
517
+ 0.0,
518
+ 0.0,
519
+ False,
520
+ False,
521
+ ],
522
+ [
523
+ "assets/demo/bears.jpg",
524
+ "Large",
525
+ "Predicted",
526
+ 0.0,
527
+ 0.0,
528
+ 0.0,
529
+ 0.0,
530
+ 0.0,
531
+ 0.0,
532
+ 0.0,
533
+ 0.0,
534
+ 0.0,
535
+ 0.0,
536
+ 0.0,
537
+ 0.0,
538
+ 0.0,
539
+ True,
540
+ False,
541
+ ],
542
+ [
543
+ "assets/demo/berzirk.jpg",
544
+ "Large",
545
+ "Predicted",
546
+ 0.0,
547
+ 0.0,
548
+ 0.0,
549
+ 0.0,
550
+ 0.0,
551
+ 0.0,
552
+ 0.0,
553
+ 0.0,
554
+ 0.0,
555
+ 0.0,
556
+ 0.0,
557
+ 0.0,
558
+ 0.0,
559
+ True,
560
+ False,
561
+ ],
562
+ [
563
+ "assets/demo/luke.webp",
564
+ "Large",
565
+ "Predicted",
566
+ 0.0,
567
+ 0.0,
568
+ 0.0,
569
+ 0.0,
570
+ 0.0,
571
+ 0.0,
572
+ 0.0,
573
+ 0.0,
574
+ 0.0,
575
+ 0.0,
576
+ 0.0,
577
+ 0.0,
578
+ 0.0,
579
+ False,
580
+ False,
581
+ ],
582
+ [
583
+ "assets/demo/equirectangular.jpg",
584
+ "Large",
585
+ "Equirectangular",
586
+ 0.0,
587
+ 0.0,
588
+ 0.0,
589
+ 0.0,
590
+ 0.0,
591
+ 0.0,
592
+ 0.0,
593
+ 0.0,
594
+ 0.0,
595
+ 0.0,
596
+ 0.0,
597
+ 0.0,
598
+ 360.0,
599
+ False,
600
+ False,
601
+ ],
602
+ [
603
+ "assets/demo/venice.jpg",
604
+ "Large",
605
+ "Equirectangular",
606
+ 0.0,
607
+ 0.0,
608
+ 0.0,
609
+ 0.0,
610
+ 0.0,
611
+ 0.0,
612
+ 0.0,
613
+ 0.0,
614
+ 0.0,
615
+ 0.0,
616
+ 0.0,
617
+ 0.0,
618
+ 360.0,
619
+ False,
620
+ True,
621
+ ],
622
+ [
623
+ "assets/demo/dl3dv.png",
624
+ "Large",
625
+ "OPENCV",
626
+ 429.57611083984375,
627
+ 429.6898193359375,
628
+ 479.5,
629
+ 269.5,
630
+ -0.0014844092074781656,
631
+ 0.0007422995404340327,
632
+ 0.0,
633
+ 0.0,
634
+ 0.0,
635
+ 0.0,
636
+ 0.00012013866944471374,
637
+ 0.001125041046179831,
638
+ 0.0,
639
+ False,
640
+ False,
641
+ ],
642
+ [
643
+ "assets/demo/scannet.jpg",
644
+ "Large",
645
+ "Fisheye624",
646
+ 791.90869140625,
647
+ 792.7230834960938,
648
+ 878.16796875,
649
+ 585.045166015625,
650
+ -0.029167557135224342,
651
+ -0.006803446915000677,
652
+ -0.0012682401575148106,
653
+ -4.6094228309812024e-05,
654
+ 0.0,
655
+ 0.0,
656
+ 0.0,
657
+ 0.0,
658
+ 0.0,
659
+ False,
660
+ False,
661
+ ],
662
+ ]
663
+
664
+ def example_pipeline(
665
+ input_image,
666
+ model_name,
667
+ camera_name,
668
+ fx,
669
+ fy,
670
+ cx,
671
+ cy,
672
+ k1,
673
+ k2,
674
+ k3,
675
+ k4,
676
+ k5,
677
+ k6,
678
+ t1,
679
+ t2,
680
+ hfov,
681
+ mask_black_bg,
682
+ mask_far_points,
683
+ ):
684
+ target_dir, image_path = handle_uploads(input_image)
685
+ glbfile, log_msg, prediction_save_path = gradio_demo(
686
+ target_dir,
687
+ model_name,
688
+ camera_name,
689
+ fx,
690
+ fy,
691
+ cx,
692
+ cy,
693
+ k1,
694
+ k2,
695
+ k3,
696
+ k4,
697
+ k5,
698
+ k6,
699
+ t1,
700
+ t2,
701
+ hfov,
702
+ mask_black_bg,
703
+ mask_far_points,
704
+ )
705
+ return (
706
+ glbfile,
707
+ log_msg,
708
+ prediction_save_path,
709
+ target_dir,
710
+ image_path,
711
+ )
712
+
713
+ gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
714
+
715
+ gr.Examples(
716
+ examples=examples,
717
+ inputs=[
718
+ input_image,
719
+ model_dropdown,
720
+ camera_dropdown,
721
+ fx,
722
+ fy,
723
+ cx,
724
+ cy,
725
+ k1,
726
+ k2,
727
+ k3,
728
+ k4,
729
+ k5,
730
+ k6,
731
+ t1,
732
+ t2,
733
+ hfov,
734
+ mask_black_bg,
735
+ mask_far_points,
736
+ ],
737
+ outputs=[reconstruction_output, log_output, reconstruction_npy],
738
+ fn=example_pipeline,
739
+ cache_examples=False,
740
+ examples_per_page=50,
741
+ )
742
+
743
+ submit_btn.click(
744
+ fn=clear_fields, inputs=[], outputs=[reconstruction_output]
745
+ ).then(fn=update_log, inputs=[], outputs=[log_output]).then(
746
+ fn=gradio_demo,
747
+ inputs=[
748
+ target_dir_output,
749
+ model_dropdown,
750
+ camera_dropdown,
751
+ fx,
752
+ fy,
753
+ cx,
754
+ cy,
755
+ k1,
756
+ k2,
757
+ k3,
758
+ k4,
759
+ k5,
760
+ k6,
761
+ t1,
762
+ t2,
763
+ hfov,
764
+ mask_black_bg,
765
+ mask_far_points,
766
+ ],
767
+ outputs=[reconstruction_output, log_output, reconstruction_npy],
768
+ ).then(
769
+ fn=lambda: "False", inputs=[], outputs=[is_example]
770
+ )
771
+
772
+ mask_black_bg.change(
773
+ update_visualization,
774
+ [target_dir_output, mask_black_bg, mask_far_points, is_example],
775
+ [reconstruction_output, log_output],
776
+ )
777
+
778
+ mask_far_points.change(
779
+ update_visualization,
780
+ [target_dir_output, mask_black_bg, mask_far_points, is_example],
781
+ [reconstruction_output, log_output],
782
+ )
783
+
784
+ input_image.change(
785
+ fn=update_gallery_on_upload,
786
+ inputs=[input_image],
787
+ outputs=[target_dir_output, log_output],
788
+ )
789
+
790
+ # Dynamically update intrinsic parameter visibility when camera selection changes.
791
+ camera_dropdown.change(
792
+ fn=update_parameters,
793
+ inputs=camera_dropdown,
794
+ outputs=[fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov],
795
+ )
796
+
797
+ # demo.queue(max_size=20).launch(show_error=True, share=False, ssr_mode=False)
798
+ demo.launch(
799
+ show_error=True,
800
+ )
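
For reference, the core inference path above (`instantiate_model` and `run_model`) can also be exercised without the Gradio UI. The following is a minimal sketch assembled from the code in this diff; the checkpoint name `lpiccinelli/unik3d-vitl` mirrors what `instantiate_model` builds for the "Large" option, and the demo image path is one of the assets added in this commit.

```python
# Minimal sketch of the inference path used by app.py above (not an official entry point).
import numpy as np
import torch
from PIL import Image

from unik3d.models import UniK3D

# "lpiccinelli/unik3d-vitl" mirrors instantiate_model() for the "Large" choice.
model = UniK3D.from_pretrained("lpiccinelli/unik3d-vitl")
model.resolution_level = 9
model.interpolation_mode = "bilinear"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

# Load an RGB image as a (1, 3, H, W) float tensor, as run_model() does.
image = np.array(Image.open("assets/demo/bears.jpg"))
image_tensor = torch.from_numpy(image).permute(2, 0, 1).unsqueeze(0).float().to(device)

with torch.no_grad():
    # camera=None lets the model predict the camera, matching the "Predicted" option.
    outputs = model.infer(image_tensor, camera=None, normalize=True)

# outputs["points"] holds the per-pixel 3D points that predictions_to_glb
# flattens into the GLB point cloud.
points = outputs["points"].squeeze().permute(1, 2, 0).cpu().numpy()
print(points.shape)
```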
assets/demo/bears.jpg ADDED
assets/demo/berzirk.jpg ADDED
assets/demo/dl3dv.json ADDED
@@ -0,0 +1,4 @@
+ {
+ "name": "OPENCV",
+ "params": [429.57611083984375, 429.6898193359375, 479.5, 269.5, -0.0014844092074781656, 0.0007422995404340327, 0.0, 0.0, 0.0, 0.0, 0.00012013866944471374, 0.001125041046179831, 0.0, 0.0, 0.0, 0.0]
+ }
assets/demo/dl3dv.png ADDED
assets/demo/equirectangular.jpg ADDED
assets/demo/kitti360.json ADDED
@@ -0,0 +1,14 @@
+ {
+ "params": [
+ 890.8814086914062,
+ 890.5255737304688,
+ 477.7955017089844,
+ 470.34332275390625,
+ 0.016798235476017,
+ 1.6548773050308228,
+ 0.000422239420004189,
+ 0.000424621335696429,
+ 2.213404655456543
+ ],
+ "name": "MEI"
+ }
assets/demo/kitti360.png ADDED
assets/demo/luke.webp ADDED
assets/demo/naruto.jpg ADDED
assets/demo/poorthings.jpg ADDED
assets/demo/scannet.jpg ADDED
assets/demo/scannet.json ADDED
@@ -0,0 +1,21 @@
+ {
+ "params": [
+ 791.90869140625,
+ 792.7230834960938,
+ 878.16796875,
+ 585.045166015625,
+ -0.029167557135224342,
+ -0.006803446915000677,
+ -0.0012682401575148106,
+ -4.6094228309812024e-05,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "name": "Fisheye624"
+ }
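
The demo JSONs under assets/demo pair a camera class name with a flat parameter vector. Below is a hedged sketch of how such a file could be turned into a camera object using the module imported by app.py; note that `instantiate_camera` above feeds only 12 values (fx, fy, cx, cy, k1–k6, t1, t2) to Fisheye624, so whether a class accepts the full 16-entry vector stored here is an assumption about the unik3d camera API.

```python
# Sketch: build a camera object from a demo camera JSON (API details are assumptions).
import json

import torch

import unik3d.utils.camera as cameras  # module imported by app.py (Pinhole, OPENCV, Fisheye624, ...)

with open("assets/demo/scannet.json") as f:
    spec = json.load(f)

camera_cls = getattr(cameras, spec["name"])  # e.g. Fisheye624
params = torch.tensor(spec["params"]).float()
# Assumption: the class accepts the full vector as stored in the JSON;
# app.py's instantiate_camera passes only the first 12 entries for its UI fields.
camera = camera_cls(params=params)
print(type(camera).__name__, tuple(params.shape))
```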
assets/demo/venice.jpg ADDED
assets/docs/unik3d-banner.png ADDED
assets/docs/unik3d-teaser.png ADDED
configs/config_vitb.json ADDED
@@ -0,0 +1,159 @@
1
+ {
2
+ "generic": {
3
+ "seed": 42,
4
+ "deterministic": true,
5
+ "name_page": "ufish"
6
+ },
7
+ "training": {
8
+ "n_iters": 250000,
9
+ "batch_size": 8,
10
+ "validation_interval": 2500,
11
+ "nsteps_accumulation_gradient": 4,
12
+ "lr": 5e-05,
13
+ "lr_final": 1e-06,
14
+ "lr_warmup": 1.0,
15
+ "cycle_beta": true,
16
+ "wd": 0.1,
17
+ "wd_final": 0.1,
18
+ "warmup_iters": 75000,
19
+ "ld": 1.0,
20
+ "drop_path": 0.0,
21
+ "ema": 0.9995,
22
+ "f16": "f16",
23
+ "clipping": 1.0,
24
+ "losses": {
25
+ "depth": {
26
+ "name": "Scale",
27
+ "weight": 1.0,
28
+ "fn": "l1",
29
+ "gamma": 1.0,
30
+ "alpha": 1.0,
31
+ "output_fn": "sqrt",
32
+ "input_fn": "log"
33
+ },
34
+ "camera": {
35
+ "name": "PolarRegression",
36
+ "weight": 1.0,
37
+ "gamma": 1.0,
38
+ "alpha": 1.0,
39
+ "fn": "l1",
40
+ "output_fn": "sqrt",
41
+ "input_fn": "linear",
42
+ "dims": [
43
+ 1,
44
+ 2
45
+ ],
46
+ "polar_weight": 3.0,
47
+ "polar_asym": 0.7
48
+ },
49
+ "confidence": {
50
+ "name": "Confidence",
51
+ "weight": 0.1,
52
+ "input_fn": "log",
53
+ "output_fn": "sqrt"
54
+ }
55
+ }
56
+ },
57
+ "data": {
58
+ "image_shape": [
59
+ 518,
60
+ 518
61
+ ],
62
+ "resize_method": "contextcrop",
63
+ "normalization": "imagenet",
64
+ "pair": 1,
65
+ "mini": 1.0,
66
+ "num_frames": 1,
67
+ "sampling": {
68
+ "KITTI": 1.0
69
+ },
70
+ "train_datasets": [
71
+ "KITTI"
72
+ ],
73
+ "val_datasets": [
74
+ "KITTI"
75
+ ],
76
+ "data_root": "datasets",
77
+ "crop": "garg",
78
+ "augmentations": {
79
+ "random_scale": 4.0,
80
+ "random_translate_x": 0.04,
81
+ "random_translate_y": 0.01,
82
+ "scale_p": 0.0,
83
+ "translate_p": 0.0,
84
+ "random_rotation": 0.0,
85
+ "rotation_p": 0.0,
86
+ "random_shear": 0.0,
87
+ "affine_p": 0.0,
88
+ "random_jitter": 0.5,
89
+ "jitter_p": 1.0,
90
+ "random_blur": 2.0,
91
+ "blur_p": 0.5,
92
+ "random_gamma": 0.5,
93
+ "gamma_p": 1.0,
94
+ "grayscale_p": 0.2,
95
+ "flip_p": 0.5,
96
+ "cut_p": 0.0,
97
+ "invert_p": 0.0,
98
+ "shape_mult": 14,
99
+ "noise_pad": 1.0,
100
+ "test_context": 1.0
101
+ },
102
+ "shape_constraints": {
103
+ "ratio_bounds": [
104
+ 0.5,
105
+ 2.5
106
+ ],
107
+ "pixels_max": 600000.0,
108
+ "pixels_min": 200000.0,
109
+ "height_min": 15,
110
+ "width_min": 15,
111
+ "shape_mult": 14,
112
+ "sample": true
113
+ }
114
+ },
115
+ "model": {
116
+ "name": "UniK3D",
117
+ "num_heads": 8,
118
+ "expansion": 4,
119
+ "num_steps": 100000,
120
+ "layer_scale": 1e-4,
121
+ "camera": {
122
+ "augment": true,
123
+ "weak_ratio": 0.9,
124
+ "tau": 50000
125
+ },
126
+ "pixel_decoder": {
127
+ "name": "Decoder",
128
+ "hidden_dim": 384,
129
+ "dropout": 0.0,
130
+ "depths": [
131
+ 2,
132
+ 2,
133
+ 2
134
+ ],
135
+ "detach": 0.1,
136
+ "out_dim": 48,
137
+ "kernel_size": 3,
138
+ "num_prompt_blocks": 1,
139
+ "use_norm": false
140
+ },
141
+ "pixel_encoder": {
142
+ "lr": 3e-06,
143
+ "wd": 0.1,
144
+ "name": "dinov2_vitb14",
145
+ "frozen_stages": 0,
146
+ "num_register_tokens": 0,
147
+ "use_norm": true,
148
+ "freeze_norm": true,
149
+ "pretrained": null,
150
+ "stacking_fn": "last",
151
+ "output_idx": [
152
+ 3,
153
+ 6,
154
+ 9,
155
+ 12
156
+ ]
157
+ }
158
+ }
159
+ }
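
All three config files share the same schema. As a quick orientation, the sketch below (standard-library JSON only, no UniK3D-specific API assumed) shows how the nested sections map to the values above; the actual consumer of these files is presumably scripts/train.py, which is listed in this commit but not shown here.

```python
# Sketch: inspect the shared training/data/model config schema.
import json

with open("configs/config_vitb.json") as f:
    cfg = json.load(f)

print(cfg["training"]["lr"], cfg["training"]["batch_size"])  # 5e-05 8
print(cfg["data"]["image_shape"])                            # [518, 518]
print(cfg["model"]["pixel_encoder"]["name"])                 # dinov2_vitb14
print(cfg["model"]["pixel_decoder"]["hidden_dim"])           # 384
```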
configs/config_vitl.json ADDED
@@ -0,0 +1,159 @@
1
+ {
2
+ "generic": {
3
+ "seed": 42,
4
+ "deterministic": true,
5
+ "name_page": "ufish"
6
+ },
7
+ "training": {
8
+ "n_iters": 250000,
9
+ "batch_size": 8,
10
+ "validation_interval": 2500,
11
+ "nsteps_accumulation_gradient": 4,
12
+ "lr": 5e-05,
13
+ "lr_final": 1e-06,
14
+ "lr_warmup": 1.0,
15
+ "cycle_beta": true,
16
+ "wd": 0.1,
17
+ "wd_final": 0.1,
18
+ "warmup_iters": 75000,
19
+ "ld": 1.0,
20
+ "drop_path": 0.0,
21
+ "ema": 0.9995,
22
+ "f16": "f16",
23
+ "clipping": 1.0,
24
+ "losses": {
25
+ "depth": {
26
+ "name": "Scale",
27
+ "weight": 1.0,
28
+ "fn": "l1",
29
+ "gamma": 1.0,
30
+ "alpha": 1.0,
31
+ "output_fn": "sqrt",
32
+ "input_fn": "log"
33
+ },
34
+ "camera": {
35
+ "name": "PolarRegression",
36
+ "weight": 1.0,
37
+ "gamma": 1.0,
38
+ "alpha": 1.0,
39
+ "fn": "l1",
40
+ "output_fn": "sqrt",
41
+ "input_fn": "linear",
42
+ "dims": [
43
+ 1,
44
+ 2
45
+ ],
46
+ "polar_weight": 3.0,
47
+ "polar_asym": 0.7
48
+ },
49
+ "confidence": {
50
+ "name": "Confidence",
51
+ "weight": 0.1,
52
+ "input_fn": "log",
53
+ "output_fn": "sqrt"
54
+ }
55
+ }
56
+ },
57
+ "data": {
58
+ "image_shape": [
59
+ 518,
60
+ 518
61
+ ],
62
+ "resize_method": "contextcrop",
63
+ "normalization": "imagenet",
64
+ "pair": 1,
65
+ "mini": 1.0,
66
+ "num_frames": 1,
67
+ "sampling": {
68
+ "KITTI": 1.0
69
+ },
70
+ "train_datasets": [
71
+ "KITTI"
72
+ ],
73
+ "val_datasets": [
74
+ "KITTI"
75
+ ],
76
+ "data_root": "datasets",
77
+ "crop": "garg",
78
+ "augmentations": {
79
+ "random_scale": 4.0,
80
+ "random_translate_x": 0.04,
81
+ "random_translate_y": 0.01,
82
+ "scale_p": 0.0,
83
+ "translate_p": 0.0,
84
+ "random_rotation": 0.0,
85
+ "rotation_p": 0.0,
86
+ "random_shear": 0.0,
87
+ "affine_p": 0.0,
88
+ "random_jitter": 0.5,
89
+ "jitter_p": 1.0,
90
+ "random_blur": 2.0,
91
+ "blur_p": 0.5,
92
+ "random_gamma": 0.5,
93
+ "gamma_p": 1.0,
94
+ "grayscale_p": 0.2,
95
+ "flip_p": 0.5,
96
+ "cut_p": 0.0,
97
+ "invert_p": 0.0,
98
+ "shape_mult": 14,
99
+ "noise_pad": 1.0,
100
+ "test_context": 1.0
101
+ },
102
+ "shape_constraints": {
103
+ "ratio_bounds": [
104
+ 0.5,
105
+ 2.5
106
+ ],
107
+ "pixels_max": 600000.0,
108
+ "pixels_min": 200000.0,
109
+ "height_min": 15,
110
+ "width_min": 15,
111
+ "shape_mult": 14,
112
+ "sample": true
113
+ }
114
+ },
115
+ "model": {
116
+ "name": "UniK3D",
117
+ "num_heads": 8,
118
+ "expansion": 4,
119
+ "num_steps": 100000,
120
+ "layer_scale": 1e-4,
121
+ "camera": {
122
+ "augment": true,
123
+ "weak_ratio": 0.9,
124
+ "tau": 50000
125
+ },
126
+ "pixel_decoder": {
127
+ "name": "Decoder",
128
+ "hidden_dim": 512,
129
+ "dropout": 0.0,
130
+ "depths": [
131
+ 2,
132
+ 2,
133
+ 2
134
+ ],
135
+ "detach": 0.1,
136
+ "out_dim": 64,
137
+ "kernel_size": 3,
138
+ "num_prompt_blocks": 1,
139
+ "use_norm": false
140
+ },
141
+ "pixel_encoder": {
142
+ "lr": 3e-06,
143
+ "wd": 0.1,
144
+ "name": "dinov2_vitl14",
145
+ "frozen_stages": 0,
146
+ "num_register_tokens": 0,
147
+ "use_norm": true,
148
+ "freeze_norm": true,
149
+ "pretrained": null,
150
+ "stacking_fn": "last",
151
+ "output_idx": [
152
+ 6,
153
+ 12,
154
+ 18,
155
+ 24
156
+ ]
157
+ }
158
+ }
159
+ }
configs/config_vits.json ADDED
@@ -0,0 +1,159 @@
1
+ {
2
+ "generic": {
3
+ "seed": 42,
4
+ "deterministic": true,
5
+ "name_page": "ufish"
6
+ },
7
+ "training": {
8
+ "n_iters": 250000,
9
+ "batch_size": 8,
10
+ "validation_interval": 2500,
11
+ "nsteps_accumulation_gradient": 4,
12
+ "lr": 5e-05,
13
+ "lr_final": 1e-06,
14
+ "lr_warmup": 1.0,
15
+ "cycle_beta": true,
16
+ "wd": 0.1,
17
+ "wd_final": 0.1,
18
+ "warmup_iters": 75000,
19
+ "ld": 1.0,
20
+ "drop_path": 0.0,
21
+ "ema": 0.9995,
22
+ "f16": "f16",
23
+ "clipping": 1.0,
24
+ "losses": {
25
+ "depth": {
26
+ "name": "Scale",
27
+ "weight": 1.0,
28
+ "fn": "l1",
29
+ "gamma": 1.0,
30
+ "alpha": 1.0,
31
+ "output_fn": "sqrt",
32
+ "input_fn": "log"
33
+ },
34
+ "camera": {
35
+ "name": "PolarRegression",
36
+ "weight": 1.0,
37
+ "gamma": 1.0,
38
+ "alpha": 1.0,
39
+ "fn": "l1",
40
+ "output_fn": "sqrt",
41
+ "input_fn": "linear",
42
+ "dims": [
43
+ 1,
44
+ 2
45
+ ],
46
+ "polar_weight": 3.0,
47
+ "polar_asym": 0.7
48
+ },
49
+ "confidence": {
50
+ "name": "Confidence",
51
+ "weight": 0.1,
52
+ "input_fn": "log",
53
+ "output_fn": "sqrt"
54
+ }
55
+ }
56
+ },
57
+ "data": {
58
+ "image_shape": [
59
+ 518,
60
+ 518
61
+ ],
62
+ "resize_method": "contextcrop",
63
+ "normalization": "imagenet",
64
+ "pair": 1,
65
+ "mini": 1.0,
66
+ "num_frames": 1,
67
+ "sampling": {
68
+ "KITTI": 1.0
69
+ },
70
+ "train_datasets": [
71
+ "KITTI"
72
+ ],
73
+ "val_datasets": [
74
+ "KITTI"
75
+ ],
76
+ "data_root": "datasets",
77
+ "crop": "garg",
78
+ "augmentations": {
79
+ "random_scale": 4.0,
80
+ "random_translate_x": 0.04,
81
+ "random_translate_y": 0.01,
82
+ "scale_p": 0.0,
83
+ "translate_p": 0.0,
84
+ "random_rotation": 0.0,
85
+ "rotation_p": 0.0,
86
+ "random_shear": 0.0,
87
+ "affine_p": 0.0,
88
+ "random_jitter": 0.5,
89
+ "jitter_p": 1.0,
90
+ "random_blur": 2.0,
91
+ "blur_p": 0.5,
92
+ "random_gamma": 0.5,
93
+ "gamma_p": 1.0,
94
+ "grayscale_p": 0.2,
95
+ "flip_p": 0.5,
96
+ "cut_p": 0.0,
97
+ "invert_p": 0.0,
98
+ "shape_mult": 14,
99
+ "noise_pad": 1.0,
100
+ "test_context": 1.0
101
+ },
102
+ "shape_constraints": {
103
+ "ratio_bounds": [
104
+ 0.5,
105
+ 2.5
106
+ ],
107
+ "pixels_max": 600000.0,
108
+ "pixels_min": 200000.0,
109
+ "height_min": 15,
110
+ "width_min": 15,
111
+ "shape_mult": 14,
112
+ "sample": true
113
+ }
114
+ },
115
+ "model": {
116
+ "name": "UniK3D",
117
+ "num_heads": 8,
118
+ "expansion": 4,
119
+ "num_steps": 100000,
120
+ "layer_scale": 1e-4,
121
+ "camera": {
122
+ "augment": true,
123
+ "weak_ratio": 0.9,
124
+ "tau": 50000
125
+ },
126
+ "pixel_decoder": {
127
+ "name": "Decoder",
128
+ "hidden_dim": 256,
129
+ "dropout": 0.0,
130
+ "depths": [
131
+ 2,
132
+ 2,
133
+ 2
134
+ ],
135
+ "detach": 0.1,
136
+ "out_dim": 32,
137
+ "kernel_size": 3,
138
+ "num_prompt_blocks": 1,
139
+ "use_norm": false
140
+ },
141
+ "pixel_encoder": {
142
+ "lr": 3e-06,
143
+ "wd": 0.1,
144
+ "name": "dinov2_vits14",
145
+ "frozen_stages": 0,
146
+ "num_register_tokens": 0,
147
+ "use_norm": true,
148
+ "freeze_norm": true,
149
+ "pretrained": null,
150
+ "stacking_fn": "last",
151
+ "output_idx": [
152
+ 3,
153
+ 6,
154
+ 9,
155
+ 12
156
+ ]
157
+ }
158
+ }
159
+ }
gradio_demo.py ADDED
@@ -0,0 +1,796 @@
1
+ import gc
2
+ import os
3
+ import shutil
4
+ import time
5
+ from datetime import datetime
6
+ from math import pi
7
+
8
+ import gradio as gr
9
+ import numpy as np
10
+ import torch
11
+ import trimesh
12
+ from PIL import Image
13
+
14
+ from unik3d.models import UniK3D
15
+ from unik3d.utils.camera import OPENCV, Fisheye624, Pinhole, Spherical
16
+ from unik3d.utils.visualization import colorize
17
+
18
+
19
+ def predictions_to_glb(
20
+ predictions,
21
+ mask_black_bg=False,
22
+ mask_far_points=False,
23
+ ) -> trimesh.Scene:
24
+ print("Building GLB scene")
25
+ images = predictions["image"].squeeze().permute(1, 2, 0).cpu().numpy()
26
+ world_points = predictions["points"].squeeze().permute(1, 2, 0).cpu().numpy()
27
+
28
+ vertices_3d = world_points.reshape(-1, 3)
29
+ # flip x and y
30
+ vertices_3d[:, 1] *= -1
31
+ vertices_3d[:, 0] *= -1
32
+ colors_rgb = (images.reshape(-1, 3)).astype(np.uint8)
33
+
34
+ if mask_black_bg:
35
+ black_bg_mask = colors_rgb.sum(axis=1) >= 16
36
+ vertices_3d = vertices_3d[black_bg_mask]
37
+ colors_rgb = colors_rgb[black_bg_mask]
38
+
39
+ if mask_far_points:
40
+ far_points_mask = np.linalg.norm(vertices_3d, axis=-1) < 100.0
41
+ vertices_3d = vertices_3d[far_points_mask]
42
+ colors_rgb = colors_rgb[far_points_mask]
43
+
44
+ scene_3d = trimesh.Scene()
45
+ point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
46
+ scene_3d.add_geometry(point_cloud_data)
47
+
48
+ return scene_3d
49
+
50
+
51
+ def instantiate_model(model_name):
52
+ type_ = model_name[0].lower()
53
+
54
+ name = f"unik3d-vit{type_}"
55
+ model = UniK3D.from_pretrained(f"lpiccinelli/{name}")
56
+
57
+ # Set resolution level and interpolation mode as specified.
58
+ model.resolution_level = 9
59
+ model.interpolation_mode = "bilinear"
60
+
61
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
62
+ model = model.to(device).eval()
63
+ return model
64
+
65
+
66
+ def instantiate_camera(camera_name, params, device):
67
+ if camera_name == "Predicted":
68
+ return None
69
+ fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov, H, W = params
70
+ if camera_name == "Pinhole":
71
+ params = [fx, fy, cx, cy]
72
+ elif camera_name == "Fisheye624":
73
+ params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2]
74
+ elif camera_name == "OPENCV":
75
+ params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2]
76
+ elif camera_name == "Equirectangular":
77
+ # dummy intrinsics for spherical camera, assume hfov -> vfov based on input shapes
78
+ hfov2 = hfov * pi / 180.0 / 2
79
+ params = [fx, fy, cx, cy, W, H, hfov2, H / W * hfov2]
80
+ camera_name = "Spherical"
81
+
82
+ return eval(camera_name)(params=torch.tensor(params).float()).to(device)
83
+
84
+
85
+ def run_model(target_dir, model_name, camera_name, params):
86
+
87
+ print("Instantiating model and camera...")
88
+ model = instantiate_model(model_name)
89
+
90
+ image_names = [x for x in os.listdir(target_dir) if x.endswith(".png")]
91
+ input_image = np.array(Image.open(os.path.join(target_dir, image_names[-1])))
92
+ image_tensor = torch.from_numpy(input_image).permute(2, 0, 1).unsqueeze(0).float()
93
+ device = next(model.parameters()).device
94
+ image_tensor = image_tensor.to(device)
95
+ H, W = image_tensor.shape[-2:]
96
+ params = params + [H, W]
97
+ camera = instantiate_camera(camera_name, params=params, device=device)
98
+
99
+ # Perform inference with the model.
100
+ print("Running inference...")
101
+ outputs = model.infer(image_tensor, camera=camera, normalize=True)
102
+ outputs["image"] = image_tensor
103
+
104
+ return outputs
105
+
106
+
107
+ def gradio_demo(
108
+ target_dir,
109
+ model_name,
110
+ camera_name,
111
+ fx,
112
+ fy,
113
+ cx,
114
+ cy,
115
+ k1,
116
+ k2,
117
+ k3,
118
+ k4,
119
+ k5,
120
+ k6,
121
+ t1,
122
+ t2,
123
+ hfov,
124
+ mask_black_bg,
125
+ mask_far_points,
126
+ ):
127
+ print(target_dir)
128
+ if not os.path.isdir(target_dir) or target_dir == "None":
129
+ return None, "No valid target directory found. Please upload first.", None
130
+
131
+ start_time = time.time()
132
+ gc.collect()
133
+
134
+ print("Running run_model...")
135
+ params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov]
136
+ with torch.no_grad():
137
+ outputs = run_model(target_dir, model_name, camera_name, params)
138
+
139
+ # Save predictions
140
+ points = outputs["points"].squeeze().permute(1, 2, 0).cpu().numpy()
141
+ rgb = outputs["image"].squeeze().permute(1, 2, 0).cpu().numpy()
142
+
143
+ prediction_save_path = os.path.join(target_dir, "predictions.npz")
144
+ np.savez(prediction_save_path, points=points, image=rgb)
145
+
146
+ # Build a GLB file name
147
+ glbfile = os.path.join(
148
+ target_dir,
149
+ f"glbscene.glb",
150
+ )
151
+
152
+ # Convert predictions to GLB
153
+ glbscene = predictions_to_glb(
154
+ outputs,
155
+ mask_black_bg=mask_black_bg,
156
+ mask_far_points=mask_far_points,
157
+ )
158
+ glbscene.export(file_obj=glbfile)
159
+
160
+ # Cleanup
161
+ del outputs
162
+ gc.collect()
163
+
164
+ end_time = time.time()
165
+ print(f"Total time: {end_time - start_time:.2f} seconds")
166
+ log_msg = f"Success. Waiting for visualization."
167
+
168
+ return glbfile, log_msg, prediction_save_path
169
+
170
+
171
+ def handle_uploads(input_image):
172
+ gc.collect()
173
+
174
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
175
+ tmpdir = os.environ.get("TMPDIR", "/tmp")
176
+ target_dir = os.path.join(tmpdir, f"input_images_{timestamp}")
177
+
178
+ if os.path.exists(target_dir):
179
+ shutil.rmtree(target_dir)
180
+ os.makedirs(target_dir)
181
+
182
+ dst_path = os.path.join(target_dir, "image.png")
183
+ Image.fromarray(input_image).save(dst_path)
184
+ image_paths = [dst_path]
185
+
186
+ print(f"Files uploaded.")
187
+ return target_dir, image_paths
188
+
189
+
190
+ def update_gallery_on_upload(input_images):
191
+ if input_images is None:
192
+ return None, None
193
+ target_dir, image_path = handle_uploads(input_images)
194
+ return target_dir, "Upload complete. Click 'Run UniK3D' to get 3D pointcloud."
195
+
196
+
197
+ def update_parameters(camera):
198
+ if camera == "Pinhole":
199
+ return (
200
+ gr.update(visible=True), # fx
201
+ gr.update(visible=True), # fy
202
+ gr.update(visible=True), # cx
203
+ gr.update(visible=True), # cy
204
+ gr.update(visible=False), # k1
205
+ gr.update(visible=False), # k2
206
+ gr.update(visible=False), # k3
207
+ gr.update(visible=False), # k4
208
+ gr.update(visible=False), # k5
209
+ gr.update(visible=False), # k6
210
+ gr.update(visible=False), # t1
211
+ gr.update(visible=False), # t2
212
+ gr.update(visible=False), # hfov
213
+ )
214
+ elif camera == "OPENCV":
215
+ return (
216
+ gr.update(visible=True), # fx
217
+ gr.update(visible=True), # fy
218
+ gr.update(visible=True), # cx
219
+ gr.update(visible=True), # cy
220
+ gr.update(visible=True), # k1
221
+ gr.update(visible=True), # k2
222
+ gr.update(visible=True), # k3
223
+ gr.update(visible=False), # k4
224
+ gr.update(visible=False), # k5
225
+ gr.update(visible=False), # k6
226
+ gr.update(visible=True), # t1
227
+ gr.update(visible=True), # t2
228
+ gr.update(visible=False), # hfov
229
+ )
230
+ elif camera == "Fisheye624":
231
+ return (
232
+ gr.update(visible=True), # fx
233
+ gr.update(visible=True), # fy
234
+ gr.update(visible=True), # cx
235
+ gr.update(visible=True), # cy
236
+ gr.update(visible=True), # k1
237
+ gr.update(visible=True), # k2
238
+ gr.update(visible=True), # k3
239
+ gr.update(visible=True), # k4
240
+ gr.update(visible=True), # k5
241
+ gr.update(visible=True), # k6
242
+ gr.update(visible=True), # t1
243
+ gr.update(visible=True), # t2
244
+ gr.update(visible=False), # hfov
245
+ )
246
+ elif camera == "Equirectangular":
247
+ return (
248
+ gr.update(visible=False), # fx
249
+ gr.update(visible=False), # fy
250
+ gr.update(visible=False), # cx
251
+ gr.update(visible=False), # cy
252
+ gr.update(visible=False), # k1
253
+ gr.update(visible=False), # k2
254
+ gr.update(visible=False), # k3
255
+ gr.update(visible=False), # k4
256
+ gr.update(visible=False), # k5
257
+ gr.update(visible=False), # k6
258
+ gr.update(visible=False), # t1
259
+ gr.update(visible=False), # t2
260
+ gr.update(visible=True), # hfov
261
+ )
262
+ elif camera == "Predicted":
263
+ return (
264
+ gr.update(visible=False), # fx
265
+ gr.update(visible=False), # fy
266
+ gr.update(visible=False), # cx
267
+ gr.update(visible=False), # cy
268
+ gr.update(visible=False), # k1
269
+ gr.update(visible=False), # k2
270
+ gr.update(visible=False), # k3
271
+ gr.update(visible=False), # k4
272
+ gr.update(visible=False), # k5
273
+ gr.update(visible=False), # k6
274
+ gr.update(visible=False), # t1
275
+ gr.update(visible=False), # t2
276
+ gr.update(visible=False), # hfov
277
+ )
278
+ else:
279
+ raise ValueError(f"Invalid camera type: {camera}")
280
+
281
+
282
+ def clear_fields():
283
+ return None
284
+
285
+
286
+ def update_log():
287
+ return "Loading Model and Running Inference..."
288
+
289
+
290
+ def update_visualization(target_dir, mask_black_bg, mask_far_points, is_example):
291
+
292
+ if is_example == "True":
293
+ return (
294
+ None,
295
+ "No reconstruction available. Please click the Reconstruct button first.",
296
+ )
297
+
298
+ if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
299
+ return (
300
+ None,
301
+ "No reconstruction available. Please click the Reconstruct button first.",
302
+ )
303
+
304
+ predictions_path = os.path.join(target_dir, "predictions.npz")
305
+ if not os.path.exists(predictions_path):
306
+ return (
307
+ None,
308
+ f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
309
+ )
310
+
311
+ loaded = np.load(predictions_path, allow_pickle=True)
312
+ predictions = {key: loaded[key] for key in loaded.keys()}
313
+
314
+ glbfile = os.path.join(
315
+ target_dir,
316
+ f"glbscene.glb",
317
+ )
318
+
319
+ if not os.path.exists(glbfile):
320
+ glbscene = predictions_to_glb(
321
+ predictions,
322
+ mask_black_bg=mask_black_bg,
323
+ mask_far_points=mask_far_points,
324
+ )
325
+ glbscene.export(file_obj=glbfile)
326
+
327
+ return glbfile, "Updating Visualization"
328
+
329
+
330
+ if __name__ == "__main__":
331
+ theme = gr.themes.Citrus()
332
+ theme.set(
333
+ checkbox_label_background_fill_selected="*button_primary_background_fill",
334
+ checkbox_label_text_color_selected="*button_primary_text_color",
335
+ )
336
+
337
+ with gr.Blocks(
338
+ theme=theme,
339
+ css="""
340
+ .custom-log * {
341
+ font-style: italic;
342
+ font-size: 22px !important;
343
+ background-image: linear-gradient(120deg, #ff7e26 0%, #ff9c59 60%, #fff4d6 100%);
344
+ -webkit-background-clip: text;
345
+ background-clip: text;
346
+ font-weight: bold !important;
347
+ color: transparent !important;
348
+ text-align: center !important;
349
+ }
350
+
351
+ .example-log * {
352
+ font-style: italic;
353
+ font-size: 16px !important;
354
+ background-image: linear-gradient(120deg, #ff7e26 0%, #ff9c59 60%, #fff4d6 100%);
355
+ -webkit-background-clip: text;
356
+ background-clip: text;
357
+ color: transparent !important;
358
+ }
359
+
360
+ #my_radio .wrap {
361
+ display: flex;
362
+ flex-wrap: nowrap;
363
+ justify-content: center;
364
+ align-items: center;
365
+ }
366
+
367
+ #my_radio .wrap label {
368
+ display: flex;
369
+ width: 50%;
370
+ justify-content: center;
371
+ align-items: center;
372
+ margin: 0;
373
+ padding: 10px 0;
374
+ box-sizing: border-box;
375
+ }
376
+ """,
377
+ ) as demo:
378
+
379
+ # Instead of gr.State, we use a hidden Textbox:
380
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
381
+
382
+ gr.HTML(
383
+ """
384
+ <h1>UniK3D: Universal Camera Monocular 3D Estimation</h1>
385
+ <p>
386
+ <a href="https://github.com/lpiccinelli-eth/UniK3D">🌟 GitHub Repository</a> |
387
+ <a href="">🚀 Project Page</a>
388
+ </p>
389
+
390
+ <div style="font-size: 16px; line-height: 1.5;">
391
+ <p>Upload one image to create a 3D estimation of a scene or object. UniK3D directly predicts the 3D of any scene from any camera.</p>
392
+
393
+ <h3>Getting Started:</h3>
394
+ <ol>
395
+ <li><strong>Upload Your Image:</strong> Use the "Upload Images" panel to provide your input.</li>
396
+ <li><strong>Run:</strong> Click the "Run UniK3D" button to start the 3D estimation process.</li>
397
+ <li><strong>Visualize:</strong> The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file.</li>
398
+ </ol>
399
+ <p><strong style="color: #ff7e26;">Please note:</strong> <span style="color: #ff7e26; font-weight: bold;">Our model runs on CPU on HuggingFace Space. Actual inference takes less than 100 ms per image on consumer-level GPUs. Web-based 3D pointcloud visualization may be slow due to Gradio's rendering. For faster visualization, use a local machine to run our demo from our <a href="https://github.com/lpiccinelli-eth/UniK3D">GitHub repository</a>. </span></p>
400
+ </div>
401
+ """
402
+ )
403
+
404
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
405
+
406
+ with gr.Row():
407
+ with gr.Column():
408
+ camera_dropdown = gr.Dropdown(
409
+ choices=[
410
+ "Predicted",
411
+ "Pinhole",
412
+ "Fisheye624",
413
+ "OPENCV",
414
+ "Equirectangular",
415
+ ],
416
+ label="Input Camera",
417
+ )
418
+ model_dropdown = gr.Dropdown(
419
+ choices=["Large", "Base", "Small"], label="Utilized Model"
420
+ )
421
+ mask_black_bg = gr.Checkbox(
422
+ label="Filter Black Background", value=False
423
+ )
424
+ mask_far_points = gr.Checkbox(label="Filter Far Points", value=False)
425
+
426
+ with gr.Column():
427
+ fx = gr.Number(label="Focal length x", value=500.0, visible=False)
428
+ fy = gr.Number(label="Focal length y", value=500.0, visible=False)
429
+ cx = gr.Number(label="Center projection x", value=320.0, visible=False)
430
+ cy = gr.Number(label="Center projection y", value=240.0, visible=False)
431
+ hfov = gr.Number(
432
+ label="Horizontal FoV (degree)", value=0.0, visible=False
433
+ )
434
+
435
+ with gr.Column():
436
+ k1 = gr.Number(label="Radial 1", value=0.0, visible=False)
437
+ k2 = gr.Number(label="Radial 2", value=0.0, visible=False)
438
+ k3 = gr.Number(label="Radial 3", value=0.0, visible=False)
439
+ k4 = gr.Number(label="Radial 4", value=0.0, visible=False)
440
+
441
+ with gr.Column():
442
+ k5 = gr.Number(label="Radial 5", value=0.0, visible=False)
443
+ k6 = gr.Number(label="Radial 6", value=0.0, visible=False)
444
+ t1 = gr.Number(label="Tangential 1", value=0.0, visible=False)
445
+ t2 = gr.Number(label="Tangential 2", value=0.0, visible=False)
446
+
447
+ with gr.Row():
448
+ with gr.Column(scale=1):
449
+ input_image = gr.Image(label="Upload Images")
450
+ gr.Markdown("**3D Estimation**")
451
+ with gr.Row():
452
+ log_output = gr.Markdown(
453
+ "Please upload one image at a time, then click `Run UniK3D`.",
454
+ elem_classes=["custom-log"],
455
+ )
456
+ reconstruction_npy = gr.File(
457
+ label="Download 3D Pointcloud", type="filepath"
458
+ )
459
+
460
+ with gr.Column(scale=2):
461
+ reconstruction_output = gr.Model3D(
462
+ height=520, zoom_speed=0.5, pan_speed=0.5
463
+ )
464
+ with gr.Row():
465
+ submit_btn = gr.Button("Run UniK3D", scale=1, variant="primary")
466
+ clear_btn = gr.ClearButton(
467
+ [
468
+ input_image,
469
+ reconstruction_output,
470
+ log_output,
471
+ target_dir_output,
472
+ reconstruction_npy,
473
+ ],
474
+ scale=1,
475
+ )
476
+
477
+ examples = [
478
+ [
479
+ "assets/demo/poorthings.jpg",
480
+ "Large",
481
+ "Predicted",
482
+ 0.0,
483
+ 0.0,
484
+ 0.0,
485
+ 0.0,
486
+ 0.0,
487
+ 0.0,
488
+ 0.0,
489
+ 0.0,
490
+ 0.0,
491
+ 0.0,
492
+ 0.0,
493
+ 0.0,
494
+ 0.0,
495
+ True,
496
+ False,
497
+ ],
498
+ [
499
+ "assets/demo/naruto.jpg",
500
+ "Large",
501
+ "Predicted",
502
+ 0.0,
503
+ 0.0,
504
+ 0.0,
505
+ 0.0,
506
+ 0.0,
507
+ 0.0,
508
+ 0.0,
509
+ 0.0,
510
+ 0.0,
511
+ 0.0,
512
+ 0.0,
513
+ 0.0,
514
+ 0.0,
515
+ False,
516
+ False,
517
+ ],
518
+ [
519
+ "assets/demo/bears.png",
520
+ "Large",
521
+ "Predicted",
522
+ 0.0,
523
+ 0.0,
524
+ 0.0,
525
+ 0.0,
526
+ 0.0,
527
+ 0.0,
528
+ 0.0,
529
+ 0.0,
530
+ 0.0,
531
+ 0.0,
532
+ 0.0,
533
+ 0.0,
534
+ 0.0,
535
+ True,
536
+ False,
537
+ ],
538
+ [
539
+ "assets/demo/berzirk.jpg",
540
+ "Large",
541
+ "Predicted",
542
+ 0.0,
543
+ 0.0,
544
+ 0.0,
545
+ 0.0,
546
+ 0.0,
547
+ 0.0,
548
+ 0.0,
549
+ 0.0,
550
+ 0.0,
551
+ 0.0,
552
+ 0.0,
553
+ 0.0,
554
+ 0.0,
555
+ True,
556
+ False,
557
+ ],
558
+ [
559
+ "assets/demo/luke.webp",
560
+ "Large",
561
+ "Predicted",
562
+ 0.0,
563
+ 0.0,
564
+ 0.0,
565
+ 0.0,
566
+ 0.0,
567
+ 0.0,
568
+ 0.0,
569
+ 0.0,
570
+ 0.0,
571
+ 0.0,
572
+ 0.0,
573
+ 0.0,
574
+ 0.0,
575
+ False,
576
+ False,
577
+ ],
578
+ [
579
+ "assets/demo/equirectangular.jpg",
580
+ "Large",
581
+ "Equirectangular",
582
+ 0.0,
583
+ 0.0,
584
+ 0.0,
585
+ 0.0,
586
+ 0.0,
587
+ 0.0,
588
+ 0.0,
589
+ 0.0,
590
+ 0.0,
591
+ 0.0,
592
+ 0.0,
593
+ 0.0,
594
+ 360.0,
595
+ False,
596
+ False,
597
+ ],
598
+ [
599
+ "assets/demo/venice.jpg",
600
+ "Large",
601
+ "Equirectangular",
602
+ 0.0,
603
+ 0.0,
604
+ 0.0,
605
+ 0.0,
606
+ 0.0,
607
+ 0.0,
608
+ 0.0,
609
+ 0.0,
610
+ 0.0,
611
+ 0.0,
612
+ 0.0,
613
+ 0.0,
614
+ 360.0,
615
+ False,
616
+ True,
617
+ ],
618
+ [
619
+ "assets/demo/dl3dv.png",
620
+ "Large",
621
+ "OPENCV",
622
+ 429.57611083984375,
623
+ 429.6898193359375,
624
+ 479.5,
625
+ 269.5,
626
+ -0.0014844092074781656,
627
+ 0.0007422995404340327,
628
+ 0.0,
629
+ 0.0,
630
+ 0.0,
631
+ 0.0,
632
+ 0.00012013866944471374,
633
+ 0.001125041046179831,
634
+ 0.0,
635
+ False,
636
+ False,
637
+ ],
638
+ [
639
+ "assets/demo/scannet.png",
640
+ "Large",
641
+ "Fisheye624",
642
+ 791.90869140625,
643
+ 792.7230834960938,
644
+ 878.16796875,
645
+ 585.045166015625,
646
+ -0.029167557135224342,
647
+ -0.006803446915000677,
648
+ -0.0012682401575148106,
649
+ -4.6094228309812024e-05,
650
+ 0.0,
651
+ 0.0,
652
+ 0.0,
653
+ 0.0,
654
+ 0.0,
655
+ False,
656
+ False,
657
+ ],
658
+ ]
659
+
660
+ def example_pipeline(
661
+ input_image,
662
+ model_name,
663
+ camera_name,
664
+ fx,
665
+ fy,
666
+ cx,
667
+ cy,
668
+ k1,
669
+ k2,
670
+ k3,
671
+ k4,
672
+ k5,
673
+ k6,
674
+ t1,
675
+ t2,
676
+ hfov,
677
+ mask_black_bg,
678
+ mask_far_points,
679
+ ):
680
+ target_dir, image_path = handle_uploads(input_image)
681
+ glbfile, log_msg, prediction_save_path = gradio_demo(
682
+ target_dir,
683
+ model_name,
684
+ camera_name,
685
+ fx,
686
+ fy,
687
+ cx,
688
+ cy,
689
+ k1,
690
+ k2,
691
+ k3,
692
+ k4,
693
+ k5,
694
+ k6,
695
+ t1,
696
+ t2,
697
+ hfov,
698
+ mask_black_bg,
699
+ mask_far_points,
700
+ )
701
+ return (
702
+ glbfile,
703
+ log_msg,
704
+ prediction_save_path,
705
+ target_dir,
706
+ image_path,
707
+ )
708
+
709
+ gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
710
+
711
+ gr.Examples(
712
+ examples=examples,
713
+ inputs=[
714
+ input_image,
715
+ model_dropdown,
716
+ camera_dropdown,
717
+ fx,
718
+ fy,
719
+ cx,
720
+ cy,
721
+ k1,
722
+ k2,
723
+ k3,
724
+ k4,
725
+ k5,
726
+ k6,
727
+ t1,
728
+ t2,
729
+ hfov,
730
+ mask_black_bg,
731
+ mask_far_points,
732
+ ],
733
+ outputs=[reconstruction_output, log_output, reconstruction_npy],
734
+ fn=example_pipeline,
735
+ cache_examples=False,
736
+ examples_per_page=50,
737
+ )
738
+
739
+ submit_btn.click(
740
+ fn=clear_fields, inputs=[], outputs=[reconstruction_output]
741
+ ).then(fn=update_log, inputs=[], outputs=[log_output]).then(
742
+ fn=gradio_demo,
743
+ inputs=[
744
+ target_dir_output,
745
+ model_dropdown,
746
+ camera_dropdown,
747
+ fx,
748
+ fy,
749
+ cx,
750
+ cy,
751
+ k1,
752
+ k2,
753
+ k3,
754
+ k4,
755
+ k5,
756
+ k6,
757
+ t1,
758
+ t2,
759
+ hfov,
760
+ mask_black_bg,
761
+ mask_far_points,
762
+ ],
763
+ outputs=[reconstruction_output, log_output, reconstruction_npy],
764
+ ).then(
765
+ fn=lambda: "False", inputs=[], outputs=[is_example]
766
+ )
767
+
768
+ mask_black_bg.change(
769
+ update_visualization,
770
+ [target_dir_output, mask_black_bg, mask_far_points, is_example],
771
+ [reconstruction_output, log_output],
772
+ )
773
+
774
+ mask_far_points.change(
775
+ update_visualization,
776
+ [target_dir_output, mask_black_bg, mask_far_points, is_example],
777
+ [reconstruction_output, log_output],
778
+ )
779
+
780
+ input_image.change(
781
+ fn=update_gallery_on_upload,
782
+ inputs=[input_image],
783
+ outputs=[target_dir_output, log_output],
784
+ )
785
+
786
+ # Dynamically update intrinsic parameter visibility when camera selection changes.
787
+ camera_dropdown.change(
788
+ fn=update_parameters,
789
+ inputs=camera_dropdown,
790
+ outputs=[fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov],
791
+ )
792
+
793
+ # demo.queue(max_size=20).launch(show_error=True, share=False, ssr_mode=False)
794
+ demo.launch(
795
+ show_error=True,
796
+ )
hubconf.py ADDED
@@ -0,0 +1,29 @@
1
+ dependencies = ["torch", "huggingface_hub"]
2
+
3
+ import os
4
+ import json
5
+
6
+ import torch
7
+ import huggingface_hub
8
+
9
+ from unik3d.models import UniK3D as UniK3D_
10
+
11
+ BACKBONES = ["vitl", "vitb", "vits"]
12
+
13
+
14
+ def UniK3D(backbone="vitl", pretrained=True):
15
+ assert backbone in BACKBONES, f"backbone must be one of {BACKBONES}"
16
+ repo_dir = os.path.dirname(os.path.realpath(__file__))
17
+ with open(os.path.join(repo_dir, "configs", f"config_{backbone}.json")) as f:
18
+ config = json.load(f)
19
+
20
+ model = UniK3D_(config)
21
+ if pretrained:
22
+ path = huggingface_hub.hf_hub_download(repo_id=f"lpiccinelli/unik3d-{backbone}", filename=f"pytorch_model.bin", repo_type="model")
23
+ info = model.load_state_dict(torch.load(path), strict=False)
24
+ print(f"UniK3D-{backbone} is loaded with:")
25
+ print(f"\t missing keys: {info.missing_keys}")
26
+ print(f"\t additional keys: {info.unexpected_keys}")
27
+
28
+ return model
29
+
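
The `hubconf.py` above registers `UniK3D` as a torch.hub entrypoint. A minimal loading sketch, assuming the GitHub repository name `lpiccinelli-eth/UniK3D` linked from the demo page:

```python
import torch

# Load UniK3D through torch.hub; keyword arguments are forwarded to the
# UniK3D() entrypoint defined in hubconf.py (backbone in {"vitl", "vitb", "vits"}).
model = torch.hub.load(
    "lpiccinelli-eth/UniK3D", "UniK3D", backbone="vitl", pretrained=True
)
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()
```
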
pyproject.toml ADDED
@@ -0,0 +1,25 @@
1
+ [build-system]
2
+ requires = ["setuptools"]
3
+ build-backend = "setuptools.build_meta"
4
+
5
+ [tool.pyright]
6
+ include = ["unik3d"]
7
+
8
+ [project]
9
+ name = "unik3d"
10
+ version = "0.1"
11
+ authors = [{name = "Luigi Piccinelli", email = "lpiccinelli@ethz.ch"}]
12
+ description = "UniK3D: Universal Monocular Metric Depth Estimation"
13
+ readme = "README.md"
14
+ license = { text="Creative Commons BY-NC 4.0 license"}
15
+ requires-python = ">=3.11.0"
16
+ dynamic = ["dependencies"]
17
+
18
+ [tool.setuptools.dynamic]
19
+ dependencies = {file = ["requirements.txt"]}
20
+
21
+ [tool.setuptools.package-data]
22
+ "*" = ["py.typed"]
23
+
24
+ [tool.setuptools.packages.find]
25
+ include = ["unik3d*"]
requirements.txt ADDED
@@ -0,0 +1,84 @@
1
+ appdirs
2
+ attrs
3
+ black
4
+ blosc2
5
+ botocore>=1.34.54
6
+ certifi>=2022.12.7
7
+ charset-normalizer
8
+ click
9
+ contourpy
10
+ cycler
11
+ docker-pycreds
12
+ einops>=0.7.0
13
+ filelock
14
+ flake8>=7.0.0
15
+ flake8-bugbear>=24.2.6
16
+ flake8-comprehensions>=3.14.0
17
+ fonttools
18
+ fsspec
19
+ fvcore>=0.1.5.post20221221
20
+ gitdb
21
+ GitPython
22
+ gradio
23
+ h5py>=3.10.0
24
+ huggingface-hub>=0.22.0
25
+ idna
26
+ imageio
27
+ imath
28
+ iopath
29
+ isort
30
+ Jinja2
31
+ jmespath
32
+ kiwisolver
33
+ MarkupSafe
34
+ matplotlib
35
+ mccabe
36
+ mpmath
37
+ msgpack
38
+ mypy-extensions
39
+ ndindex
40
+ networkx
41
+ ninja
42
+ numexpr
43
+ numpy<2.0.0
44
+ opencv-python
45
+ OpenEXR
46
+ packaging
47
+ pandas
48
+ pathspec
49
+ pillow>=10.2.0
50
+ platformdirs
51
+ portalocker
52
+ protobuf>=4.25.3
53
+ psutil
54
+ py-cpuinfo
55
+ pycodestyle
56
+ pyflakes
57
+ pyparsing
58
+ python-dateutil
59
+ pytz
60
+ PyYAML
61
+ requests
62
+ safetensors
63
+ scipy
64
+ sentry-sdk
65
+ setproctitle
66
+ six
67
+ smmap
68
+ sympy
69
+ tables
70
+ tabulate
71
+ termcolor
72
+ timm
73
+ tqdm
74
+ trimesh
75
+ triton>=2.4.0
76
+ typing_extensions
77
+ tzdata==2024.1
78
+ urllib3==1.26.13
79
+ wandb
80
+ yacs
81
+ torch>=2.4.0
82
+ torchvision>=0.19.0
83
+ torchaudio>=2.4.0
84
+ xformers>=0.0.26
requirements_demo.txt ADDED
@@ -0,0 +1,84 @@
1
+ appdirs
2
+ attrs
3
+ black
4
+ blosc2
5
+ botocore>=1.34.54
6
+ certifi>=2022.12.7
7
+ charset-normalizer
8
+ click
9
+ contourpy
10
+ cycler
11
+ docker-pycreds
12
+ einops>=0.7.0
13
+ filelock
14
+ flake8>=7.0.0
15
+ flake8-bugbear>=24.2.6
16
+ flake8-comprehensions>=3.14.0
17
+ fonttools
18
+ fsspec
19
+ fvcore>=0.1.5.post20221221
20
+ gitdb
21
+ GitPython
22
+ gradio
23
+ h5py>=3.10.0
24
+ huggingface-hub>=0.22.0
25
+ idna
26
+ imageio
27
+ imath
28
+ iopath
29
+ isort
30
+ Jinja2
31
+ jmespath
32
+ kiwisolver
33
+ MarkupSafe
34
+ matplotlib
35
+ mccabe
36
+ mpmath
37
+ msgpack
38
+ mypy-extensions
39
+ ndindex
40
+ networkx
41
+ ninja
42
+ numexpr
43
+ numpy<2.0.0
44
+ opencv-python
45
+ OpenEXR
46
+ packaging
47
+ pandas
48
+ pathspec
49
+ pillow>=10.2.0
50
+ platformdirs
51
+ portalocker
52
+ protobuf>=4.25.3
53
+ psutil
54
+ py-cpuinfo
55
+ pycodestyle
56
+ pyflakes
57
+ pyparsing
58
+ python-dateutil
59
+ pytz
60
+ PyYAML
61
+ requests
62
+ safetensors
63
+ scipy
64
+ sentry-sdk
65
+ setproctitle
66
+ six
67
+ smmap
68
+ sympy
69
+ tables
70
+ tabulate
71
+ termcolor
72
+ timm
73
+ tqdm
74
+ trimesh
75
+ triton>=2.4.0
76
+ typing_extensions
77
+ tzdata==2024.1
78
+ urllib3==1.26.13
79
+ wandb
80
+ yacs
81
+ torch>=2.4.0
82
+ torchvision>=0.19.0
83
+ torchaudio>=2.4.0
84
+ xformers>=0.0.26
scripts/README.md ADDED
@@ -0,0 +1,55 @@
1
+ ## Training
2
+
3
+ We provide the `train.py` script that lets you load the datasets, initialize the model, and start training. From the root of the repo:
4
+
5
+ ```bash
6
+ export REPO=`pwd`
7
+ export PYTHONPATH=${REPO}:${PYTHONPATH}
8
+
9
+ # Adapt all this to your setup
10
+ export TMPDIR="/tmp"
11
+ export TORCH_HOME=${TMPDIR}
12
+ export HUGGINGFACE_HUB_CACHE=${TMPDIR}
13
+ export WANDB_HOME=${TMPDIR}
14
+ export DATAROOT=<where-you-stored-the-hdf5>
15
+
16
+
17
+ export MASTER_PORT=$((( RANDOM % 600 ) + 29400 ))
18
+ if [ $NNODES -gt 1 ]; then
19
+ export MASTER_PORT=29400
20
+ fi
21
+
22
+ # this is the config that will be used
23
+ export CFG="config_vitl.json"
24
+ ```
25
+
26
+ If you are on a machine without SLURM, you can run the following:
27
+ ```bash
28
+ # make the following input-dependent for multi-node
29
+ export NNODES=1
30
+ export RANK=0
31
+ export MASTER_ADDR=127.0.0.1
32
+ export CUDA_VISIBLE_DEVICES="0" # set yours
33
+
34
+ export GPUS=$(echo ${CUDA_VISIBLE_DEVICES} | tr ',' '\n' | wc -l)
35
+ echo "Start script with python from: `which python`"
36
+ torchrun --rdzv-backend=c10d --nnodes=${NNODES} --nproc_per_node=${GPUS} --rdzv-endpoint ${MASTER_ADDR}:${MASTER_PORT} ${REPO}/scripts/train.py --config-file ${REPO}/configs/${CFG} --distributed
37
+ ```
38
+
39
+ If your system has SLURM, all this information will be set by the scheduler and you just have to run:
40
+ ```bash
41
+ srun -c ${SLURM_CPUS_PER_TASK} --kill-on-bad-exit=1 python -u ${REPO}/scripts/train.py --config-file ${REPO}/configs/${CFG} --master-port ${MASTER_PORT} --distributed
42
+ ```
43
+
44
+
45
+ ### Datasets
46
+
47
+ We use both image-based and sequence-based datasets. The `ImageDataset` class is kept for legacy reasons only, as image-based datasets have been moved to "dummy" single-frame sequences.<br>
48
+ We [provide two example datasets to get familiar with the pipeline and structure, namely iBims-1 and Sintel](https://drive.google.com/drive/folders/1FKsa5-b3EX0ukZq7bxord5fC5OfUiy16?usp=sharing), image- and sequence-based, respectively.<br>
49
+ You can adapt the data loading and processing to your own data; however, you will need to keep the same interface so that the model stays consistent and trains "out-of-the-box" (a minimal subclass sketch is shown after this file).<br>
50
+
51
+
52
+ ### Additional dependencies
53
+
54
+ We require the chamfer distance for evaluation; you can compile the KNN operation under `ops/knn` by running `bash compile.sh` from the directory `$REPO/unik3d/ops/knn`. Set the correct `export TORCH_CUDA_ARCH_LIST` according to the hardware you are working on.
55
+ For training and to perform camera augmentation, you can use `camera_augmenter.py`; however, the splatting operation requires you to clone and install `github.com/hperrot/splatting`.
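
As referenced in the dataset note above, here is a minimal sketch of a custom sequence-based dataset, modeled on the bundled `unik3d/datasets/_2d3ds.py` example; the depth range, decoding scale, and split file names below are hypothetical placeholders to adapt to your data:

```python
from unik3d.datasets.sequence_dataset import SequenceDataset


class MyDataset(SequenceDataset):
    # Hypothetical per-dataset attributes: valid depth range, the factor used
    # to decode the stored depth maps, and the split/sequence files of your data.
    min_depth = 0.01
    max_depth = 80.0
    depth_scale = 256.0
    train_split = "train.txt"
    test_split = "val.txt"
    sequences_file = "sequences.json"
```

Since `scripts/train.py` looks datasets up by name with `getattr(unik3d.datasets, name)`, a new class like this also needs to be exported from `unik3d/datasets/__init__.py` and listed under `train_datasets`/`val_datasets` in the config.
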
scripts/demo.py ADDED
@@ -0,0 +1,150 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ from PIL import Image
9
+
10
+ from unik3d.models import UniK3D
11
+ from unik3d.utils.camera import (MEI, OPENCV, BatchCamera, Fisheye624, Pinhole,
12
+ Spherical)
13
+ from unik3d.utils.visualization import colorize, save_file_ply
14
+
15
+ SAVE = False
16
+ BASE_PATH = os.path.join(
17
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "assets", "demo"
18
+ )
19
+
20
+
21
+ def infer(model, rgb_path, camera_path, rays=None):
22
+ rgb = np.array(Image.open(rgb_path))
23
+ rgb_torch = torch.from_numpy(rgb).permute(2, 0, 1)
24
+
25
+ camera = None
26
+ if camera_path is not None:
27
+ with open(camera_path, "r") as f:
28
+ camera_dict = json.load(f)
29
+
30
+ params = torch.tensor(camera_dict["params"])
31
+ name = camera_dict["name"]
32
+ assert name in ["Fisheye624", "Spherical", "OPENCV", "Pinhole", "MEI"]
33
+ camera = eval(name)(params=params)
34
+
35
+ outputs = model.infer(rgb=rgb_torch, camera=camera, normalize=True, rays=rays)
36
+
37
+ return rgb_torch, outputs
38
+
39
+
40
+ def infer_equirectangular(model, rgb_path):
41
+ rgb = np.array(Image.open(rgb_path))
42
+ rgb_torch = torch.from_numpy(rgb).permute(2, 0, 1)
43
+
44
+ # assuming full equirectangular image horizontally
45
+ H, W = rgb.shape[:2]
46
+ hfov_half = np.pi
47
+ vfov_half = np.pi * H / W
48
+ assert vfov_half <= np.pi / 2
49
+
50
+ params = [W, H, hfov_half, vfov_half]
51
+ camera = Spherical(params=torch.tensor([1.0] * 4 + params))
52
+
53
+ outputs = model.infer(rgb=rgb_torch, camera=camera, normalize=True)
54
+ return rgb_torch, outputs
55
+
56
+
57
+ def save(rgb, outputs, name, base_path, save_pointcloud=False):
58
+ depth = outputs["depth"]
59
+ rays = outputs["rays"]
60
+ points = outputs["points"]
61
+
62
+ depth = depth.cpu().numpy()
63
+ rays = ((rays + 1) * 127.5).clip(0, 255)
64
+
65
+ Image.fromarray(colorize(depth.squeeze())).save(
66
+ os.path.join(base_path, f"{name}_depth.png")
67
+ )
68
+ Image.fromarray(rgb.squeeze().permute(1, 2, 0).cpu().numpy()).save(
69
+ os.path.join(base_path, f"{name}_rgb.png")
70
+ )
71
+ Image.fromarray(rays.squeeze().permute(1, 2, 0).byte().cpu().numpy()).save(
72
+ os.path.join(base_path, f"{name}_rays.png")
73
+ )
74
+
75
+ if save_pointcloud:
76
+ predictions_3d = points.permute(0, 2, 3, 1).reshape(-1, 3).cpu().numpy()
77
+ rgb = rgb.permute(1, 2, 0).reshape(-1, 3).cpu().numpy()
78
+ save_file_ply(predictions_3d, rgb, os.path.join(base_path, f"{name}.ply"))
79
+
80
+
81
+ def demo(model):
82
+ # RGB + CAMERA
83
+ rgb, outputs = infer(
84
+ model,
85
+ os.path.join(BASE_PATH, f"scannet.png"),
86
+ os.path.join(BASE_PATH, "scannet.json"),
87
+ )
88
+ if SAVE:
89
+ save(rgb, outputs, name="scannet", base_path=BASE_PATH)
90
+
91
+ # get GT and pred
92
+ pts_pred = outputs["points"].squeeze().cpu().permute(1, 2, 0).numpy()
93
+ pts_gt = np.load("./assets/demo/scannet.npy").astype(float)
94
+ mask = np.linalg.norm(pts_gt, axis=-1) > 0
95
+ error = np.linalg.norm(pts_pred - pts_gt, axis=-1)
96
+ error = np.mean(error[mask] ** 2) ** 0.5
97
+
98
+ # Trade-off between speed and resolution
99
+ model.resolution_level = 1
100
+ rgb, outputs = infer(
101
+ model,
102
+ os.path.join(BASE_PATH, f"scannet.png"),
103
+ os.path.join(BASE_PATH, "scannet.json"),
104
+ )
105
+ if SAVE:
106
+ save(rgb, outputs, name="scannet_lowres", base_path=BASE_PATH)
107
+
108
+ # RGB
109
+ rgb, outputs = infer(model, os.path.join(BASE_PATH, f"poorthings.jpg"), None)
110
+ if SAVE:
111
+ save(rgb, outputs, name="poorthings", base_path=BASE_PATH)
112
+
113
+ # RGB + CAMERA
114
+ rgb, outputs = infer(
115
+ model,
116
+ os.path.join(BASE_PATH, f"dl3dv.png"),
117
+ os.path.join(BASE_PATH, "dl3dv.json"),
118
+ )
119
+ if SAVE:
120
+ save(rgb, outputs, name="dl3dv", base_path=BASE_PATH)
121
+
122
+ # EQUIRECTANGULAR
123
+ rgb, outputs = infer_equirectangular(
124
+ model, os.path.join(BASE_PATH, f"equirectangular.jpg")
125
+ )
126
+ if SAVE:
127
+ save(rgb, outputs, name="equirectangular", base_path=BASE_PATH)
128
+
129
+ print("Output keys are", outputs.keys())
130
+
131
+ if SAVE:
132
+ print("Done! Results saved in", BASE_PATH)
133
+
134
+ print(f"RMSE on 3D clouds for ScanNet++ sample: {100*error:.1f}cm")
135
+
136
+
137
+ if __name__ == "__main__":
138
+ print("Torch version:", torch.__version__)
139
+ type_ = "l" # available types: s, b, l
140
+ name = f"unik3d-vit{type_}"
141
+ model = UniK3D.from_pretrained(f"lpiccinelli/{name}")
142
+
143
+ # set resolution level in [0,10) and output interpolation
144
+ model.resolution_level = 9
145
+ model.interpolation_mode = "bilinear"
146
+
147
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
148
+ model = model.to(device).eval()
149
+
150
+ demo(model)
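
The camera files consumed by `infer` above (e.g. `scannet.json`, `dl3dv.json`) are JSON dictionaries with a `name` matching one of the imported camera classes and a `params` list. As a sketch with hypothetical intrinsics, a pinhole camera can also be built directly in code and passed to `model.infer`, mirroring the flow of `demo.py`; the `[fx, fy, cx, cy]` ordering follows `instantiate_camera` in `gradio_demo.py`, and `image.png` is a placeholder path:

```python
import numpy as np
import torch
from PIL import Image

from unik3d.models import UniK3D
from unik3d.utils.camera import Pinhole

# Hypothetical pinhole intrinsics; replace with your own calibration.
camera = Pinhole(params=torch.tensor([500.0, 500.0, 320.0, 240.0]))  # fx, fy, cx, cy

model = UniK3D.from_pretrained("lpiccinelli/unik3d-vitl")
model = model.to("cuda" if torch.cuda.is_available() else "cpu").eval()

rgb = torch.from_numpy(np.array(Image.open("image.png"))).permute(2, 0, 1)
outputs = model.infer(rgb=rgb, camera=camera, normalize=True)
print(outputs["depth"].shape, outputs["points"].shape, outputs["rays"].shape)
```
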
scripts/train.py ADDED
@@ -0,0 +1,630 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import random
5
+ import uuid
6
+ from contextlib import nullcontext
7
+ from copy import deepcopy
8
+ from datetime import datetime as dt
9
+ from functools import partial
10
+ from math import log2
11
+ from time import sleep, time
12
+ from typing import Any, Dict
13
+
14
+ import git
15
+ import numpy as np
16
+ import psutil
17
+ import torch
18
+ import torch.nn as nn
19
+ import torch.utils.data.distributed
20
+ import wandb
21
+ from PIL import Image
22
+ from torch import distributed as dist
23
+ from torch import optim
24
+ from torch.nn.parallel.distributed import DistributedDataParallel
25
+ from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
26
+ from tqdm import tqdm
27
+
28
+ import unik3d.datasets as datasets
29
+ from unik3d.datasets import (ConcatDataset, DistributedSamplerNoDuplicate,
30
+ collate_fn, get_weights)
31
+ from unik3d.models import UniK3D
32
+ from unik3d.ops.scheduler import CosineScheduler
33
+ from unik3d.utils import (barrier, format_seconds, is_main_process,
34
+ log_train_artifacts, validate)
35
+ from unik3d.utils.distributed import (create_local_process_group,
36
+ local_broadcast_process_authkey,
37
+ setup_multi_processes, setup_slurm,
38
+ sync_string_across_gpus,
39
+ sync_tensor_across_gpus)
40
+ from unik3d.utils.ema_torch import (DummyExponentialMovingAverage,
41
+ ExponentialMovingAverage)
42
+ from unik3d.utils.misc import calculate_mean_values
43
+
44
+ EMA_INTERVAL = 10
45
+ EMA_TAU = 10000
46
+ EMA_START = 50000
47
+
48
+
49
+ MAP_DTYPE = {
50
+ "f16": torch.float16,
51
+ "bf16": torch.bfloat16,
52
+ "f32": torch.float32,
53
+ }
54
+
55
+
56
+ def aggregate_sync_losses(dict_: dict[str, torch.Tensor], device):
57
+ keys = list(dict_.keys())
58
+ values = torch.tensor(list(dict_.values()), device=device)
59
+ keys = sync_string_across_gpus(keys, device)
60
+ values = sync_tensor_across_gpus(values, dim=0).cpu().tolist()
61
+ dict_ = calculate_mean_values(keys, values)
62
+ return dict_
63
+
64
+
65
+ def main_worker(config: Dict[str, Any], args: argparse.Namespace):
66
+
67
+ current_process = psutil.Process(os.getpid())
68
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
69
+ seed = config["generic"]["seed"]
70
+
71
+ if not args.distributed:
72
+ args.rank = 0
73
+ args.local_rank = 0
74
+ args.world_size = 1
75
+ else:
76
+ # initializes the distributed backend which will take care of synchronizing nodes/GPUs
77
+ setup_multi_processes(config)
78
+ is_slurm = "SLURM_PROCID" in os.environ
79
+ if is_slurm:
80
+ setup_slurm("nccl", port=args.master_port)
81
+ args.rank = int(os.environ["RANK"])
82
+ args.world_size = int(os.environ["WORLD_SIZE"])
83
+ args.local_rank = device = int(os.environ["LOCAL_RANK"])
84
+ if not is_slurm:
85
+ import datetime
86
+
87
+ dist.init_process_group(
88
+ "nccl",
89
+ rank=args.rank,
90
+ world_size=args.world_size,
91
+ timeout=datetime.timedelta(seconds=30 * 60),
92
+ )
93
+ torch.cuda.set_device(device)
94
+ create_local_process_group()
95
+ local_broadcast_process_authkey()
96
+ print(
97
+ f"Start running DDP on: {args.rank} (local: {args.local_rank}) with seed {seed + args.rank}."
98
+ )
99
+ config["training"]["batch_size"] = int(
100
+ config["training"]["batch_size"] / args.world_size
101
+ )
102
+ dist.barrier()
103
+
104
+ # Fix seed
105
+ # Different for every machine to avoid sampling
106
+ # the same element across machines
107
+ seed = seed + args.rank
108
+ random.seed(seed)
109
+ np.random.seed(seed)
110
+ torch.manual_seed(seed)
111
+ torch.cuda.manual_seed(seed)
112
+ torch.cuda.manual_seed_all(seed)
113
+ os.environ["PYTHONHASHSEED"] = str(seed)
114
+
115
+ batch_size = config["training"]["batch_size"]
116
+ if is_main_process():
117
+ print("Config: ", args.config_file)
118
+ print(
119
+ f"Torch version:{torch.__version__}, cuda:{torch.version.cuda}, cudnn:{torch.backends.cudnn.version()}, threads:{torch.get_num_threads()}"
120
+ )
121
+ print("BatchSize per GPU: ", batch_size)
122
+ print(
123
+ f"Divided into {config['training']['nsteps_accumulation_gradient']} accumulation step"
124
+ )
125
+
126
+ ##############################
127
+ ########### MODEL ############
128
+ ##############################
129
+ # Build model
130
+ model = UniK3D(config).to(device)
131
+ model.eval()
132
+ print(f"MODEL: {model.__class__.__name__} at {model.device}")
133
+ torch.cuda.empty_cache()
134
+
135
+ if args.distributed:
136
+ model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
137
+ model = DistributedDataParallel(
138
+ model,
139
+ find_unused_parameters=False,
140
+ device_ids=[device],
141
+ output_device=device,
142
+ )
143
+
144
+ ##############################
145
+ ######### OPTIMIZER ##########
146
+ ##############################
147
+ dtype_16bit = config["training"]["f16"]
148
+ is_16bit = dtype_16bit != "f32"
149
+ clipping = config["training"].get("clipping", None)
150
+
151
+ # Optimize
152
+ ddp_model = model.module if args.distributed else model
153
+ params = ddp_model.get_params(config)
154
+ optimizer = optim.AdamW(
155
+ params,
156
+ eps=6e-8 if is_16bit else 1e-8, # smallest subnormal fp16 number is 5.96e-8
157
+ # amsgrad=is_16bit, # use max instead of avg v_hat, avoid small number divisions?
158
+ )
159
+
160
+ # Load Model:
161
+ step = 0
162
+ if config["training"].get("pretrained", None) is not None:
163
+ ddp_model.load_pretrained(config["training"]["pretrained"])
164
+ pretrained = torch.load(
165
+ config["training"]["pretrained"], map_location="cpu", weights_only=False
166
+ )
167
+ try:
168
+ optimizer.load_state_dict(pretrained["optimizer"])
169
+ except Exception as e:
170
+ if is_main_process():
171
+ print("Could not load optimizer state dict:", e)
172
+ step = pretrained.get("step", 0)
173
+ ddp_model.pixel_decoder.steps = step
174
+
175
+ # EMA
176
+ ema_class = (
177
+ ExponentialMovingAverage
178
+ if config["training"]["ema"] > 0.0
179
+ else DummyExponentialMovingAverage
180
+ )
181
+ ema_handle = ema_class(
182
+ ddp_model.parameters_grad(),
183
+ 1 - (1 - config["training"]["ema"]) * EMA_INTERVAL,
184
+ update_after_step=config["training"]["warmup_iters"] / EMA_INTERVAL,
185
+ switch=True,
186
+ tau=EMA_TAU // EMA_INTERVAL,
187
+ )
188
+ setattr(ema_handle, "num_updates", step // EMA_INTERVAL)
189
+
190
+ ##############################
191
+ ######### GENERICS ###########
192
+ ##############################
193
+ resize_method = config["data"].get("resize_method", "hard")
194
+ crop = config["data"].get("crop", "garg")
195
+ augmentations_db = config["data"].get("augmentations", {})
196
+ shape_constraints = config["data"].get("shape_constraints", {})
197
+ image_shape = config["data"]["image_shape"]
198
+ mini = config["data"]["mini"]
199
+ nsteps_accumulation_gradient = config["training"]["nsteps_accumulation_gradient"]
200
+ batch_size = config["training"]["batch_size"]
201
+ clipping_fn = torch.nn.utils.clip_grad_norm_
202
+
203
+ is_shell = int(os.environ.get("SHELL_JOB", 0))
204
+ run_id = sync_string_across_gpus(
205
+ [f"{dt.now().strftime('%d-%h_%H-%M')}-{uuid.uuid4()}"], device
206
+ )[0]
207
+
208
+ if not is_shell and is_main_process():
209
+ repo_folder = os.path.dirname(os.path.realpath(__file__))
210
+ try:
211
+ repo = git.Repo(repo_folder)
212
+ current_head = repo.head if repo.head.is_detached else repo.active_branch
213
+ notes = f"MESSAGE: {current_head.commit.message} HASH:{current_head.commit.hexsha} BRANCH:{current_head.name}"
214
+ except:
215
+ print(f"problem with {repo_folder}, does it exist?")
216
+ notes = ""
217
+
218
+ # restore the original batchsize, not acquired by other calls from now on
219
+ if args.distributed:
220
+ config["training"]["batch_size"] = (
221
+ config["training"]["batch_size"] * args.world_size
222
+ )
223
+ wandb.init(
224
+ project="UniK3D",
225
+ name=run_id,
226
+ config=config,
227
+ tags=None,
228
+ notes=notes,
229
+ dir=os.environ.get("WANDB_HOME", os.environ.get("TMPDIR", "/tmp")),
230
+ )
231
+ wandb.watch(model)
232
+
233
+ ##############################
234
+ ########## DATASET ###########
235
+ ##############################
236
+ # Datasets loading
237
+ train_datasets, val_datasets = {}, {}
238
+ if is_main_process():
239
+ print("Loading training datasets...")
240
+ dims = 0
241
+
242
+ for dataset in config["data"]["train_datasets"]:
243
+ assert hasattr(datasets, dataset), f"{dataset} not a custom dataset"
244
+ train_dataset: datasets.BaseDataset = getattr(datasets, dataset)
245
+ train_datasets[dataset] = train_dataset(
246
+ image_shape=image_shape,
247
+ split_file=train_dataset.train_split,
248
+ test_mode=False,
249
+ crop=crop,
250
+ augmentations_db=augmentations_db,
251
+ shape_constraints=shape_constraints,
252
+ normalize=config["data"].get("normalization", "imagenet"),
253
+ resize_method=resize_method,
254
+ mini=mini,
255
+ num_frames=config["data"].get("num_frames", 1),
256
+ fps_range=[1, 5],
257
+ num_copies=config["data"]["pair"],
258
+ )
259
+ dim = (
260
+ train_datasets[dataset].dataset._addr.numel() * 8
261
+ + train_datasets[dataset].dataset._lst.numel()
262
+ ) / (2**20)
263
+ if hasattr(train_datasets[dataset], "sequences"):
264
+ dim += (
265
+ train_datasets[dataset].sequences._addr.numel() * 8
266
+ + train_datasets[dataset].sequences._lst.numel()
267
+ ) / (2**20)
268
+ dims = dims + dim
269
+ if is_main_process():
270
+ print(f"{dataset}: {dim:.1f}MB")
271
+
272
+ print(f"All training datasets loaded, with total size: {dims:.1f}MB")
273
+
274
+ barrier()
275
+
276
+ assert batch_size % config["data"]["pair"] == 0
277
+ batch_size = batch_size // config["data"]["pair"]
278
+ assert batch_size % nsteps_accumulation_gradient == 0
279
+ batch_chunk = batch_size // nsteps_accumulation_gradient
280
+
281
+ train_dataset = ConcatDataset(
282
+ list(train_datasets.values()),
283
+ shape_constraints=shape_constraints,
284
+ )
285
+
286
+ if is_main_process():
287
+ print("Loading validation datasets...")
288
+ for dataset in config["data"]["val_datasets"]:
289
+ val_dataset: datasets.BaseDataset = getattr(datasets, dataset)
290
+ val_datasets[dataset] = val_dataset(
291
+ image_shape=image_shape,
292
+ split_file=val_dataset.test_split,
293
+ test_mode=True,
294
+ crop=crop,
295
+ shape_constraints=shape_constraints,
296
+ augmentations_db=augmentations_db,
297
+ normalize=config["data"].get("normalization", "imagenet"),
298
+ resize_method=resize_method,
299
+ num_frames=1,
300
+ mini=1.0,
301
+ num_copies=1,
302
+ )
303
+
304
+ # Dataset samplers, create distributed sampler pinned to rank
305
+ if args.distributed:
306
+ sampling = deepcopy(config["data"]["sampling"])
307
+ weights, num_samples = get_weights(train_datasets, sampling)
308
+ train_sampler = torch.utils.data.WeightedRandomSampler(
309
+ weights, num_samples, replacement=True
310
+ )
311
+ valid_samplers = {
312
+ k: DistributedSamplerNoDuplicate(
313
+ v,
314
+ num_replicas=args.world_size,
315
+ rank=args.rank,
316
+ shuffle=False,
317
+ drop_last=False,
318
+ )
319
+ for k, v in val_datasets.items()
320
+ }
321
+ else:
322
+ train_sampler = RandomSampler(train_dataset)
323
+ valid_samplers = {k: SequentialSampler(v) for k, v in val_datasets.items()}
324
+
325
+ train_sampler = torch.utils.data.BatchSampler(
326
+ train_sampler, batch_size=batch_size, drop_last=True
327
+ )
328
+
329
+ # Dataset loader
330
+ val_batch_size = 1
331
+ num_workers = int(os.environ.get("SLURM_CPUS_PER_TASK", 4))
332
+ train_loader = DataLoader(
333
+ train_dataset,
334
+ num_workers=num_workers,
335
+ sampler=train_sampler,
336
+ pin_memory=True,
337
+ collate_fn=partial(collate_fn, is_batched=True),
338
+ persistent_workers=True if num_workers else None,
339
+ )
340
+ val_loaders = {
341
+ name_dataset: DataLoader(
342
+ dataset,
343
+ batch_size=val_batch_size,
344
+ shuffle=False,
345
+ num_workers=num_workers,
346
+ sampler=valid_samplers[name_dataset],
347
+ pin_memory=True,
348
+ drop_last=False,
349
+ collate_fn=partial(collate_fn, is_batched=False),
350
+ )
351
+ for name_dataset, dataset in val_datasets.items()
352
+ }
353
+
354
+ # SCHEDULERS!
355
+ scheduler_wd = CosineScheduler(
356
+ optimizer,
357
+ key="weight_decay",
358
+ init_value=config["training"]["wd"],
359
+ base_value=config["training"]["wd"],
360
+ final_value=config["training"]["wd_final"],
361
+ warmup_iters=0,
362
+ total_iters=config["training"]["n_iters"],
363
+ flat_iters=config["training"]["warmup_iters"],
364
+ step_init=step - 1,
365
+ )
366
+ scheduler_lr = CosineScheduler(
367
+ optimizer,
368
+ key="lr",
369
+ init_value=config["training"]["lr"] * config["training"].get("lr_warmup", 1.0),
370
+ final_value=config["training"]["lr_final"],
371
+ warmup_iters=5000,
372
+ flat_iters=config["training"]["warmup_iters"],
373
+ total_iters=config["training"]["n_iters"],
374
+ step_init=step - 1,
375
+ )
376
+ scheduler_betas = CosineScheduler(
377
+ optimizer,
378
+ key="betas",
379
+ init_value=0.95 if config["training"].get("cycle_betas", True) else 0.9,
380
+ base_value=0.85 if config["training"].get("cycle_betas", True) else 0.9,
381
+ final_value=0.95 if config["training"].get("cycle_betas", True) else 0.9,
382
+ warmup_iters=config["training"]["warmup_iters"],
383
+ total_iters=config["training"]["n_iters"],
384
+ step_init=step - 1,
385
+ )
386
+
387
+ # Set loss scaler for half precision training + sanity zeroing grads
388
+ dtype = MAP_DTYPE[dtype_16bit]
389
+ if not torch.cuda.is_bf16_supported() and is_16bit:
390
+ dtype = torch.float16
391
+
392
+ context = torch.autocast(device_type="cuda", dtype=dtype, enabled=is_16bit)
393
+ # use float16 to check for instability at inference and avoid bfloat16 due to its coarseness
394
+ context_val = torch.autocast(
395
+ device_type="cuda", dtype=torch.float16, enabled=is_16bit
396
+ )
397
+ optimizer.zero_grad(set_to_none=True)
398
+
399
+ ##############################
400
+ ########## TRAINING ##########
401
+ ##############################
402
+ # Remember that if i-th layer is frozen, this will break gradient checkpointing
403
+ # in layer i+1-th. This is because CheckpointFunction treats the i+1-th input as
404
+ # without gradient, thus the i+1-th layer does not have grads (?). To solve it,
405
+ # just add requires_grad_() to the inputs coming from the frozen layer
406
+ ddp_model.train()
407
+
408
+ start = time()
409
+ n_steps = config["training"]["n_iters"]
410
+ init_steps = int(step)
411
+ track_pbar = is_shell
412
+
413
+ if is_main_process():
414
+ print("Is a shell job?", is_shell)
415
+ print("Use dtype:", dtype if is_16bit else torch.float32)
416
+ print(
417
+ f'Train for {config["training"]["n_iters"]} steps, validate every {config["training"]["validation_interval"]} steps'
418
+ )
419
+ print(f"START with {num_workers} workers")
420
+ if track_pbar:
421
+ pbar = tqdm(total=n_steps - init_steps)
422
+
423
+ scaler = torch.amp.GradScaler(
424
+ "cuda",
425
+ init_scale=2**14 if dtype_16bit == "f16" else 2**40,
426
+ enabled=is_16bit,
427
+ growth_factor=1.2,
428
+ backoff_factor=0.8,
429
+ growth_interval=500,
430
+ )
431
+ track_losses, track_grad = {}, {}
432
+ system_memory = dict(psutil.virtual_memory()._asdict())["available"] / 2**30
433
+ cpid_memory = current_process.memory_info()[0] / 2.0**30
434
+ gpu_mem = (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 2**30
435
+ while True:
436
+ for j, batches in enumerate(train_loader):
437
+ system_memory = (
438
+ 0.99 * system_memory
439
+ + 0.01 * dict(psutil.virtual_memory()._asdict())["available"] / 2**30
440
+ )
441
+ cpid_memory = (
442
+ 0.99 * cpid_memory + 0.01 * current_process.memory_info()[0] / 2.0**30
443
+ )
444
+ gpu_mem = (
445
+ 0.99 * gpu_mem
446
+ + 0.01
447
+ * (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0])
448
+ / 2**30
449
+ )
450
+ if j % 1000 == 0 and is_main_process():
451
+ print(f"System information at step {j}")
452
+ print(f"System-wide RAM available: {system_memory:.2f}GB")
453
+ print(f"CPU utilization: {psutil.cpu_percent(interval=None)}%")
454
+ print(f"GPU memory utilized: {gpu_mem:.2f}GB")
455
+
456
+ batches["data"] = {
457
+ k: v.to(model.device, non_blocking=True)
458
+ for k, v in batches["data"].items()
459
+ }
460
+ for idx in range(nsteps_accumulation_gradient):
461
+ batch = {}
462
+ batch_slice = slice(idx * batch_chunk, (idx + 1) * batch_chunk)
463
+ batch["data"] = {k: v[batch_slice] for k, v in batches["data"].items()}
464
+ batch["img_metas"] = batches["img_metas"][batch_slice]
465
+ with (
466
+ model.no_sync()
467
+ if idx < nsteps_accumulation_gradient - 1
468
+ else nullcontext()
469
+ ):
470
+ with context:
471
+ preds, losses = model(batch["data"], batch["img_metas"])
472
+ loss = sum(losses["opt"].values())
473
+ scaler.scale(loss).backward()
474
+
475
+ losses_dict = {
476
+ k: v.detach() for loss in losses.values() for k, v in loss.items()
477
+ }
478
+ track_losses.update(
479
+ {
480
+ k: track_losses.get(k, 0.0)
481
+ + torch.nan_to_num(v, nan=1e5, posinf=1e5, neginf=1e5)
482
+ for k, v in losses_dict.items()
483
+ }
484
+ )
485
+ ddp_model.loss_history = track_losses
486
+
487
+ if clipping is not None:
488
+ scaler.unscale_(optimizer)
489
+ grad_norm = clipping_fn(ddp_model.parameters_grad(), clipping)
490
+ if torch.isfinite(grad_norm):
491
+ track_losses.update(
492
+ {"Grad_Norm": track_losses.get("Grad_Norm", 0.0) + grad_norm}
493
+ )
494
+
495
+ # there is a deeper issue, either log/sqrt of negative loss
496
+ # or the inputs create large values and destroy model weights
497
+ if is_16bit and scaler.get_scale() < 1:
498
+ raise ValueError("Scale went less than 1, ISSUE!!!")
499
+
500
+ scaler.step(optimizer)
501
+ scaler.update()
502
+
503
+ scheduler_wd.step()
504
+ scheduler_lr.step()
505
+ scheduler_betas.step()
506
+ model.module.step()
507
+ optimizer.zero_grad(set_to_none=True)
508
+ if step % EMA_INTERVAL == 0:
509
+ ema_handle.update()
510
+
511
+ if is_main_process() and track_pbar:
512
+ pbar.update(1)
513
+
514
+ step += 1
515
+
516
+ # LOGGING
517
+ if step % 100 == 0 and is_main_process():
518
+ log_num = min(10, preds["depth"].shape[0])
519
+ log_train_artifacts(
520
+ batch["data"]["image"][-log_num:, 0].float(),
521
+ (
522
+ batch["data"]["depth"][-log_num:, 0].float()
523
+ if "depth" in batch["data"]
524
+ else []
525
+ ),
526
+ preds["depth"][-log_num:, 0].detach().float(),
527
+ infos={
528
+ k: v[-log_num:, 0] for k, v in preds.get("infos", {}).items()
529
+ },
530
+ step=step,
531
+ )
532
+
533
+ if step % 50 == 0:
534
+ track_losses = {
535
+ k: v / (50 * nsteps_accumulation_gradient)
536
+ for k, v in track_losses.items()
537
+ }
538
+ # grad norm is for every step!
539
+ track_losses["Grad_Norm"] = (
540
+ track_losses["Grad_Norm"] * nsteps_accumulation_gradient
541
+ )
542
+ track_losses = aggregate_sync_losses(track_losses, device=model.device)
543
+ if is_main_process():
544
+ elapsed = int(time() - start)
545
+ eta = int(elapsed * (n_steps - step) / max(1, step - init_steps))
546
+ print(
547
+ f"Step {step}/{n_steps} [{format_seconds(elapsed)}<{format_seconds(eta)}]"
548
+ )
549
+ try:
550
+ wandb.log(
551
+ {
552
+ **{f"Train/{k}": v for k, v in track_losses.items()},
553
+ **{f"Train/lr": scheduler_lr.get()[-1]},
554
+ **{f"Train/wd": scheduler_wd.get()[-2]},
555
+ **{f"Train/scale_f16": log2(scaler.get_scale())},
556
+ },
557
+ step=step,
558
+ )
559
+ except Exception as e:
560
+ print("Not logging loss because of:", e)
561
+ if step % 100 == 0:
562
+ log_loss_dict = {
563
+ f"Train/{k}": v for k, v in track_losses.items()
564
+ }
565
+ print(
566
+ ", ".join(
567
+ [f"{k}: {v:.5f}" for k, v in log_loss_dict.items()]
568
+ )
569
+ )
570
+ track_losses = {} # reinit every 50 steps, average the current 50 steps
571
+
572
+ # Validation
573
+ is_last_step = step >= config["training"]["n_iters"]
574
+ is_validation = step % config["training"]["validation_interval"] == 0
575
+ if is_last_step or is_validation:
576
+ torch.cuda.empty_cache()
577
+ barrier()
578
+ if is_main_process():
579
+ print(f"Validation at {step}th step...")
580
+ ddp_model.eval()
581
+ start_validation = time()
582
+ with torch.no_grad(), ema_handle.average_parameters():
583
+ validate(
584
+ model,
585
+ test_loaders=val_loaders,
586
+ step=step,
587
+ run_id=run_id,
588
+ idxs=(64, 96, 224, 256), # random
589
+ context=context_val,
590
+ )
591
+
592
+ if is_main_process():
593
+ print(f"Elapsed: {format_seconds(int(time() - start_validation))}")
594
+ ddp_model.train()
595
+ torch.cuda.empty_cache()
596
+
597
+ if step >= config["training"]["n_iters"]:
598
+ if is_main_process() and track_pbar:
599
+ pbar.close()
600
+ wandb.finish(0)
601
+ dist.destroy_process_group()
602
+ return 0
603
+
604
+
605
+ if __name__ == "__main__":
606
+ if "SLURM_PROCID" in os.environ:
607
+ os.environ["TRITON_CACHE_DIR"] = "/tmp"
608
+ # Arguments
609
+ parser = argparse.ArgumentParser(
610
+ description="Training script", conflict_handler="resolve"
611
+ )
612
+ parser.add_argument("--config-file", type=str, required=True)
613
+ parser.add_argument("--master-port", type=str)
614
+ parser.add_argument("--distributed", action="store_true")
615
+ parser.add_argument("--local_rank", type=int, default=0)
616
+
617
+ args = parser.parse_args()
618
+ with open(args.config_file, "r") as f:
619
+ config = json.load(f)
620
+
621
+ deterministic = config["generic"].get("deterministic", True)
622
+ torch.backends.cudnn.deterministic = deterministic
623
+ torch.backends.cudnn.benchmark = not deterministic
624
+
625
+ torch.backends.cudnn.allow_tf32 = True
626
+ torch.backends.cuda.matmul.allow_tf32 = True
627
+ torch.set_float32_matmul_precision("high")
628
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
629
+ torch.set_num_threads(1)
630
+ main_worker(config, args)
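Note: the inner loop above splits each loaded batch into `nsteps_accumulation_gradient` micro-batches and wraps every micro-batch except the last in `model.no_sync()`, so DDP all-reduces gradients only once per optimizer step. A minimal, self-contained sketch of that accumulation pattern under mixed precision (generic model, data and loss; an illustration, not the UniK3D trainer):

from contextlib import nullcontext

import torch

def accumulation_step(ddp_model, optimizer, scaler, micro_batches, loss_fn):
    # micro_batches: list of (inputs, targets) chunks of one logical batch
    n = len(micro_batches)
    for i, (x, y) in enumerate(micro_batches):
        # skip the gradient all-reduce for all but the last micro-batch
        sync_ctx = ddp_model.no_sync() if i < n - 1 else nullcontext()
        with sync_ctx:
            with torch.autocast(device_type="cuda", dtype=torch.float16):
                loss = loss_fn(ddp_model(x), y) / n  # the script above accumulates the sum instead
            scaler.scale(loss).backward()
    scaler.step(optimizer)   # unscales internally and skips the step on inf/NaN grads
    scaler.update()
    optimizer.zero_grad(set_to_none=True)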
unik3d/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .models import UniK3D
unik3d/datasets/_2d3ds.py ADDED
@@ -0,0 +1,67 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+
5
+ from unik3d.datasets.pipelines import Compose, PanoCrop, PanoRoll
6
+ from unik3d.datasets.sequence_dataset import SequenceDataset
7
+
8
+
9
+ class d2D3DS(SequenceDataset):
10
+ min_depth = 0.01
11
+ max_depth = 10.0
12
+ depth_scale = 512.0
13
+ test_split = "train.txt"
14
+ train_split = "train.txt"
15
+ sequences_file = "sequences.json"
16
+ hdf5_paths = [f"2D3DS.hdf5"]
17
+
18
+ def __init__(
19
+ self,
20
+ image_shape: tuple[int, int],
21
+ split_file: str,
22
+ test_mode: bool,
23
+ normalize: bool,
24
+ augmentations_db: dict[str, Any],
25
+ resize_method: str,
26
+ mini: float = 1.0,
27
+ num_frames: int = 1,
28
+ benchmark: bool = False,
29
+ decode_fields: list[str] = ["image", "depth"],
30
+ inplace_fields: list[str] = ["cam2w", "camera_params"],
31
+ **kwargs,
32
+ ) -> None:
33
+ super().__init__(
34
+ image_shape=image_shape,
35
+ split_file=split_file,
36
+ test_mode=test_mode,
37
+ benchmark=benchmark,
38
+ normalize=normalize,
39
+ augmentations_db=augmentations_db,
40
+ resize_method=resize_method,
41
+ mini=mini,
42
+ num_frames=num_frames,
43
+ decode_fields=decode_fields,
44
+ inplace_fields=inplace_fields,
45
+ **kwargs,
46
+ )
47
+ self.resizer = Compose(
48
+ [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer]
49
+ )
50
+
51
+ def preprocess(self, results):
52
+ self.resizer.ctx = None
53
+ if self.test_mode:
54
+ for i, seq in enumerate(results["sequence_fields"]):
55
+ results[seq]["points"] = results[seq]["camera"].reconstruct(
56
+ results[seq]["depth"]
57
+ )
58
+ results[seq]["depth"] = results[seq]["points"][:, -1:]
59
+ results[seq]["gt_fields"].add("points")
60
+ return super().preprocess(results)
61
+
62
+ def pre_pipeline(self, results):
63
+ results = super().pre_pipeline(results)
64
+ results["dense"] = [True] * self.num_frames * self.num_copies
65
+ results["synthetic"] = [False] * self.num_frames * self.num_copies
66
+ results["quality"] = [1] * self.num_frames * self.num_copies
67
+ return results
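Note: at test time, `d2D3DS.preprocess` reconstructs 3D points from the stored depth with the (spherical) camera and keeps only the last channel of those points as `depth`, i.e. evaluation presumably runs on z-depth rather than on the radial distance stored in the panorama. A hedged sketch of such a distance-to-z conversion for an equirectangular image (axis conventions here are assumptions, not the repo's camera code):

import torch

def equirect_distance_to_zdepth(distance: torch.Tensor) -> torch.Tensor:
    # distance: (1, 1, H, W) radial distance along each viewing ray
    _, _, H, W = distance.shape
    lon = (torch.arange(W) + 0.5) / W * 2 * torch.pi - torch.pi   # longitude in [-pi, pi)
    lat = torch.pi / 2 - (torch.arange(H) + 0.5) / H * torch.pi   # latitude in [pi/2, -pi/2]
    lat, lon = torch.meshgrid(lat, lon, indexing="ij")
    dz = torch.cos(lat) * torch.cos(lon)   # forward (z) component of the unit ray
    return distance * dz.reshape(1, 1, H, W)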
unik3d/datasets/_4dor.py ADDED
@@ -0,0 +1,52 @@
1
+ from typing import Any
2
+
3
+ from unik3d.datasets.sequence_dataset import SequenceDataset
4
+
5
+
6
+ class d4DOR(SequenceDataset):
7
+ min_depth = 0.01
8
+ max_depth = 10.0
9
+ depth_scale = 1000.0
10
+ default_fps = 10
11
+ test_split = "train.txt"
12
+ train_split = "train.txt"
13
+ sequences_file = "sequences.json"
14
+ hdf5_paths = ["4DOR.hdf5"]
15
+
16
+ def __init__(
17
+ self,
18
+ image_shape: tuple[int, int],
19
+ split_file: str,
20
+ test_mode: bool,
21
+ normalize: bool,
22
+ augmentations_db: dict[str, Any],
23
+ resize_method: str,
24
+ mini: float = 1.0,
25
+ num_frames: int = 1,
26
+ benchmark: bool = False,
27
+ decode_fields: list[str] = ["image", "depth"],
28
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
29
+ **kwargs,
30
+ ) -> None:
31
+ super().__init__(
32
+ image_shape=image_shape,
33
+ split_file=split_file,
34
+ test_mode=test_mode,
35
+ benchmark=benchmark,
36
+ normalize=normalize,
37
+ augmentations_db=augmentations_db,
38
+ resize_method=resize_method,
39
+ mini=mini,
40
+ num_frames=num_frames,
41
+ decode_fields=decode_fields,
42
+ inplace_fields=inplace_fields,
43
+ **kwargs,
44
+ )
45
+
46
+ def pre_pipeline(self, results):
47
+ results = super().pre_pipeline(results)
48
+ results["dense"] = [True] * self.num_frames * self.num_copies
49
+ results["synthetic"] = [False] * self.num_frames * self.num_copies
50
+ results["si"] = [False] * self.num_frames * self.num_copies
51
+ results["quality"] = [2] * self.num_frames * self.num_copies
52
+ return results
unik3d/datasets/__init__.py ADDED
@@ -0,0 +1,161 @@
1
+ from ._2d3ds import d2D3DS
2
+ from ._4dor import d4DOR
3
+ from .a2d2 import A2D2
4
+ from .adt import ADT
5
+ from .aimotive import aiMotive
6
+ from .argoverse import Argoverse
7
+ from .argoverse2 import Argoverse2
8
+ from .arkit import ARKit
9
+ from .ase import ASE
10
+ from .base_dataset import BaseDataset
11
+ from .bdd import BDD
12
+ from .bedlam import BEDLAM
13
+ from .behave import Behave
14
+ from .blendedmvg import BlendedMVG
15
+ from .cityscape import Cityscape
16
+ from .ddad import DDAD
17
+ from .deep360 import Deep360
18
+ from .dense import DENSE
19
+ from .diml import DIML
20
+ from .diode import DiodeIndoor, DiodeIndoor_F
21
+ from .dl3dv import DL3DV
22
+ from .driving_stereo import DrivingStereo
23
+ from .dtu_rmvd import DTURMVD
24
+ from .dummy import Dummy
25
+ from .dynamic_replica import DynReplica
26
+ from .eden import EDEN
27
+ from .eth3d import ETH3D, ETH3D_F, ETH3DRMVD
28
+ from .facedepth import FaceDepth
29
+ from .flsea import FLSea
30
+ from .futurehouse import FutureHouse
31
+ from .gibson import Gibson
32
+ from .hammer import HAMMER
33
+ from .hm3d import HM3D
34
+ from .hoi4d import HOI4D
35
+ from .hypersim import HyperSim
36
+ from .ibims import IBims, IBims_F
37
+ from .ken_burns import KenBurns
38
+ from .kitti import KITTI, KITTIRMVD, KITTIBenchmark
39
+ from .kitti360 import KITTI360
40
+ from .lyft import Lyft
41
+ from .mapillary import Mapillary
42
+ from .matrix_city import MatrixCity
43
+ from .matterport3d import Matterport3D
44
+ from .megadepth import MegaDepth
45
+ from .megadepth_s import MegaDepthS
46
+ from .midair import MidAir
47
+ from .mip import MIP
48
+ from .ms2 import MS2
49
+ from .mvimgnet import MVImgNet
50
+ from .mvsynth import MVSynth
51
+ from .nerds360 import NeRDS360
52
+ from .niantic_mapfree import NianticMapFree
53
+ from .nuscenes import Nuscenes
54
+ from .nyuv2 import NYUv2Depth
55
+ from .point_odyssey import PointOdyssey
56
+ from .proteus import Proteus
57
+ from .samplers import (DistributedSamplerNoDuplicate,
58
+ DistributedSamplerWrapper, ShardedInfiniteSampler)
59
+ from .scannet import ScanNet
60
+ from .scannetpp import ScanNetpp, ScanNetpp_F
61
+ from .sintel import Sintel
62
+ from .sunrgbd import SUNRGBD
63
+ from .synscapes import Synscapes
64
+ from .tartanair import TartanAir
65
+ from .taskonomy import Taskonomy
66
+ from .tat_rmvd import TATRMVD
67
+ from .theo import Theo
68
+ from .unrealstereo4k import UnrealStereo4K
69
+ from .urbansyn import UrbanSyn
70
+ from .utils import ConcatDataset, collate_fn, get_weights
71
+ from .vkitti import VKITTI
72
+ from .void import VOID
73
+ from .waymo import Waymo
74
+ from .wildrgbd import WildRGBD
75
+
76
+ __all__ = [
77
+ "Dummy",
78
+ "BaseDataset",
79
+ "get_weights", "DistributedSamplerNoDuplicate",
80
+ "ShardedInfiniteSampler",
81
+ "DistributedSamplerWrapper",
82
+ "ConcatDataset",
83
+ "PairDataset",
84
+ "collate_fn",
85
+ # additional, do not count
86
+ "WaymoImage",
87
+ "MegaDepth",
88
+ "COCO2017",
89
+ "ImageNet",
90
+ "OASISv2",
91
+ # image based
92
+ "Argoverse",
93
+ "DDAD",
94
+ "IBims",
95
+ "NYUv2Depth",
96
+ "DrivingStereo",
97
+ "VOID",
98
+ "Mapillary",
99
+ "ScanNet",
100
+ "Taskonomy",
101
+ "BDD",
102
+ "A2D2",
103
+ "Nuscenes",
104
+ "SUNRGBD",
105
+ "ETH3D",
106
+ "HAMMER",
107
+ "Cityscape",
108
+ "KITTI",
109
+ "DENSE",
110
+ "DIML",
111
+ "DiodeIndoor",
112
+ "FLSea",
113
+ "ARKitScenes",
114
+ "Lyft",
115
+ "HyperSim",
116
+ "KenBurns",
117
+ "HRWSI",
118
+ "UrbanSyn",
119
+ "Synscapes",
120
+ "Gibson",
121
+ "Matterport3D",
122
+ "_2D3DS",
123
+ # sequence based
124
+ "TartanAir",
125
+ "WildRGBD",
126
+ "ScanNetS",
127
+ "ScanNetpp",
128
+ "MVImgNet",
129
+ "NianticMapFree",
130
+ "DL3DV",
131
+ "PointOdyssey",
132
+ "KITTIMulti",
133
+ "Waymo",
134
+ "Argoverse2",
135
+ "UnrealStereo4K",
136
+ "MatrixCity",
137
+ "HM3D",
138
+ "MVSynth",
139
+ "EDEN",
140
+ # sequence based, but not usable for seq, only image
141
+ "BEDLAM",
142
+ "NeRDS360",
143
+ "BlendedMVG",
144
+ "DynReplica",
145
+ "ARKitS",
146
+ "Sintel",
147
+ "VKITTI",
148
+ "MegaDepthS",
149
+ # benchmarks
150
+ "KITTIBenchmark",
151
+ "ETH3DRMVD",
152
+ "DTURMVD",
153
+ "KITTIRMVD",
154
+ "TATRMVD",
155
+ "DiodeIndoor_F",
156
+ "IBims_F",
157
+ "ETH3D_F",
158
+ "KITTI360",
159
+ "ScanNetpp_F",
160
+ "ADT",
161
+ ]
unik3d/datasets/a2d2.py ADDED
@@ -0,0 +1,78 @@
1
+ import json
2
+ import os
3
+
4
+ import h5py
5
+ import numpy as np
6
+ import torch
7
+
8
+ from unik3d.datasets.image_dataset import ImageDataset
9
+ from unik3d.datasets.utils import DatasetFromList
10
+
11
+
12
+ class A2D2(ImageDataset):
13
+ min_depth = 0.01
14
+ max_depth = 120.0
15
+ depth_scale = 256.0
16
+ train_split = "train_clean.txt"
17
+ intrisics_file = "intrinsics.json"
18
+ hdf5_paths = ["a2d2.hdf5"]
19
+
20
+ def __init__(
21
+ self,
22
+ image_shape,
23
+ split_file,
24
+ test_mode,
25
+ crop=None,
26
+ benchmark=False,
27
+ augmentations_db={},
28
+ normalize=True,
29
+ resize_method="hard",
30
+ mini=1.0,
31
+ **kwargs,
32
+ ):
33
+ super().__init__(
34
+ image_shape=image_shape,
35
+ split_file=split_file,
36
+ test_mode=test_mode,
37
+ benchmark=benchmark,
38
+ normalize=normalize,
39
+ augmentations_db=augmentations_db,
40
+ resize_method=resize_method,
41
+ mini=mini,
42
+ **kwargs,
43
+ )
44
+ self.test_mode = test_mode
45
+ self.load_dataset()
46
+
47
+ def load_dataset(self):
48
+ h5file = h5py.File(
49
+ os.path.join(self.data_root, self.hdf5_paths[0]),
50
+ "r",
51
+ libver="latest",
52
+ swmr=True,
53
+ )
54
+ txt_file = np.array(h5file[self.split_file])
55
+ txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1
56
+ intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
57
+ intrinsics = json.loads(intrinsics)
58
+ h5file.close()
59
+ dataset = []
60
+ for line in txt_string.split("\n"):
61
+ image_filename, depth_filename = line.strip().split(" ")
62
+ intrinsics_val = torch.tensor(
63
+ intrinsics[os.path.join(*image_filename.split("/")[:2])]
64
+ ).squeeze()[:, :3]
65
+ sample = [image_filename, depth_filename, intrinsics_val]
66
+ dataset.append(sample)
67
+
68
+ if not self.test_mode:
69
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
70
+
71
+ self.dataset = DatasetFromList(dataset)
72
+ self.log_load_dataset()
73
+
74
+ def pre_pipeline(self, results):
75
+ results = super().pre_pipeline(results)
76
+ results["dense"] = [False] * self.num_copies
77
+ results["quality"] = [1] * self.num_copies
78
+ return results
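Note: A2D2 (like Argoverse, BDD, Cityscape and DDAD below) follows the same loading pattern: the split list and the intrinsics dictionary are stored as byte blobs inside the HDF5 archive and decoded at load time. A minimal sketch of that pattern, with placeholder file and dataset names:

import json
import os

import h5py
import numpy as np

def read_split_and_intrinsics(data_root, hdf5_name, split_key, intrinsics_key):
    # the split is a newline-separated "image depth" list, the intrinsics a JSON blob
    path = os.path.join(data_root, hdf5_name)
    with h5py.File(path, "r", libver="latest", swmr=True) as f:
        split_txt = np.array(f[split_key]).tobytes().decode("ascii").strip("\n")
        intrinsics = json.loads(np.array(f[intrinsics_key]).tobytes().decode("ascii"))
    samples = [line.strip().split(" ") for line in split_txt.split("\n") if line.strip()]
    return samples, intrinsics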
unik3d/datasets/adt.py ADDED
@@ -0,0 +1,68 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+
5
+ from unik3d.datasets.sequence_dataset import SequenceDataset
6
+
7
+
8
+ class ADT(SequenceDataset):
9
+ min_depth = 0.01
10
+ max_depth = 20.0
11
+ depth_scale = 1000.0
12
+ test_split = "val.txt"
13
+ train_split = "train.txt"
14
+ sequences_file = "sequences.json"
15
+ hdf5_paths = [f"ADT.hdf5"]
16
+
17
+ def __init__(
18
+ self,
19
+ image_shape: tuple[int, int],
20
+ split_file: str,
21
+ test_mode: bool,
22
+ normalize: bool,
23
+ augmentations_db: dict[str, Any],
24
+ resize_method: str,
25
+ mini: float = 1.0,
26
+ num_frames: int = 1,
27
+ benchmark: bool = False,
28
+ decode_fields: list[str] = ["image", "depth"],
29
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
30
+ **kwargs,
31
+ ) -> None:
32
+ super().__init__(
33
+ image_shape=image_shape,
34
+ split_file=split_file,
35
+ test_mode=test_mode,
36
+ benchmark=benchmark,
37
+ normalize=normalize,
38
+ augmentations_db=augmentations_db,
39
+ resize_method=resize_method,
40
+ mini=mini,
41
+ num_frames=num_frames,
42
+ decode_fields=decode_fields, # if not test_mode else [*decode_fields, "points"],
43
+ inplace_fields=inplace_fields,
44
+ **kwargs,
45
+ )
46
+
47
+ def preprocess(self, results):
48
+ self.resizer.ctx = None
49
+ for i, seq in enumerate(results["sequence_fields"]):
50
+ # Create a mask where the distance from the center is less than H/2
51
+ H, W = results[seq]["image"].shape[-2:]
52
+ x = torch.linspace(-W / 2 - 0.5, W / 2 + 0.5, W)
53
+ y = torch.linspace(-H / 2 - 0.5, H / 2 + 0.5, H)
54
+ xv, yv = torch.meshgrid(x, y, indexing="xy")
55
+ distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W)
56
+ results[seq]["validity_mask"] = distance_from_center < (H / 2) + 20
57
+ results[seq]["depth_mask"] = results[seq]["validity_mask"].clone()
58
+ results[seq]["mask_fields"].add("depth_mask")
59
+ results[seq]["mask_fields"].add("validity_mask")
60
+
61
+ return super().preprocess(results)
62
+
63
+ def pre_pipeline(self, results):
64
+ results = super().pre_pipeline(results)
65
+ results["dense"] = [True] * self.num_frames * self.num_copies
66
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
67
+ results["quality"] = [0] * self.num_frames * self.num_copies
68
+ return results
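Note: ADT (and ASE below) keeps only pixels within a circle of radius about H/2 around the image center, which suggests frames whose corners carry no valid signal. A standalone sketch of that radial mask, slightly simplified with respect to the linspace bounds used above:

import torch

def radial_validity_mask(H: int, W: int, margin: float = 20.0) -> torch.Tensor:
    # boolean (1, 1, H, W) mask, True where the pixel lies inside the circle
    x = torch.linspace(-W / 2, W / 2, W)
    y = torch.linspace(-H / 2, H / 2, H)
    xv, yv = torch.meshgrid(x, y, indexing="xy")
    radius = torch.sqrt(xv**2 + yv**2)
    return (radius < H / 2 + margin).reshape(1, 1, H, W)

mask = radial_validity_mask(480, 640)   # e.g. masks the corners of a 480x640 frame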
unik3d/datasets/aimotive.py ADDED
@@ -0,0 +1,51 @@
1
+ from typing import Any
2
+
3
+ from unik3d.datasets.sequence_dataset import SequenceDataset
4
+
5
+
6
+ class aiMotive(SequenceDataset):
7
+ min_depth = 0.01
8
+ max_depth = 100.0
9
+ depth_scale = 256.0
10
+ default_fps = 10
11
+ test_split = "train.txt"
12
+ train_split = "train.txt"
13
+ sequences_file = "sequences.json"
14
+ hdf5_paths = ["aiMotive.hdf5"]
15
+
16
+ def __init__(
17
+ self,
18
+ image_shape: tuple[int, int],
19
+ split_file: str,
20
+ test_mode: bool,
21
+ normalize: bool,
22
+ augmentations_db: dict[str, Any],
23
+ resize_method: str,
24
+ mini: float = 1.0,
25
+ num_frames: int = 1,
26
+ benchmark: bool = False,
27
+ decode_fields: list[str] = ["image", "depth"],
28
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
29
+ **kwargs,
30
+ ) -> None:
31
+ super().__init__(
32
+ image_shape=image_shape,
33
+ split_file=split_file,
34
+ test_mode=test_mode,
35
+ benchmark=benchmark,
36
+ normalize=normalize,
37
+ augmentations_db=augmentations_db,
38
+ resize_method=resize_method,
39
+ mini=mini,
40
+ num_frames=num_frames,
41
+ decode_fields=decode_fields,
42
+ inplace_fields=inplace_fields,
43
+ **kwargs,
44
+ )
45
+
46
+ def pre_pipeline(self, results):
47
+ results = super().pre_pipeline(results)
48
+ results["dense"] = [False] * self.num_frames * self.num_copies
49
+ results["synthetic"] = [False] * self.num_frames * self.num_copies
50
+ results["quality"] = [2] * self.num_frames * self.num_copies
51
+ return results
unik3d/datasets/argoverse.py ADDED
@@ -0,0 +1,73 @@
1
+ import json
2
+ import os
3
+
4
+ import h5py
5
+ import numpy as np
6
+ import torch
7
+
8
+ from unik3d.datasets.image_dataset import ImageDataset
9
+ from unik3d.datasets.utils import DatasetFromList
10
+
11
+
12
+ class Argoverse(ImageDataset):
13
+ min_depth = 0.05
14
+ max_depth = 120.0
15
+ depth_scale = 256.0
16
+ test_split = "argo_val.txt"
17
+ train_split = "argo_train.txt"
18
+ intrisics_file = "argo_intrinsics.json"
19
+ hdf5_paths = ["argoverse11.hdf5"]
20
+
21
+ def __init__(
22
+ self,
23
+ image_shape,
24
+ split_file,
25
+ test_mode,
26
+ crop=None,
27
+ benchmark=False,
28
+ augmentations_db={},
29
+ normalize=True,
30
+ resize_method="hard",
31
+ mini=1.0,
32
+ **kwargs,
33
+ ):
34
+ super().__init__(
35
+ image_shape=image_shape,
36
+ split_file=split_file,
37
+ test_mode=test_mode,
38
+ benchmark=benchmark,
39
+ normalize=normalize,
40
+ augmentations_db=augmentations_db,
41
+ resize_method=resize_method,
42
+ mini=mini,
43
+ **kwargs,
44
+ )
45
+ self.test_mode = test_mode
46
+
47
+ self.crop = crop
48
+ self.load_dataset()
49
+
50
+ def load_dataset(self):
51
+ h5file = h5py.File(
52
+ os.path.join(self.data_root, self.hdf5_paths[0]),
53
+ "r",
54
+ libver="latest",
55
+ swmr=True,
56
+ )
57
+ txt_file = np.array(h5file[self.split_file])
58
+ txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1
59
+ intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
60
+ intrinsics = json.loads(intrinsics)
61
+ h5file.close()
62
+ dataset = []
63
+ for line in txt_string.split("\n"):
64
+ image_filename, depth_filename = line.strip().split(" ")
65
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
66
+ sample = [image_filename, depth_filename, intrinsics_val]
67
+ dataset.append(sample)
68
+
69
+ if not self.test_mode:
70
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
71
+
72
+ self.dataset = DatasetFromList(dataset)
73
+ self.log_load_dataset()
unik3d/datasets/argoverse2.py ADDED
@@ -0,0 +1,49 @@
1
+ from typing import Any
2
+
3
+ from unik3d.datasets.sequence_dataset import SequenceDataset
4
+
5
+
6
+ class Argoverse2(SequenceDataset):
7
+ min_depth = 0.05
8
+ max_depth = 120.0
9
+ depth_scale = 256.0
10
+ test_split = "val.txt"
11
+ train_split = "train.txt"
12
+ sequences_file = "sequences_clean.json"
13
+ hdf5_paths = [f"AV2_viz.hdf5"]
14
+
15
+ def __init__(
16
+ self,
17
+ image_shape: tuple[int, int],
18
+ split_file: str,
19
+ test_mode: bool,
20
+ normalize: bool,
21
+ augmentations_db: dict[str, Any],
22
+ resize_method: str,
23
+ mini: float = 1.0,
24
+ num_frames: int = 1,
25
+ benchmark: bool = False,
26
+ decode_fields: list[str] = ["image", "depth"],
27
+ inplace_fields: list[str] = ["K", "cam2w"],
28
+ **kwargs,
29
+ ) -> None:
30
+ super().__init__(
31
+ image_shape=image_shape,
32
+ split_file=split_file,
33
+ test_mode=test_mode,
34
+ benchmark=benchmark,
35
+ normalize=normalize,
36
+ augmentations_db=augmentations_db,
37
+ resize_method=resize_method,
38
+ mini=mini,
39
+ num_frames=num_frames,
40
+ decode_fields=decode_fields,
41
+ inplace_fields=inplace_fields,
42
+ **kwargs,
43
+ )
44
+
45
+ def pre_pipeline(self, results):
46
+ results = super().pre_pipeline(results)
47
+ results["dense"] = [False] * self.num_frames * self.num_copies
48
+ results["quality"] = [1] * self.num_frames * self.num_copies
49
+ return results
unik3d/datasets/arkit.py ADDED
@@ -0,0 +1,49 @@
1
+ from typing import Any
2
+
3
+ from unik3d.datasets.sequence_dataset import SequenceDataset
4
+
5
+
6
+ class ARKit(SequenceDataset):
7
+ min_depth = 0.01
8
+ max_depth = 10.0
9
+ depth_scale = 1000.0
10
+ test_split = "Training.txt"
11
+ train_split = "Training.txt"
12
+ sequences_file = "sequences.json"
13
+ hdf5_paths = ["ARKitS.hdf5"]
14
+
15
+ def __init__(
16
+ self,
17
+ image_shape: tuple[int, int],
18
+ split_file: str,
19
+ test_mode: bool,
20
+ normalize: bool,
21
+ augmentations_db: dict[str, Any],
22
+ resize_method: str,
23
+ mini: float = 1.0,
24
+ num_frames: int = 1,
25
+ benchmark: bool = False,
26
+ decode_fields: list[str] = ["image", "depth"],
27
+ inplace_fields: list[str] = ["K", "cam2w"],
28
+ **kwargs,
29
+ ) -> None:
30
+ super().__init__(
31
+ image_shape=image_shape,
32
+ split_file=split_file,
33
+ test_mode=test_mode,
34
+ benchmark=benchmark,
35
+ normalize=normalize,
36
+ augmentations_db=augmentations_db,
37
+ resize_method=resize_method,
38
+ mini=mini,
39
+ num_frames=num_frames,
40
+ decode_fields=decode_fields,
41
+ inplace_fields=inplace_fields,
42
+ **kwargs,
43
+ )
44
+
45
+ def pre_pipeline(self, results):
46
+ results = super().pre_pipeline(results)
47
+ results["dense"] = [True] * self.num_frames * self.num_copies
48
+ results["quality"] = [2] * self.num_frames * self.num_copies
49
+ return results
unik3d/datasets/ase.py ADDED
@@ -0,0 +1,66 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+
5
+ from unik3d.datasets.sequence_dataset import SequenceDataset
6
+
7
+
8
+ class ASE(SequenceDataset):
9
+ min_depth = 0.01
10
+ max_depth = 20.0
11
+ depth_scale = 1000.0
12
+ test_split = "val.txt"
13
+ train_split = "train.txt"
14
+ sequences_file = "sequences.json"
15
+ hdf5_paths = [f"ASE.hdf5"]
16
+
17
+ def __init__(
18
+ self,
19
+ image_shape: tuple[int, int],
20
+ split_file: str,
21
+ test_mode: bool,
22
+ normalize: bool,
23
+ augmentations_db: dict[str, Any],
24
+ resize_method: str,
25
+ mini: float = 1.0,
26
+ num_frames: int = 1,
27
+ benchmark: bool = False,
28
+ decode_fields: list[str] = ["image", "depth"],
29
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
30
+ **kwargs,
31
+ ) -> None:
32
+ super().__init__(
33
+ image_shape=image_shape,
34
+ split_file=split_file,
35
+ test_mode=test_mode,
36
+ benchmark=benchmark,
37
+ normalize=normalize,
38
+ augmentations_db=augmentations_db,
39
+ resize_method=resize_method,
40
+ mini=mini,
41
+ num_frames=num_frames,
42
+ decode_fields=decode_fields,
43
+ inplace_fields=inplace_fields,
44
+ **kwargs,
45
+ )
46
+
47
+ def preprocess(self, results):
48
+ self.resizer.ctx = None
49
+ for i, seq in enumerate(results["sequence_fields"]):
50
+ # Create a mask where the distance from the center is less than H/2
51
+ H, W = results[seq]["image"].shape[-2:]
52
+ x = torch.linspace(-W / 2 - 0.5, W / 2 + 0.5, W)
53
+ y = torch.linspace(-H / 2 - 0.5, H / 2 + 0.5, H)
54
+ xv, yv = torch.meshgrid(x, y, indexing="xy")
55
+ distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W)
56
+ results[seq]["validity_mask"] = distance_from_center < (H / 2) + 20
57
+ results[seq]["mask_fields"].add("validity_mask")
58
+
59
+ return super().preprocess(results)
60
+
61
+ def pre_pipeline(self, results):
62
+ results = super().pre_pipeline(results)
63
+ results["dense"] = [True] * self.num_frames * self.num_copies
64
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
65
+ results["quality"] = [0] * self.num_frames * self.num_copies
66
+ return results
unik3d/datasets/base_dataset.py ADDED
@@ -0,0 +1,344 @@
1
+ import os
2
+ from abc import abstractmethod
3
+ from copy import deepcopy
4
+ from math import ceil, log
5
+ from typing import Any, Dict, Tuple
6
+
7
+ import numpy as np
8
+ import torch
9
+ from torch.utils.data import Dataset
10
+
11
+ import unik3d.datasets.pipelines as pipelines
12
+ from unik3d.utils import (eval_3d, eval_depth, identity, is_main_process,
13
+ recursive_index, sync_tensor_across_gpus)
14
+ from unik3d.utils.constants import (IMAGENET_DATASET_MEAN,
15
+ IMAGENET_DATASET_STD, OPENAI_DATASET_MEAN,
16
+ OPENAI_DATASET_STD)
17
+
18
+
19
+ class BaseDataset(Dataset):
20
+ min_depth = 0.01
21
+ max_depth = 1000.0
22
+
23
+ def __init__(
24
+ self,
25
+ image_shape: Tuple[int, int],
26
+ split_file: str,
27
+ test_mode: bool,
28
+ normalize: bool,
29
+ augmentations_db: Dict[str, Any],
30
+ shape_constraints: Dict[str, Any],
31
+ resize_method: str,
32
+ mini: float,
33
+ num_copies: int = 1,
34
+ **kwargs,
35
+ ) -> None:
36
+ super().__init__()
37
+ assert normalize in [None, "imagenet", "openai"]
38
+
39
+ self.split_file = split_file
40
+ self.test_mode = test_mode
41
+ self.data_root = os.environ["DATAROOT"]
42
+ self.image_shape = image_shape
43
+ self.resize_method = resize_method
44
+ self.mini = mini
45
+ self.num_frames = 1
46
+ self.num_copies = num_copies
47
+ self.metrics_store = {}
48
+ self.metrics_count = {}
49
+
50
+ if normalize == "imagenet":
51
+ self.normalization_stats = {
52
+ "mean": torch.tensor(IMAGENET_DATASET_MEAN),
53
+ "std": torch.tensor(IMAGENET_DATASET_STD),
54
+ }
55
+ elif normalize == "openai":
56
+ self.normalization_stats = {
57
+ "mean": torch.tensor(OPENAI_DATASET_MEAN),
58
+ "std": torch.tensor(OPENAI_DATASET_STD),
59
+ }
60
+ else:
61
+ self.normalization_stats = {
62
+ "mean": torch.tensor([0.0, 0.0, 0.0]),
63
+ "std": torch.tensor([1.0, 1.0, 1.0]),
64
+ }
65
+
66
+ for k, v in augmentations_db.items():
67
+ setattr(self, k, v)
68
+ self.shape_constraints = shape_constraints
69
+ if not self.test_mode:
70
+ self._augmentation_space()
71
+
72
+ self.masker = pipelines.AnnotationMask(
73
+ min_value=0.0,
74
+ max_value=self.max_depth if test_mode else None,
75
+ custom_fn=identity,
76
+ )
77
+ self.filler = pipelines.RandomFiller(test_mode=test_mode)
78
+
79
+ shape_mult = self.shape_constraints["shape_mult"]
80
+ self.image_shape = [
81
+ ceil(self.image_shape[0] / shape_mult) * shape_mult,
82
+ ceil(self.image_shape[1] / shape_mult) * shape_mult,
83
+ ]
84
+ self.resizer = pipelines.ContextCrop(
85
+ image_shape=self.image_shape,
86
+ train_ctx_range=(1.0 / self.random_scale, 1.0 * self.random_scale),
87
+ test_min_ctx=self.test_context,
88
+ keep_original=test_mode,
89
+ shape_constraints=self.shape_constraints,
90
+ )
91
+
92
+ self.collecter = pipelines.Collect(
93
+ keys=["image_fields", "mask_fields", "gt_fields", "camera_fields"]
94
+ )
95
+
96
+ def __len__(self):
97
+ return len(self.dataset)
98
+
99
+ def pack_batch(self, results):
100
+ results["paddings"] = [
101
+ results[x]["paddings"][0] for x in results["sequence_fields"]
102
+ ]
103
+ for fields_name in [
104
+ "image_fields",
105
+ "gt_fields",
106
+ "mask_fields",
107
+ "camera_fields",
108
+ ]:
109
+ fields = results.get(fields_name)
110
+ packed = {
111
+ field: torch.cat(
112
+ [results[seq][field] for seq in results["sequence_fields"]]
113
+ )
114
+ for field in fields
115
+ }
116
+ results.update(packed)
117
+ return results
118
+
119
+ def unpack_batch(self, results):
120
+ for fields_name in [
121
+ "image_fields",
122
+ "gt_fields",
123
+ "mask_fields",
124
+ "camera_fields",
125
+ ]:
126
+ fields = results.get(fields_name)
127
+ unpacked = {
128
+ field: {
129
+ seq: results[field][idx : idx + 1]
130
+ for idx, seq in enumerate(results["sequence_fields"])
131
+ }
132
+ for field in fields
133
+ }
134
+ results.update(unpacked)
135
+ return results
136
+
137
+ def _augmentation_space(self):
138
+ self.augmentations_dict = {
139
+ "Flip": pipelines.RandomFlip(prob=self.flip_p),
140
+ "Jitter": pipelines.RandomColorJitter(
141
+ (-self.random_jitter, self.random_jitter), prob=self.jitter_p
142
+ ),
143
+ "Gamma": pipelines.RandomGamma(
144
+ (-self.random_gamma, self.random_gamma), prob=self.gamma_p
145
+ ),
146
+ "Blur": pipelines.GaussianBlur(
147
+ kernel_size=13, sigma=(0.1, self.random_blur), prob=self.blur_p
148
+ ),
149
+ "Grayscale": pipelines.RandomGrayscale(prob=self.grayscale_p),
150
+ }
151
+
152
+ def augment(self, results):
153
+ for name, aug in self.augmentations_dict.items():
154
+ results = aug(results)
155
+ return results
156
+
157
+ def prepare_depth_eval(self, inputs, preds):
158
+ new_preds = {}
159
+ keyframe_idx = getattr(self, "keyframe_idx", None)
160
+ slice_idx = slice(
161
+ keyframe_idx, keyframe_idx + 1 if keyframe_idx is not None else None
162
+ )
163
+ new_gts = inputs["depth"][slice_idx]
164
+ new_masks = inputs["depth_mask"][slice_idx].bool()
165
+ for key, val in preds.items():
166
+ if "depth" in key:
167
+ new_preds[key] = val[slice_idx]
168
+ return new_gts, new_preds, new_masks
169
+
170
+ def prepare_points_eval(self, inputs, preds):
171
+ new_preds = {}
172
+ new_gts = inputs["points"]
173
+ new_masks = inputs["depth_mask"].bool()
174
+ if "points_mask" in inputs:
175
+ new_masks = inputs["points_mask"].bool()
176
+ for key, val in preds.items():
177
+ if "points" in key:
178
+ new_preds[key] = val
179
+ return new_gts, new_preds, new_masks
180
+
181
+ def add_points(self, inputs):
182
+ inputs["points"] = inputs.get("camera_original", inputs["camera"]).reconstruct(
183
+ inputs["depth"]
184
+ )
185
+ return inputs
186
+
187
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
188
+ def accumulate_metrics(
189
+ self,
190
+ inputs,
191
+ preds,
192
+ keyframe_idx=None,
193
+ metrics=["depth", "points", "flow_fwd", "pairwise"],
194
+ ):
195
+ if "depth" in inputs and "points" not in inputs:
196
+ inputs = self.add_points(inputs)
197
+
198
+ available_metrics = []
199
+ for metric in metrics:
200
+ metric_in_gt = any((metric in k for k in inputs.keys()))
201
+ metric_in_pred = any((metric in k for k in preds.keys()))
202
+ if metric_in_gt and metric_in_pred:
203
+ available_metrics.append(metric)
204
+
205
+ if keyframe_idx is not None:
206
+ inputs = recursive_index(inputs, slice(keyframe_idx, keyframe_idx + 1))
207
+ preds = recursive_index(preds, slice(keyframe_idx, keyframe_idx + 1))
208
+
209
+ if "depth" in available_metrics:
210
+ depth_gt, depth_pred, depth_masks = self.prepare_depth_eval(inputs, preds)
211
+ self.accumulate_metrics_depth(depth_gt, depth_pred, depth_masks)
212
+
213
+ if "points" in available_metrics:
214
+ points_gt, points_pred, points_masks = self.prepare_points_eval(
215
+ inputs, preds
216
+ )
217
+ self.accumulate_metrics_3d(points_gt, points_pred, points_masks)
218
+
219
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
220
+ def accumulate_metrics_depth(self, gts, preds, masks):
221
+ for eval_type, pred in preds.items():
222
+ log_name = eval_type.replace("depth", "").strip("-").strip("_")
223
+ if log_name not in self.metrics_store:
224
+ self.metrics_store[log_name] = {}
225
+ current_count = self.metrics_count.get(
226
+ log_name, torch.tensor([], device=gts.device)
227
+ )
228
+ new_count = masks.view(gts.shape[0], -1).sum(dim=-1)
229
+ self.metrics_count[log_name] = torch.cat([current_count, new_count])
230
+ for k, v in eval_depth(gts, pred, masks, max_depth=self.max_depth).items():
231
+ current_metric = self.metrics_store[log_name].get(
232
+ k, torch.tensor([], device=gts.device)
233
+ )
234
+ self.metrics_store[log_name][k] = torch.cat([current_metric, v])
235
+
236
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
237
+ def accumulate_metrics_3d(self, gts, preds, masks):
238
+ thresholds = torch.linspace(
239
+ log(self.min_depth),
240
+ log(self.max_depth / 20),
241
+ steps=100,
242
+ device=gts.device,
243
+ ).exp()
244
+ for eval_type, pred in preds.items():
245
+ log_name = eval_type.replace("points", "").strip("-").strip("_")
246
+ if log_name not in self.metrics_store:
247
+ self.metrics_store[log_name] = {}
248
+ current_count = self.metrics_count.get(
249
+ log_name, torch.tensor([], device=gts.device)
250
+ )
251
+ new_count = masks.view(gts.shape[0], -1).sum(dim=-1)
252
+ self.metrics_count[log_name] = torch.cat([current_count, new_count])
253
+ for k, v in eval_3d(gts, pred, masks, thresholds=thresholds).items():
254
+ current_metric = self.metrics_store[log_name].get(
255
+ k, torch.tensor([], device=gts.device)
256
+ )
257
+ self.metrics_store[log_name][k] = torch.cat([current_metric, v])
258
+
259
+ def get_evaluation(self, metrics=None):
260
+ metric_vals = {}
261
+ for eval_type in metrics if metrics is not None else self.metrics_store.keys():
262
+ assert self.metrics_store[eval_type]
263
+ cnts = sync_tensor_across_gpus(self.metrics_count[eval_type])
264
+ for name, val in self.metrics_store[eval_type].items():
265
+ # vals_r = (sync_tensor_across_gpus(val) * cnts / cnts.sum()).sum()
266
+ vals_r = sync_tensor_across_gpus(val).mean()
267
+ metric_vals[f"{eval_type}_{name}".strip("_")] = np.round(
268
+ vals_r.cpu().item(), 5
269
+ )
270
+ self.metrics_store[eval_type] = {}
271
+ self.metrics_count = {}
272
+ return metric_vals
273
+
274
+ def replicate(self, results):
275
+ for i in range(1, self.num_copies):
276
+ results[(0, i)] = {k: deepcopy(v) for k, v in results[(0, 0)].items()}
277
+ results["sequence_fields"].append((0, i))
278
+ return results
279
+
280
+ def log_load_dataset(self):
281
+ if is_main_process():
282
+ info = f"Loaded {self.__class__.__name__} with {len(self)} images."
283
+ print(info)
284
+
285
+ def pre_pipeline(self, results):
286
+ results["image_fields"] = results.get("image_fields", set())
287
+ results["gt_fields"] = results.get("gt_fields", set())
288
+ results["mask_fields"] = results.get("mask_fields", set())
289
+ results["sequence_fields"] = results.get("sequence_fields", set())
290
+ results["camera_fields"] = results.get("camera_fields", set())
291
+ results["dataset_name"] = (
292
+ [self.__class__.__name__] * self.num_frames * self.num_copies
293
+ )
294
+ results["depth_scale"] = [self.depth_scale] * self.num_frames * self.num_copies
295
+ results["si"] = [False] * self.num_frames * self.num_copies
296
+ results["dense"] = [False] * self.num_frames * self.num_copies
297
+ results["synthetic"] = [False] * self.num_frames * self.num_copies
298
+ results["quality"] = [0] * self.num_frames * self.num_copies
299
+ results["valid_camera"] = [True] * self.num_frames * self.num_copies
300
+ results["valid_pose"] = [True] * self.num_frames * self.num_copies
301
+ return results
302
+
303
+ def eval_mask(self, valid_mask):
304
+ return valid_mask
305
+
306
+ def chunk(self, dataset, chunk_dim=1, pct=1.0):
307
+ subsampled_datasets = [
308
+ x
309
+ for i in range(0, len(dataset), int(1 / pct * chunk_dim))
310
+ for x in dataset[i : i + chunk_dim]
311
+ ]
312
+ return subsampled_datasets
313
+
314
+ @abstractmethod
315
+ def preprocess(self, results):
316
+ raise NotImplementedError
317
+
318
+ @abstractmethod
319
+ def postprocess(self, results):
320
+ raise NotImplementedError
321
+
322
+ @abstractmethod
323
+ def get_mapper(self):
324
+ raise NotImplementedError
325
+
326
+ @abstractmethod
327
+ def get_intrinsics(self, idx, image_name):
328
+ raise NotImplementedError
329
+
330
+ @abstractmethod
331
+ def get_extrinsics(self, idx, image_name):
332
+ raise NotImplementedError
333
+
334
+ @abstractmethod
335
+ def load_dataset(self):
336
+ raise NotImplementedError
337
+
338
+ @abstractmethod
339
+ def get_single_item(self, idx, sample=None, mapper=None):
340
+ raise NotImplementedError
341
+
342
+ @abstractmethod
343
+ def __getitem__(self, idx):
344
+ raise NotImplementedError
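Note: `BaseDataset.chunk` subsamples the sample list by keeping one run of `chunk_dim` consecutive items out of every `1 / pct * chunk_dim` items, so `pct` is the retained fraction while consecutive samples stay grouped. A small worked example of that behavior:

def chunk(dataset, chunk_dim=1, pct=1.0):
    return [
        x
        for i in range(0, len(dataset), int(1 / pct * chunk_dim))
        for x in dataset[i : i + chunk_dim]
    ]

samples = list(range(12))
print(chunk(samples, chunk_dim=1, pct=0.25))  # [0, 4, 8] -> every 4th sample
print(chunk(samples, chunk_dim=2, pct=0.5))   # [0, 1, 4, 5, 8, 9] -> pairs, half kept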
unik3d/datasets/bdd.py ADDED
@@ -0,0 +1,82 @@
1
+ import json
2
+ import os
3
+
4
+ import h5py
5
+ import numpy as np
6
+ import torch
7
+
8
+ from unik3d.datasets.image_dataset import ImageDataset
9
+ from unik3d.datasets.utils import DatasetFromList
10
+
11
+
12
+ class BDD(ImageDataset):
13
+ min_depth = 0.01
14
+ max_depth = 70.0
15
+ depth_scale = 256.0
16
+ test_split = "val.txt"
17
+ train_split = "train_clean.txt"
18
+ intrisics_file = "intrinsics.json"
19
+ hdf5_paths = ["BDD.hdf5"]
20
+
21
+ def __init__(
22
+ self,
23
+ image_shape,
24
+ split_file,
25
+ test_mode,
26
+ benchmark=False,
27
+ augmentations_db={},
28
+ normalize=True,
29
+ resize_method="hard",
30
+ mini=1.0,
31
+ **kwargs,
32
+ ):
33
+ super().__init__(
34
+ image_shape=image_shape,
35
+ split_file=split_file,
36
+ test_mode=test_mode,
37
+ benchmark=benchmark,
38
+ normalize=normalize,
39
+ augmentations_db=augmentations_db,
40
+ resize_method=resize_method,
41
+ mini=mini,
42
+ **kwargs,
43
+ )
44
+ self.test_mode = test_mode
45
+ self.load_dataset()
46
+
47
+ def load_dataset(self):
48
+ h5file = h5py.File(
49
+ os.path.join(self.data_root, self.hdf5_paths[0]),
50
+ "r",
51
+ libver="latest",
52
+ swmr=True,
53
+ )
54
+ txt_file = np.array(h5file[self.split_file])
55
+ txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1
56
+ intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
57
+ intrinsics = json.loads(intrinsics)
58
+
59
+ dataset = []
60
+ for line in txt_string.split("\n"):
61
+ image_filename, depth_filename = line.strip().split(" ")
62
+ intrinsics_val = torch.tensor(
63
+ intrinsics[os.path.join(*image_filename.split("/")[:2])]
64
+ ).squeeze()[:, :3]
65
+ sample = [image_filename, depth_filename, intrinsics_val]
66
+ dataset.append(sample)
67
+ h5file.close()
68
+ if not self.test_mode:
69
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
70
+ if self.test_mode:
71
+ dataset = self.chunk(dataset, chunk_dim=1, pct=0.1)
72
+
73
+ self.dataset = DatasetFromList(dataset)
74
+ self.log_load_dataset()
75
+
76
+ def pre_pipeline(self, results):
77
+ results = super().pre_pipeline(results)
78
+ results["si"] = [True] * self.num_copies
79
+ results["valid_camera"] = [False] * self.num_copies
80
+ results["dense"] = [False] * self.num_copies
81
+ results["quality"] = [2] * self.num_copies
82
+ return results
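Note: BDD is the only image dataset here flagged with `si=True` and `valid_camera=False`: its depth can supervise only up to an unknown scale and without trusted intrinsics. A typical way to use such data is to align the prediction to the ground truth, for example by the ratio of medians, before computing the loss; a generic sketch, not the UniK3D loss:

import torch

def scale_invariant_l1(pred: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # align prediction and ground truth by the ratio of medians over valid pixels, then L1
    scale = gt[mask].median() / pred[mask].median().clamp(min=1e-6)
    return (pred[mask] * scale - gt[mask]).abs().mean()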
unik3d/datasets/bedlam.py ADDED
@@ -0,0 +1,50 @@
1
+ from typing import Any
2
+
3
+ from unik3d.datasets.sequence_dataset import SequenceDataset
4
+
5
+
6
+ class BEDLAM(SequenceDataset):
7
+ min_depth = 0.01
8
+ max_depth = 256.0
9
+ depth_scale = 1000.0
10
+ test_split = "train.txt"
11
+ train_split = "val.txt"
12
+ sequences_file = "sequences.json"
13
+ hdf5_paths = ["BEDLAM.hdf5"]
14
+
15
+ def __init__(
16
+ self,
17
+ image_shape: tuple[int, int],
18
+ split_file: str,
19
+ test_mode: bool,
20
+ normalize: bool,
21
+ augmentations_db: dict[str, Any],
22
+ resize_method: str,
23
+ mini: float = 1.0,
24
+ num_frames: int = 1,
25
+ benchmark: bool = False,
26
+ decode_fields: list[str] = ["image", "depth"],
27
+ inplace_fields: list[str] = ["K", "cam2w"],
28
+ **kwargs,
29
+ ) -> None:
30
+ super().__init__(
31
+ image_shape=image_shape,
32
+ split_file=split_file,
33
+ test_mode=test_mode,
34
+ benchmark=benchmark,
35
+ normalize=normalize,
36
+ augmentations_db=augmentations_db,
37
+ resize_method=resize_method,
38
+ mini=mini,
39
+ num_frames=num_frames,
40
+ decode_fields=decode_fields,
41
+ inplace_fields=inplace_fields,
42
+ **kwargs,
43
+ )
44
+
45
+ def pre_pipeline(self, results):
46
+ results = super().pre_pipeline(results)
47
+ results["dense"] = [True] * self.num_frames * self.num_copies
48
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
49
+ results["quality"] = [0] * self.num_frames * self.num_copies
50
+ return results
unik3d/datasets/behave.py ADDED
@@ -0,0 +1,52 @@
1
+ from typing import Any
2
+
3
+ from unik3d.datasets.sequence_dataset import SequenceDataset
4
+
5
+
6
+ class Behave(SequenceDataset):
7
+ min_depth = 0.01
8
+ max_depth = 10.0
9
+ depth_scale = 1000.0
10
+ default_fps = 10
11
+ test_split = "train.txt"
12
+ train_split = "train.txt"
13
+ sequences_file = "sequences.json"
14
+ hdf5_paths = ["Behave.hdf5"]
15
+
16
+ def __init__(
17
+ self,
18
+ image_shape: tuple[int, int],
19
+ split_file: str,
20
+ test_mode: bool,
21
+ normalize: bool,
22
+ augmentations_db: dict[str, Any],
23
+ resize_method: str,
24
+ mini: float = 1.0,
25
+ num_frames: int = 1,
26
+ benchmark: bool = False,
27
+ decode_fields: list[str] = ["image", "depth"],
28
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
29
+ **kwargs,
30
+ ) -> None:
31
+ super().__init__(
32
+ image_shape=image_shape,
33
+ split_file=split_file,
34
+ test_mode=test_mode,
35
+ benchmark=benchmark,
36
+ normalize=normalize,
37
+ augmentations_db=augmentations_db,
38
+ resize_method=resize_method,
39
+ mini=mini,
40
+ num_frames=num_frames,
41
+ decode_fields=decode_fields,
42
+ inplace_fields=inplace_fields,
43
+ **kwargs,
44
+ )
45
+
46
+ def pre_pipeline(self, results):
47
+ results = super().pre_pipeline(results)
48
+ results["dense"] = [True] * self.num_frames * self.num_copies
49
+ results["synthetic"] = [False] * self.num_frames * self.num_copies
50
+ results["si"] = [False] * self.num_frames * self.num_copies
51
+ results["quality"] = [1] * self.num_frames * self.num_copies
52
+ return results
unik3d/datasets/blendedmvg.py ADDED
1
+ from typing import Any
2
+
3
+ from unik3d.datasets.sequence_dataset import SequenceDataset
4
+
5
+
6
+ class BlendedMVG(SequenceDataset):
7
+ min_depth = 0.01
8
+ max_depth = 5000.0
9
+ depth_scale = 1000.0
10
+ test_split = "train.txt"
11
+ train_split = "train.txt"
12
+ sequences_file = "sequences_clean.json"
13
+ hdf5_paths = ["BlendedMVG_.hdf5"]
14
+
15
+ def __init__(
16
+ self,
17
+ image_shape: tuple[int, int],
18
+ split_file: str,
19
+ test_mode: bool,
20
+ normalize: bool,
21
+ augmentations_db: dict[str, Any],
22
+ resize_method: str,
23
+ mini: float = 1.0,
24
+ num_frames: int = 1,
25
+ benchmark: bool = False,
26
+ decode_fields: list[str] = ["image", "depth"],
27
+ inplace_fields: list[str] = ["K", "cam2w"],
28
+ **kwargs,
29
+ ) -> None:
30
+ super().__init__(
31
+ image_shape=image_shape,
32
+ split_file=split_file,
33
+ test_mode=test_mode,
34
+ benchmark=benchmark,
35
+ normalize=normalize,
36
+ augmentations_db=augmentations_db,
37
+ resize_method=resize_method,
38
+ mini=mini,
39
+ num_frames=num_frames,
40
+ decode_fields=decode_fields,
41
+ inplace_fields=inplace_fields,
42
+ **kwargs,
43
+ )
44
+
45
+ def pre_pipeline(self, results):
46
+ results = super().pre_pipeline(results)
47
+ results["dense"] = [True] * self.num_frames * self.num_copies
48
+ results["si"] = [False] * self.num_frames * self.num_copies
49
+ results["quality"] = [2] * self.num_frames * self.num_copies
50
+ return results
unik3d/datasets/cityscape.py ADDED
1
+ import json
2
+ import os
3
+
4
+ import h5py
5
+ import numpy as np
6
+ import torch
7
+
8
+ from unik3d.datasets.image_dataset import ImageDataset
9
+ from unik3d.datasets.utils import DatasetFromList
10
+
11
+
12
+ class Cityscape(ImageDataset):
13
+ min_depth = 0.05
14
+ max_depth = 80.0
15
+ depth_scale = 256.0
16
+ test_split = "val.txt"
17
+ train_split = "train.txt"
18
+ intrisics_file = "intrinsics.json"
19
+ hdf5_paths = ["cityscape.hdf5"]
20
+
21
+ def __init__(
22
+ self,
23
+ image_shape,
24
+ split_file,
25
+ test_mode,
26
+ crop=None,
27
+ benchmark=False,
28
+ augmentations_db={},
29
+ normalize=True,
30
+ resize_method="hard",
31
+ mini=1.0,
32
+ **kwargs,
33
+ ):
34
+ super().__init__(
35
+ image_shape=image_shape,
36
+ split_file=split_file,
37
+ test_mode=test_mode,
38
+ benchmark=benchmark,
39
+ normalize=normalize,
40
+ augmentations_db=augmentations_db,
41
+ resize_method=resize_method,
42
+ mini=mini,
43
+ **kwargs,
44
+ )
45
+ self.test_mode = test_mode
46
+
47
+ self.crop = crop
48
+ self.load_dataset()
49
+
50
+ def load_dataset(self):
51
+ h5file = h5py.File(
52
+ os.path.join(self.data_root, self.hdf5_paths[0]),
53
+ "r",
54
+ libver="latest",
55
+ swmr=True,
56
+ )
57
+ txt_file = np.array(h5file[self.split_file])
58
+ txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1
59
+ intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
60
+ intrinsics = json.loads(intrinsics)
61
+ h5file.close()
62
+ dataset = []
63
+ for line in txt_string.split("\n"):
64
+ image_filename, depth_filename = line.strip().split(" ")
65
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
66
+ sample = [image_filename, depth_filename, intrinsics_val]
67
+ dataset.append(sample)
68
+
69
+ if not self.test_mode:
70
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
71
+
72
+ self.dataset = DatasetFromList(dataset)
73
+ self.log_load_dataset()
74
+
75
+ def pre_pipeline(self, results):
76
+ results = super().pre_pipeline(results)
77
+ results["quality"] = [2] * self.num_copies
78
+ return results
unik3d/datasets/ddad.py ADDED
@@ -0,0 +1,84 @@
1
+ import json
2
+ import os
3
+
4
+ import h5py
5
+ import numpy as np
6
+ import torch
7
+
8
+ from unik3d.datasets.image_dataset import ImageDataset
9
+ from unik3d.datasets.utils import DatasetFromList
10
+
11
+
12
+ class DDAD(ImageDataset):
13
+ min_depth = 0.05
14
+ max_depth = 120.0
15
+ depth_scale = 256.0
16
+ test_split = "val.txt"
17
+ train_split = "train.txt"
18
+ intrisics_file = "intrinsics.json"
19
+ hdf5_paths = [f"ddad/ddad_{i}.hdf5" for i in range(8)]
20
+
21
+ def __init__(
22
+ self,
23
+ image_shape,
24
+ split_file,
25
+ test_mode,
26
+ benchmark=False,
27
+ augmentations_db={},
28
+ normalize=True,
29
+ resize_method="hard",
30
+ mini=1.0,
31
+ **kwargs,
32
+ ):
33
+ super().__init__(
34
+ image_shape=image_shape,
35
+ split_file=split_file,
36
+ test_mode=test_mode,
37
+ benchmark=benchmark,
38
+ normalize=normalize,
39
+ augmentations_db=augmentations_db,
40
+ resize_method=resize_method,
41
+ mini=mini,
42
+ **kwargs,
43
+ )
44
+ self.test_mode = test_mode
45
+ self.load_dataset()
46
+
47
+ def load_dataset(self):
48
+ h5file = h5py.File(
49
+ os.path.join(self.data_root, self.hdf5_paths[0]),
50
+ "r",
51
+ libver="latest",
52
+ swmr=True,
53
+ )
54
+ txt_file = np.array(h5file[self.split_file])
55
+ txt_string = txt_file.tostring().decode("ascii").strip("\n")
56
+ intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
57
+ intrinsics = json.loads(intrinsics)
58
+ h5file.close()
59
+ dataset = []
60
+ for line in txt_string.split("\n"):
61
+ image_filename, depth_filename, chunk_idx = line.strip().split(" ")
62
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
63
+ sample = [image_filename, depth_filename, intrinsics_val, chunk_idx]
64
+ dataset.append(sample)
65
+
66
+ if not self.test_mode:
67
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
68
+
69
+ self.dataset = DatasetFromList(dataset)
70
+ self.log_load_dataset()
71
+
72
+ def get_mapper(self):
73
+ return {
74
+ "image_filename": 0,
75
+ "depth_filename": 1,
76
+ "K": 2,
77
+ "chunk_idx": 3,
78
+ }
79
+
80
+ def pre_pipeline(self, results):
81
+ results = super().pre_pipeline(results)
82
+ results["dense"] = [False] * self.num_copies
83
+ results["quality"] = [1] * self.num_copies
84
+ return results
unik3d/datasets/deep360.py ADDED
@@ -0,0 +1,56 @@
1
+ from typing import Any
2
+
3
+ import torch
4
+
5
+ from unik3d.datasets.pipelines import Compose, PanoCrop, PanoRoll
6
+ from unik3d.datasets.sequence_dataset import SequenceDataset
7
+
8
+
9
+ class Deep360(SequenceDataset):
10
+ min_depth = 0.1
11
+ max_depth = 1000.0
12
+ depth_scale = 1000.0
13
+ test_split = "train.txt"
14
+ train_split = "train.txt"
15
+ sequences_file = "sequences.json"
16
+ hdf5_paths = [f"Deep360.hdf5"]
17
+
18
+ def __init__(
19
+ self,
20
+ image_shape: tuple[int, int],
21
+ split_file: str,
22
+ test_mode: bool,
23
+ normalize: bool,
24
+ augmentations_db: dict[str, Any],
25
+ resize_method: str,
26
+ mini: float = 1.0,
27
+ num_frames: int = 1,
28
+ benchmark: bool = False,
29
+ decode_fields: list[str] = ["image", "depth"],
30
+ inplace_fields: list[str] = ["cam2w", "camera_params"],
31
+ **kwargs,
32
+ ) -> None:
33
+ super().__init__(
34
+ image_shape=image_shape,
35
+ split_file=split_file,
36
+ test_mode=test_mode,
37
+ benchmark=benchmark,
38
+ normalize=normalize,
39
+ augmentations_db=augmentations_db,
40
+ resize_method=resize_method,
41
+ mini=mini,
42
+ num_frames=num_frames,
43
+ decode_fields=decode_fields,
44
+ inplace_fields=inplace_fields,
45
+ **kwargs,
46
+ )
47
+ self.resizer = Compose(
48
+ [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer]
49
+ )
50
+
51
+ def pre_pipeline(self, results):
52
+ results = super().pre_pipeline(results)
53
+ results["dense"] = [True] * self.num_frames * self.num_copies
54
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
55
+ results["quality"] = [0] * self.num_frames * self.num_copies
56
+ return results
unik3d/datasets/dense.py ADDED
@@ -0,0 +1,91 @@
1
+ import os
2
+
3
+ import h5py
4
+ import numpy as np
5
+ import torch
6
+
7
+ from unik3d.datasets.image_dataset import ImageDataset
8
+ from unik3d.datasets.utils import DatasetFromList
9
+
10
+
11
+ class DENSE(ImageDataset):
12
+ CAM_INTRINSIC = {
13
+ "ALL": torch.tensor(
14
+ [
15
+ [1177.8614, 0.0, 474.319027],
16
+ [0.0, 1177.8614, 224.275919],
17
+ [0.0, 0.0, 1.0],
18
+ ]
19
+ )
20
+ }
21
+ min_depth = 0.05
22
+ max_depth = 80.0
23
+ depth_scale = 255.0
24
+ test_split = "train.txt"
25
+ train_split = "train.txt"
26
+ hdf5_paths = ["DENSE.hdf5"]
27
+
28
+ def __init__(
29
+ self,
30
+ image_shape,
31
+ split_file,
32
+ test_mode,
33
+ benchmark=False,
34
+ augmentations_db={},
35
+ normalize=True,
36
+ resize_method="hard",
37
+ mini=1.0,
38
+ **kwargs,
39
+ ):
40
+ super().__init__(
41
+ image_shape=image_shape,
42
+ split_file=split_file,
43
+ test_mode=test_mode,
44
+ benchmark=benchmark,
45
+ normalize=normalize,
46
+ augmentations_db=augmentations_db,
47
+ resize_method=resize_method,
48
+ mini=mini,
49
+ **kwargs,
50
+ )
51
+ self.test_mode = test_mode
52
+
53
+ self.intrisics = {}
54
+ self.load_dataset()
55
+
56
+ def load_dataset(self):
57
+ h5file = h5py.File(
58
+ os.path.join(self.data_root, self.hdf5_paths[0]),
59
+ "r",
60
+ libver="latest",
61
+ swmr=True,
62
+ )
63
+ txt_file = np.array(h5file[self.split_file])
64
+ txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1
65
+ h5file.close()
66
+ dataset = []
67
+ for line in txt_string.split("\n"):
68
+ image_filename, depth_filename = line.strip().split(" ")
69
+ sample = [image_filename, depth_filename]
70
+ dataset.append(sample)
71
+
72
+ if not self.test_mode:
73
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
74
+
75
+ self.dataset = DatasetFromList(dataset)
76
+ self.log_load_dataset()
77
+
78
+ def get_intrinsics(self, idx, image_name):
79
+ return self.CAM_INTRINSIC["ALL"].clone()
80
+
81
+ def get_mapper(self):
82
+ return {
83
+ "image_filename": 0,
84
+ "depth_filename": 1,
85
+ }
86
+
87
+ def pre_pipeline(self, results):
88
+ results = super().pre_pipeline(results)
89
+ results["dense"] = [False] * self.num_copies
90
+ results["quality"] = [1] * self.num_copies
91
+ return results
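Note: DENSE uses a single hard-coded pinhole intrinsics matrix for every frame. As a reminder of what those entries encode, a pixel back-projects to a viewing ray through the inverse intrinsics; at the principal point the ray is the optical axis:

import torch

K = torch.tensor([
    [1177.8614, 0.0, 474.319027],
    [0.0, 1177.8614, 224.275919],
    [0.0, 0.0, 1.0],
])

def pixel_to_ray(u: float, v: float, K: torch.Tensor) -> torch.Tensor:
    # r = K^-1 [u, v, 1]^T, an (unnormalized) direction in camera coordinates
    return torch.linalg.inv(K) @ torch.tensor([u, v, 1.0])

print(pixel_to_ray(474.319027, 224.275919, K))  # principal point -> ray along +z: [0., 0., 1.]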