diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..a1ca0df48306fa3795b32723d9f6d6d76e2fc4d3
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,16 @@
+**/__pycache__/
+**/build/
+**/dist/
+**/*egg-info
+.gradio/
+
+# ignore scripts
+_*.sh
+__*.png
+__*.jpg
+__*.webp
+___*.py
+**/___*.py
+
+# ignore pcds
+*.ply
diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..12d4ca74d98cfb36269464dced75e9441b7bf645
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,407 @@
+Attribution-NonCommercial 4.0 International
+
+=======================================================================
+
+Creative Commons Corporation ("Creative Commons") is not a law firm and
+does not provide legal services or legal advice. Distribution of
+Creative Commons public licenses does not create a lawyer-client or
+other relationship. Creative Commons makes its licenses and related
+information available on an "as-is" basis. Creative Commons gives no
+warranties regarding its licenses, any material licensed under their
+terms and conditions, or any related information. Creative Commons
+disclaims all liability for damages resulting from their use to the
+fullest extent possible.
+
+Using Creative Commons Public Licenses
+
+Creative Commons public licenses provide a standard set of terms and
+conditions that creators and other rights holders may use to share
+original works of authorship and other material subject to copyright
+and certain other rights specified in the public license below. The
+following considerations are for informational purposes only, are not
+exhaustive, and do not form part of our licenses.
+
+ Considerations for licensors: Our public licenses are
+ intended for use by those authorized to give the public
+ permission to use material in ways otherwise restricted by
+ copyright and certain other rights. Our licenses are
+ irrevocable. Licensors should read and understand the terms
+ and conditions of the license they choose before applying it.
+ Licensors should also secure all rights necessary before
+ applying our licenses so that the public can reuse the
+ material as expected. Licensors should clearly mark any
+ material not subject to the license. This includes other CC-
+ licensed material, or material used under an exception or
+ limitation to copyright. More considerations for licensors:
+ wiki.creativecommons.org/Considerations_for_licensors
+
+ Considerations for the public: By using one of our public
+ licenses, a licensor grants the public permission to use the
+ licensed material under specified terms and conditions. If
+ the licensor's permission is not necessary for any reason--for
+ example, because of any applicable exception or limitation to
+ copyright--then that use is not regulated by the license. Our
+ licenses grant only permissions under copyright and certain
+ other rights that a licensor has authority to grant. Use of
+ the licensed material may still be restricted for other
+ reasons, including because others have copyright or other
+ rights in the material. A licensor may make special requests,
+ such as asking that all changes be marked or described.
+ Although not required by our licenses, you are encouraged to
+ respect those requests where reasonable. More considerations
+ for the public:
+ wiki.creativecommons.org/Considerations_for_licensees
+
+=======================================================================
+
+Creative Commons Attribution-NonCommercial 4.0 International Public
+License
+
+By exercising the Licensed Rights (defined below), You accept and agree
+to be bound by the terms and conditions of this Creative Commons
+Attribution-NonCommercial 4.0 International Public License ("Public
+License"). To the extent this Public License may be interpreted as a
+contract, You are granted the Licensed Rights in consideration of Your
+acceptance of these terms and conditions, and the Licensor grants You
+such rights in consideration of benefits the Licensor receives from
+making the Licensed Material available under these terms and
+conditions.
+
+
+Section 1 -- Definitions.
+
+ a. Adapted Material means material subject to Copyright and Similar
+ Rights that is derived from or based upon the Licensed Material
+ and in which the Licensed Material is translated, altered,
+ arranged, transformed, or otherwise modified in a manner requiring
+ permission under the Copyright and Similar Rights held by the
+ Licensor. For purposes of this Public License, where the Licensed
+ Material is a musical work, performance, or sound recording,
+ Adapted Material is always produced where the Licensed Material is
+ synched in timed relation with a moving image.
+
+ b. Adapter's License means the license You apply to Your Copyright
+ and Similar Rights in Your contributions to Adapted Material in
+ accordance with the terms and conditions of this Public License.
+
+ c. Copyright and Similar Rights means copyright and/or similar rights
+ closely related to copyright including, without limitation,
+ performance, broadcast, sound recording, and Sui Generis Database
+ Rights, without regard to how the rights are labeled or
+ categorized. For purposes of this Public License, the rights
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
+     Rights.
+
+  d. Effective Technological Measures means those measures that, in the
+ absence of proper authority, may not be circumvented under laws
+ fulfilling obligations under Article 11 of the WIPO Copyright
+ Treaty adopted on December 20, 1996, and/or similar international
+ agreements.
+
+ e. Exceptions and Limitations means fair use, fair dealing, and/or
+ any other exception or limitation to Copyright and Similar Rights
+ that applies to Your use of the Licensed Material.
+
+ f. Licensed Material means the artistic or literary work, database,
+ or other material to which the Licensor applied this Public
+ License.
+
+ g. Licensed Rights means the rights granted to You subject to the
+ terms and conditions of this Public License, which are limited to
+ all Copyright and Similar Rights that apply to Your use of the
+ Licensed Material and that the Licensor has authority to license.
+
+ h. Licensor means the individual(s) or entity(ies) granting rights
+ under this Public License.
+
+ i. NonCommercial means not primarily intended for or directed towards
+ commercial advantage or monetary compensation. For purposes of
+ this Public License, the exchange of the Licensed Material for
+ other material subject to Copyright and Similar Rights by digital
+ file-sharing or similar means is NonCommercial provided there is
+ no payment of monetary compensation in connection with the
+ exchange.
+
+ j. Share means to provide material to the public by any means or
+ process that requires permission under the Licensed Rights, such
+ as reproduction, public display, public performance, distribution,
+ dissemination, communication, or importation, and to make material
+ available to the public including in ways that members of the
+ public may access the material from a place and at a time
+ individually chosen by them.
+
+ k. Sui Generis Database Rights means rights other than copyright
+ resulting from Directive 96/9/EC of the European Parliament and of
+ the Council of 11 March 1996 on the legal protection of databases,
+ as amended and/or succeeded, as well as other essentially
+ equivalent rights anywhere in the world.
+
+ l. You means the individual or entity exercising the Licensed Rights
+ under this Public License. Your has a corresponding meaning.
+
+
+Section 2 -- Scope.
+
+ a. License grant.
+
+ 1. Subject to the terms and conditions of this Public License,
+ the Licensor hereby grants You a worldwide, royalty-free,
+ non-sublicensable, non-exclusive, irrevocable license to
+ exercise the Licensed Rights in the Licensed Material to:
+
+ a. reproduce and Share the Licensed Material, in whole or
+ in part, for NonCommercial purposes only; and
+
+ b. produce, reproduce, and Share Adapted Material for
+ NonCommercial purposes only.
+
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
+ Exceptions and Limitations apply to Your use, this Public
+ License does not apply, and You do not need to comply with
+ its terms and conditions.
+
+ 3. Term. The term of this Public License is specified in Section
+ 6(a).
+
+ 4. Media and formats; technical modifications allowed. The
+ Licensor authorizes You to exercise the Licensed Rights in
+ all media and formats whether now known or hereafter created,
+ and to make technical modifications necessary to do so. The
+ Licensor waives and/or agrees not to assert any right or
+ authority to forbid You from making technical modifications
+ necessary to exercise the Licensed Rights, including
+ technical modifications necessary to circumvent Effective
+ Technological Measures. For purposes of this Public License,
+ simply making modifications authorized by this Section 2(a)
+ (4) never produces Adapted Material.
+
+ 5. Downstream recipients.
+
+ a. Offer from the Licensor -- Licensed Material. Every
+ recipient of the Licensed Material automatically
+ receives an offer from the Licensor to exercise the
+ Licensed Rights under the terms and conditions of this
+ Public License.
+
+ b. No downstream restrictions. You may not offer or impose
+ any additional or different terms or conditions on, or
+ apply any Effective Technological Measures to, the
+ Licensed Material if doing so restricts exercise of the
+ Licensed Rights by any recipient of the Licensed
+ Material.
+
+ 6. No endorsement. Nothing in this Public License constitutes or
+ may be construed as permission to assert or imply that You
+ are, or that Your use of the Licensed Material is, connected
+ with, or sponsored, endorsed, or granted official status by,
+ the Licensor or others designated to receive attribution as
+ provided in Section 3(a)(1)(A)(i).
+
+ b. Other rights.
+
+ 1. Moral rights, such as the right of integrity, are not
+ licensed under this Public License, nor are publicity,
+ privacy, and/or other similar personality rights; however, to
+ the extent possible, the Licensor waives and/or agrees not to
+ assert any such rights held by the Licensor to the limited
+ extent necessary to allow You to exercise the Licensed
+ Rights, but not otherwise.
+
+ 2. Patent and trademark rights are not licensed under this
+ Public License.
+
+ 3. To the extent possible, the Licensor waives any right to
+ collect royalties from You for the exercise of the Licensed
+ Rights, whether directly or through a collecting society
+ under any voluntary or waivable statutory or compulsory
+ licensing scheme. In all other cases the Licensor expressly
+ reserves any right to collect such royalties, including when
+ the Licensed Material is used other than for NonCommercial
+ purposes.
+
+
+Section 3 -- License Conditions.
+
+Your exercise of the Licensed Rights is expressly made subject to the
+following conditions.
+
+ a. Attribution.
+
+ 1. If You Share the Licensed Material (including in modified
+ form), You must:
+
+ a. retain the following if it is supplied by the Licensor
+ with the Licensed Material:
+
+ i. identification of the creator(s) of the Licensed
+ Material and any others designated to receive
+ attribution, in any reasonable manner requested by
+ the Licensor (including by pseudonym if
+ designated);
+
+ ii. a copyright notice;
+
+ iii. a notice that refers to this Public License;
+
+ iv. a notice that refers to the disclaimer of
+ warranties;
+
+ v. a URI or hyperlink to the Licensed Material to the
+ extent reasonably practicable;
+
+ b. indicate if You modified the Licensed Material and
+ retain an indication of any previous modifications; and
+
+ c. indicate the Licensed Material is licensed under this
+ Public License, and include the text of, or the URI or
+ hyperlink to, this Public License.
+
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
+ reasonable manner based on the medium, means, and context in
+ which You Share the Licensed Material. For example, it may be
+ reasonable to satisfy the conditions by providing a URI or
+ hyperlink to a resource that includes the required
+ information.
+
+ 3. If requested by the Licensor, You must remove any of the
+ information required by Section 3(a)(1)(A) to the extent
+ reasonably practicable.
+
+ 4. If You Share Adapted Material You produce, the Adapter's
+ License You apply must not prevent recipients of the Adapted
+ Material from complying with this Public License.
+
+
+Section 4 -- Sui Generis Database Rights.
+
+Where the Licensed Rights include Sui Generis Database Rights that
+apply to Your use of the Licensed Material:
+
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
+ to extract, reuse, reproduce, and Share all or a substantial
+ portion of the contents of the database for NonCommercial purposes
+ only;
+
+ b. if You include all or a substantial portion of the database
+ contents in a database in which You have Sui Generis Database
+ Rights, then the database in which You have Sui Generis Database
+ Rights (but not its individual contents) is Adapted Material; and
+
+ c. You must comply with the conditions in Section 3(a) if You Share
+ all or a substantial portion of the contents of the database.
+
+For the avoidance of doubt, this Section 4 supplements and does not
+replace Your obligations under this Public License where the Licensed
+Rights include other Copyright and Similar Rights.
+
+
+Section 5 -- Disclaimer of Warranties and Limitation of Liability.
+
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
+
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
+
+ c. The disclaimer of warranties and limitation of liability provided
+ above shall be interpreted in a manner that, to the extent
+ possible, most closely approximates an absolute disclaimer and
+ waiver of all liability.
+
+
+Section 6 -- Term and Termination.
+
+ a. This Public License applies for the term of the Copyright and
+ Similar Rights licensed here. However, if You fail to comply with
+ this Public License, then Your rights under this Public License
+ terminate automatically.
+
+ b. Where Your right to use the Licensed Material has terminated under
+ Section 6(a), it reinstates:
+
+ 1. automatically as of the date the violation is cured, provided
+ it is cured within 30 days of Your discovery of the
+ violation; or
+
+ 2. upon express reinstatement by the Licensor.
+
+ For the avoidance of doubt, this Section 6(b) does not affect any
+ right the Licensor may have to seek remedies for Your violations
+ of this Public License.
+
+ c. For the avoidance of doubt, the Licensor may also offer the
+ Licensed Material under separate terms or conditions or stop
+ distributing the Licensed Material at any time; however, doing so
+ will not terminate this Public License.
+
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
+ License.
+
+
+Section 7 -- Other Terms and Conditions.
+
+ a. The Licensor shall not be bound by any additional or different
+ terms or conditions communicated by You unless expressly agreed.
+
+ b. Any arrangements, understandings, or agreements regarding the
+ Licensed Material not stated herein are separate from and
+ independent of the terms and conditions of this Public License.
+
+
+Section 8 -- Interpretation.
+
+ a. For the avoidance of doubt, this Public License does not, and
+ shall not be interpreted to, reduce, limit, restrict, or impose
+ conditions on any use of the Licensed Material that could lawfully
+ be made without permission under this Public License.
+
+ b. To the extent possible, if any provision of this Public License is
+ deemed unenforceable, it shall be automatically reformed to the
+ minimum extent necessary to make it enforceable. If the provision
+ cannot be reformed, it shall be severed from this Public License
+ without affecting the enforceability of the remaining terms and
+ conditions.
+
+ c. No term or condition of this Public License will be waived and no
+ failure to comply consented to unless expressly agreed to by the
+ Licensor.
+
+ d. Nothing in this Public License constitutes or may be interpreted
+ as a limitation upon, or waiver of, any privileges and immunities
+ that apply to the Licensor or You, including from the legal
+ processes of any jurisdiction or authority.
+
+=======================================================================
+
+Creative Commons is not a party to its public
+licenses. Notwithstanding, Creative Commons may elect to apply one of
+its public licenses to material it publishes and in those instances
+will be considered the “Licensor.” The text of the Creative Commons
+public licenses is dedicated to the public domain under the CC0 Public
+Domain Dedication. Except for the limited purpose of indicating that
+material is shared under a Creative Commons public license or as
+otherwise permitted by the Creative Commons policies published at
+creativecommons.org/policies, Creative Commons does not authorize the
+use of the trademark "Creative Commons" or any other trademark or logo
+of Creative Commons without its prior written consent including,
+without limitation, in connection with any unauthorized modifications
+to any of its public licenses or any other arrangements,
+understandings, or agreements concerning use of licensed material. For
+the avoidance of doubt, this paragraph does not form part of the
+public licenses.
+
+Creative Commons may be contacted at creativecommons.org.
\ No newline at end of file
diff --git a/README.md b/README.md
index f0442905a8d4cb9e6d0020015731ee82a9fe7c3b..bf3e83af9af16301aaf295e155560c7916b0596c 100644
--- a/README.md
+++ b/README.md
@@ -8,6 +8,7 @@ sdk_version: 5.22.0
app_file: app.py
pinned: false
license: cc-by-nc-4.0
+short_description: UniK3D (CVPR 2025)
---
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
diff --git a/app.py b/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..950e0d92ffe20c33dbab3d03dd2a06f14fb86652
--- /dev/null
+++ b/app.py
@@ -0,0 +1,800 @@
+import gc
+import os
+import shutil
+import time
+from datetime import datetime
+from math import pi
+import sys
+
+import gradio as gr
+import numpy as np
+import torch
+import trimesh
+from PIL import Image
+
+
+sys.path.append("unik3d/")
+
+from unik3d.models import UniK3D
+from unik3d.utils.camera import OPENCV, Fisheye624, Pinhole, Spherical
+from unik3d.utils.visualization import colorize
+
+
+def predictions_to_glb(
+ predictions,
+ mask_black_bg=False,
+ mask_far_points=False,
+) -> trimesh.Scene:
+ print("Building GLB scene")
+ images = predictions["image"].squeeze().permute(1, 2, 0).cpu().numpy()
+ world_points = predictions["points"].squeeze().permute(1, 2, 0).cpu().numpy()
+
+ vertices_3d = world_points.reshape(-1, 3)
+ # flip x and y
+ vertices_3d[:, 1] *= -1
+ vertices_3d[:, 0] *= -1
+ colors_rgb = (images.reshape(-1, 3)).astype(np.uint8)
+
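+    # Optionally drop near-black pixels (summed RGB below 16) when filtering the black background.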
+ if mask_black_bg:
+ black_bg_mask = colors_rgb.sum(axis=1) >= 16
+ vertices_3d = vertices_3d[black_bg_mask]
+ colors_rgb = colors_rgb[black_bg_mask]
+
+ if mask_far_points:
+ far_points_mask = np.linalg.norm(vertices_3d, axis=-1) < 100.0
+ vertices_3d = vertices_3d[far_points_mask]
+ colors_rgb = colors_rgb[far_points_mask]
+
+ scene_3d = trimesh.Scene()
+ point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
+ scene_3d.add_geometry(point_cloud_data)
+
+ return scene_3d
+
+
+def instantiate_model(model_name):
+ type_ = model_name[0].lower()
+
+ name = f"unik3d-vit{type_}"
+ model = UniK3D.from_pretrained(f"lpiccinelli/{name}")
+
+ # Set resolution level and interpolation mode as specified.
+ model.resolution_level = 9
+ model.interpolation_mode = "bilinear"
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = model.to(device).eval()
+ return model
+
+
+def instantiate_camera(camera_name, params, device):
+ if camera_name == "Predicted":
+ return None
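+    # params is a fixed 15-element list: fx, fy, cx, cy, six radial terms (k1-k6),
+    # two tangential terms (t1, t2), the horizontal FoV, and the image H, W appended by run_model.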
+ fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov, H, W = params
+ if camera_name == "Pinhole":
+ params = [fx, fy, cx, cy]
+ elif camera_name == "Fisheye624":
+ params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2]
+ elif camera_name == "OPENCV":
+ params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2]
+ elif camera_name == "Equirectangular":
+ # dummy intrinsics for spherical camera, assume hfov -> vfov based on input shapes
+ hfov2 = hfov * pi / 180.0 / 2
+ params = [fx, fy, cx, cy, W, H, hfov2, H / W * hfov2]
+ camera_name = "Spherical"
+
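+    # Resolve the camera class by name among the imported ones (Pinhole, OPENCV,
+    # Fisheye624, Spherical) and move the constructed camera to the target device.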
+ return eval(camera_name)(params=torch.tensor(params).float()).to(device)
+
+
+def run_model(target_dir, model_name, camera_name, params):
+
+ print("Instantiating model and camera...")
+ model = instantiate_model(model_name)
+
+ image_names = [x for x in os.listdir(target_dir) if x.endswith(".png")]
+ input_image = np.array(Image.open(os.path.join(target_dir, image_names[-1])))
+ image_tensor = torch.from_numpy(input_image).permute(2, 0, 1).unsqueeze(0).float()
+ device = next(model.parameters()).device
+ image_tensor = image_tensor.to(device)
+ H, W = image_tensor.shape[-2:]
+ params = params + [H, W]
+ camera = instantiate_camera(camera_name, params=params, device=device)
+
+ # Perform inference with the model.
+ print("Running inference...")
+ outputs = model.infer(image_tensor, camera=camera, normalize=True)
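+    # Keep the input image in the outputs dict so predictions_to_glb can color the point cloud.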
+ outputs["image"] = image_tensor
+
+ return outputs
+
+
+def gradio_demo(
+ target_dir,
+ model_name,
+ camera_name,
+ fx,
+ fy,
+ cx,
+ cy,
+ k1,
+ k2,
+ k3,
+ k4,
+ k5,
+ k6,
+ t1,
+ t2,
+ hfov,
+ mask_black_bg,
+ mask_far_points,
+):
+ print(target_dir)
+ if not os.path.isdir(target_dir) or target_dir == "None":
+ return None, "No valid target directory found. Please upload first.", None
+
+ start_time = time.time()
+ gc.collect()
+
+ print("Running run_model...")
+ params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov]
+ with torch.no_grad():
+ outputs = run_model(target_dir, model_name, camera_name, params)
+
+ # Save predictions
+ points = outputs["points"].squeeze().permute(1, 2, 0).cpu().numpy()
+ rgb = outputs["image"].squeeze().permute(1, 2, 0).cpu().numpy()
+
+ prediction_save_path = os.path.join(target_dir, "predictions.npz")
+    # Save with named keys so update_visualization can reload them by name.
+    np.savez(prediction_save_path, points=points, image=rgb)
+
+    # Build a GLB file name that encodes the active mask flags, so
+    # update_visualization can tell whether a matching scene already exists.
+    glbfile = os.path.join(
+        target_dir,
+        f"glbscene_{mask_black_bg}_{mask_far_points}.glb",
+    )
+
+ # Convert predictions to GLB
+ glbscene = predictions_to_glb(
+ outputs,
+ mask_black_bg=mask_black_bg,
+ mask_far_points=mask_far_points,
+ )
+ glbscene.export(file_obj=glbfile)
+
+ # Cleanup
+ del outputs
+ gc.collect()
+
+ end_time = time.time()
+ print(f"Total time: {end_time - start_time:.2f} seconds")
+    log_msg = "Success. Waiting for visualization."
+
+ return glbfile, log_msg, prediction_save_path
+
+
+def handle_uploads(input_image):
+ gc.collect()
+
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+ tmpdir = os.environ.get("TMPDIR", "/tmp")
+ target_dir = os.path.join(tmpdir, f"input_images_{timestamp}")
+
+ if os.path.exists(target_dir):
+ shutil.rmtree(target_dir)
+ os.makedirs(target_dir)
+
+ dst_path = os.path.join(target_dir, "image.png")
+ Image.fromarray(input_image).save(dst_path)
+ image_paths = [dst_path]
+
+    print("Files uploaded.")
+ return target_dir, image_paths
+
+
+def update_gallery_on_upload(input_images):
+ if input_images is None:
+ return None, None
+ target_dir, image_path = handle_uploads(input_images)
+    return target_dir, "Upload complete. Click 'Run UniK3D' to get the 3D point cloud."
+
+
+def update_parameters(camera):
+ if camera == "Pinhole":
+ return (
+ gr.update(visible=True), # fx
+ gr.update(visible=True), # fy
+ gr.update(visible=True), # cx
+ gr.update(visible=True), # cy
+ gr.update(visible=False), # k1
+ gr.update(visible=False), # k2
+ gr.update(visible=False), # k3
+ gr.update(visible=False), # k4
+ gr.update(visible=False), # k5
+ gr.update(visible=False), # k6
+ gr.update(visible=False), # t1
+ gr.update(visible=False), # t2
+ gr.update(visible=False), # hfov
+ )
+ elif camera == "OPENCV":
+ return (
+ gr.update(visible=True), # fx
+ gr.update(visible=True), # fy
+ gr.update(visible=True), # cx
+ gr.update(visible=True), # cy
+ gr.update(visible=True), # k1
+ gr.update(visible=True), # k2
+ gr.update(visible=True), # k3
+ gr.update(visible=False), # k4
+ gr.update(visible=False), # k5
+ gr.update(visible=False), # k6
+ gr.update(visible=True), # t1
+ gr.update(visible=True), # t2
+ gr.update(visible=False), # hfov
+ )
+ elif camera == "Fisheye624":
+ return (
+ gr.update(visible=True), # fx
+ gr.update(visible=True), # fy
+ gr.update(visible=True), # cx
+ gr.update(visible=True), # cy
+ gr.update(visible=True), # k1
+ gr.update(visible=True), # k2
+ gr.update(visible=True), # k3
+ gr.update(visible=True), # k4
+ gr.update(visible=True), # k5
+ gr.update(visible=True), # k6
+ gr.update(visible=True), # t1
+ gr.update(visible=True), # t2
+ gr.update(visible=False), # hfov
+ )
+ elif camera == "Equirectangular":
+ return (
+ gr.update(visible=False), # fx
+ gr.update(visible=False), # fy
+ gr.update(visible=False), # cx
+ gr.update(visible=False), # cy
+ gr.update(visible=False), # k1
+ gr.update(visible=False), # k2
+ gr.update(visible=False), # k3
+ gr.update(visible=False), # k4
+ gr.update(visible=False), # k5
+ gr.update(visible=False), # k6
+ gr.update(visible=False), # t1
+ gr.update(visible=False), # t2
+ gr.update(visible=True), # hfov
+ )
+ elif camera == "Predicted":
+ return (
+ gr.update(visible=False), # fx
+ gr.update(visible=False), # fy
+ gr.update(visible=False), # cx
+ gr.update(visible=False), # cy
+ gr.update(visible=False), # k1
+ gr.update(visible=False), # k2
+ gr.update(visible=False), # k3
+ gr.update(visible=False), # k4
+ gr.update(visible=False), # k5
+ gr.update(visible=False), # k6
+ gr.update(visible=False), # t1
+ gr.update(visible=False), # t2
+ gr.update(visible=False), # hfov
+ )
+ else:
+ raise ValueError(f"Invalid camera type: {camera}")
+
+
+def clear_fields():
+ return None
+
+
+def update_log():
+ return "Loading Model and Running Inference..."
+
+
+def update_visualization(target_dir, mask_black_bg, mask_far_points, is_example):
+
+ if is_example == "True":
+ return (
+ None,
+ "No reconstruction available. Please click the Reconstruct button first.",
+ )
+
+ if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
+ return (
+ None,
+ "No reconstruction available. Please click the Reconstruct button first.",
+ )
+
+ predictions_path = os.path.join(target_dir, "predictions.npz")
+ if not os.path.exists(predictions_path):
+ return (
+ None,
+ f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
+ )
+
+    # The arrays were saved in HWC layout; convert them back to CHW tensors so
+    # predictions_to_glb can reuse the same permute/cpu pipeline as at inference time.
+    loaded = np.load(predictions_path, allow_pickle=True)
+    predictions = {
+        key: torch.from_numpy(loaded[key]).permute(2, 0, 1) for key in loaded.keys()
+    }
+
+    glbfile = os.path.join(
+        target_dir,
+        f"glbscene_{mask_black_bg}_{mask_far_points}.glb",
+    )
+
+ if not os.path.exists(glbfile):
+ glbscene = predictions_to_glb(
+ predictions,
+ mask_black_bg=mask_black_bg,
+ mask_far_points=mask_far_points,
+ )
+ glbscene.export(file_obj=glbfile)
+
+ return glbfile, "Updating Visualization"
+
+
+if __name__ == "__main__":
+ theme = gr.themes.Citrus()
+ theme.set(
+ checkbox_label_background_fill_selected="*button_primary_background_fill",
+ checkbox_label_text_color_selected="*button_primary_text_color",
+ )
+
+ with gr.Blocks(
+ theme=theme,
+ css="""
+ .custom-log * {
+ font-style: italic;
+ font-size: 22px !important;
+ background-image: linear-gradient(120deg, #ff7e26 0%, #ff9c59 60%, #fff4d6 100%);
+ -webkit-background-clip: text;
+ background-clip: text;
+ font-weight: bold !important;
+ color: transparent !important;
+ text-align: center !important;
+ }
+
+ .example-log * {
+ font-style: italic;
+ font-size: 16px !important;
+ background-image: linear-gradient(120deg, #ff7e26 0%, #ff9c59 60%, #fff4d6 100%);
+ -webkit-background-clip: text;
+ background-clip: text;
+ color: transparent !important;
+ }
+
+ #my_radio .wrap {
+ display: flex;
+ flex-wrap: nowrap;
+ justify-content: center;
+ align-items: center;
+ }
+
+ #my_radio .wrap label {
+ display: flex;
+ width: 50%;
+ justify-content: center;
+ align-items: center;
+ margin: 0;
+ padding: 10px 0;
+ box-sizing: border-box;
+ }
+ """,
+ ) as demo:
+
+ # Instead of gr.State, we use a hidden Textbox:
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
+
+ gr.HTML(
+ """
+
+            UniK3D: Universal Camera Monocular 3D Estimation
+
+ 🌟 GitHub Repository |
+ 🚀 Project Page
+
+
+
+
+            Upload one image to create a 3D estimation of a scene or object. UniK3D directly predicts the 3D geometry for any camera and scene.
+
+
+            Getting Started:
+
+ - Upload Your Image: Use the "Upload Images" panel to provide your input.
+ - Run: Click the "Run UniK3D" button to start the 3D estimation process.
+ - Visualize: The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file.
+
+
+            Please note: Our model runs on CPU on HuggingFace Space. Actual inference takes less than 100 ms per image on consumer-level GPUs. Web-based 3D point cloud visualization may be slow due to Gradio's rendering. For faster visualization, run the demo locally from our GitHub repository.
+
+ """
+ )
+
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
+
+ with gr.Row():
+ with gr.Column():
+ camera_dropdown = gr.Dropdown(
+ choices=[
+ "Predicted",
+ "Pinhole",
+ "Fisheye624",
+ "OPENCV",
+ "Equirectangular",
+ ],
+ label="Input Camera",
+ )
+ model_dropdown = gr.Dropdown(
+ choices=["Large", "Base", "Small"], label="Utilized Model"
+ )
+ mask_black_bg = gr.Checkbox(
+ label="Filter Black Background", value=False
+ )
+ mask_far_points = gr.Checkbox(label="Filter Far Points", value=False)
+
+ with gr.Column():
+ fx = gr.Number(label="Focal length x", value=500.0, visible=False)
+ fy = gr.Number(label="Focal length y", value=500.0, visible=False)
+ cx = gr.Number(label="Center projection x", value=320.0, visible=False)
+ cy = gr.Number(label="Center projection y", value=240.0, visible=False)
+ hfov = gr.Number(
+ label="Horizontal FoV (degree)", value=0.0, visible=False
+ )
+
+ with gr.Column():
+ k1 = gr.Number(label="Radial 1", value=0.0, visible=False)
+ k2 = gr.Number(label="Radial 2", value=0.0, visible=False)
+ k3 = gr.Number(label="Radial 3", value=0.0, visible=False)
+ k4 = gr.Number(label="Radial 4", value=0.0, visible=False)
+
+ with gr.Column():
+ k5 = gr.Number(label="Radial 5", value=0.0, visible=False)
+ k6 = gr.Number(label="Radial 6", value=0.0, visible=False)
+ t1 = gr.Number(label="Tangential 1", value=0.0, visible=False)
+ t2 = gr.Number(label="Tangential 2", value=0.0, visible=False)
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ input_image = gr.Image(label="Upload Images")
+ gr.Markdown("**3D Estimation**")
+ with gr.Row():
+ log_output = gr.Markdown(
+ "Please upload one image at a time, then click `Run UniK3D`.",
+ elem_classes=["custom-log"],
+ )
+ reconstruction_npy = gr.File(
+ label="Download 3D Pointcloud", type="filepath"
+ )
+
+ with gr.Column(scale=2):
+ reconstruction_output = gr.Model3D(
+ height=520, zoom_speed=0.5, pan_speed=0.5
+ )
+ with gr.Row():
+ submit_btn = gr.Button("Run UniK3D", scale=1, variant="primary")
+ clear_btn = gr.ClearButton(
+ [
+ input_image,
+ reconstruction_output,
+ log_output,
+ target_dir_output,
+ reconstruction_npy,
+ ],
+ scale=1,
+ )
+
+ examples = [
+ [
+ "assets/demo/poorthings.jpg",
+ "Large",
+ "Predicted",
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ True,
+ False,
+ ],
+ [
+ "assets/demo/naruto.jpg",
+ "Large",
+ "Predicted",
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ False,
+ False,
+ ],
+ [
+ "assets/demo/bears.jpg",
+ "Large",
+ "Predicted",
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ True,
+ False,
+ ],
+ [
+ "assets/demo/berzirk.jpg",
+ "Large",
+ "Predicted",
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ True,
+ False,
+ ],
+ [
+ "assets/demo/luke.webp",
+ "Large",
+ "Predicted",
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ False,
+ False,
+ ],
+ [
+ "assets/demo/equirectangular.jpg",
+ "Large",
+ "Equirectangular",
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 360.0,
+ False,
+ False,
+ ],
+ [
+ "assets/demo/venice.jpg",
+ "Large",
+ "Equirectangular",
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 360.0,
+ False,
+ True,
+ ],
+ [
+ "assets/demo/dl3dv.png",
+ "Large",
+ "OPENCV",
+ 429.57611083984375,
+ 429.6898193359375,
+ 479.5,
+ 269.5,
+ -0.0014844092074781656,
+ 0.0007422995404340327,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.00012013866944471374,
+ 0.001125041046179831,
+ 0.0,
+ False,
+ False,
+ ],
+ [
+ "assets/demo/scannet.jpg",
+ "Large",
+ "Fisheye624",
+ 791.90869140625,
+ 792.7230834960938,
+ 878.16796875,
+ 585.045166015625,
+ -0.029167557135224342,
+ -0.006803446915000677,
+ -0.0012682401575148106,
+ -4.6094228309812024e-05,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ False,
+ False,
+ ],
+ ]
+
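+        # Runs the full upload + reconstruction pipeline for an example row.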
+ def example_pipeline(
+ input_image,
+ model_name,
+ camera_name,
+ fx,
+ fy,
+ cx,
+ cy,
+ k1,
+ k2,
+ k3,
+ k4,
+ k5,
+ k6,
+ t1,
+ t2,
+ hfov,
+ mask_black_bg,
+ mask_far_points,
+ ):
+ target_dir, image_path = handle_uploads(input_image)
+ glbfile, log_msg, prediction_save_path = gradio_demo(
+ target_dir,
+ model_name,
+ camera_name,
+ fx,
+ fy,
+ cx,
+ cy,
+ k1,
+ k2,
+ k3,
+ k4,
+ k5,
+ k6,
+ t1,
+ t2,
+ hfov,
+ mask_black_bg,
+ mask_far_points,
+ )
+ return (
+ glbfile,
+ log_msg,
+ prediction_save_path,
+ target_dir,
+ image_path,
+ )
+
+ gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
+
+ gr.Examples(
+ examples=examples,
+ inputs=[
+ input_image,
+ model_dropdown,
+ camera_dropdown,
+ fx,
+ fy,
+ cx,
+ cy,
+ k1,
+ k2,
+ k3,
+ k4,
+ k5,
+ k6,
+ t1,
+ t2,
+ hfov,
+ mask_black_bg,
+ mask_far_points,
+ ],
+ outputs=[reconstruction_output, log_output, reconstruction_npy],
+ fn=example_pipeline,
+ cache_examples=False,
+ examples_per_page=50,
+ )
+
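+        # Chained handlers: clear the viewer, show the loading message, run inference, then flag the run as non-example.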
+ submit_btn.click(
+ fn=clear_fields, inputs=[], outputs=[reconstruction_output]
+ ).then(fn=update_log, inputs=[], outputs=[log_output]).then(
+ fn=gradio_demo,
+ inputs=[
+ target_dir_output,
+ model_dropdown,
+ camera_dropdown,
+ fx,
+ fy,
+ cx,
+ cy,
+ k1,
+ k2,
+ k3,
+ k4,
+ k5,
+ k6,
+ t1,
+ t2,
+ hfov,
+ mask_black_bg,
+ mask_far_points,
+ ],
+ outputs=[reconstruction_output, log_output, reconstruction_npy],
+ ).then(
+ fn=lambda: "False", inputs=[], outputs=[is_example]
+ )
+
+ mask_black_bg.change(
+ update_visualization,
+ [target_dir_output, mask_black_bg, mask_far_points, is_example],
+ [reconstruction_output, log_output],
+ )
+
+ mask_far_points.change(
+ update_visualization,
+ [target_dir_output, mask_black_bg, mask_far_points, is_example],
+ [reconstruction_output, log_output],
+ )
+
+ input_image.change(
+ fn=update_gallery_on_upload,
+ inputs=[input_image],
+ outputs=[target_dir_output, log_output],
+ )
+
+ # Dynamically update intrinsic parameter visibility when camera selection changes.
+ camera_dropdown.change(
+ fn=update_parameters,
+ inputs=camera_dropdown,
+ outputs=[fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov],
+ )
+
+ # demo.queue(max_size=20).launch(show_error=True, share=False, ssr_mode=False)
+ demo.launch(
+ show_error=True,
+ )
diff --git a/assets/demo/bears.jpg b/assets/demo/bears.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..99f6ec57a49b31f946b034a0a3831324abf91e25
Binary files /dev/null and b/assets/demo/bears.jpg differ
diff --git a/assets/demo/berzirk.jpg b/assets/demo/berzirk.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..e5fa22911064487993c6a70297a544655fb172a2
Binary files /dev/null and b/assets/demo/berzirk.jpg differ
diff --git a/assets/demo/dl3dv.json b/assets/demo/dl3dv.json
new file mode 100644
index 0000000000000000000000000000000000000000..6bf2bbd8a0a952e9a8462a975d1e1e63c5a0a0cc
--- /dev/null
+++ b/assets/demo/dl3dv.json
@@ -0,0 +1,4 @@
+{
+ "name": "OPENCV",
+ "params": [429.57611083984375, 429.6898193359375, 479.5, 269.5, -0.0014844092074781656, 0.0007422995404340327, 0.0, 0.0, 0.0, 0.0, 0.00012013866944471374, 0.001125041046179831, 0.0, 0.0, 0.0, 0.0]
+}
\ No newline at end of file
diff --git a/assets/demo/dl3dv.png b/assets/demo/dl3dv.png
new file mode 100644
index 0000000000000000000000000000000000000000..a3f2bb1302251b09c20f99a452bae29d6251daa6
Binary files /dev/null and b/assets/demo/dl3dv.png differ
diff --git a/assets/demo/equirectangular.jpg b/assets/demo/equirectangular.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ef669756640bf3b9f71775ae078c116502e9184d
Binary files /dev/null and b/assets/demo/equirectangular.jpg differ
diff --git a/assets/demo/kitti360.json b/assets/demo/kitti360.json
new file mode 100644
index 0000000000000000000000000000000000000000..cdacd45482f9052ca9a10047a6db44c5a7a670ce
--- /dev/null
+++ b/assets/demo/kitti360.json
@@ -0,0 +1,14 @@
+{
+ "params": [
+ 890.8814086914062,
+ 890.5255737304688,
+ 477.7955017089844,
+ 470.34332275390625,
+ 0.016798235476017,
+ 1.6548773050308228,
+ 0.000422239420004189,
+ 0.000424621335696429,
+ 2.213404655456543
+ ],
+ "name": "MEI"
+}
\ No newline at end of file
diff --git a/assets/demo/kitti360.png b/assets/demo/kitti360.png
new file mode 100644
index 0000000000000000000000000000000000000000..4a58026a71cfbc3787dc2e91c11d29d54ef90138
Binary files /dev/null and b/assets/demo/kitti360.png differ
diff --git a/assets/demo/luke.webp b/assets/demo/luke.webp
new file mode 100644
index 0000000000000000000000000000000000000000..b60384b4bb9bf67a626b3f84e226bfa39b2e561c
Binary files /dev/null and b/assets/demo/luke.webp differ
diff --git a/assets/demo/naruto.jpg b/assets/demo/naruto.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6bb0aad52c5aeaa9a1ac062fede8d1313017f062
Binary files /dev/null and b/assets/demo/naruto.jpg differ
diff --git a/assets/demo/poorthings.jpg b/assets/demo/poorthings.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..12b018df8c45097d3cb3c2ff2c6725327c0897e9
Binary files /dev/null and b/assets/demo/poorthings.jpg differ
diff --git a/assets/demo/scannet.jpg b/assets/demo/scannet.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..ad3b4a68bbf4e5a232de75a2bf19ac29d84687f8
Binary files /dev/null and b/assets/demo/scannet.jpg differ
diff --git a/assets/demo/scannet.json b/assets/demo/scannet.json
new file mode 100644
index 0000000000000000000000000000000000000000..4df47995b7af91b18360192537e2aa3aad80a527
--- /dev/null
+++ b/assets/demo/scannet.json
@@ -0,0 +1,21 @@
+{
+ "params": [
+ 791.90869140625,
+ 792.7230834960938,
+ 878.16796875,
+ 585.045166015625,
+ -0.029167557135224342,
+ -0.006803446915000677,
+ -0.0012682401575148106,
+ -4.6094228309812024e-05,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0
+ ],
+ "name": "Fisheye624"
+}
\ No newline at end of file
diff --git a/assets/demo/venice.jpg b/assets/demo/venice.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..2c6c7dd8629c9887b51786b92a1c8e110014070b
Binary files /dev/null and b/assets/demo/venice.jpg differ
diff --git a/assets/docs/unik3d-banner.png b/assets/docs/unik3d-banner.png
new file mode 100644
index 0000000000000000000000000000000000000000..598ac6713d629653e18850bf493967cb07800acb
Binary files /dev/null and b/assets/docs/unik3d-banner.png differ
diff --git a/assets/docs/unik3d-teaser.png b/assets/docs/unik3d-teaser.png
new file mode 100644
index 0000000000000000000000000000000000000000..a60145bf44be0cc2091ef94471608de1a98378d3
Binary files /dev/null and b/assets/docs/unik3d-teaser.png differ
diff --git a/configs/config_vitb.json b/configs/config_vitb.json
new file mode 100644
index 0000000000000000000000000000000000000000..ad6b813bb4f43744b45bd5ba7626f28b03e48302
--- /dev/null
+++ b/configs/config_vitb.json
@@ -0,0 +1,159 @@
+{
+ "generic": {
+ "seed": 42,
+ "deterministic": true,
+ "name_page": "ufish"
+ },
+ "training": {
+ "n_iters": 250000,
+ "batch_size": 8,
+ "validation_interval": 2500,
+ "nsteps_accumulation_gradient": 4,
+ "lr": 5e-05,
+ "lr_final": 1e-06,
+ "lr_warmup": 1.0,
+ "cycle_beta": true,
+ "wd": 0.1,
+ "wd_final": 0.1,
+ "warmup_iters": 75000,
+ "ld": 1.0,
+ "drop_path": 0.0,
+ "ema": 0.9995,
+ "f16": "f16",
+ "clipping": 1.0,
+ "losses": {
+ "depth": {
+ "name": "Scale",
+ "weight": 1.0,
+ "fn": "l1",
+ "gamma": 1.0,
+ "alpha": 1.0,
+ "output_fn": "sqrt",
+ "input_fn": "log"
+ },
+ "camera": {
+ "name": "PolarRegression",
+ "weight": 1.0,
+ "gamma": 1.0,
+ "alpha": 1.0,
+ "fn": "l1",
+ "output_fn": "sqrt",
+ "input_fn": "linear",
+ "dims": [
+ 1,
+ 2
+ ],
+ "polar_weight": 3.0,
+ "polar_asym": 0.7
+ },
+ "confidence": {
+ "name": "Confidence",
+ "weight": 0.1,
+ "input_fn": "log",
+ "output_fn": "sqrt"
+ }
+ }
+ },
+ "data": {
+ "image_shape": [
+ 518,
+ 518
+ ],
+ "resize_method": "contextcrop",
+ "normalization": "imagenet",
+ "pair": 1,
+ "mini": 1.0,
+ "num_frames": 1,
+ "sampling": {
+ "KITTI": 1.0
+ },
+ "train_datasets": [
+ "KITTI"
+ ],
+ "val_datasets": [
+ "KITTI"
+ ],
+ "data_root": "datasets",
+ "crop": "garg",
+ "augmentations": {
+ "random_scale": 4.0,
+ "random_translate_x": 0.04,
+ "random_translate_y": 0.01,
+ "scale_p": 0.0,
+ "translate_p": 0.0,
+ "random_rotation": 0.0,
+ "rotation_p": 0.0,
+ "random_shear": 0.0,
+ "affine_p": 0.0,
+ "random_jitter": 0.5,
+ "jitter_p": 1.0,
+ "random_blur": 2.0,
+ "blur_p": 0.5,
+ "random_gamma": 0.5,
+ "gamma_p": 1.0,
+ "grayscale_p": 0.2,
+ "flip_p": 0.5,
+ "cut_p": 0.0,
+ "invert_p": 0.0,
+ "shape_mult": 14,
+ "noise_pad": 1.0,
+ "test_context": 1.0
+ },
+ "shape_constraints": {
+ "ratio_bounds": [
+ 0.5,
+ 2.5
+ ],
+ "pixels_max": 600000.0,
+ "pixels_min": 200000.0,
+ "height_min": 15,
+ "width_min": 15,
+ "shape_mult": 14,
+ "sample": true
+ }
+ },
+ "model": {
+ "name": "UniK3D",
+ "num_heads": 8,
+ "expansion": 4,
+ "num_steps": 100000,
+ "layer_scale": 1e-4,
+ "camera": {
+ "augment": true,
+ "weak_ratio": 0.9,
+ "tau": 50000
+ },
+ "pixel_decoder": {
+ "name": "Decoder",
+ "hidden_dim": 384,
+ "dropout": 0.0,
+ "depths": [
+ 2,
+ 2,
+ 2
+ ],
+ "detach": 0.1,
+ "out_dim": 48,
+ "kernel_size": 3,
+ "num_prompt_blocks": 1,
+ "use_norm": false
+ },
+ "pixel_encoder": {
+ "lr": 3e-06,
+ "wd": 0.1,
+ "name": "dinov2_vitb14",
+ "frozen_stages": 0,
+ "num_register_tokens": 0,
+ "use_norm": true,
+ "freeze_norm": true,
+ "pretrained": null,
+ "stacking_fn": "last",
+ "output_idx": [
+ 3,
+ 6,
+ 9,
+ 12
+ ]
+ }
+ }
+}
\ No newline at end of file
diff --git a/configs/config_vitl.json b/configs/config_vitl.json
new file mode 100644
index 0000000000000000000000000000000000000000..ca6865b908e9c8a0319ed0e6f3e5bc3f7450f362
--- /dev/null
+++ b/configs/config_vitl.json
@@ -0,0 +1,159 @@
+{
+ "generic": {
+ "seed": 42,
+ "deterministic": true,
+ "name_page": "ufish"
+ },
+ "training": {
+ "n_iters": 250000,
+ "batch_size": 8,
+ "validation_interval": 2500,
+ "nsteps_accumulation_gradient": 4,
+ "lr": 5e-05,
+ "lr_final": 1e-06,
+ "lr_warmup": 1.0,
+ "cycle_beta": true,
+ "wd": 0.1,
+ "wd_final": 0.1,
+ "warmup_iters": 75000,
+ "ld": 1.0,
+ "drop_path": 0.0,
+ "ema": 0.9995,
+ "f16": "f16",
+ "clipping": 1.0,
+ "losses": {
+ "depth": {
+ "name": "Scale",
+ "weight": 1.0,
+ "fn": "l1",
+ "gamma": 1.0,
+ "alpha": 1.0,
+ "output_fn": "sqrt",
+ "input_fn": "log"
+ },
+ "camera": {
+ "name": "PolarRegression",
+ "weight": 1.0,
+ "gamma": 1.0,
+ "alpha": 1.0,
+ "fn": "l1",
+ "output_fn": "sqrt",
+ "input_fn": "linear",
+ "dims": [
+ 1,
+ 2
+ ],
+ "polar_weight": 3.0,
+ "polar_asym": 0.7
+ },
+ "confidence": {
+ "name": "Confidence",
+ "weight": 0.1,
+ "input_fn": "log",
+ "output_fn": "sqrt"
+ }
+ }
+ },
+ "data": {
+ "image_shape": [
+ 518,
+ 518
+ ],
+ "resize_method": "contextcrop",
+ "normalization": "imagenet",
+ "pair": 1,
+ "mini": 1.0,
+ "num_frames": 1,
+ "sampling": {
+ "KITTI": 1.0
+ },
+ "train_datasets": [
+ "KITTI"
+ ],
+ "val_datasets": [
+ "KITTI"
+ ],
+ "data_root": "datasets",
+ "crop": "garg",
+ "augmentations": {
+ "random_scale": 4.0,
+ "random_translate_x": 0.04,
+ "random_translate_y": 0.01,
+ "scale_p": 0.0,
+ "translate_p": 0.0,
+ "random_rotation": 0.0,
+ "rotation_p": 0.0,
+ "random_shear": 0.0,
+ "affine_p": 0.0,
+ "random_jitter": 0.5,
+ "jitter_p": 1.0,
+ "random_blur": 2.0,
+ "blur_p": 0.5,
+ "random_gamma": 0.5,
+ "gamma_p": 1.0,
+ "grayscale_p": 0.2,
+ "flip_p": 0.5,
+ "cut_p": 0.0,
+ "invert_p": 0.0,
+ "shape_mult": 14,
+ "noise_pad": 1.0,
+ "test_context": 1.0
+ },
+ "shape_constraints": {
+ "ratio_bounds": [
+ 0.5,
+ 2.5
+ ],
+ "pixels_max": 600000.0,
+ "pixels_min": 200000.0,
+ "height_min": 15,
+ "width_min": 15,
+ "shape_mult": 14,
+ "sample": true
+ }
+ },
+ "model": {
+ "name": "UniK3D",
+ "num_heads": 8,
+ "expansion": 4,
+ "num_steps": 100000,
+ "layer_scale": 1e-4,
+ "camera": {
+ "augment": true,
+ "weak_ratio": 0.9,
+ "tau": 50000
+ },
+ "pixel_decoder": {
+ "name": "Decoder",
+ "hidden_dim": 512,
+ "dropout": 0.0,
+ "depths": [
+ 2,
+ 2,
+ 2
+ ],
+ "detach": 0.1,
+ "out_dim": 64,
+ "kernel_size": 3,
+ "num_prompt_blocks": 1,
+ "use_norm": false
+ },
+ "pixel_encoder": {
+ "lr": 3e-06,
+ "wd": 0.1,
+ "name": "dinov2_vitl14",
+ "frozen_stages": 0,
+ "num_register_tokens": 0,
+ "use_norm": true,
+ "freeze_norm": true,
+ "pretrained": null,
+ "stacking_fn": "last",
+ "output_idx": [
+ 6,
+ 12,
+ 18,
+ 24
+ ]
+ }
+ }
+}
\ No newline at end of file
diff --git a/configs/config_vits.json b/configs/config_vits.json
new file mode 100644
index 0000000000000000000000000000000000000000..ca37e00100b51e681f233f9b638aa0c4beb7b734
--- /dev/null
+++ b/configs/config_vits.json
@@ -0,0 +1,159 @@
+{
+ "generic": {
+ "seed": 42,
+ "deterministic": true,
+ "name_page": "ufish"
+ },
+ "training": {
+ "n_iters": 250000,
+ "batch_size": 8,
+ "validation_interval": 2500,
+ "nsteps_accumulation_gradient": 4,
+ "lr": 5e-05,
+ "lr_final": 1e-06,
+ "lr_warmup": 1.0,
+ "cycle_beta": true,
+ "wd": 0.1,
+ "wd_final": 0.1,
+ "warmup_iters": 75000,
+ "ld": 1.0,
+ "drop_path": 0.0,
+ "ema": 0.9995,
+ "f16": "f16",
+ "clipping": 1.0,
+ "losses": {
+ "depth": {
+ "name": "Scale",
+ "weight": 1.0,
+ "fn": "l1",
+ "gamma": 1.0,
+ "alpha": 1.0,
+ "output_fn": "sqrt",
+ "input_fn": "log"
+ },
+ "camera": {
+ "name": "PolarRegression",
+ "weight": 1.0,
+ "gamma": 1.0,
+ "alpha": 1.0,
+ "fn": "l1",
+ "output_fn": "sqrt",
+ "input_fn": "linear",
+ "dims": [
+ 1,
+ 2
+ ],
+ "polar_weight": 3.0,
+ "polar_asym": 0.7
+ },
+ "confidence": {
+ "name": "Confidence",
+ "weight": 0.1,
+ "input_fn": "log",
+ "output_fn": "sqrt"
+ }
+ }
+ },
+ "data": {
+ "image_shape": [
+ 518,
+ 518
+ ],
+ "resize_method": "contextcrop",
+ "normalization": "imagenet",
+ "pair": 1,
+ "mini": 1.0,
+ "num_frames": 1,
+ "sampling": {
+ "KITTI": 1.0
+ },
+ "train_datasets": [
+ "KITTI"
+ ],
+ "val_datasets": [
+ "KITTI"
+ ],
+ "data_root": "datasets",
+ "crop": "garg",
+ "augmentations": {
+ "random_scale": 4.0,
+ "random_translate_x": 0.04,
+ "random_translate_y": 0.01,
+ "scale_p": 0.0,
+ "translate_p": 0.0,
+ "random_rotation": 0.0,
+ "rotation_p": 0.0,
+ "random_shear": 0.0,
+ "affine_p": 0.0,
+ "random_jitter": 0.5,
+ "jitter_p": 1.0,
+ "random_blur": 2.0,
+ "blur_p": 0.5,
+ "random_gamma": 0.5,
+ "gamma_p": 1.0,
+ "grayscale_p": 0.2,
+ "flip_p": 0.5,
+ "cut_p": 0.0,
+ "invert_p": 0.0,
+ "shape_mult": 14,
+ "noise_pad": 1.0,
+ "test_context": 1.0
+ },
+ "shape_constraints": {
+ "ratio_bounds": [
+ 0.5,
+ 2.5
+ ],
+ "pixels_max": 600000.0,
+ "pixels_min": 200000.0,
+ "height_min": 15,
+ "width_min": 15,
+ "shape_mult": 14,
+ "sample": true
+ }
+ },
+ "model": {
+ "name": "UniK3D",
+ "num_heads": 8,
+ "expansion": 4,
+ "num_steps": 100000,
+ "layer_scale": 1e-4,
+ "camera": {
+ "augment": true,
+ "weak_ratio": 0.9,
+ "tau": 50000
+ },
+ "pixel_decoder": {
+ "name": "Decoder",
+ "hidden_dim": 256,
+ "dropout": 0.0,
+ "depths": [
+ 2,
+ 2,
+ 2
+ ],
+ "detach": 0.1,
+ "out_dim": 32,
+ "kernel_size": 3,
+ "num_prompt_blocks": 1,
+ "use_norm": false
+ },
+ "pixel_encoder": {
+ "lr": 3e-06,
+ "wd": 0.1,
+ "name": "dinov2_vits14",
+ "frozen_stages": 0,
+ "num_register_tokens": 0,
+ "use_norm": true,
+ "freeze_norm": true,
+ "pretrained": null,
+ "stacking_fn": "last",
+ "output_idx": [
+ 3,
+ 6,
+ 9,
+ 12
+ ]
+ }
+ }
+}
\ No newline at end of file
diff --git a/gradio_demo.py b/gradio_demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f7c654cda3d22989cf54fa0a25c91d5f46341dd
--- /dev/null
+++ b/gradio_demo.py
@@ -0,0 +1,796 @@
+import gc
+import os
+import shutil
+import time
+from datetime import datetime
+from math import pi
+
+import gradio as gr
+import numpy as np
+import torch
+import trimesh
+from PIL import Image
+
+from unik3d.models import UniK3D
+from unik3d.utils.camera import OPENCV, Fisheye624, Pinhole, Spherical
+from unik3d.utils.visualization import colorize
+
+
+def predictions_to_glb(
+ predictions,
+ mask_black_bg=False,
+ mask_far_points=False,
+) -> trimesh.Scene:
+ print("Building GLB scene")
+ images = predictions["image"].squeeze().permute(1, 2, 0).cpu().numpy()
+ world_points = predictions["points"].squeeze().permute(1, 2, 0).cpu().numpy()
+
+ vertices_3d = world_points.reshape(-1, 3)
+ # flip x and y
+ vertices_3d[:, 1] *= -1
+ vertices_3d[:, 0] *= -1
+ colors_rgb = (images.reshape(-1, 3)).astype(np.uint8)
+
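+    # Optionally drop near-black pixels (summed RGB below 16) when filtering the black background.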
+ if mask_black_bg:
+ black_bg_mask = colors_rgb.sum(axis=1) >= 16
+ vertices_3d = vertices_3d[black_bg_mask]
+ colors_rgb = colors_rgb[black_bg_mask]
+
+ if mask_far_points:
+ far_points_mask = np.linalg.norm(vertices_3d, axis=-1) < 100.0
+ vertices_3d = vertices_3d[far_points_mask]
+ colors_rgb = colors_rgb[far_points_mask]
+
+ scene_3d = trimesh.Scene()
+ point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
+ scene_3d.add_geometry(point_cloud_data)
+
+ return scene_3d
+
+
+def instantiate_model(model_name):
+ type_ = model_name[0].lower()
+
+ name = f"unik3d-vit{type_}"
+ model = UniK3D.from_pretrained(f"lpiccinelli/{name}")
+
+ # Set resolution level and interpolation mode as specified.
+ model.resolution_level = 9
+ model.interpolation_mode = "bilinear"
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = model.to(device).eval()
+ return model
+
+
+def instantiate_camera(camera_name, params, device):
+ if camera_name == "Predicted":
+ return None
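+    # params is a fixed 15-element list: fx, fy, cx, cy, six radial terms (k1-k6),
+    # two tangential terms (t1, t2), the horizontal FoV, and the image H, W appended by run_model.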
+ fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov, H, W = params
+ if camera_name == "Pinhole":
+ params = [fx, fy, cx, cy]
+ elif camera_name == "Fisheye624":
+ params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2]
+ elif camera_name == "OPENCV":
+ params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2]
+ elif camera_name == "Equirectangular":
+ # dummy intrinsics for spherical camera, assume hfov -> vfov based on input shapes
+ hfov2 = hfov * pi / 180.0 / 2
+ params = [fx, fy, cx, cy, W, H, hfov2, H / W * hfov2]
+ camera_name = "Spherical"
+
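+    # Resolve the camera class by name among the imported ones (Pinhole, OPENCV,
+    # Fisheye624, Spherical) and move the constructed camera to the target device.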
+ return eval(camera_name)(params=torch.tensor(params).float()).to(device)
+
+
+def run_model(target_dir, model_name, camera_name, params):
+
+ print("Instantiating model and camera...")
+ model = instantiate_model(model_name)
+
+ image_names = [x for x in os.listdir(target_dir) if x.endswith(".png")]
+ input_image = np.array(Image.open(os.path.join(target_dir, image_names[-1])))
+ image_tensor = torch.from_numpy(input_image).permute(2, 0, 1).unsqueeze(0).float()
+ device = next(model.parameters()).device
+ image_tensor = image_tensor.to(device)
+ H, W = image_tensor.shape[-2:]
+ params = params + [H, W]
+ camera = instantiate_camera(camera_name, params=params, device=device)
+
+ # Perform inference with the model.
+ print("Running inference...")
+ outputs = model.infer(image_tensor, camera=camera, normalize=True)
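+    # Keep the input image in the outputs dict so predictions_to_glb can color the point cloud.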
+ outputs["image"] = image_tensor
+
+ return outputs
+
+
+def gradio_demo(
+ target_dir,
+ model_name,
+ camera_name,
+ fx,
+ fy,
+ cx,
+ cy,
+ k1,
+ k2,
+ k3,
+ k4,
+ k5,
+ k6,
+ t1,
+ t2,
+ hfov,
+ mask_black_bg,
+ mask_far_points,
+):
+ print(target_dir)
+ if not os.path.isdir(target_dir) or target_dir == "None":
+ return None, "No valid target directory found. Please upload first.", None
+
+ start_time = time.time()
+ gc.collect()
+
+ print("Running run_model...")
+ params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov]
+ with torch.no_grad():
+ outputs = run_model(target_dir, model_name, camera_name, params)
+
+ # Save predictions
+ points = outputs["points"].squeeze().permute(1, 2, 0).cpu().numpy()
+ rgb = outputs["image"].squeeze().permute(1, 2, 0).cpu().numpy()
+
+ prediction_save_path = os.path.join(target_dir, "predictions.npz")
+    # Save with named keys so update_visualization can reload them by name.
+    np.savez(prediction_save_path, points=points, image=rgb)
+
+    # Build a GLB file name that encodes the active mask flags, so
+    # update_visualization can tell whether a matching scene already exists.
+    glbfile = os.path.join(
+        target_dir,
+        f"glbscene_{mask_black_bg}_{mask_far_points}.glb",
+    )
+
+ # Convert predictions to GLB
+ glbscene = predictions_to_glb(
+ outputs,
+ mask_black_bg=mask_black_bg,
+ mask_far_points=mask_far_points,
+ )
+ glbscene.export(file_obj=glbfile)
+
+ # Cleanup
+ del outputs
+ gc.collect()
+
+ end_time = time.time()
+ print(f"Total time: {end_time - start_time:.2f} seconds")
+    log_msg = "Success. Waiting for visualization."
+
+ return glbfile, log_msg, prediction_save_path
+
+
+def handle_uploads(input_image):
+ gc.collect()
+
+ timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f")
+ tmpdir = os.environ.get("TMPDIR", "/tmp")
+ target_dir = os.path.join(tmpdir, f"input_images_{timestamp}")
+
+ if os.path.exists(target_dir):
+ shutil.rmtree(target_dir)
+ os.makedirs(target_dir)
+
+ dst_path = os.path.join(target_dir, "image.png")
+ Image.fromarray(input_image).save(dst_path)
+ image_paths = [dst_path]
+
+    print("Files uploaded.")
+ return target_dir, image_paths
+
+
+def update_gallery_on_upload(input_images):
+ if input_images is None:
+ return None, None
+ target_dir, image_path = handle_uploads(input_images)
+    return target_dir, "Upload complete. Click 'Run UniK3D' to get the 3D point cloud."
+
+
+def update_parameters(camera):
+ if camera == "Pinhole":
+ return (
+ gr.update(visible=True), # fx
+ gr.update(visible=True), # fy
+ gr.update(visible=True), # cx
+ gr.update(visible=True), # cy
+ gr.update(visible=False), # k1
+ gr.update(visible=False), # k2
+ gr.update(visible=False), # k3
+ gr.update(visible=False), # k4
+ gr.update(visible=False), # k5
+ gr.update(visible=False), # k6
+ gr.update(visible=False), # t1
+ gr.update(visible=False), # t2
+ gr.update(visible=False), # hfov
+ )
+ elif camera == "OPENCV":
+ return (
+ gr.update(visible=True), # fx
+ gr.update(visible=True), # fy
+ gr.update(visible=True), # cx
+ gr.update(visible=True), # cy
+ gr.update(visible=True), # k1
+ gr.update(visible=True), # k2
+ gr.update(visible=True), # k3
+ gr.update(visible=False), # k4
+ gr.update(visible=False), # k5
+ gr.update(visible=False), # k6
+ gr.update(visible=True), # t1
+ gr.update(visible=True), # t2
+ gr.update(visible=False), # hfov
+ )
+ elif camera == "Fisheye624":
+ return (
+ gr.update(visible=True), # fx
+ gr.update(visible=True), # fy
+ gr.update(visible=True), # cx
+ gr.update(visible=True), # cy
+ gr.update(visible=True), # k1
+ gr.update(visible=True), # k2
+ gr.update(visible=True), # k3
+ gr.update(visible=True), # k4
+ gr.update(visible=True), # k5
+ gr.update(visible=True), # k6
+ gr.update(visible=True), # t1
+ gr.update(visible=True), # t2
+ gr.update(visible=False), # hfov
+ )
+ elif camera == "Equirectangular":
+ return (
+ gr.update(visible=False), # fx
+ gr.update(visible=False), # fy
+ gr.update(visible=False), # cx
+ gr.update(visible=False), # cy
+ gr.update(visible=False), # k1
+ gr.update(visible=False), # k2
+ gr.update(visible=False), # k3
+ gr.update(visible=False), # k4
+ gr.update(visible=False), # k5
+ gr.update(visible=False), # k6
+ gr.update(visible=False), # t1
+ gr.update(visible=False), # t2
+ gr.update(visible=True), # hfov
+ )
+ elif camera == "Predicted":
+ return (
+ gr.update(visible=False), # fx
+ gr.update(visible=False), # fy
+ gr.update(visible=False), # cx
+ gr.update(visible=False), # cy
+ gr.update(visible=False), # k1
+ gr.update(visible=False), # k2
+ gr.update(visible=False), # k3
+ gr.update(visible=False), # k4
+ gr.update(visible=False), # k5
+ gr.update(visible=False), # k6
+ gr.update(visible=False), # t1
+ gr.update(visible=False), # t2
+ gr.update(visible=False), # hfov
+ )
+ else:
+ raise ValueError(f"Invalid camera type: {camera}")
+
+
+def clear_fields():
+ return None
+
+
+def update_log():
+ return "Loading Model and Running Inference..."
+
+
+def update_visualization(target_dir, mask_black_bg, mask_far_points, is_example):
+
+ if is_example == "True":
+ return (
+ None,
+ "No reconstruction available. Please click the Reconstruct button first.",
+ )
+
+ if not target_dir or target_dir == "None" or not os.path.isdir(target_dir):
+ return (
+ None,
+ "No reconstruction available. Please click the Reconstruct button first.",
+ )
+
+ predictions_path = os.path.join(target_dir, "predictions.npz")
+ if not os.path.exists(predictions_path):
+ return (
+ None,
+ f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.",
+ )
+
+ loaded = np.load(predictions_path, allow_pickle=True)
+ predictions = {key: loaded[key] for key in loaded.keys()}
+
+ glbfile = os.path.join(target_dir, "glbscene.glb")
+
+ if not os.path.exists(glbfile):
+ glbscene = predictions_to_glb(
+ predictions,
+ mask_black_bg=mask_black_bg,
+ mask_far_points=mask_far_points,
+ )
+ glbscene.export(file_obj=glbfile)
+
+ return glbfile, "Updating Visualization"
+
+
+if __name__ == "__main__":
+ theme = gr.themes.Citrus()
+ theme.set(
+ checkbox_label_background_fill_selected="*button_primary_background_fill",
+ checkbox_label_text_color_selected="*button_primary_text_color",
+ )
+
+ with gr.Blocks(
+ theme=theme,
+ css="""
+ .custom-log * {
+ font-style: italic;
+ font-size: 22px !important;
+ background-image: linear-gradient(120deg, #ff7e26 0%, #ff9c59 60%, #fff4d6 100%);
+ -webkit-background-clip: text;
+ background-clip: text;
+ font-weight: bold !important;
+ color: transparent !important;
+ text-align: center !important;
+ }
+
+ .example-log * {
+ font-style: italic;
+ font-size: 16px !important;
+ background-image: linear-gradient(120deg, #ff7e26 0%, #ff9c59 60%, #fff4d6 100%);
+ -webkit-background-clip: text;
+ background-clip: text;
+ color: transparent !important;
+ }
+
+ #my_radio .wrap {
+ display: flex;
+ flex-wrap: nowrap;
+ justify-content: center;
+ align-items: center;
+ }
+
+ #my_radio .wrap label {
+ display: flex;
+ width: 50%;
+ justify-content: center;
+ align-items: center;
+ margin: 0;
+ padding: 10px 0;
+ box-sizing: border-box;
+ }
+ """,
+ ) as demo:
+
+ # Instead of gr.State, we use a hidden Textbox:
+ is_example = gr.Textbox(label="is_example", visible=False, value="None")
+
+ gr.HTML(
+ """
+ UniK3D: Universal Camera Monocular 3D Estimation
+
+ 🌟 GitHub Repository |
+ 🚀 Project Page
+
+
+
+
+ Upload one image to create a 3D estimation of a scene or object. UniK3D predicts the 3D structure directly, for any camera and scene.
+
+
+ Getting Started:
+
+ - Upload Your Image: Use the "Upload Images" panel to provide your input.
+ - Run: Click the "Run UniK3D" button to start the 3D estimation process.
+ - Visualize: The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file.
+
+
+ Please note: Our model runs on CPU on HuggingFace Space. Actual inference takes less than 100 ms per image on consumer-level GPUs. Web-based 3D pointcloud visualization may be slow due to Gradio's rendering. For faster visualization, use a local machine to run our demo from our GitHub repository.
+
+ """
+ )
+
+ target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None")
+
+ with gr.Row():
+ with gr.Column():
+ camera_dropdown = gr.Dropdown(
+ choices=[
+ "Predicted",
+ "Pinhole",
+ "Fisheye624",
+ "OPENCV",
+ "Equirectangular",
+ ],
+ label="Input Camera",
+ )
+ model_dropdown = gr.Dropdown(
+ choices=["Large", "Base", "Small"], label="Utilized Model"
+ )
+ mask_black_bg = gr.Checkbox(
+ label="Filter Black Background", value=False
+ )
+ mask_far_points = gr.Checkbox(label="Filter Far Points", value=False)
+
+ with gr.Column():
+ fx = gr.Number(label="Focal length x", value=500.0, visible=False)
+ fy = gr.Number(label="Focal length y", value=500.0, visible=False)
+ cx = gr.Number(label="Center projection x", value=320.0, visible=False)
+ cy = gr.Number(label="Center projection y", value=240.0, visible=False)
+ hfov = gr.Number(
+ label="Horizontal FoV (degree)", value=0.0, visible=False
+ )
+
+ with gr.Column():
+ k1 = gr.Number(label="Radial 1", value=0.0, visible=False)
+ k2 = gr.Number(label="Radial 2", value=0.0, visible=False)
+ k3 = gr.Number(label="Radial 3", value=0.0, visible=False)
+ k4 = gr.Number(label="Radial 4", value=0.0, visible=False)
+
+ with gr.Column():
+ k5 = gr.Number(label="Radial 5", value=0.0, visible=False)
+ k6 = gr.Number(label="Radial 6", value=0.0, visible=False)
+ t1 = gr.Number(label="Tangential 1", value=0.0, visible=False)
+ t2 = gr.Number(label="Tangential 2", value=0.0, visible=False)
+
+ with gr.Row():
+ with gr.Column(scale=1):
+ input_image = gr.Image(label="Upload Images")
+ gr.Markdown("**3D Estimation**")
+ with gr.Row():
+ log_output = gr.Markdown(
+ "Please upload one image at a time, then click `Run UniK3D`.",
+ elem_classes=["custom-log"],
+ )
+ reconstruction_npy = gr.File(
+ label="Download 3D Pointcloud", type="filepath"
+ )
+
+ with gr.Column(scale=2):
+ reconstruction_output = gr.Model3D(
+ height=520, zoom_speed=0.5, pan_speed=0.5
+ )
+ with gr.Row():
+ submit_btn = gr.Button("Run UniK3D", scale=1, variant="primary")
+ clear_btn = gr.ClearButton(
+ [
+ input_image,
+ reconstruction_output,
+ log_output,
+ target_dir_output,
+ reconstruction_npy,
+ ],
+ scale=1,
+ )
+
+ examples = [
+ [
+ "assets/demo/poorthings.jpg",
+ "Large",
+ "Predicted",
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ True,
+ False,
+ ],
+ [
+ "assets/demo/naruto.jpg",
+ "Large",
+ "Predicted",
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ False,
+ False,
+ ],
+ [
+ "assets/demo/bears.png",
+ "Large",
+ "Predicted",
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ True,
+ False,
+ ],
+ [
+ "assets/demo/berzirk.jpg",
+ "Large",
+ "Predicted",
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ True,
+ False,
+ ],
+ [
+ "assets/demo/luke.webp",
+ "Large",
+ "Predicted",
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ False,
+ False,
+ ],
+ [
+ "assets/demo/equirectangular.jpg",
+ "Large",
+ "Equirectangular",
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 360.0,
+ False,
+ False,
+ ],
+ [
+ "assets/demo/venice.jpg",
+ "Large",
+ "Equirectangular",
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 360.0,
+ False,
+ True,
+ ],
+ [
+ "assets/demo/dl3dv.png",
+ "Large",
+ "OPENCV",
+ 429.57611083984375,
+ 429.6898193359375,
+ 479.5,
+ 269.5,
+ -0.0014844092074781656,
+ 0.0007422995404340327,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.00012013866944471374,
+ 0.001125041046179831,
+ 0.0,
+ False,
+ False,
+ ],
+ [
+ "assets/demo/scannet.png",
+ "Large",
+ "Fisheye624",
+ 791.90869140625,
+ 792.7230834960938,
+ 878.16796875,
+ 585.045166015625,
+ -0.029167557135224342,
+ -0.006803446915000677,
+ -0.0012682401575148106,
+ -4.6094228309812024e-05,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ 0.0,
+ False,
+ False,
+ ],
+ ]
+
+ def example_pipeline(
+ input_image,
+ model_name,
+ camera_name,
+ fx,
+ fy,
+ cx,
+ cy,
+ k1,
+ k2,
+ k3,
+ k4,
+ k5,
+ k6,
+ t1,
+ t2,
+ hfov,
+ mask_black_bg,
+ mask_far_points,
+ ):
+ target_dir, image_path = handle_uploads(input_image)
+ glbfile, log_msg, prediction_save_path = gradio_demo(
+ target_dir,
+ model_name,
+ camera_name,
+ fx,
+ fy,
+ cx,
+ cy,
+ k1,
+ k2,
+ k3,
+ k4,
+ k5,
+ k6,
+ t1,
+ t2,
+ hfov,
+ mask_black_bg,
+ mask_far_points,
+ )
+ return (
+ glbfile,
+ log_msg,
+ prediction_save_path,
+ target_dir,
+ image_path,
+ )
+
+ gr.Markdown("Click any row to load an example.", elem_classes=["example-log"])
+
+ gr.Examples(
+ examples=examples,
+ inputs=[
+ input_image,
+ model_dropdown,
+ camera_dropdown,
+ fx,
+ fy,
+ cx,
+ cy,
+ k1,
+ k2,
+ k3,
+ k4,
+ k5,
+ k6,
+ t1,
+ t2,
+ hfov,
+ mask_black_bg,
+ mask_far_points,
+ ],
+ outputs=[reconstruction_output, log_output, reconstruction_npy],
+ fn=example_pipeline,
+ cache_examples=False,
+ examples_per_page=50,
+ )
+
+ submit_btn.click(
+ fn=clear_fields, inputs=[], outputs=[reconstruction_output]
+ ).then(fn=update_log, inputs=[], outputs=[log_output]).then(
+ fn=gradio_demo,
+ inputs=[
+ target_dir_output,
+ model_dropdown,
+ camera_dropdown,
+ fx,
+ fy,
+ cx,
+ cy,
+ k1,
+ k2,
+ k3,
+ k4,
+ k5,
+ k6,
+ t1,
+ t2,
+ hfov,
+ mask_black_bg,
+ mask_far_points,
+ ],
+ outputs=[reconstruction_output, log_output, reconstruction_npy],
+ ).then(
+ fn=lambda: "False", inputs=[], outputs=[is_example]
+ )
+
+ mask_black_bg.change(
+ update_visualization,
+ [target_dir_output, mask_black_bg, mask_far_points, is_example],
+ [reconstruction_output, log_output],
+ )
+
+ mask_far_points.change(
+ update_visualization,
+ [target_dir_output, mask_black_bg, mask_far_points, is_example],
+ [reconstruction_output, log_output],
+ )
+
+ input_image.change(
+ fn=update_gallery_on_upload,
+ inputs=[input_image],
+ outputs=[target_dir_output, log_output],
+ )
+
+ # Dynamically update intrinsic parameter visibility when camera selection changes.
+ camera_dropdown.change(
+ fn=update_parameters,
+ inputs=camera_dropdown,
+ outputs=[fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov],
+ )
+
+ # demo.queue(max_size=20).launch(show_error=True, share=False, ssr_mode=False)
+ demo.launch(
+ show_error=True,
+ )
diff --git a/hubconf.py b/hubconf.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f66ecb5b1cb48e8169d25578656f3137cf4498b
--- /dev/null
+++ b/hubconf.py
@@ -0,0 +1,29 @@
+dependencies = ["torch", "huggingface_hub"]
+
+import os
+import json
+
+import torch
+import huggingface_hub
+
+from unik3d.models import UniK3D as UniK3D_
+
+BACKBONES = ["vitl", "vitb", "vits"]
+
+
+def UniK3D(backbone="vitl", pretrained=True):
+ assert backbone in BACKBONES, f"backbone must be one of {BACKBONES}"
+ repo_dir = os.path.dirname(os.path.realpath(__file__))
+ with open(os.path.join(repo_dir, "configs", f"config_{backbone}.json")) as f:
+ config = json.load(f)
+
+ model = UniK3D_(config)
+ if pretrained:
+ path = huggingface_hub.hf_hub_download(repo_id=f"lpiccinelli/unik3d-{backbone}", filename="pytorch_model.bin", repo_type="model")
+ info = model.load_state_dict(torch.load(path), strict=False)
+ print(f"UniK3D-{backbone} is loaded with:")
+ print(f"\t missing keys: {info.missing_keys}")
+ print(f"\t additional keys: {info.unexpected_keys}")
+
+ return model
+
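+# Indicative usage via torch.hub (the GitHub namespace below is an assumption,
+# adapt it to wherever this repository is hosted):
+#   model = torch.hub.load("lpiccinelli-eth/UniK3D", "UniK3D", backbone="vitl", pretrained=True)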
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 0000000000000000000000000000000000000000..fd34b34b16cc63a01cd357e6eea96433888acda0
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,25 @@
+[build-system]
+requires = ["setuptools"]
+build-backend = "setuptools.build_meta"
+
+[tool.pyright]
+include = ["unik3d"]
+
+[project]
+name = "unik3d"
+version = "0.1"
+authors = [{name = "Luigi Piccinelli", email = "lpiccinelli@ethz.ch"}]
+description = "UniK3D: Universal Monocular Metric Depth Estimation"
+readme = "README.md"
+license = { text = "Creative Commons BY-NC 4.0 license" }
+requires-python = ">=3.11.0"
+dynamic = ["dependencies"]
+
+[tool.setuptools.dynamic]
+dependencies = {file = ["requirements.txt"]}
+
+[tool.setuptools.package-data]
+"*" = ["py.typed"]
+
+[tool.setuptools.packages.find]
+include = ["unik3d*"]
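+
+# Note: with this setuptools configuration the package can typically be installed
+# from the repository root with `pip install .` (or `pip install -e .` for editable mode).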
diff --git a/requirements.txt b/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..536fd8919111c6731776590ebd8e821b04f73bd7
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,84 @@
+appdirs
+attrs
+black
+blosc2
+botocore>=1.34.54
+certifi>=2022.12.7
+charset-normalizer
+click
+contourpy
+cycler
+docker-pycreds
+einops>=0.7.0
+filelock
+flake8>=7.0.0
+flake8-bugbear>=24.2.6
+flake8-comprehensions>=3.14.0
+fonttools
+fsspec
+fvcore>=0.1.5.post20221221
+gitdb
+GitPython
+gradio
+h5py>=3.10.0
+huggingface-hub>=0.22.0
+idna
+imageio
+imath
+iopath
+isort
+Jinja2
+jmespath
+kiwisolver
+MarkupSafe
+matplotlib
+mccabe
+mpmath
+msgpack
+mypy-extensions
+ndindex
+networkx
+ninja
+numexpr
+numpy<2.0.0
+opencv-python
+OpenEXR
+packaging
+pandas
+pathspec
+pillow>=10.2.0
+platformdirs
+portalocker
+protobuf>=4.25.3
+psutil
+py-cpuinfo
+pycodestyle
+pyflakes
+pyparsing
+python-dateutil
+pytz
+PyYAML
+requests
+safetensors
+scipy
+sentry-sdk
+setproctitle
+six
+smmap
+sympy
+tables
+tabulate
+termcolor
+timm
+tqdm
+trimesh
+triton>=2.4.0
+typing_extensions
+tzdata==2024.1
+urllib3==1.26.13
+wandb
+yacs
+torch>=2.4.0
+torchvision>=0.19.0
+torchaudio>=2.4.0
+xformers>=0.0.26
\ No newline at end of file
diff --git a/requirements_demo.txt b/requirements_demo.txt
new file mode 100644
index 0000000000000000000000000000000000000000..536fd8919111c6731776590ebd8e821b04f73bd7
--- /dev/null
+++ b/requirements_demo.txt
@@ -0,0 +1,84 @@
+appdirs
+attrs
+black
+blosc2
+botocore>=1.34.54
+certifi>=2022.12.7
+charset-normalizer
+click
+contourpy
+cycler
+docker-pycreds
+einops>=0.7.0
+filelock
+flake8>=7.0.0
+flake8-bugbear>=24.2.6
+flake8-comprehensions>=3.14.0
+fonttools
+fsspec
+fvcore>=0.1.5.post20221221
+gitdb
+GitPython
+gradio
+h5py>=3.10.0
+huggingface-hub>=0.22.0
+idna
+imageio
+imath
+iopath
+isort
+Jinja2
+jmespath
+kiwisolver
+MarkupSafe
+matplotlib
+mccabe
+mpmath
+msgpack
+mypy-extensions
+ndindex
+networkx
+ninja
+numexpr
+numpy<2.0.0
+opencv-python
+OpenEXR
+packaging
+pandas
+pathspec
+pillow>=10.2.0
+platformdirs
+portalocker
+protobuf>=4.25.3
+psutil
+py-cpuinfo
+pycodestyle
+pyflakes
+pyparsing
+python-dateutil
+pytz
+PyYAML
+requests
+safetensors
+scipy
+sentry-sdk
+setproctitle
+six
+smmap
+sympy
+tables
+tabulate
+termcolor
+timm
+tqdm
+trimesh
+triton>=2.4.0
+typing_extensions
+tzdata==2024.1
+urllib3==1.26.13
+wandb
+yacs
+torch>=2.4.0
+torchvision>=0.19.0
+torchaudio>=2.4.0
+xformers>=0.0.26
\ No newline at end of file
diff --git a/scripts/README.md b/scripts/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..f35f660e7ee343c9d7a962d4d366a81c610f786c
--- /dev/null
+++ b/scripts/README.md
@@ -0,0 +1,55 @@
+## Training
+
+We provide the `train.py` script, which loads the datasets, initializes the model, and starts the training. From the root of the repo:
+
+```bash
+export REPO=`pwd`
+export PYTHONPATH=${REPO}:${PYTHONPATH}
+
+# Adapt all this to your setup
+export TMPDIR="/tmp"
+export TORCH_HOME=${TMPDIR}
+export HUGGINGFACE_HUB_CACHE=${TMPDIR}
+export WANDB_HOME=${TMPDIR}
+export DATAROOT=
+
+
+export MASTER_PORT=$((( RANDOM % 600 ) + 29400 ))
+if [ $NNODES -gt 1 ]; then
+ export MASTER_PORT=29400
+fi
+
+# this is the config that will be used
+export CFG="config_vitl.json"
+```
+
+If you are on a machine without SLURM, you can run the following:
+```bash
+# make the following input-dependent for multi-node
+export NNODES=1
+export RANK=0
+export MASTER_ADDR=127.0.0.1
+export CUDA_VISIBLE_DEVICES="0" # set yours
+
+export GPUS=$(echo ${CUDA_VISIBLE_DEVICES} | tr ',' '\n' | wc -l)
+echo "Start script with python from: `which python`"
+torchrun --rdzv-backend=c10d --nnodes=${NNODES} --nproc_per_node=${GPUS} --rdzv-endpoint ${MASTER_ADDR}:${MASTER_PORT} ${REPO}/scripts/train.py --config-file ${REPO}/configs/${CFG} --distributed
+```
+
+If your system has SLURM, all the information will be set by the scheduler and you only have to run:
+```bash
+srun -c ${SLURM_CPUS_PER_TASK} --kill-on-bad-exit=1 python -u ${REPO}/scripts/train.py --config-file ${REPO}/configs/${CFG} --master-port ${MASTER_PORT} --distributed
+```
+
+
+### Datasets
+
+We use both image-based and sequence-based datasets. The `ImageDataset` class is kept for legacy reasons only, as image-based datasets have been moved to "dummy" single-frame sequences.
+We [provide two example datasets to get familiar with the pipeline and structure, namely iBims-1 and Sintel](https://drive.google.com/drive/folders/1FKsa5-b3EX0ukZq7bxord5fC5OfUiy16?usp=sharing), image- and sequence-based, respectively.
+You can adapt the data loading and processing to your own data; however, you will need to keep the same interface so that the model stays consistent and trains "out-of-the-box"; a minimal sketch is shown below.
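+
+As a minimal sketch (not a drop-in implementation), a new sequence-based dataset mirrors the classes under `unik3d/datasets/` (e.g. `ADT`, `aiMotive`): it subclasses `SequenceDataset`, declares its metadata, and sets the per-sample flags in `pre_pipeline`. All names and values below are placeholders to adapt:
+
+```python
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class MyDataset(SequenceDataset):
+    # Placeholder metadata: adapt to your data and HDF5 packaging.
+    min_depth = 0.01
+    max_depth = 100.0
+    depth_scale = 256.0
+    test_split = "val.txt"
+    train_split = "train.txt"
+    sequences_file = "sequences.json"
+    hdf5_paths = ["MyDataset.hdf5"]
+
+    def __init__(
+        self,
+        image_shape: tuple[int, int],
+        split_file: str,
+        test_mode: bool,
+        normalize: bool,
+        augmentations_db: dict[str, Any],
+        resize_method: str,
+        mini: float = 1.0,
+        num_frames: int = 1,
+        benchmark: bool = False,
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            image_shape=image_shape,
+            split_file=split_file,
+            test_mode=test_mode,
+            benchmark=benchmark,
+            normalize=normalize,
+            augmentations_db=augmentations_db,
+            resize_method=resize_method,
+            mini=mini,
+            num_frames=num_frames,
+            decode_fields=["image", "depth"],
+            inplace_fields=["camera_params", "cam2w"],
+            **kwargs,
+        )
+
+    def pre_pipeline(self, results):
+        # Per-sample flags consumed by the training pipeline (see the existing datasets).
+        results = super().pre_pipeline(results)
+        results["dense"] = [False] * self.num_frames * self.num_copies
+        results["synthetic"] = [False] * self.num_frames * self.num_copies
+        results["quality"] = [1] * self.num_frames * self.num_copies
+        return results
+```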
+
+
+### Additional dependencies
+
+We require the Chamfer distance for evaluation; you can compile the KNN operation under `ops/knn` by running `bash compile.sh` from the directory `$REPO/unik3d/ops/knn`. Set `export TORCH_CUDA_ARCH_LIST` correctly, according to the hardware you are working on.
+For training and augmentation you can use `camera_augmenter.py`; however, the splatting operation requires cloning and installing `github.com/hperrot/splatting`.
\ No newline at end of file
diff --git a/scripts/demo.py b/scripts/demo.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd6765b53e06b3d0d24f8a65561d0021a0d9b4bf
--- /dev/null
+++ b/scripts/demo.py
@@ -0,0 +1,150 @@
+import argparse
+import json
+import os
+
+import numpy as np
+import torch
+import torch.nn as nn
+from PIL import Image
+
+from unik3d.models import UniK3D
+from unik3d.utils.camera import (MEI, OPENCV, BatchCamera, Fisheye624, Pinhole,
+ Spherical)
+from unik3d.utils.visualization import colorize, save_file_ply
+
+SAVE = False
+BASE_PATH = os.path.join(
+ os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "assets", "demo"
+)
+
+
+def infer(model, rgb_path, camera_path, rays=None):
+ rgb = np.array(Image.open(rgb_path))
+ rgb_torch = torch.from_numpy(rgb).permute(2, 0, 1)
+
+ camera = None
+ if camera_path is not None:
+ with open(camera_path, "r") as f:
+ camera_dict = json.load(f)
+
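+ # The camera JSON is expected to carry the camera class name and its flat
+ # parameter vector, e.g. {"name": "Pinhole", "params": [...]}; the exact
+ # parameter layout depends on the camera class in unik3d.utils.camera.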
+ params = torch.tensor(camera_dict["params"])
+ name = camera_dict["name"]
+ assert name in ["Fisheye624", "Spherical", "OPENCV", "Pinhole", "MEI"]
+ camera = eval(name)(params=params)
+
+ outputs = model.infer(rgb=rgb_torch, camera=camera, normalize=True, rays=rays)
+
+ return rgb_torch, outputs
+
+
+def infer_equirectangular(model, rgb_path):
+ rgb = np.array(Image.open(rgb_path))
+ rgb_torch = torch.from_numpy(rgb).permute(2, 0, 1)
+
+ # assuming full equirectangular image horizontally
+ H, W = rgb.shape[:2]
+ hfov_half = np.pi
+ vfov_half = np.pi * H / W
+ assert vfov_half <= np.pi / 2
+
+ params = [W, H, hfov_half, vfov_half]
+ camera = Spherical(params=torch.tensor([1.0] * 4 + params))
+
+ outputs = model.infer(rgb=rgb_torch, camera=camera, normalize=True)
+ return rgb_torch, outputs
+
+
+def save(rgb, outputs, name, base_path, save_pointcloud=False):
+ depth = outputs["depth"]
+ rays = outputs["rays"]
+ points = outputs["points"]
+
+ depth = depth.cpu().numpy()
+ rays = ((rays + 1) * 127.5).clip(0, 255)
+
+ Image.fromarray(colorize(depth.squeeze())).save(
+ os.path.join(base_path, f"{name}_depth.png")
+ )
+ Image.fromarray(rgb.squeeze().permute(1, 2, 0).cpu().numpy()).save(
+ os.path.join(base_path, f"{name}_rgb.png")
+ )
+ Image.fromarray(rays.squeeze().permute(1, 2, 0).byte().cpu().numpy()).save(
+ os.path.join(base_path, f"{name}_rays.png")
+ )
+
+ if save_pointcloud:
+ predictions_3d = points.permute(0, 2, 3, 1).reshape(-1, 3).cpu().numpy()
+ rgb = rgb.permute(1, 2, 0).reshape(-1, 3).cpu().numpy()
+ save_file_ply(predictions_3d, rgb, os.path.join(base_path, f"{name}.ply"))
+
+
+def demo(model):
+ # RGB + CAMERA
+ rgb, outputs = infer(
+ model,
+ os.path.join(BASE_PATH, "scannet.png"),
+ os.path.join(BASE_PATH, "scannet.json"),
+ )
+ if SAVE:
+ save(rgb, outputs, name="scannet", base_path=BASE_PATH)
+
+ # get GT and pred
+ pts_pred = outputs["points"].squeeze().cpu().permute(1, 2, 0).numpy()
+ pts_gt = np.load("./assets/demo/scannet.npy").astype(float)
+ mask = np.linalg.norm(pts_gt, axis=-1) > 0
+ error = np.linalg.norm(pts_pred - pts_gt, axis=-1)
+ error = np.mean(error[mask] ** 2) ** 0.5
+
+ # Trade-off between speed and resolution
+ model.resolution_level = 1
+ rgb, outputs = infer(
+ model,
+ os.path.join(BASE_PATH, "scannet.png"),
+ os.path.join(BASE_PATH, "scannet.json"),
+ )
+ if SAVE:
+ save(rgb, outputs, name="scannet_lowres", base_path=BASE_PATH)
+
+ # RGB
+ rgb, outputs = infer(model, os.path.join(BASE_PATH, "poorthings.jpg"), None)
+ if SAVE:
+ save(rgb, outputs, name="poorthings", base_path=BASE_PATH)
+
+ # RGB + CAMERA
+ rgb, outputs = infer(
+ model,
+ os.path.join(BASE_PATH, "dl3dv.png"),
+ os.path.join(BASE_PATH, "dl3dv.json"),
+ )
+ if SAVE:
+ save(rgb, outputs, name="dl3dv", base_path=BASE_PATH)
+
+ # EQUIRECTANGULAR
+ rgb, outputs = infer_equirectangular(
+ model, os.path.join(BASE_PATH, "equirectangular.jpg")
+ )
+ if SAVE:
+ save(rgb, outputs, name="equirectangular", base_path=BASE_PATH)
+
+ print("Output keys are", outputs.keys())
+
+ if SAVE:
+ print("Done! Results saved in", BASE_PATH)
+
+ print(f"RMSE on 3D clouds for ScanNet++ sample: {100*error:.1f}cm")
+
+
+if __name__ == "__main__":
+ print("Torch version:", torch.__version__)
+ type_ = "l" # available types: s, b, l
+ name = f"unik3d-vit{type_}"
+ model = UniK3D.from_pretrained(f"lpiccinelli/{name}")
+
+ # set resolution level in [0,10) and output interpolation
+ model.resolution_level = 9
+ model.interpolation_mode = "bilinear"
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ model = model.to(device).eval()
+
+ demo(model)
diff --git a/scripts/train.py b/scripts/train.py
new file mode 100644
index 0000000000000000000000000000000000000000..95d25552fa45af91d42e5e89c8462f0a4205b545
--- /dev/null
+++ b/scripts/train.py
@@ -0,0 +1,630 @@
+import argparse
+import json
+import os
+import random
+import uuid
+from contextlib import nullcontext
+from copy import deepcopy
+from datetime import datetime as dt
+from functools import partial
+from math import log2
+from time import sleep, time
+from typing import Any, Dict
+
+import git
+import numpy as np
+import psutil
+import torch
+import torch.nn as nn
+import torch.utils.data.distributed
+import wandb
+from PIL import Image
+from torch import distributed as dist
+from torch import optim
+from torch.nn.parallel.distributed import DistributedDataParallel
+from torch.utils.data import DataLoader, RandomSampler, SequentialSampler
+from tqdm import tqdm
+
+import unik3d.datasets as datasets
+from unik3d.datasets import (ConcatDataset, DistributedSamplerNoDuplicate,
+ collate_fn, get_weights)
+from unik3d.models import UniK3D
+from unik3d.ops.scheduler import CosineScheduler
+from unik3d.utils import (barrier, format_seconds, is_main_process,
+ log_train_artifacts, validate)
+from unik3d.utils.distributed import (create_local_process_group,
+ local_broadcast_process_authkey,
+ setup_multi_processes, setup_slurm,
+ sync_string_across_gpus,
+ sync_tensor_across_gpus)
+from unik3d.utils.ema_torch import (DummyExponentialMovingAverage,
+ ExponentialMovingAverage)
+from unik3d.utils.misc import calculate_mean_values
+
+EMA_INTERVAL = 10
+EMA_TAU = 10000
+EMA_START = 50000
+
+
+MAP_DTYPE = {
+ "f16": torch.float16,
+ "bf16": torch.bfloat16,
+ "f32": torch.float32,
+}
+
+
+def aggregate_sync_losses(dict_: dict[str, torch.Tensor], device):
+ keys = list(dict_.keys())
+ values = torch.tensor(list(dict_.values()), device=device)
+ keys = sync_string_across_gpus(keys, device)
+ values = sync_tensor_across_gpus(values, dim=0).cpu().tolist()
+ dict_ = calculate_mean_values(keys, values)
+ return dict_
+
+
+def main_worker(config: Dict[str, Any], args: argparse.Namespace):
+
+ current_process = psutil.Process(os.getpid())
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ seed = config["generic"]["seed"]
+
+ if not args.distributed:
+ args.rank = 0
+ args.local_rank = 0
+ args.world_size = 1
+ else:
+ # initializes the distributed backend which will take care of synchronizing nodes/GPUs
+ setup_multi_processes(config)
+ is_slurm = "SLURM_PROCID" in os.environ
+ if is_slurm:
+ setup_slurm("nccl", port=args.master_port)
+ args.rank = int(os.environ["RANK"])
+ args.world_size = int(os.environ["WORLD_SIZE"])
+ args.local_rank = device = int(os.environ["LOCAL_RANK"])
+ if not is_slurm:
+ import datetime
+
+ dist.init_process_group(
+ "nccl",
+ rank=args.rank,
+ world_size=args.world_size,
+ timeout=datetime.timedelta(seconds=30 * 60),
+ )
+ torch.cuda.set_device(device)
+ create_local_process_group()
+ local_broadcast_process_authkey()
+ print(
+ f"Start running DDP on: {args.rank} (local: {args.local_rank}) with seed {seed + args.rank}."
+ )
+ config["training"]["batch_size"] = int(
+ config["training"]["batch_size"] / args.world_size
+ )
+ dist.barrier()
+
+ # Fix seed
+ # Different for every machine to avoid sampling
+ # the same element across machines
+ seed = seed + args.rank
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ os.environ["PYTHONHASHSEED"] = str(seed)
+
+ batch_size = config["training"]["batch_size"]
+ if is_main_process():
+ print("Config: ", args.config_file)
+ print(
+ f"Torch version:{torch.__version__}, cuda:{torch.version.cuda}, cudnn:{torch.backends.cudnn.version()}, threads:{torch.get_num_threads()}"
+ )
+ print("BatchSize per GPU: ", batch_size)
+ print(
+ f"Divided into {config['training']['nsteps_accumulation_gradient']} accumulation steps"
+ )
+
+ ##############################
+ ########### MODEL ############
+ ##############################
+ # Build model
+ model = UniK3D(config).to(device)
+ model.eval()
+ print(f"MODEL: {model.__class__.__name__} at {model.device}")
+ torch.cuda.empty_cache()
+
+ if args.distributed:
+ model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
+ model = DistributedDataParallel(
+ model,
+ find_unused_parameters=False,
+ device_ids=[device],
+ output_device=device,
+ )
+
+ ##############################
+ ######### OPTIMIZER ##########
+ ##############################
+ dtype_16bit = config["training"]["f16"]
+ is_16bit = dtype_16bit != "f32"
+ clipping = config["training"].get("clipping", None)
+
+ # Optimize
+ ddp_model = model.module if args.distributed else model
+ params = ddp_model.get_params(config)
+ optimizer = optim.AdamW(
+ params,
+ eps=6e-8 if is_16bit else 1e-8, # smallest subnormal fp16 number is 5.96e-8
+ # amsgrad=is_16bit, # use max instead of avg v_hat, avoid small number divisions?
+ )
+
+ # Load Model:
+ step = 0
+ if config["training"].get("pretrained", None) is not None:
+ ddp_model.load_pretrained(config["training"]["pretrained"])
+ pretrained = torch.load(
+ config["training"]["pretrained"], map_location="cpu", weights_only=False
+ )
+ try:
+ optimizer.load_state_dict(pretrained["optimizer"])
+ except Exception as e:
+ if is_main_process():
+ print("Could not load optimizer state dict:", e)
+ step = pretrained.get("step", 0)
+ ddp_model.pixel_decoder.steps = step
+
+ # EMA
+ ema_class = (
+ ExponentialMovingAverage
+ if config["training"]["ema"] > 0.0
+ else DummyExponentialMovingAverage
+ )
+ ema_handle = ema_class(
+ ddp_model.parameters_grad(),
+ 1 - (1 - config["training"]["ema"]) * EMA_INTERVAL,
+ update_after_step=config["training"]["warmup_iters"] / EMA_INTERVAL,
+ switch=True,
+ tau=EMA_TAU // EMA_INTERVAL,
+ )
+ setattr(ema_handle, "num_updates", step // EMA_INTERVAL)
+
+ ##############################
+ ######### GENERICS ###########
+ ##############################
+ resize_method = config["data"].get("resize_method", "hard")
+ crop = config["data"].get("crop", "garg")
+ augmentations_db = config["data"].get("augmentations", {})
+ shape_constraints = config["data"].get("shape_constraints", {})
+ image_shape = config["data"]["image_shape"]
+ mini = config["data"]["mini"]
+ nsteps_accumulation_gradient = config["training"]["nsteps_accumulation_gradient"]
+ batch_size = config["training"]["batch_size"]
+ clipping_fn = torch.nn.utils.clip_grad_norm_
+
+ is_shell = int(os.environ.get("SHELL_JOB", 0))
+ run_id = sync_string_across_gpus(
+ [f"{dt.now().strftime('%d-%h_%H-%M')}-{uuid.uuid4()}"], device
+ )[0]
+
+ if not is_shell and is_main_process():
+ repo_folder = os.path.dirname(os.path.realpath(__file__))
+ try:
+ repo = git.Repo(repo_folder)
+ current_head = repo.head if repo.head.is_detached else repo.active_branch
+ notes = f"MESSAGE: {current_head.commit.message} HASH:{current_head.commit.hexsha} BRANCH:{current_head.name}"
+ except Exception:
+ print(f"problem with {repo_folder}, does it exist?")
+ notes = ""
+
+ # restore the original batchsize, not acquired by other calls from now on
+ if args.distributed:
+ config["training"]["batch_size"] = (
+ config["training"]["batch_size"] * args.world_size
+ )
+ wandb.init(
+ project="UniK3D",
+ name=run_id,
+ config=config,
+ tags=None,
+ notes=notes,
+ dir=os.environ.get("WANDB_HOME", os.environ.get("TMPDIR", "/tmp")),
+ )
+ wandb.watch(model)
+
+ ##############################
+ ########## DATASET ###########
+ ##############################
+ # Datasets loading
+ train_datasets, val_datasets = {}, {}
+ if is_main_process():
+ print("Loading training datasets...")
+ dims = 0
+
+ for dataset in config["data"]["train_datasets"]:
+ assert hasattr(datasets, dataset), f"{dataset} not a custom dataset"
+ train_dataset: datasets.BaseDataset = getattr(datasets, dataset)
+ train_datasets[dataset] = train_dataset(
+ image_shape=image_shape,
+ split_file=train_dataset.train_split,
+ test_mode=False,
+ crop=crop,
+ augmentations_db=augmentations_db,
+ shape_constraints=shape_constraints,
+ normalize=config["data"].get("normalization", "imagenet"),
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=config["data"].get("num_frames", 1),
+ fps_range=[1, 5],
+ num_copies=config["data"]["pair"],
+ )
+ dim = (
+ train_datasets[dataset].dataset._addr.numel() * 8
+ + train_datasets[dataset].dataset._lst.numel()
+ ) / (2**20)
+ if hasattr(train_datasets[dataset], "sequences"):
+ dim += (
+ train_datasets[dataset].sequences._addr.numel() * 8
+ + train_datasets[dataset].sequences._lst.numel()
+ ) / (2**20)
+ dims = dims + dim
+ if is_main_process():
+ print(f"{dataset}: {dim:.1f}MB")
+
+ print(f"All training datasets loaded, with total size: {dims:.1f}MB")
+
+ barrier()
+
+ assert batch_size % config["data"]["pair"] == 0
+ batch_size = batch_size // config["data"]["pair"]
+ assert batch_size % nsteps_accumulation_gradient == 0
+ batch_chunk = batch_size // nsteps_accumulation_gradient
+
+ train_dataset = ConcatDataset(
+ list(train_datasets.values()),
+ shape_constraints=shape_constraints,
+ )
+
+ if is_main_process():
+ print("Loading validation datasets...")
+ for dataset in config["data"]["val_datasets"]:
+ val_dataset: datasets.BaseDataset = getattr(datasets, dataset)
+ val_datasets[dataset] = val_dataset(
+ image_shape=image_shape,
+ split_file=val_dataset.test_split,
+ test_mode=True,
+ crop=crop,
+ shape_constraints=shape_constraints,
+ augmentations_db=augmentations_db,
+ normalize=config["data"].get("normalization", "imagenet"),
+ resize_method=resize_method,
+ num_frames=1,
+ mini=1.0,
+ num_copies=1,
+ )
+
+ # Dataset samplers, create distributed sampler pinned to rank
+ if args.distributed:
+ sampling = deepcopy(config["data"]["sampling"])
+ weights, num_samples = get_weights(train_datasets, sampling)
+ train_sampler = torch.utils.data.WeightedRandomSampler(
+ weights, num_samples, replacement=True
+ )
+ valid_samplers = {
+ k: DistributedSamplerNoDuplicate(
+ v,
+ num_replicas=args.world_size,
+ rank=args.rank,
+ shuffle=False,
+ drop_last=False,
+ )
+ for k, v in val_datasets.items()
+ }
+ else:
+ train_sampler = RandomSampler(train_dataset)
+ valid_samplers = {k: SequentialSampler(v) for k, v in val_datasets.items()}
+
+ train_sampler = torch.utils.data.BatchSampler(
+ train_sampler, batch_size=batch_size, drop_last=True
+ )
+
+ # Dataset loader
+ val_batch_size = 1
+ num_workers = int(os.environ.get("SLURM_CPUS_PER_TASK", 4))
+ train_loader = DataLoader(
+ train_dataset,
+ num_workers=num_workers,
+ sampler=train_sampler,
+ pin_memory=True,
+ collate_fn=partial(collate_fn, is_batched=True),
+ persistent_workers=True if num_workers else None,
+ )
+ val_loaders = {
+ name_dataset: DataLoader(
+ dataset,
+ batch_size=val_batch_size,
+ shuffle=False,
+ num_workers=num_workers,
+ sampler=valid_samplers[name_dataset],
+ pin_memory=True,
+ drop_last=False,
+ collate_fn=partial(collate_fn, is_batched=False),
+ )
+ for name_dataset, dataset in val_datasets.items()
+ }
+
+ # SCHEDULERS!
+ scheduler_wd = CosineScheduler(
+ optimizer,
+ key="weight_decay",
+ init_value=config["training"]["wd"],
+ base_value=config["training"]["wd"],
+ final_value=config["training"]["wd_final"],
+ warmup_iters=0,
+ total_iters=config["training"]["n_iters"],
+ flat_iters=config["training"]["warmup_iters"],
+ step_init=step - 1,
+ )
+ scheduler_lr = CosineScheduler(
+ optimizer,
+ key="lr",
+ init_value=config["training"]["lr"] * config["training"].get("lr_warmup", 1.0),
+ final_value=config["training"]["lr_final"],
+ warmup_iters=5000,
+ flat_iters=config["training"]["warmup_iters"],
+ total_iters=config["training"]["n_iters"],
+ step_init=step - 1,
+ )
+ scheduler_betas = CosineScheduler(
+ optimizer,
+ key="betas",
+ init_value=0.95 if config["training"].get("cycle_betas", True) else 0.9,
+ base_value=0.85 if config["training"].get("cycle_betas", True) else 0.9,
+ final_value=0.95 if config["training"].get("cycle_betas", True) else 0.9,
+ warmup_iters=config["training"]["warmup_iters"],
+ total_iters=config["training"]["n_iters"],
+ step_init=step - 1,
+ )
+
+ # Set loss scaler for half precision training + sanity zeroing grads
+ dtype = MAP_DTYPE[dtype_16bit]
+ if not torch.cuda.is_bf16_supported() and is_16bit:
+ dtype = torch.float16
+
+ context = torch.autocast(device_type="cuda", dtype=dtype, enabled=is_16bit)
+ # use float16 to check for instability at inference and avoid bfloat16 due to its coarseness
+ context_val = torch.autocast(
+ device_type="cuda", dtype=torch.float16, enabled=is_16bit
+ )
+ optimizer.zero_grad(set_to_none=True)
+
+ ##############################
+ ########## TRAINING ##########
+ ##############################
+ # Remember that if i-th layer is frozen, this will break gradient checkpointing
+ # in layer i+1-th. This is because CheckpointFunction treats the i+1-th input as
+ # without gradient, thus the i+1-th layer does not have grads (?). To solve it,
+ # just add requires_grad_() to the inputs coming from the frozen layer
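+ # e.g. (illustrative only): x = frozen_layer(x).requires_grad_()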
+ ddp_model.train()
+
+ start = time()
+ n_steps = config["training"]["n_iters"]
+ init_steps = int(step)
+ track_pbar = is_shell
+
+ if is_main_process():
+ print("Is a shell job?", is_shell)
+ print("Use dtype:", dtype if is_16bit else torch.float32)
+ print(
+ f'Train for {config["training"]["n_iters"]} steps, validate every {config["training"]["validation_interval"]} steps'
+ )
+ print(f"START with {num_workers} workers")
+ if track_pbar:
+ pbar = tqdm(total=n_steps - init_steps)
+
+ scaler = torch.amp.GradScaler(
+ "cuda",
+ init_scale=2**14 if dtype_16bit == "f16" else 2**40,
+ enabled=is_16bit,
+ growth_factor=1.2,
+ backoff_factor=0.8,
+ growth_interval=500,
+ )
+ track_losses, track_grad = {}, {}
+ system_memory = dict(psutil.virtual_memory()._asdict())["available"] / 2**30
+ cpid_memory = current_process.memory_info()[0] / 2.0**30
+ gpu_mem = (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 2**30
+ while True:
+ for j, batches in enumerate(train_loader):
+ system_memory = (
+ 0.99 * system_memory
+ + 0.01 * dict(psutil.virtual_memory()._asdict())["available"] / 2**30
+ )
+ cpid_memory = (
+ 0.99 * cpid_memory + 0.01 * current_process.memory_info()[0] / 2.0**30
+ )
+ gpu_mem = (
+ 0.99 * gpu_mem
+ + 0.01
+ * (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0])
+ / 2**30
+ )
+ if j % 1000 == 0 and is_main_process():
+ print(f"System information at step {j}")
+ print(f"System-wide RAM available: {system_memory:.2f}GB")
+ print(f"CPU utilization: {psutil.cpu_percent(interval=None)}%")
+ print(f"GPU memory utilized: {gpu_mem:.2f}GB")
+
+ batches["data"] = {
+ k: v.to(model.device, non_blocking=True)
+ for k, v in batches["data"].items()
+ }
+ for idx in range(nsteps_accumulation_gradient):
+ batch = {}
+ batch_slice = slice(idx * batch_chunk, (idx + 1) * batch_chunk)
+ batch["data"] = {k: v[batch_slice] for k, v in batches["data"].items()}
+ batch["img_metas"] = batches["img_metas"][batch_slice]
+ with (
+ model.no_sync()
+ if idx < nsteps_accumulation_gradient - 1
+ else nullcontext()
+ ):
+ with context:
+ preds, losses = model(batch["data"], batch["img_metas"])
+ loss = sum(losses["opt"].values())
+ scaler.scale(loss).backward()
+
+ losses_dict = {
+ k: v.detach() for loss in losses.values() for k, v in loss.items()
+ }
+ track_losses.update(
+ {
+ k: track_losses.get(k, 0.0)
+ + torch.nan_to_num(v, nan=1e5, posinf=1e5, neginf=1e5)
+ for k, v in losses_dict.items()
+ }
+ )
+ ddp_model.loss_history = track_losses
+
+ if clipping is not None:
+ scaler.unscale_(optimizer)
+ grad_norm = clipping_fn(ddp_model.parameters_grad(), clipping)
+ if torch.isfinite(grad_norm):
+ track_losses.update(
+ {"Grad_Norm": track_losses.get("Grad_Norm", 0.0) + grad_norm}
+ )
+
+ # there is a deeper issue, either log/sqrt of negative loss
+ # or the inputs create large values and destroy model weights
+ if is_16bit and scaler.get_scale() < 1:
+ raise ValueError("Scale went less than 1, ISSUE!!!")
+
+ scaler.step(optimizer)
+ scaler.update()
+
+ scheduler_wd.step()
+ scheduler_lr.step()
+ scheduler_betas.step()
+ model.module.step()
+ optimizer.zero_grad(set_to_none=True)
+ if step % EMA_INTERVAL == 0:
+ ema_handle.update()
+
+ if is_main_process() and track_pbar:
+ pbar.update(1)
+
+ step += 1
+
+ # LOGGING
+ if step % 100 == 0 and is_main_process():
+ log_num = min(10, preds["depth"].shape[0])
+ log_train_artifacts(
+ batch["data"]["image"][-log_num:, 0].float(),
+ (
+ batch["data"]["depth"][-log_num:, 0].float()
+ if "depth" in batch["data"]
+ else []
+ ),
+ preds["depth"][-log_num:, 0].detach().float(),
+ infos={
+ k: v[-log_num:, 0] for k, v in preds.get("infos", {}).items()
+ },
+ step=step,
+ )
+
+ if step % 50 == 0:
+ track_losses = {
+ k: v / (50 * nsteps_accumulation_gradient)
+ for k, v in track_losses.items()
+ }
+ # grad norm is for every step!
+ track_losses["Grad_Norm"] = (
+ track_losses["Grad_Norm"] * nsteps_accumulation_gradient
+ )
+ track_losses = aggregate_sync_losses(track_losses, device=model.device)
+ if is_main_process():
+ elapsed = int(time() - start)
+ eta = int(elapsed * (n_steps - step) / max(1, step - init_steps))
+ print(
+ f"Step {step}/{n_steps} [{format_seconds(elapsed)}<{format_seconds(eta)}]"
+ )
+ try:
+ wandb.log(
+ {
+ **{f"Train/{k}": v for k, v in track_losses.items()},
+ **{"Train/lr": scheduler_lr.get()[-1]},
+ **{"Train/wd": scheduler_wd.get()[-2]},
+ **{"Train/scale_f16": log2(scaler.get_scale())},
+ },
+ step=step,
+ )
+ except Exception as e:
+ print("Not logging loss because of:", e)
+ if step % 100 == 0:
+ log_loss_dict = {
+ f"Train/{k}": v for k, v in track_losses.items()
+ }
+ print(
+ ", ".join(
+ [f"{k}: {v:.5f}" for k, v in log_loss_dict.items()]
+ )
+ )
+ track_losses = {} # reinit every 50 steps, average the current 50 steps
+
+ # Validation
+ is_last_step = step >= config["training"]["n_iters"]
+ is_validation = step % config["training"]["validation_interval"] == 0
+ if is_last_step or is_validation:
+ torch.cuda.empty_cache()
+ barrier()
+ if is_main_process():
+ print(f"Validation at {step}th step...")
+ ddp_model.eval()
+ start_validation = time()
+ with torch.no_grad(), ema_handle.average_parameters():
+ validate(
+ model,
+ test_loaders=val_loaders,
+ step=step,
+ run_id=run_id,
+ idxs=(64, 96, 224, 256), # random
+ context=context_val,
+ )
+
+ if is_main_process():
+ print(f"Elapsed: {format_seconds(int(time() - start_validation))}")
+ ddp_model.train()
+ torch.cuda.empty_cache()
+
+ if step >= config["training"]["n_iters"]:
+ if is_main_process() and track_pbar:
+ pbar.close()
+ wandb.finish(0)
+ dist.destroy_process_group()
+ return 0
+
+
+if __name__ == "__main__":
+ if "SLURM_PROCID" in os.environ:
+ os.environ["TRITON_CACHE_DIR"] = "/tmp"
+ # Arguments
+ parser = argparse.ArgumentParser(
+ description="Training script", conflict_handler="resolve"
+ )
+ parser.add_argument("--config-file", type=str, required=True)
+ parser.add_argument("--master-port", type=str)
+ parser.add_argument("--distributed", action="store_true")
+ parser.add_argument("--local_rank", type=int, default=0)
+
+ args = parser.parse_args()
+ with open(args.config_file, "r") as f:
+ config = json.load(f)
+
+ deterministic = config["generic"].get("deterministic", True)
+ torch.backends.cudnn.deterministic = deterministic
+ torch.backends.cudnn.benchmark = not deterministic
+
+ torch.backends.cudnn.allow_tf32 = True
+ torch.backends.cuda.matmul.allow_tf32 = True
+ torch.set_float32_matmul_precision("high")
+ torch.backends.cuda.enable_mem_efficient_sdp(False)
+ torch.set_num_threads(1)
+ main_worker(config, args)
diff --git a/unik3d/__init__.py b/unik3d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4526bd9a1b4b0b508458724e3157d4d8dd8ebc30
--- /dev/null
+++ b/unik3d/__init__.py
@@ -0,0 +1 @@
+from .models import UniK3D
diff --git a/unik3d/datasets/_2d3ds.py b/unik3d/datasets/_2d3ds.py
new file mode 100644
index 0000000000000000000000000000000000000000..021e86c11a55eba9c9d5ff6754f6ab2d1db34b74
--- /dev/null
+++ b/unik3d/datasets/_2d3ds.py
@@ -0,0 +1,67 @@
+from typing import Any
+
+import torch
+
+from unik3d.datasets.pipelines import Compose, PanoCrop, PanoRoll
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class d2D3DS(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 10.0
+ depth_scale = 512.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["2D3DS.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["cam2w", "camera_params"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+ self.resizer = Compose(
+ [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer]
+ )
+
+ def preprocess(self, results):
+ self.resizer.ctx = None
+ if self.test_mode:
+ for i, seq in enumerate(results["sequence_fields"]):
+ results[seq]["points"] = results[seq]["camera"].reconstruct(
+ results[seq]["depth"]
+ )
+ results[seq]["depth"] = results[seq]["points"][:, -1:]
+ results[seq]["gt_fields"].add("points")
+ return super().preprocess(results)
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [False] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/_4dor.py b/unik3d/datasets/_4dor.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9db5355112799eef5ae8bd13cb0fb52144ca064
--- /dev/null
+++ b/unik3d/datasets/_4dor.py
@@ -0,0 +1,52 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class d4DOR(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 10.0
+ depth_scale = 1000.0
+ default_fps = 10
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["4DOR.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [False] * self.num_frames * self.num_copies
+ results["si"] = [False] * self.num_frames * self.num_copies
+ results["quality"] = [2] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/__init__.py b/unik3d/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e6a339b51c69a6f612615136dc2429ca36fac95
--- /dev/null
+++ b/unik3d/datasets/__init__.py
@@ -0,0 +1,161 @@
+from ._2d3ds import d2D3DS
+from ._4dor import d4DOR
+from .a2d2 import A2D2
+from .adt import ADT
+from .aimotive import aiMotive
+from .argoverse import Argoverse
+from .argoverse2 import Argoverse2
+from .arkit import ARKit
+from .ase import ASE
+from .base_dataset import BaseDataset
+from .bdd import BDD
+from .bedlam import BEDLAM
+from .behave import Behave
+from .blendedmvg import BlendedMVG
+from .cityscape import Cityscape
+from .ddad import DDAD
+from .deep360 import Deep360
+from .dense import DENSE
+from .diml import DIML
+from .diode import DiodeIndoor, DiodeIndoor_F
+from .dl3dv import DL3DV
+from .driving_stereo import DrivingStereo
+from .dtu_rmvd import DTURMVD
+from .dummy import Dummy
+from .dynamic_replica import DynReplica
+from .eden import EDEN
+from .eth3d import ETH3D, ETH3D_F, ETH3DRMVD
+from .facedepth import FaceDepth
+from .flsea import FLSea
+from .futurehouse import FutureHouse
+from .gibson import Gibson
+from .hammer import HAMMER
+from .hm3d import HM3D
+from .hoi4d import HOI4D
+from .hypersim import HyperSim
+from .ibims import IBims, IBims_F
+from .ken_burns import KenBurns
+from .kitti import KITTI, KITTIRMVD, KITTIBenchmark
+from .kitti360 import KITTI360
+from .lyft import Lyft
+from .mapillary import Mapillary
+from .matrix_city import MatrixCity
+from .matterport3d import Matterport3D
+from .megadepth import MegaDepth
+from .megadepth_s import MegaDepthS
+from .midair import MidAir
+from .mip import MIP
+from .ms2 import MS2
+from .mvimgnet import MVImgNet
+from .mvsynth import MVSynth
+from .nerds360 import NeRDS360
+from .niantic_mapfree import NianticMapFree
+from .nuscenes import Nuscenes
+from .nyuv2 import NYUv2Depth
+from .point_odyssey import PointOdyssey
+from .proteus import Proteus
+from .samplers import (DistributedSamplerNoDuplicate,
+ DistributedSamplerWrapper, ShardedInfiniteSampler)
+from .scannet import ScanNet
+from .scannetpp import ScanNetpp, ScanNetpp_F
+from .sintel import Sintel
+from .sunrgbd import SUNRGBD
+from .synscapes import Synscapes
+from .tartanair import TartanAir
+from .taskonomy import Taskonomy
+from .tat_rmvd import TATRMVD
+from .theo import Theo
+from .unrealstereo4k import UnrealStereo4K
+from .urbansyn import UrbanSyn
+from .utils import ConcatDataset, collate_fn, get_weights
+from .vkitti import VKITTI
+from .void import VOID
+from .waymo import Waymo
+from .wildrgbd import WildRGBD
+
+__all__ = [
+ "Dummy",
+ "BaseDataset",
+ "get_weights" "DistributedSamplerNoDuplicate",
+ "ShardedInfiniteSampler",
+ "DistributedSamplerWrapper",
+ "ConcatDataset",
+ "PairDataset",
+ "collate_fn",
+ # additional, do not count
+ "WaymoImage",
+ "MegaDepth",
+ "COCO2017",
+ "ImageNet",
+ "OASISv2",
+ # image based
+ "Argoverse",
+ "DDAD",
+ "IBims",
+ "NYUv2Depth",
+ "DrivingStereo",
+ "VOID",
+ "Mapillary",
+ "ScanNet",
+ "Taskonomy",
+ "BDD",
+ "A2D2",
+ "Nuscenes",
+ "SUNRGBD",
+ "ETH3D",
+ "HAMMER",
+ "Cityscape",
+ "KITTI",
+ "DENSE",
+ "DIML",
+ "DiodeIndoor",
+ "FLSea",
+ "ARKitScenes",
+ "Lyft",
+ "HyperSim",
+ "KenBurns",
+ "HRWSI",
+ "UrbanSyn",
+ "Synscapes",
+ "Gibson",
+ "Matterport3D",
+ "_2D3DS",
+ # sequence based
+ "TartanAir",
+ "WildRGBD",
+ "ScanNetS",
+ "ScanNetpp",
+ "MVImgNet",
+ "NianticMapFree",
+ "DL3DV",
+ "PointOdyssey",
+ "KITTIMulti",
+ "Waymo",
+ "Argoverse2",
+ "UnrealStereo4K",
+ "MatrixCity",
+ "HM3D",
+ "MVSynth",
+ "EDEN",
+ # sequence based, but not usable for seq, only image
+ "BEDLAM",
+ "NeRDS360",
+ "BlendedMVG",
+ "DynReplica",
+ "ARKitS",
+ "Sintel",
+ "VKITTI",
+ "MegaDepthS",
+ # benchmarks
+ "KITTIBenchmark",
+ "ETH3DRMVD",
+ "DTURMVD",
+ "KITTIRMVD",
+ "TATRMVD",
+ "DiodeIndoor_F",
+ "IBims_F",
+ "ETH3D_F",
+ "KITTI360",
+ "ScanNetpp_F",
+ "ADT",
+]
diff --git a/unik3d/datasets/a2d2.py b/unik3d/datasets/a2d2.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a01b2037dd27f88137161b4d1a850df69f71d90
--- /dev/null
+++ b/unik3d/datasets/a2d2.py
@@ -0,0 +1,78 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class A2D2(ImageDataset):
+ min_depth = 0.01
+ max_depth = 120.0
+ depth_scale = 256.0
+ train_split = "train_clean.txt"
+ intrisics_file = "intrinsics.json"
+ hdf5_paths = ["a2d2.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+ txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop the trailing newline
+ intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ intrinsics_val = torch.tensor(
+ intrinsics[os.path.join(*image_filename.split("/")[:2])]
+ ).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val]
+ dataset.append(sample)
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [False] * self.num_copies
+ results["quality"] = [1] * self.num_copies
+ return results
diff --git a/unik3d/datasets/adt.py b/unik3d/datasets/adt.py
new file mode 100644
index 0000000000000000000000000000000000000000..8883efd6d8e2590467ea4314ab2f13d837da684e
--- /dev/null
+++ b/unik3d/datasets/adt.py
@@ -0,0 +1,68 @@
+from typing import Any
+
+import torch
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class ADT(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 20.0
+ depth_scale = 1000.0
+ test_split = "val.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["ADT.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields, # if not test_mode else [*decode_fields, "points"],
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def preprocess(self, results):
+ self.resizer.ctx = None
+ for i, seq in enumerate(results["sequence_fields"]):
+ # Create a mask where the distance from the center is less than H/2
+ H, W = results[seq]["image"].shape[-2:]
+ x = torch.linspace(-W / 2 - 0.5, W / 2 + 0.5, W)
+ y = torch.linspace(-H / 2 - 0.5, H / 2 + 0.5, H)
+ xv, yv = torch.meshgrid(x, y, indexing="xy")
+ distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W)
+ results[seq]["validity_mask"] = distance_from_center < (H / 2) + 20
+ results[seq]["depth_mask"] = results[seq]["validity_mask"].clone()
+ results[seq]["mask_fields"].add("depth_mask")
+ results[seq]["mask_fields"].add("validity_mask")
+
+ return super().preprocess(results)
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/aimotive.py b/unik3d/datasets/aimotive.py
new file mode 100644
index 0000000000000000000000000000000000000000..38557bc68a9f1aeeff0003eae46f12b46ec197ba
--- /dev/null
+++ b/unik3d/datasets/aimotive.py
@@ -0,0 +1,51 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class aiMotive(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 100.0
+ depth_scale = 256.0
+ default_fps = 10
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["aiMotive.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [False] * self.num_frames * self.num_copies
+ results["synthetic"] = [False] * self.num_frames * self.num_copies
+ results["quality"] = [2] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/argoverse.py b/unik3d/datasets/argoverse.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4f4ea6612b3a29f537e0690ccb3b30b04d06d24
--- /dev/null
+++ b/unik3d/datasets/argoverse.py
@@ -0,0 +1,73 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class Argoverse(ImageDataset):
+ min_depth = 0.05
+ max_depth = 120.0
+ depth_scale = 256.0
+ test_split = "argo_val.txt"
+ train_split = "argo_train.txt"
+ intrisics_file = "argo_intrinsics.json"
+ hdf5_paths = ["argoverse11.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.crop = crop
+ self.load_dataset()
+
+ def load_dataset(self):
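+        # Split lists and per-image intrinsics are stored as raw byte blobs inside the
+        # HDF5 archive; decode them to text / JSON before building the sample list.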
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop the trailing newline
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val]
+ dataset.append(sample)
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
diff --git a/unik3d/datasets/argoverse2.py b/unik3d/datasets/argoverse2.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a49f49d18979cb39b1be7385b4ec4c09b47a244
--- /dev/null
+++ b/unik3d/datasets/argoverse2.py
@@ -0,0 +1,49 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class Argoverse2(SequenceDataset):
+ min_depth = 0.05
+ max_depth = 120.0
+ depth_scale = 256.0
+ test_split = "val.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences_clean.json"
+ hdf5_paths = [f"AV2_viz.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [False] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/arkit.py b/unik3d/datasets/arkit.py
new file mode 100644
index 0000000000000000000000000000000000000000..96225cd8b248cad6d1a5072cfc39adcf880c2297
--- /dev/null
+++ b/unik3d/datasets/arkit.py
@@ -0,0 +1,49 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class ARKit(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 10.0
+ depth_scale = 1000.0
+ test_split = "Training.txt"
+ train_split = "Training.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["ARKitS.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [2] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/ase.py b/unik3d/datasets/ase.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f57fd2119dd589ca62800d907fc4c235dc1ed1b
--- /dev/null
+++ b/unik3d/datasets/ase.py
@@ -0,0 +1,66 @@
+from typing import Any
+
+import torch
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class ASE(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 20.0
+ depth_scale = 1000.0
+ test_split = "val.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = [f"ASE.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def preprocess(self, results):
+ self.resizer.ctx = None
+        for seq in results["sequence_fields"]:
+            # Circular validity mask: keep pixels within H/2 + 20 px of the image center.
+ H, W = results[seq]["image"].shape[-2:]
+ x = torch.linspace(-W / 2 - 0.5, W / 2 + 0.5, W)
+ y = torch.linspace(-H / 2 - 0.5, H / 2 + 0.5, H)
+ xv, yv = torch.meshgrid(x, y, indexing="xy")
+ distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W)
+ results[seq]["validity_mask"] = distance_from_center < (H / 2) + 20
+ results[seq]["mask_fields"].add("validity_mask")
+
+ return super().preprocess(results)
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/base_dataset.py b/unik3d/datasets/base_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9474e767c569cc75b89da2d0510de808ed79208b
--- /dev/null
+++ b/unik3d/datasets/base_dataset.py
@@ -0,0 +1,344 @@
+import os
+from abc import abstractmethod
+from copy import deepcopy
+from math import ceil, log
+from typing import Any, Dict, Tuple
+
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+import unik3d.datasets.pipelines as pipelines
+from unik3d.utils import (eval_3d, eval_depth, identity, is_main_process,
+ recursive_index, sync_tensor_across_gpus)
+from unik3d.utils.constants import (IMAGENET_DATASET_MEAN,
+ IMAGENET_DATASET_STD, OPENAI_DATASET_MEAN,
+ OPENAI_DATASET_STD)
+
+
+class BaseDataset(Dataset):
+ min_depth = 0.01
+ max_depth = 1000.0
+
+ def __init__(
+ self,
+ image_shape: Tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: Dict[str, Any],
+ shape_constraints: Dict[str, Any],
+ resize_method: str,
+ mini: float,
+ num_copies: int = 1,
+ **kwargs,
+ ) -> None:
+ super().__init__()
+ assert normalize in [None, "imagenet", "openai"]
+
+ self.split_file = split_file
+ self.test_mode = test_mode
+ self.data_root = os.environ["DATAROOT"]
+ self.image_shape = image_shape
+ self.resize_method = resize_method
+ self.mini = mini
+ self.num_frames = 1
+ self.num_copies = num_copies
+ self.metrics_store = {}
+ self.metrics_count = {}
+
+ if normalize == "imagenet":
+ self.normalization_stats = {
+ "mean": torch.tensor(IMAGENET_DATASET_MEAN),
+ "std": torch.tensor(IMAGENET_DATASET_STD),
+ }
+ elif normalize == "openai":
+ self.normalization_stats = {
+ "mean": torch.tensor(OPENAI_DATASET_MEAN),
+ "std": torch.tensor(OPENAI_DATASET_STD),
+ }
+ else:
+ self.normalization_stats = {
+ "mean": torch.tensor([0.0, 0.0, 0.0]),
+ "std": torch.tensor([1.0, 1.0, 1.0]),
+ }
+
+ for k, v in augmentations_db.items():
+ setattr(self, k, v)
+ self.shape_constraints = shape_constraints
+ if not self.test_mode:
+ self._augmentation_space()
+
+ self.masker = pipelines.AnnotationMask(
+ min_value=0.0,
+ max_value=self.max_depth if test_mode else None,
+ custom_fn=identity,
+ )
+ self.filler = pipelines.RandomFiller(test_mode=test_mode)
+
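+        # Round the target image shape up to a multiple of shape_mult (presumably the
+        # backbone's patch / stride size) so crops stay network-compatible.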
+ shape_mult = self.shape_constraints["shape_mult"]
+ self.image_shape = [
+ ceil(self.image_shape[0] / shape_mult) * shape_mult,
+ ceil(self.image_shape[1] / shape_mult) * shape_mult,
+ ]
+ self.resizer = pipelines.ContextCrop(
+ image_shape=self.image_shape,
+ train_ctx_range=(1.0 / self.random_scale, 1.0 * self.random_scale),
+ test_min_ctx=self.test_context,
+ keep_original=test_mode,
+ shape_constraints=self.shape_constraints,
+ )
+
+ self.collecter = pipelines.Collect(
+ keys=["image_fields", "mask_fields", "gt_fields", "camera_fields"]
+ )
+
+ def __len__(self):
+ return len(self.dataset)
+
+ def pack_batch(self, results):
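+        # Concatenate every per-sequence tensor (images, depths, masks, camera fields)
+        # along the batch dimension into a single packed tensor per field.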
+ results["paddings"] = [
+ results[x]["paddings"][0] for x in results["sequence_fields"]
+ ]
+ for fields_name in [
+ "image_fields",
+ "gt_fields",
+ "mask_fields",
+ "camera_fields",
+ ]:
+ fields = results.get(fields_name)
+ packed = {
+ field: torch.cat(
+ [results[seq][field] for seq in results["sequence_fields"]]
+ )
+ for field in fields
+ }
+ results.update(packed)
+ return results
+
+ def unpack_batch(self, results):
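+        # Inverse of pack_batch: slice each packed tensor back into per-sequence entries.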
+ for fields_name in [
+ "image_fields",
+ "gt_fields",
+ "mask_fields",
+ "camera_fields",
+ ]:
+ fields = results.get(fields_name)
+ unpacked = {
+ field: {
+ seq: results[field][idx : idx + 1]
+ for idx, seq in enumerate(results["sequence_fields"])
+ }
+ for field in fields
+ }
+ results.update(unpacked)
+ return results
+
+ def _augmentation_space(self):
+ self.augmentations_dict = {
+ "Flip": pipelines.RandomFlip(prob=self.flip_p),
+ "Jitter": pipelines.RandomColorJitter(
+ (-self.random_jitter, self.random_jitter), prob=self.jitter_p
+ ),
+ "Gamma": pipelines.RandomGamma(
+ (-self.random_gamma, self.random_gamma), prob=self.gamma_p
+ ),
+ "Blur": pipelines.GaussianBlur(
+ kernel_size=13, sigma=(0.1, self.random_blur), prob=self.blur_p
+ ),
+ "Grayscale": pipelines.RandomGrayscale(prob=self.grayscale_p),
+ }
+
+ def augment(self, results):
+ for name, aug in self.augmentations_dict.items():
+ results = aug(results)
+ return results
+
+ def prepare_depth_eval(self, inputs, preds):
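+        # If a keyframe index is set, restrict evaluation to that frame; otherwise use all frames.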
+ new_preds = {}
+ keyframe_idx = getattr(self, "keyframe_idx", None)
+ slice_idx = slice(
+ keyframe_idx, keyframe_idx + 1 if keyframe_idx is not None else None
+ )
+ new_gts = inputs["depth"][slice_idx]
+ new_masks = inputs["depth_mask"][slice_idx].bool()
+ for key, val in preds.items():
+ if "depth" in key:
+ new_preds[key] = val[slice_idx]
+ return new_gts, new_preds, new_masks
+
+ def prepare_points_eval(self, inputs, preds):
+ new_preds = {}
+ new_gts = inputs["points"]
+ new_masks = inputs["depth_mask"].bool()
+ if "points_mask" in inputs:
+ new_masks = inputs["points_mask"].bool()
+ for key, val in preds.items():
+ if "points" in key:
+ new_preds[key] = val
+ return new_gts, new_preds, new_masks
+
+ def add_points(self, inputs):
+ inputs["points"] = inputs.get("camera_original", inputs["camera"]).reconstruct(
+ inputs["depth"]
+ )
+ return inputs
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def accumulate_metrics(
+ self,
+ inputs,
+ preds,
+ keyframe_idx=None,
+ metrics=["depth", "points", "flow_fwd", "pairwise"],
+ ):
+ if "depth" in inputs and "points" not in inputs:
+ inputs = self.add_points(inputs)
+
+ available_metrics = []
+ for metric in metrics:
+ metric_in_gt = any((metric in k for k in inputs.keys()))
+ metric_in_pred = any((metric in k for k in preds.keys()))
+ if metric_in_gt and metric_in_pred:
+ available_metrics.append(metric)
+
+ if keyframe_idx is not None:
+ inputs = recursive_index(inputs, slice(keyframe_idx, keyframe_idx + 1))
+ preds = recursive_index(preds, slice(keyframe_idx, keyframe_idx + 1))
+
+ if "depth" in available_metrics:
+ depth_gt, depth_pred, depth_masks = self.prepare_depth_eval(inputs, preds)
+ self.accumulate_metrics_depth(depth_gt, depth_pred, depth_masks)
+
+ if "points" in available_metrics:
+ points_gt, points_pred, points_masks = self.prepare_points_eval(
+ inputs, preds
+ )
+ self.accumulate_metrics_3d(points_gt, points_pred, points_masks)
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def accumulate_metrics_depth(self, gts, preds, masks):
+ for eval_type, pred in preds.items():
+ log_name = eval_type.replace("depth", "").strip("-").strip("_")
+ if log_name not in self.metrics_store:
+ self.metrics_store[log_name] = {}
+ current_count = self.metrics_count.get(
+ log_name, torch.tensor([], device=gts.device)
+ )
+ new_count = masks.view(gts.shape[0], -1).sum(dim=-1)
+ self.metrics_count[log_name] = torch.cat([current_count, new_count])
+ for k, v in eval_depth(gts, pred, masks, max_depth=self.max_depth).items():
+ current_metric = self.metrics_store[log_name].get(
+ k, torch.tensor([], device=gts.device)
+ )
+ self.metrics_store[log_name][k] = torch.cat([current_metric, v])
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def accumulate_metrics_3d(self, gts, preds, masks):
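+        # 100 distance thresholds, log-uniformly spaced between min_depth and
+        # max_depth / 20, passed to eval_3d below.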
+ thresholds = torch.linspace(
+ log(self.min_depth),
+ log(self.max_depth / 20),
+ steps=100,
+ device=gts.device,
+ ).exp()
+ for eval_type, pred in preds.items():
+ log_name = eval_type.replace("points", "").strip("-").strip("_")
+ if log_name not in self.metrics_store:
+ self.metrics_store[log_name] = {}
+ current_count = self.metrics_count.get(
+ log_name, torch.tensor([], device=gts.device)
+ )
+ new_count = masks.view(gts.shape[0], -1).sum(dim=-1)
+ self.metrics_count[log_name] = torch.cat([current_count, new_count])
+ for k, v in eval_3d(gts, pred, masks, thresholds=thresholds).items():
+ current_metric = self.metrics_store[log_name].get(
+ k, torch.tensor([], device=gts.device)
+ )
+ self.metrics_store[log_name][k] = torch.cat([current_metric, v])
+
+ def get_evaluation(self, metrics=None):
+ metric_vals = {}
+ for eval_type in metrics if metrics is not None else self.metrics_store.keys():
+ assert self.metrics_store[eval_type]
+ cnts = sync_tensor_across_gpus(self.metrics_count[eval_type])
+ for name, val in self.metrics_store[eval_type].items():
+ # vals_r = (sync_tensor_across_gpus(val) * cnts / cnts.sum()).sum()
+ vals_r = sync_tensor_across_gpus(val).mean()
+ metric_vals[f"{eval_type}_{name}".strip("_")] = np.round(
+ vals_r.cpu().item(), 5
+ )
+ self.metrics_store[eval_type] = {}
+ self.metrics_count = {}
+ return metric_vals
+
+ def replicate(self, results):
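+        # Duplicate the (0, 0) sample under new sequence keys (0, i) so each batch item
+        # carries num_copies copies of the same frame.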
+ for i in range(1, self.num_copies):
+ results[(0, i)] = {k: deepcopy(v) for k, v in results[(0, 0)].items()}
+ results["sequence_fields"].append((0, i))
+ return results
+
+ def log_load_dataset(self):
+ if is_main_process():
+ info = f"Loaded {self.__class__.__name__} with {len(self)} images."
+ print(info)
+
+ def pre_pipeline(self, results):
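+        # Per-frame metadata defaults; subclasses override dense / synthetic / quality / si
+        # and the validity flags ("si" presumably marks scale-invariant, non-metric depth).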
+ results["image_fields"] = results.get("image_fields", set())
+ results["gt_fields"] = results.get("gt_fields", set())
+ results["mask_fields"] = results.get("mask_fields", set())
+ results["sequence_fields"] = results.get("sequence_fields", set())
+ results["camera_fields"] = results.get("camera_fields", set())
+ results["dataset_name"] = (
+ [self.__class__.__name__] * self.num_frames * self.num_copies
+ )
+ results["depth_scale"] = [self.depth_scale] * self.num_frames * self.num_copies
+ results["si"] = [False] * self.num_frames * self.num_copies
+ results["dense"] = [False] * self.num_frames * self.num_copies
+ results["synthetic"] = [False] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ results["valid_camera"] = [True] * self.num_frames * self.num_copies
+ results["valid_pose"] = [True] * self.num_frames * self.num_copies
+ return results
+
+ def eval_mask(self, valid_mask):
+ return valid_mask
+
+ def chunk(self, dataset, chunk_dim=1, pct=1.0):
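+        # Subsample: keep `chunk_dim` consecutive samples out of every `int(1 / pct * chunk_dim)`,
+        # reducing the dataset to roughly `pct` of its original size.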
+ subsampled_datasets = [
+ x
+ for i in range(0, len(dataset), int(1 / pct * chunk_dim))
+ for x in dataset[i : i + chunk_dim]
+ ]
+ return subsampled_datasets
+
+ @abstractmethod
+ def preprocess(self, results):
+ raise NotImplementedError
+
+ @abstractmethod
+ def postprocess(self, results):
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_mapper(self):
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_intrinsics(self, idx, image_name):
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_extrinsics(self, idx, image_name):
+ raise NotImplementedError
+
+ @abstractmethod
+ def load_dataset(self):
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_single_item(self, idx, sample=None, mapper=None):
+ raise NotImplementedError
+
+ @abstractmethod
+ def __getitem__(self, idx):
+ raise NotImplementedError
diff --git a/unik3d/datasets/bdd.py b/unik3d/datasets/bdd.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f8956ea4ea77cf0db86b1be744a3daaacd04616
--- /dev/null
+++ b/unik3d/datasets/bdd.py
@@ -0,0 +1,82 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class BDD(ImageDataset):
+ min_depth = 0.01
+ max_depth = 70.0
+ depth_scale = 256.0
+ test_split = "val.txt"
+ train_split = "train_clean.txt"
+ intrisics_file = "intrinsics.json"
+ hdf5_paths = ["BDD.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")  # no trailing newline to strip
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ intrinsics_val = torch.tensor(
+ intrinsics[os.path.join(*image_filename.split("/")[:2])]
+ ).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val]
+ dataset.append(sample)
+ h5file.close()
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+ if self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=0.1)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["si"] = [True] * self.num_copies
+ results["valid_camera"] = [False] * self.num_copies
+ results["dense"] = [False] * self.num_copies
+ results["quality"] = [2] * self.num_copies
+ return results
diff --git a/unik3d/datasets/bedlam.py b/unik3d/datasets/bedlam.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9e9cfdf46e5ae57bcc5b3becdf6043dbe1f7f8b
--- /dev/null
+++ b/unik3d/datasets/bedlam.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class BEDLAM(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 256.0
+ depth_scale = 1000.0
+ test_split = "train.txt"
+ train_split = "val.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["BEDLAM.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/behave.py b/unik3d/datasets/behave.py
new file mode 100644
index 0000000000000000000000000000000000000000..155afc72800367f09133505032dea8d94eb80083
--- /dev/null
+++ b/unik3d/datasets/behave.py
@@ -0,0 +1,52 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class Behave(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 10.0
+ depth_scale = 1000.0
+ default_fps = 10
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["Behave.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [False] * self.num_frames * self.num_copies
+ results["si"] = [False] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/blendedmvg.py b/unik3d/datasets/blendedmvg.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3fa5f52b63785d0a4b02addb347506c51048ab9
--- /dev/null
+++ b/unik3d/datasets/blendedmvg.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class BlendedMVG(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 5000.0
+ depth_scale = 1000.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences_clean.json"
+ hdf5_paths = ["BlendedMVG_.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["si"] = [False] * self.num_frames * self.num_copies
+ results["quality"] = [2] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/cityscape.py b/unik3d/datasets/cityscape.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f462d6571fc24acc47925550f4331bb1053b651
--- /dev/null
+++ b/unik3d/datasets/cityscape.py
@@ -0,0 +1,78 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class Cityscape(ImageDataset):
+ min_depth = 0.05
+ max_depth = 80.0
+ depth_scale = 256.0
+ test_split = "val.txt"
+ train_split = "train.txt"
+ intrisics_file = "intrinsics.json"
+ hdf5_paths = ["cityscape.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.crop = crop
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop the trailing newline
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val]
+ dataset.append(sample)
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["quality"] = [2] * self.num_copies
+ return results
diff --git a/unik3d/datasets/ddad.py b/unik3d/datasets/ddad.py
new file mode 100644
index 0000000000000000000000000000000000000000..322089c73c88c7ac2208c6b2b5f5b682effb7c95
--- /dev/null
+++ b/unik3d/datasets/ddad.py
@@ -0,0 +1,84 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class DDAD(ImageDataset):
+ min_depth = 0.05
+ max_depth = 120.0
+ depth_scale = 256.0
+ test_split = "val.txt"
+ train_split = "train.txt"
+ intrisics_file = "intrinsics.json"
+ hdf5_paths = [f"ddad/ddad_{i}.hdf5" for i in range(8)]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii").strip("\n")
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename, chunk_idx = line.strip().split(" ")
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val, chunk_idx]
+ dataset.append(sample)
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def get_mapper(self):
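+        # Positional mapping from each sample tuple built in load_dataset to named fields.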
+ return {
+ "image_filename": 0,
+ "depth_filename": 1,
+ "K": 2,
+ "chunk_idx": 3,
+ }
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [False] * self.num_copies
+ results["quality"] = [1] * self.num_copies
+ return results
diff --git a/unik3d/datasets/deep360.py b/unik3d/datasets/deep360.py
new file mode 100644
index 0000000000000000000000000000000000000000..20bf72aeacade6166201fac523137900d9ca1eba
--- /dev/null
+++ b/unik3d/datasets/deep360.py
@@ -0,0 +1,56 @@
+from typing import Any
+
+import torch
+
+from unik3d.datasets.pipelines import Compose, PanoCrop, PanoRoll
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class Deep360(SequenceDataset):
+ min_depth = 0.1
+ max_depth = 1000.0
+ depth_scale = 1000.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = [f"Deep360.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["cam2w", "camera_params"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+ self.resizer = Compose(
+ [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer]
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/dense.py b/unik3d/datasets/dense.py
new file mode 100644
index 0000000000000000000000000000000000000000..69e31b74a2d48a43067f2112a9e5d339620d2ced
--- /dev/null
+++ b/unik3d/datasets/dense.py
@@ -0,0 +1,91 @@
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class DENSE(ImageDataset):
+ CAM_INTRINSIC = {
+ "ALL": torch.tensor(
+ [
+ [1177.8614, 0.0, 474.319027],
+ [0.0, 1177.8614, 224.275919],
+ [0.0, 0.0, 1.0],
+ ]
+ )
+ }
+ min_depth = 0.05
+ max_depth = 80.0
+ depth_scale = 255.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ hdf5_paths = ["DENSE.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.intrisics = {}
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop the trailing newline
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ sample = [image_filename, depth_filename]
+ dataset.append(sample)
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def get_intrinsics(self, idx, image_name):
+ return self.CAM_INTRINSIC["ALL"].clone()
+
+ def get_mapper(self):
+ return {
+ "image_filename": 0,
+ "depth_filename": 1,
+ }
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [False] * self.num_copies
+ results["quality"] = [1] * self.num_copies
+ return results
diff --git a/unik3d/datasets/diml.py b/unik3d/datasets/diml.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a5b38411611655b7219bfacbe8c6fb528331084
--- /dev/null
+++ b/unik3d/datasets/diml.py
@@ -0,0 +1,79 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class DIML(ImageDataset):
+ min_depth = 0.01
+ max_depth = 100.0
+ depth_scale = 256.0
+ test_split = "test.txt"
+ train_split = "train.txt"
+ intrisics_file = "intrinsics.json"
+ hdf5_paths = ["DIML.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.intrisics = {}
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")  # no trailing newline to strip
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ intrinsics_val = torch.tensor(
+ intrinsics[image_filename.split("/")[0]]
+ ).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val]
+ dataset.append(sample)
+ h5file.close()
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_copies
+ results["quality"] = [2] * self.num_copies
+ return results
diff --git a/unik3d/datasets/diode.py b/unik3d/datasets/diode.py
new file mode 100644
index 0000000000000000000000000000000000000000..9dc180eb096c52e50c237e8555cb952927f537c8
--- /dev/null
+++ b/unik3d/datasets/diode.py
@@ -0,0 +1,278 @@
+import os
+
+import h5py
+import numpy as np
+import polars as pl  # assumed: `pl.from_dict` used below follows the polars API
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.pipelines import AnnotationMask
+from unik3d.datasets.sequence_dataset import SequenceDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class DiodeIndoor(ImageDataset):
+ CAM_INTRINSIC = {
+ "ALL": torch.tensor([[886.81, 0, 512], [0, 927.06, 384], [0, 0, 1]])
+ }
+ min_depth = 0.01
+ max_depth = 25.0
+ depth_scale = 256.0
+ test_split = "val.txt"
+ train_split = "train.txt"
+ hdf5_paths = ["DiodeIndoor.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ # load annotations
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop the trailing newline
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ sample = [
+ image_filename,
+ depth_filename,
+ ]
+ dataset.append(sample)
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def get_intrinsics(self, *args, **kwargs):
+ return self.CAM_INTRINSIC["ALL"].clone()
+
+ def get_mapper(self):
+ return {
+ "image_filename": 0,
+ "depth_filename": 1,
+ }
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_copies
+ results["quality"] = [1] * self.num_copies
+ return results
+
+
+class DiodeIndoor_F(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 25.0
+ depth_scale = 1000.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["DiodeIndoor-F.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, float],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=(
+ decode_fields if not test_mode else [*decode_fields, "points"]
+ ),
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
+
+
+class DiodeOutdoor(ImageDataset):
+ CAM_INTRINSIC = {
+ "ALL": torch.tensor([[886.81, 0, 512], [0, 927.06, 384], [0, 0, 1]])
+ }
+ min_depth = 0.1
+ max_depth = 80.0
+ log_mean = 0
+ log_std = 1
+ test_split = "diode_outdoor_val.txt"
+ train_split = "diode_outdoor_train.txt"
+ hdf5_paths = ["diode.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ depth_scale=256,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+ self.depth_scale = depth_scale
+
+ self.masker = AnnotationMask(
+ min_value=self.min_depth,
+ max_value=self.max_depth if test_mode else None,
+ custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x,
+ )
+ # load annotations
+ self.load_dataset()
+
+ def load_dataset(self):
+ self.h5file = h5py.File(
+            os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(self.h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")[:-1]
+ dataset = {"depth_filename": [], "image_filename": []}
+ for line in txt_string.split("\n"):
+ depth_filename = line.strip().split(" ")[1]
+ img_name = line.strip().split(" ")[0]
+ image_filename = img_name
+ dataset["depth_filename"].append(depth_filename)
+ dataset["image_filename"].append(image_filename)
+
+ self.dataset = pl.from_dict(dataset)
+
+ if not self.test_mode and self.mini:
+ self.dataset = self.dataset[::2]
+
+
+class Diode(ImageDataset):
+ CAM_INTRINSIC = {
+ "ALL": torch.tensor([[886.81, 0, 512], [0, 927.06, 384], [0, 0, 1]])
+ }
+ log_mean = 0
+ log_std = 1
+ min_depth = 0.6
+ max_depth = 80.0
+ test_split = "diode_val.txt"
+ train_split = "diode_train.txt"
+ hdf5_paths = ["diode.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ depth_scale=256,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+ self.depth_scale = depth_scale
+
+ self.masker = AnnotationMask(
+ min_value=self.min_depth,
+ max_value=self.max_depth if test_mode else None,
+ custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x,
+ )
+ # load annotations
+ self.load_dataset()
+
+ def load_dataset(self):
+ self.h5file = h5py.File(
+            os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(self.h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")[:-1]
+ dataset = {"depth_filename": [], "image_filename": []}
+ for line in txt_string.split("\n"):
+ depth_filename = line.strip().split(" ")[1]
+ image_filename = line.strip().split(" ")[0]
+ dataset["depth_filename"].append(depth_filename)
+ dataset["image_filename"].append(image_filename)
+
+ self.dataset = pl.from_dict(dataset)
+
+ if not self.test_mode and self.mini:
+ self.dataset = self.dataset[::2]
+
+ def get_intrinsics(self, *args, **kwargs):
+ return self.CAM_INTRINSIC["ALL"].clone()
diff --git a/unik3d/datasets/dl3dv.py b/unik3d/datasets/dl3dv.py
new file mode 100644
index 0000000000000000000000000000000000000000..fed077b27393389f38054175eed101d76922269e
--- /dev/null
+++ b/unik3d/datasets/dl3dv.py
@@ -0,0 +1,49 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class DL3DV(SequenceDataset):
+ min_depth = 0.001
+ max_depth = 250.0
+ depth_scale = 512.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = [f"DL3DVcv.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["si"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [2] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/driving_stereo.py b/unik3d/datasets/driving_stereo.py
new file mode 100644
index 0000000000000000000000000000000000000000..512c3bd00658609538790c3893a4a3d367ca4629
--- /dev/null
+++ b/unik3d/datasets/driving_stereo.py
@@ -0,0 +1,82 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class DrivingStereo(ImageDataset):
+ min_depth = 0.05
+ max_depth = 80.0
+ depth_scale = 256.0
+ test_split = "drivingstereo_val.txt"
+ train_split = "drivingstereo_train.txt"
+ intrisics_file = "drivingstereo_intrinsics.json"
+ hdf5_paths = ["DrivingStereo.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.crop = crop
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")  # no trailing newline to strip
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val]
+ dataset.append(sample)
+ h5file.close()
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ if self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=1.0)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [False] * self.num_copies
+ results["quality"] = [1] * self.num_copies
+ return results
diff --git a/unik3d/datasets/dtu_rmvd.py b/unik3d/datasets/dtu_rmvd.py
new file mode 100644
index 0000000000000000000000000000000000000000..80ee8d54944b7029a5eb5fcd59f499df0910adfb
--- /dev/null
+++ b/unik3d/datasets/dtu_rmvd.py
@@ -0,0 +1,62 @@
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class DTURMVD(SequenceDataset):
+ min_depth = 0.05
+ max_depth = 3.0
+ depth_scale = 1000.0
+ default_fps = 6
+ test_split = "test.txt"
+ train_split = "test.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["dtu_rmvd.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["si"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/dummy.py b/unik3d/datasets/dummy.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e0157c7e2c60631d16603aa7b5616fa32d07cf2
--- /dev/null
+++ b/unik3d/datasets/dummy.py
@@ -0,0 +1,33 @@
+import numpy as np
+import torch
+from torch.utils.data import Dataset
+
+
+class Dummy(Dataset):
+ train_split = None
+ test_split = None
+
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.dataset = np.arange(1_000_000)
+
+ def get_single_item(self, idx):
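+        # Emit random, correctly shaped tensors only; presumably used to profile the
+        # dataloader / training loop without touching real data.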
+ # results = {}
+ # results["cam2w"] = torch.eye(4).unsqueeze(0)
+ # results["K"] = torch.eye(3).unsqueeze(0)
+ # results["image"] = torch.zeros(1, 3, 1024, 1024).to(torch.uint8)
+ # results["depth"] = torch.zeros(1, 1, 1024, 1024).to(torch.float32)
+ return {
+ "x": {(0, 0): torch.rand(1, 3, 1024, 1024, dtype=torch.float32)},
+ "img_metas": {"val": torch.rand(1, 1024, dtype=torch.float32)},
+ }
+
+ def __getitem__(self, idx):
+ if isinstance(idx, (list, tuple)):
+ results = [self.get_single_item(i) for i in idx]
+ else:
+ results = self.get_single_item(idx)
+ return results
+
+ def __len__(self):
+ return len(self.dataset)
diff --git a/unik3d/datasets/dynamic_replica.py b/unik3d/datasets/dynamic_replica.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb6211d476cd8012dfa6be2bccf57e3a37ba1e5f
--- /dev/null
+++ b/unik3d/datasets/dynamic_replica.py
@@ -0,0 +1,51 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class DynReplica(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 20.0
+ default_fps = 30.0
+ depth_scale = 512.0
+ test_split = "val.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences_clean.json"
+ hdf5_paths = ["DynReplica.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/eden.py b/unik3d/datasets/eden.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e12daeaa1a94cc8a3ed98aa5facb51268e5db28
--- /dev/null
+++ b/unik3d/datasets/eden.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class EDEN(SequenceDataset):
+ min_depth = 0.1
+ max_depth = 100.0
+ depth_scale = 256.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = [f"EDEN.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/eth3d.py b/unik3d/datasets/eth3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..20c1110e27124b803458677cbf96592a13df4170
--- /dev/null
+++ b/unik3d/datasets/eth3d.py
@@ -0,0 +1,164 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.sequence_dataset import SequenceDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class ETH3D(ImageDataset):
+ min_depth = 0.01
+ max_depth = 50.0
+ depth_scale = 1000.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ intrisics_file = "intrinsics.json"
+ hdf5_paths = ["ETH3D.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop the trailing newline
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val]
+ dataset.append(sample)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
+
+
+class ETH3D_F(SequenceDataset):
+ min_depth = 0.05
+ max_depth = 60.0
+ depth_scale = 1000.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["ETH3D-F.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, float],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=(
+ decode_fields if not test_mode else [*decode_fields, "points"]
+ ),
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
+
+
+class ETH3DRMVD(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 50.0
+ depth_scale = 1000.0
+ default_fps = 6
+ test_split = "test.txt"
+ train_split = "test.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["eth3d_rmvd.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
diff --git a/unik3d/datasets/facedepth.py b/unik3d/datasets/facedepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..108afbcde8c332f4b1736646893fddb93e0c5020
--- /dev/null
+++ b/unik3d/datasets/facedepth.py
@@ -0,0 +1,51 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class FaceDepth(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 10.0
+ depth_scale = 1000.0
+ default_fps = 10
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["FaceDepth.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/flsea.py b/unik3d/datasets/flsea.py
new file mode 100644
index 0000000000000000000000000000000000000000..394a5dd1ddfbe4b1154922032e35103ca6d9e62a
--- /dev/null
+++ b/unik3d/datasets/flsea.py
@@ -0,0 +1,100 @@
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class FLSea(ImageDataset):
+ CAM_INTRINSIC = {
+ "canyons": torch.tensor(
+ [
+ [1175.3913431656817, 0.0, 466.2595428966926],
+ [0.0, 1174.2805075232263, 271.2116633091501],
+ [0.0, 0.0, 1.0],
+ ]
+ ),
+ "red_sea": torch.tensor(
+ [
+ [1296.666758476217, 0.0, 501.50386149846],
+ [0.0, 1300.831316354508, 276.161712082695],
+ [0.0, 0.0, 1.0],
+ ]
+ ),
+ }
+ min_depth = 0.05
+ max_depth = 20.0
+ depth_scale = 1000.0
+ train_split = "train.txt"
+ hdf5_paths = ["FLSea.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+        mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.crop = crop
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop the trailing newline
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ sample = [image_filename, depth_filename]
+ dataset.append(sample)
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+ if self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=0.33)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def get_intrinsics(self, idx, image_name):
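+        # intrinsics are selected by the top-level folder of the image path (canyons vs red_sea)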
+ return self.CAM_INTRINSIC[image_name.split("/")[0]][:, :3].clone()
+
+ def get_mapper(self):
+ return {
+ "image_filename": 0,
+ "depth_filename": 1,
+ }
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_copies
+ results["quality"] = [2] * self.num_copies
+ return results
diff --git a/unik3d/datasets/futurehouse.py b/unik3d/datasets/futurehouse.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0346177f183735676cfc7f941ed3db09d6a3cb8
--- /dev/null
+++ b/unik3d/datasets/futurehouse.py
@@ -0,0 +1,56 @@
+from typing import Any
+
+import torch
+
+from unik3d.datasets.pipelines import Compose, PanoCrop, PanoRoll
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class FutureHouse(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 10.0
+ depth_scale = 1000.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+    hdf5_paths = ["FutureHouse.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["cam2w", "camera_params"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
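+        # prepend panorama-specific steps (PanoCrop, then PanoRoll conditioned on
+        # test_mode) to the shared resizer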
+ self.resizer = Compose(
+ [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer]
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/gibson.py b/unik3d/datasets/gibson.py
new file mode 100644
index 0000000000000000000000000000000000000000..da662fffee527d53adfe3bb3467e316186dfb8fd
--- /dev/null
+++ b/unik3d/datasets/gibson.py
@@ -0,0 +1,56 @@
+from typing import Any
+
+import torch
+
+from unik3d.datasets.pipelines import Compose, PanoCrop, PanoRoll
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class Gibson(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 10.0
+ depth_scale = 1000.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+    hdf5_paths = ["Gibson.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["cam2w", "camera_params"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+ self.resizer = Compose(
+ [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer]
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/hammer.py b/unik3d/datasets/hammer.py
new file mode 100644
index 0000000000000000000000000000000000000000..90b8a8594747fd7cbad04ef6f683f7709605db7c
--- /dev/null
+++ b/unik3d/datasets/hammer.py
@@ -0,0 +1,76 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class HAMMER(ImageDataset):
+ min_depth = 0.005
+ max_depth = 10.0
+ depth_scale = 1000.0
+ train_split = "test.txt"
+ test_split = "test.txt"
+ intrisics_file = "intrinsics.json"
+ hdf5_paths = ["hammer.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.crop = crop
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop the trailing newline
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val]
+ dataset.append(sample)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/hm3d.py b/unik3d/datasets/hm3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..f77897da89f7412de9efd6c78b31fdf9f133b531
--- /dev/null
+++ b/unik3d/datasets/hm3d.py
@@ -0,0 +1,49 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class HM3D(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 10.0
+ depth_scale = 1000.0
+ test_split = "val.txt"
+ train_split = "full.txt"
+ sequences_file = "sequences.json"
+    hdf5_paths = ["HM3D.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [2] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/hoi4d.py b/unik3d/datasets/hoi4d.py
new file mode 100644
index 0000000000000000000000000000000000000000..c71408d3b34dc91c080173222a625266027befa1
--- /dev/null
+++ b/unik3d/datasets/hoi4d.py
@@ -0,0 +1,51 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class HOI4D(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 10.0
+ depth_scale = 1000.0
+ default_fps = 5
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["HOI4D.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [False] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/hypersim.py b/unik3d/datasets/hypersim.py
new file mode 100644
index 0000000000000000000000000000000000000000..542147980b543e3247b46f505617cd2a1a842341
--- /dev/null
+++ b/unik3d/datasets/hypersim.py
@@ -0,0 +1,97 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class HyperSim(ImageDataset):
+ min_depth = 0.01
+ max_depth = 50.0
+ depth_scale = 1000.0
+ test_split = "val.txt"
+ train_split = "train.txt"
+ intrisics_file = "intrinsics.json"
+ hdf5_paths = [f"hypersim/hypersim_{i}.hdf5" for i in range(8)]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii").strip("\n")
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+
+ # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f:
+ # f.write(txt_string)
+ # with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f:
+ # json.dump(intrinsics, f)
+
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename, chunk_idx = line.strip().split(" ")
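+            # intrinsics are stored per scene, keyed by the first two path components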
+ intrinsics_val = torch.tensor(
+ intrinsics[os.path.join(*image_filename.split("/")[:2])]
+ ).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val, chunk_idx]
+ dataset.append(sample)
+ h5file.close()
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ if self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=0.1)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def get_mapper(self):
+ return {
+ "image_filename": 0,
+ "depth_filename": 1,
+ "K": 2,
+ "chunk_idx": 3,
+ }
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_copies
+ results["synthetic"] = [True] * self.num_copies
+ results["quality"] = [0] * self.num_copies
+ return results
diff --git a/unik3d/datasets/ibims.py b/unik3d/datasets/ibims.py
new file mode 100644
index 0000000000000000000000000000000000000000..11af2b1bbe67c196a991d016dc1d4f9876847b18
--- /dev/null
+++ b/unik3d/datasets/ibims.py
@@ -0,0 +1,125 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.sequence_dataset import SequenceDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class IBims(ImageDataset):
+ min_depth = 0.005
+ max_depth = 25.0
+ depth_scale = 1000.0
+ train_split = "ibims_val.txt"
+ test_split = "ibims_val.txt"
+ intrisics_file = "ibims_intrinsics.json"
+ hdf5_paths = ["ibims.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.crop = crop
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop the trailing newline
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val]
+ dataset.append(sample)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_copies
+ results["quality"] = [1] * self.num_copies
+ return results
+
+
+class IBims_F(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 25.0
+ depth_scale = 1000.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["IBims-F.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, float],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=(
+ decode_fields if not test_mode else [*decode_fields, "points"]
+ ),
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/image_dataset.py b/unik3d/datasets/image_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..78fca306d0f9e884eb0c3ade305a8386030ff9ef
--- /dev/null
+++ b/unik3d/datasets/image_dataset.py
@@ -0,0 +1,194 @@
+import io
+import os
+from time import time
+from typing import Any, Dict, List, Tuple
+
+import numpy as np
+import tables
+import torch
+import torchvision
+import torchvision.transforms.v2.functional as TF
+from PIL import Image
+
+from unik3d.datasets.base_dataset import BaseDataset
+from unik3d.utils import is_main_process
+from unik3d.utils.camera import BatchCamera, Pinhole
+
+"""
+Legacy class: only pinhole cameras are assumed, and sequences are "faked" by
+setting sequence_fields to [(0, 0)] and cam2w to eye(4).
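+
+Each __getitem__ call therefore returns a dict shaped like a one-frame sequence, e.g.:
+    results["sequence_fields"] == [(0, 0)]
+    results[(0, 0)]["camera"]  is a BatchCamera wrapping a single Pinhole camera
+    results[(0, 0)]["cam2w"]   is the identity when no extrinsics are stored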
+"""
+
+
+class ImageDataset(BaseDataset):
+ def __init__(
+ self,
+ image_shape: Tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: Dict[str, Any],
+ shape_constraints: Dict[str, Any],
+ resize_method: str,
+ mini: float,
+ benchmark: bool = False,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ shape_constraints=shape_constraints,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.mapper = self.get_mapper()
+
+ def get_single_item(self, idx, sample=None, mapper=None):
+ sample = self.dataset[idx] if sample is None else sample
+ mapper = self.mapper if mapper is None else mapper
+
+ results = {
+ (0, 0): dict(
+ gt_fields=set(),
+ image_fields=set(),
+ mask_fields=set(),
+ camera_fields=set(),
+ )
+ }
+ results = self.pre_pipeline(results)
+ results["sequence_fields"] = [(0, 0)]
+
+ chunk_idx = (
+ int(sample[self.mapper["chunk_idx"]]) if "chunk_idx" in self.mapper else 0
+ )
+ h5_path = os.path.join(self.data_root, self.hdf5_paths[chunk_idx])
+ with tables.File(
+ h5_path,
+ mode="r",
+ libver="latest",
+ swmr=True,
+ ) as h5file_chunk:
+ for key_mapper, idx_mapper in mapper.items():
+ if "image" not in key_mapper and "depth" not in key_mapper:
+ continue
+ value = sample[idx_mapper]
+ results[(0, 0)][key_mapper] = value
+ name = key_mapper.replace("_filename", "")
+ value_root = "/" + value
+
+ if "image" in key_mapper:
+ results[(0, 0)]["filename"] = value
+ file = h5file_chunk.get_node(value_root).read()
+ image = (
+ torchvision.io.decode_image(torch.from_numpy(file))
+ .to(torch.uint8)
+ .squeeze()
+ )
+ results[(0, 0)]["image_fields"].add(name)
+                    results[(0, 0)]["image_ori_shape"] = image.shape[-2:]
+ results[(0, 0)][name] = image[None, ...]
+
+ # collect camera information for the given image
+ name = name.replace("image_", "")
+ results[(0, 0)]["camera_fields"].update({"camera", "cam2w"})
+ K = self.get_intrinsics(idx, value)
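+                    # no stored intrinsics: fall back to a generic pinhole guess with
+                    # focal length 0.7 * width and principal point at the image center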
+ if K is None:
+ K = torch.eye(3)
+ K[0, 0] = K[1, 1] = 0.7 * self.image_shape[1]
+ K[0, 2] = 0.5 * self.image_shape[1]
+ K[1, 2] = 0.5 * self.image_shape[0]
+
+ camera = Pinhole(K=K[None, ...].clone())
+ results[(0, 0)]["camera"] = BatchCamera.from_camera(camera)
+ results[(0, 0)]["cam2w"] = self.get_extrinsics(idx, value)[
+ None, ...
+ ]
+
+ elif "depth" in key_mapper:
+ # start = time()
+ file = h5file_chunk.get_node(value_root).read()
+ depth = Image.open(io.BytesIO(file))
+ depth = TF.pil_to_tensor(depth).squeeze().to(torch.float32)
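+                    # 3-channel depth PNGs pack the value across R, G and B; recombine into a single map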
+ if depth.ndim == 3:
+ depth = depth[2] + depth[1] * 255 + depth[0] * 255 * 255
+
+ results[(0, 0)]["gt_fields"].add(name)
+                    results[(0, 0)]["depth_ori_shape"] = depth.shape
+
+ depth = (
+ depth.view(1, 1, *depth.shape).contiguous() / self.depth_scale
+ )
+ results[(0, 0)][name] = depth
+
+ results = self.preprocess(results)
+ if not self.test_mode:
+ results = self.augment(results)
+ results = self.postprocess(results)
+ return results
+
+ def preprocess(self, results):
+ results = self.replicate(results)
+ for i, seq in enumerate(results["sequence_fields"]):
+ self.resizer.ctx = None
+ results[seq] = self.resizer(results[seq])
+ num_pts = torch.count_nonzero(results[seq]["depth"] > 0)
+ if num_pts < 50:
+ raise IndexError(f"Too few points in depth map ({num_pts})")
+
+ for key in results[seq].get("image_fields", ["image"]):
+ results[seq][key] = results[seq][key].to(torch.float32) / 255
+
+ # update fields common in sequence
+ for key in ["image_fields", "gt_fields", "mask_fields", "camera_fields"]:
+ if key in results[(0, 0)]:
+ results[key] = results[(0, 0)][key]
+ results = self.pack_batch(results)
+ return results
+
+ def postprocess(self, results):
+        # normalize only here (after augmentation), since the color augmentations expect un-normalized inputs
+ for key in results.get("image_fields", ["image"]):
+ results[key] = TF.normalize(results[key], **self.normalization_stats)
+ results = self.filler(results)
+ results = self.unpack_batch(results)
+ results = self.masker(results)
+ results = self.collecter(results)
+ return results
+
+ def __getitem__(self, idx):
+ try:
+ if isinstance(idx, (list, tuple)):
+ results = [self.get_single_item(i) for i in idx]
+ else:
+ results = self.get_single_item(idx)
+ except Exception as e:
+ print(f"Error loading sequence {idx} for {self.__class__.__name__}: {e}")
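+            # fall back to a random sample so one corrupted item does not stop training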
+ idx = np.random.randint(0, len(self.dataset))
+ results = self[idx]
+ return results
+
+ def get_intrinsics(self, idx, image_name):
+ idx_sample = self.mapper.get("K", 1000)
+ sample = self.dataset[idx]
+ if idx_sample >= len(sample):
+ return None
+ return sample[idx_sample]
+
+ def get_extrinsics(self, idx, image_name):
+ idx_sample = self.mapper.get("cam2w", 1000)
+ sample = self.dataset[idx]
+ if idx_sample >= len(sample):
+ return torch.eye(4)
+ return sample[idx_sample]
+
+ def get_mapper(self):
+ return {
+ "image_filename": 0,
+ "depth_filename": 1,
+ "K": 2,
+ }
diff --git a/unik3d/datasets/ken_burns.py b/unik3d/datasets/ken_burns.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab443b4ef1298b472a63ecb88d7432f6e1ad95ba
--- /dev/null
+++ b/unik3d/datasets/ken_burns.py
@@ -0,0 +1,95 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class KenBurns(ImageDataset):
+ min_depth = 0.05
+ max_depth = 50.0
+ depth_scale = 256.0
+ test_split = "val.txt"
+ train_split = "train.txt"
+ intrisics_file = "intrinsics.json"
+ hdf5_paths = [f"3dkenburns/3DKenBurns_{i}.hdf5" for i in range(8)]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii").strip("\n")
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+
+ # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f:
+ # f.write(txt_string)
+ # with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f:
+ # json.dump(intrinsics, f)
+
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename, chunk_idx = line.strip().split(" ")
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val, chunk_idx]
+ dataset.append(sample)
+ h5file.close()
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ if self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=0.25)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def get_mapper(self):
+ return {
+ "image_filename": 0,
+ "depth_filename": 1,
+ "K": 2,
+ "chunk_idx": 3,
+ }
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_copies
+ results["synthetic"] = [True] * self.num_copies
+ results["quality"] = [0] * self.num_copies
+ return results
diff --git a/unik3d/datasets/kitti.py b/unik3d/datasets/kitti.py
new file mode 100644
index 0000000000000000000000000000000000000000..77b7cef52109279edcc758c0f20a846d4fcbd01b
--- /dev/null
+++ b/unik3d/datasets/kitti.py
@@ -0,0 +1,317 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.pipelines import AnnotationMask, Collect, Compose, KittiCrop
+from unik3d.datasets.sequence_dataset import SequenceDataset
+from unik3d.datasets.utils import DatasetFromList
+from unik3d.utils import identity
+
+
+class KITTI(ImageDataset):
+ CAM_INTRINSIC = {
+ "2011_09_26": torch.tensor(
+ [
+ [7.215377e02, 0.000000e00, 6.095593e02, 4.485728e01],
+ [0.000000e00, 7.215377e02, 1.728540e02, 2.163791e-01],
+ [0.000000e00, 0.000000e00, 1.000000e00, 2.745884e-03],
+ ]
+ ),
+ "2011_09_28": torch.tensor(
+ [
+ [7.070493e02, 0.000000e00, 6.040814e02, 4.575831e01],
+ [0.000000e00, 7.070493e02, 1.805066e02, -3.454157e-01],
+ [0.000000e00, 0.000000e00, 1.000000e00, 4.981016e-03],
+ ]
+ ),
+ "2011_09_29": torch.tensor(
+ [
+ [7.183351e02, 0.000000e00, 6.003891e02, 4.450382e01],
+ [0.000000e00, 7.183351e02, 1.815122e02, -5.951107e-01],
+ [0.000000e00, 0.000000e00, 1.000000e00, 2.616315e-03],
+ ]
+ ),
+ "2011_09_30": torch.tensor(
+ [
+ [7.070912e02, 0.000000e00, 6.018873e02, 4.688783e01],
+ [0.000000e00, 7.070912e02, 1.831104e02, 1.178601e-01],
+ [0.000000e00, 0.000000e00, 1.000000e00, 6.203223e-03],
+ ]
+ ),
+ "2011_10_03": torch.tensor(
+ [
+ [7.188560e02, 0.000000e00, 6.071928e02, 4.538225e01],
+ [0.000000e00, 7.188560e02, 1.852157e02, -1.130887e-01],
+ [0.000000e00, 0.000000e00, 1.000000e00, 3.779761e-03],
+ ]
+ ),
+ }
+ min_depth = 0.05
+ max_depth = 80.0
+ depth_scale = 256.0
+ log_mean = 2.5462
+ log_std = 0.5871
+ test_split = "kitti_eigen_test.txt"
+ train_split = "kitti_eigen_train.txt"
+ test_split_benchmark = "kitti_test.txt"
+ hdf5_paths = ["kitti.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.masker = AnnotationMask(
+ min_value=0.0,
+ max_value=self.max_depth if test_mode else None,
+ custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x,
+ )
+ self.test_mode = test_mode
+ self.crop = crop
+ self.cropper_base = KittiCrop(crop_size=(352, 1216))
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop the trailing newline
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename = line.strip().split(" ")[0]
+ depth_filename = line.strip().split(" ")[1]
+ if depth_filename == "None":
+ continue
+ sample = [
+ image_filename,
+ depth_filename,
+ ]
+ dataset.append(sample)
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def get_intrinsics(self, idx, image_name):
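+        # calibration is keyed by the recording date, i.e. the first path component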
+ return self.CAM_INTRINSIC[image_name.split("/")[0]][:, :3].clone()
+
+ def preprocess(self, results):
+ results = self.replicate(results)
+ for i, seq in enumerate(results["sequence_fields"]):
+ self.resizer.ctx = None
+ results[seq] = self.cropper_base(results[seq])
+ results[seq] = self.resizer(results[seq])
+ num_pts = torch.count_nonzero(results[seq]["depth"] > 0)
+ if num_pts < 50:
+ raise IndexError(f"Too few points in depth map ({num_pts})")
+
+ for key in results[seq].get("image_fields", ["image"]):
+ results[seq][key] = results[seq][key].to(torch.float32) / 255
+
+ # update fields common in sequence
+ for key in ["image_fields", "gt_fields", "mask_fields", "camera_fields"]:
+ if key in results[(0, 0)]:
+ results[key] = results[(0, 0)][key]
+ results = self.pack_batch(results)
+ return results
+
+ def eval_mask(self, valid_mask, info={}):
+        """Apply the standard Garg or Eigen evaluation crop at test time."""
+ mask_height, mask_width = valid_mask.shape[-2:]
+ eval_mask = torch.zeros_like(valid_mask)
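+        # crop boundaries follow the standard Garg and Eigen KITTI evaluation protocols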
+ if "garg" in self.crop:
+ eval_mask[
+ ...,
+ int(0.40810811 * mask_height) : int(0.99189189 * mask_height),
+ int(0.03594771 * mask_width) : int(0.96405229 * mask_width),
+ ] = 1
+ elif "eigen" in self.crop:
+ eval_mask[
+ ...,
+ int(0.3324324 * mask_height) : int(0.91351351 * mask_height),
+ int(0.03594771 * mask_width) : int(0.96405229 * mask_width),
+ ] = 1
+ return torch.logical_and(valid_mask, eval_mask)
+
+ def get_mapper(self):
+ return {
+ "image_filename": 0,
+ "depth_filename": 1,
+ }
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [False] * self.num_copies
+ results["quality"] = [1] * self.num_copies
+ return results
+
+
+
+class KITTIBenchmark(ImageDataset):
+ min_depth = 0.05
+ max_depth = 80.0
+ depth_scale = 256.0
+ test_split = "test_split.txt"
+ train_split = "val_split.txt"
+ intrinsics_file = "intrinsics.json"
+ hdf5_paths = ["kitti_benchmark.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=True,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.crop = crop
+
+ self.masker = AnnotationMask(
+ min_value=self.min_depth,
+ max_value=self.max_depth if test_mode else None,
+ custom_fn=lambda x, *args, **kwargs: x,
+ )
+ self.collecter = Collect(keys=["image_fields", "mask_fields", "gt_fields"])
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+            os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+        txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop the trailing newline
+        intrinsics = np.array(h5file[self.intrinsics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+            image_filename, depth_filename = line.strip().split(" ")
+            intrinsics_val = torch.tensor(
+                intrinsics[os.path.join(*image_filename.split("/")[:2])]
+            ).squeeze()[:, :3]
+            # keep the list layout expected by ImageDataset.get_mapper (index-based access)
+            sample = [image_filename, depth_filename, intrinsics_val]
+            dataset.append(sample)
+
+ self.dataset = DatasetFromList(dataset)
+
+ self.log_load_dataset()
+
+
+class KITTIRMVD(SequenceDataset):
+ min_depth = 0.05
+ max_depth = 80.0
+ depth_scale = 256.0
+ default_fps = 10
+ test_split = "test.txt"
+ train_split = "test.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["kitti_rmvd.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+ self.crop = crop
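+        # apply the fixed 352x1216 KITTI crop before the generic resizer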
+ self.resizer = Compose([KittiCrop(crop_size=(352, 1216)), self.resizer])
+
+ def eval_mask(self, valid_mask, info={}):
+        """Apply the standard Garg or Eigen evaluation crop at test time."""
+ mask_height, mask_width = valid_mask.shape[-2:]
+ eval_mask = torch.zeros_like(valid_mask)
+ if "garg" in self.crop:
+ eval_mask[
+ ...,
+ int(0.40810811 * mask_height) : int(0.99189189 * mask_height),
+ int(0.03594771 * mask_width) : int(0.96405229 * mask_width),
+ ] = 1
+ elif "eigen" in self.crop:
+ eval_mask[
+ ...,
+ int(0.3324324 * mask_height) : int(0.91351351 * mask_height),
+ int(0.03594771 * mask_width) : int(0.96405229 * mask_width),
+ ] = 1
+ else:
+ return valid_mask
+ return torch.logical_and(valid_mask, eval_mask)
diff --git a/unik3d/datasets/kitti360.py b/unik3d/datasets/kitti360.py
new file mode 100644
index 0000000000000000000000000000000000000000..403afdd2b28d4052181b7b65c986af847be3c99c
--- /dev/null
+++ b/unik3d/datasets/kitti360.py
@@ -0,0 +1,65 @@
+from typing import Any
+
+import torch
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class KITTI360(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 80.0
+ depth_scale = 256.0
+ train_split = "train.txt"
+ test_split = "val_split.txt"
+ sequences_file = "sequences_split.json"
+    hdf5_paths = ["KITTI360.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=(
+ decode_fields if not test_mode else [*decode_fields, "points"]
+ ),
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def preprocess(self, results):
+ self.resizer.ctx = None
+ for i, seq in enumerate(results["sequence_fields"]):
+ # Create a mask where the distance from the center is less than H/2
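+            # (KITTI-360 fisheye frames are circular, so pixels outside the circle are invalid)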
+ H, W = results[seq]["image"].shape[-2:]
+ x = torch.linspace(-W / 2, W / 2, W)
+ y = torch.linspace(-H / 2, H / 2, H)
+ xv, yv = torch.meshgrid(x, y, indexing="xy")
+ distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W)
+ results[seq]["validity_mask"] = distance_from_center < (H / 2)
+ return super().preprocess(results)
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [False] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/lyft.py b/unik3d/datasets/lyft.py
new file mode 100644
index 0000000000000000000000000000000000000000..bafeef68bc07a8222202c315e25ae24e8acf8597
--- /dev/null
+++ b/unik3d/datasets/lyft.py
@@ -0,0 +1,84 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class Lyft(ImageDataset):
+ min_depth = 0.05
+ max_depth = 80.0
+ depth_scale = 256.0
+ test_split = "test.txt"
+ train_split = "train.txt"
+ intrisics_file = "intrinsics.json"
+ hdf5_paths = ["Lyft2.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")  # [:-1] intentionally not applied here
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+ # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f:
+ # f.write(txt_string)
+ # with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f:
+ # json.dump(intrinsics, f)
+
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
+ sample = [
+ image_filename,
+ depth_filename,
+ intrinsics_val,
+ ]
+ dataset.append(sample)
+ h5file.close()
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+        results["dense"] = [False] * self.num_copies
+ return results
diff --git a/unik3d/datasets/mapillary.py b/unik3d/datasets/mapillary.py
new file mode 100644
index 0000000000000000000000000000000000000000..d58054add6dcd1cd5243de3bbd9eb496862d3488
--- /dev/null
+++ b/unik3d/datasets/mapillary.py
@@ -0,0 +1,84 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class Mapillary(ImageDataset):
+ min_depth = 0.01
+ max_depth = 70.0
+ depth_scale = 256.0
+ test_split = "mapillary_val.txt"
+ train_split = "mapillary_train_clean.txt"
+ intrisics_file = "intrinsics.json"
+ hdf5_paths = ["Mapillary.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.crop = crop
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")  # [:-1] intentionally not applied here
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val]
+ dataset.append(sample)
+ h5file.close()
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+ if self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=0.05)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["si"] = [True] * self.num_copies
+ results["valid_camera"] = [False] * self.num_copies
+ results["dense"] = [False] * self.num_copies
+ results["quality"] = [2] * self.num_copies
+ return results
diff --git a/unik3d/datasets/matrix_city.py b/unik3d/datasets/matrix_city.py
new file mode 100644
index 0000000000000000000000000000000000000000..32a7f9f93451def80a1561d18a93777bc357f445
--- /dev/null
+++ b/unik3d/datasets/matrix_city.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class MatrixCity(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 200.0
+ depth_scale = 1000.0
+ test_split = "test.txt"
+ train_split = "train_full.txt"
+ sequences_file = "sequences.json"
+    hdf5_paths = ["MatrixCity.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/matterport3d.py b/unik3d/datasets/matterport3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c1602725974ae7959987ee1127024696dc0a3f2
--- /dev/null
+++ b/unik3d/datasets/matterport3d.py
@@ -0,0 +1,56 @@
+from typing import Any
+
+import torch
+
+from unik3d.datasets.pipelines import Compose, PanoCrop, PanoRoll
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class Matterport3D(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 10.0
+ depth_scale = 1000.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+    hdf5_paths = ["Matterport3D.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["cam2w", "camera_params"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+ self.resizer = Compose(
+ [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer]
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/megadepth.py b/unik3d/datasets/megadepth.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d65d8d1604595b14e516b1258ad49d23fb33acd
--- /dev/null
+++ b/unik3d/datasets/megadepth.py
@@ -0,0 +1,83 @@
+import os
+
+import h5py
+import numpy as np
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class MegaDepth(ImageDataset):
+ min_depth = 0.01
+ max_depth = 1000.0
+ depth_scale = 50.0
+ test_split = "test.txt"
+ train_split = "train.txt"
+ hdf5_paths = ["MegaDepth.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")  # [:-1] intentionally not applied here
+ # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f:
+ # f.write(txt_string)
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ sample = [
+ image_filename,
+ depth_filename,
+ ]
+ dataset.append(sample)
+ h5file.close()
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+ else:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=0.5)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+        results["ssi"] = [True] * self.num_copies
+        results["valid_camera"] = [False] * self.num_copies
+        results["dense"] = [False] * self.num_copies
+ return results
+
+ def get_mapper(self):
+ return {
+ "image_filename": 0,
+ "depth_filename": 1,
+ }
diff --git a/unik3d/datasets/megadepth_s.py b/unik3d/datasets/megadepth_s.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad6189460f90aa527574c4e1554131e0c17b1fd1
--- /dev/null
+++ b/unik3d/datasets/megadepth_s.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class MegaDepthS(SequenceDataset):
+ min_depth = 0.001
+ max_depth = 10000.0
+ depth_scale = 512.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences_filter_clean.json"
+ hdf5_paths = ["MegaDepthS.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["intrinsics", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["si"] = [True] * self.num_frames * self.num_copies
+ results["dense"] = [False] * self.num_frames * self.num_copies
+ results["quality"] = [2] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/midair.py b/unik3d/datasets/midair.py
new file mode 100644
index 0000000000000000000000000000000000000000..184a78305fd89e9ffc8dd6682ccc9ab5a645d0fc
--- /dev/null
+++ b/unik3d/datasets/midair.py
@@ -0,0 +1,51 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class MidAir(SequenceDataset):
+ min_depth = 0.1
+ max_depth = 1000.0
+ depth_scale = 1000.0
+ default_fps = 6
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["MidAir.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/mip.py b/unik3d/datasets/mip.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba18a000e60f5ffc694088de38f69bdf30dc2fe0
--- /dev/null
+++ b/unik3d/datasets/mip.py
@@ -0,0 +1,52 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class MIP(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 100.0
+ depth_scale = 1000.0
+ default_fps = 10
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["MIP.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [False] * self.num_frames * self.num_copies
+ results["si"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [2] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/ms2.py b/unik3d/datasets/ms2.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd9f13dd1fb3ff96a55cf4992fe35127fc8a2626
--- /dev/null
+++ b/unik3d/datasets/ms2.py
@@ -0,0 +1,51 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class MS2(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 100.0
+ depth_scale = 256.0
+ default_fps = 5
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["MS2.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [False] * self.num_frames * self.num_copies
+ results["synthetic"] = [False] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/mvimgnet.py b/unik3d/datasets/mvimgnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..53976cb31f8f0abd006d910478716fdbe83d8bf2
--- /dev/null
+++ b/unik3d/datasets/mvimgnet.py
@@ -0,0 +1,137 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
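+# sequences with unusable data; exposed to the base class via the invalid_sequences attribute below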
+INVALID_SEQUENCES = [
+ "1/000121f2-0",
+ "15/1600ae56-0",
+ "26/000000f3-0",
+ "33/1d00e677-0",
+ "43/22008925-0",
+ "49/000147db-0",
+ "51/23002a43-0",
+ "51/23000916-0",
+ "108/000133ae-0",
+ "129/000037f2-0",
+ "141/17012545-0",
+ "141/1700f3de-0",
+ "152/1b00e061-0",
+ "154/1d00decb-0",
+ "154/1d017c1c-0",
+ "154/1d0019a5-0",
+ "154/1d00334d-0",
+ "154/1d012ed6-0",
+ "154/1d016b8a-0",
+ "154/1d016cc1-0",
+ "154/1d008d5f-0",
+ "159/000157f9-0",
+ "159/00000b96-0",
+ "159/000075c0-0",
+ "159/0000445c-0",
+ "159/000056a0-0",
+ "159/00010c68-0",
+ "159/0000573b-0",
+ "159/00002698-0",
+ "159/00008fca-0",
+ "159/00009ef8-0",
+ "159/00015f05-0",
+ "159/0000c6df-0",
+ "159/0000ee59-0",
+ "163/290159d2-0",
+ "163/29016c7c-0",
+ "163/2900239c-0",
+ "163/29002f7b-0",
+ "163/29014b05-0",
+ "163/29000196-0",
+ "163/2901750f-0",
+ "164/1b0145cf-0",
+ "164/1b00eb1d-0",
+ "164/1b00c28b-0",
+ "164/1b0110d0-0",
+ "164/1b00dd20-0",
+ "165/2600e15a-0",
+ "165/26008444-0",
+ "165/260145c5-0",
+ "165/26003a0c-0",
+ "165/260106ba-0",
+ "165/26001548-0",
+ "167/2a0092b0-0",
+ "167/2a014dbe-0",
+ "167/2a003ce6-0",
+ "169/1800c645-0",
+ "171/2500014d-0",
+ "176/1d0021c2-0",
+ "176/1d014abf-0",
+ "176/1d00e714-0",
+ "176/1d0159cb-0",
+ "176/1e016629-0",
+ "178/000102b8-0",
+ "191/23008fdb-0",
+ "191/2300187f-0",
+ "191/2300ae68-0",
+ "191/230076dd-0",
+ "191/24007d7e-0",
+ "192/000107b5-0",
+ "195/1f012359-0",
+ "195/1f00f751-0",
+ "195/1f011331-0",
+ "195/1e00d999-0",
+ "196/1c01304e-0",
+ "198/1a00e02f-0",
+ "198/050084ac-0",
+ "198/1a0075fa-0",
+ "199/1e001742-0",
+ "199/1e00116a-0",
+ "199/1e011d00-0",
+ "199/1e018040-0",
+ "199/1e001107-0",
+]
+
+
+class MVImgNet(SequenceDataset):
+ min_depth = 0.005
+ max_depth = 10.0
+    # NOTE: scale looks off; the nominal depth_scale is 1000, yet the average depth comes out around 10 m
+ depth_scale = 1000.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["MVImgNet.hdf5"]
+ invalid_sequences = INVALID_SEQUENCES
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["intrinsics", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["si"] = [True] * self.num_frames * self.num_copies
+ results["dense"] = [False] * self.num_frames * self.num_copies
+ results["quality"] = [2] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/mvsynth.py b/unik3d/datasets/mvsynth.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d34f6c0c57ebe32baab09d80459f932cd9aff43
--- /dev/null
+++ b/unik3d/datasets/mvsynth.py
@@ -0,0 +1,51 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class MVSynth(SequenceDataset):
+ min_depth = 0.1
+ max_depth = 1000.0
+ depth_scale = 256.0
+ test_split = "val.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+    hdf5_paths = ["MVSynth.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["si"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/nerds360.py b/unik3d/datasets/nerds360.py
new file mode 100644
index 0000000000000000000000000000000000000000..c669a0ab46a3d7b09cdfefa5c688d7aecadb7f32
--- /dev/null
+++ b/unik3d/datasets/nerds360.py
@@ -0,0 +1,49 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class NeRDS360(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 1000.0
+ depth_scale = 1000.0
+ test_split = "val.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["NeRDS360.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/niantic_mapfree.py b/unik3d/datasets/niantic_mapfree.py
new file mode 100644
index 0000000000000000000000000000000000000000..768cb3f0c4f11dd821bef00dd51dbeffb0d9b330
--- /dev/null
+++ b/unik3d/datasets/niantic_mapfree.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class NianticMapFree(SequenceDataset):
+ min_depth = 0.1
+ max_depth = 250.0
+ depth_scale = 512.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+    hdf5_paths = ["NianticMapFree.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["si"] = [True] * self.num_frames * self.num_copies
+ results["dense"] = [False] * self.num_frames * self.num_copies
+ results["quality"] = [2] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/nuscenes.py b/unik3d/datasets/nuscenes.py
new file mode 100644
index 0000000000000000000000000000000000000000..1992efba64c81927f0452f363f3018180d4c0873
--- /dev/null
+++ b/unik3d/datasets/nuscenes.py
@@ -0,0 +1,89 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class Nuscenes(ImageDataset):
+ min_depth = 0.05
+ max_depth = 80.0
+ depth_scale = 256.0
+ test_split = "val.txt"
+ train_split = "train.txt"
+ intrisics_file = "intrinsics.json"
+ # hdf5_paths = ["Nuscenes2.hdf5"]
+ hdf5_paths = [f"nuscenes/nuscenes_{i}.hdf5" for i in range(8)]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii").strip("\n")
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+
+ dataset = []
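+        # Each split line is formatted as: "<image_filename> <depth_filename> <chunk_idx>".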
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename, chunk_idx = line.strip().split(" ")
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val, chunk_idx]
+ dataset.append(sample)
+ h5file.close()
+
+        if not self.test_mode:
+            dataset = self.chunk(dataset, chunk_dim=6, pct=self.mini)
+        else:
+            dataset = self.chunk(dataset, chunk_dim=6, pct=0.1)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def get_mapper(self):
+ return {
+ "image_filename": 0,
+ "depth_filename": 1,
+ "K": 2,
+ "chunk_idx": 3,
+ }
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [False] * self.num_copies
+ results["quality"] = [1] * self.num_copies
+ return results
diff --git a/unik3d/datasets/nyuv2.py b/unik3d/datasets/nyuv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..c270c83af17f4527adacc3fe0122022501c4b14b
--- /dev/null
+++ b/unik3d/datasets/nyuv2.py
@@ -0,0 +1,112 @@
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.pipelines import AnnotationMask
+from unik3d.datasets.utils import DatasetFromList
+from unik3d.utils import identity
+
+
+class NYUv2Depth(ImageDataset):
+ CAM_INTRINSIC = {
+ "ALL": torch.tensor(
+ [
+ [5.1885790117450188e02, 0, 3.2558244941119034e02],
+ [0, 5.1946961112127485e02, 2.5373616633400465e02],
+ [0, 0, 1],
+ ]
+ )
+ }
+ min_depth = 0.005
+ max_depth = 10.0
+ depth_scale = 1000.0
+ log_mean = 0.9140
+ log_std = 0.4825
+ test_split = "nyu_test.txt"
+ train_split = "nyu_train.txt"
+ hdf5_paths = ["nyuv2.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.masker = AnnotationMask(
+ min_value=0.0,
+ max_value=self.max_depth if test_mode else None,
+ custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x,
+ )
+ self.test_mode = test_mode
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop the trailing newline
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename, _ = line.strip().split(" ")
+ sample = [
+ image_filename,
+ depth_filename,
+ ]
+ dataset.append(sample)
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def get_intrinsics(self, idx, image_name):
+ return self.CAM_INTRINSIC["ALL"].clone()
+
+ def eval_mask(self, valid_mask, info={}):
+ border_mask = torch.zeros_like(valid_mask)
+ border_mask[..., 45:-9, 41:-39] = 1
+ return torch.logical_and(valid_mask, border_mask)
+
+ def get_mapper(self):
+ return {
+ "image_filename": 0,
+ "depth_filename": 1,
+ }
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_copies
+ results["quality"] = [2] * self.num_copies
+ return results
diff --git a/unik3d/datasets/pipelines/__init__.py b/unik3d/datasets/pipelines/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..07aa134db64b5134bf32e797608360ec70e450df
--- /dev/null
+++ b/unik3d/datasets/pipelines/__init__.py
@@ -0,0 +1,46 @@
+from .formating import AnnotationMask, Collect
+from .transforms import (Compose, ContextCrop, Crop, Dilation, DownsamplerGT,
+ DummyCrop, GaussianBlur, KittiCrop, MotionBlur,
+ PanoCrop, PanoRoll, RandomAutoContrast,
+ RandomBrightness, RandomColor, RandomColorJitter,
+ RandomContrast, RandomCut, RandomEqualize,
+ RandomFiller, RandomFlip, RandomGamma,
+ RandomGrayscale, RandomInvert, RandomMasking,
+ RandomPosterize, RandomSaturation, RandomSharpness,
+ RandomShear, RandomSolarize, RandomTranslate, Resize,
+ Rotate)
+
+__all__ = [
+ "Resize",
+ "Rotate",
+ "RandomFlip",
+ "RandomBrightness",
+ "Crop",
+ "Dilation",
+ "RandomColor",
+ "RandomContrast",
+ "RandomEqualize",
+ "RandomSaturation",
+ "Collect",
+ "AnnotationMask",
+ "RandomSolarize",
+ "RandomPosterize",
+ "RandomSharpness",
+ "RandomShear",
+ "RandomTranslate",
+ "RandomAutoContrast",
+ "RandomInvert",
+ "KittiCrop",
+ "DownsamplerGT",
+ "RandomCut",
+ "RandomGamma",
+ "RandomMasking",
+ "RandomColorJitter",
+ "GaussianBlur",
+ "RandomGrayscale",
+ "ContextCrop",
+ "RandomFiller",
+ "MotionBlur",
+ "Compose",
+ "DummyCrop",
+]
diff --git a/unik3d/datasets/pipelines/formating.py b/unik3d/datasets/pipelines/formating.py
new file mode 100644
index 0000000000000000000000000000000000000000..f746d1727e5152bd1047f5972a5e175d8ed03b2d
--- /dev/null
+++ b/unik3d/datasets/pipelines/formating.py
@@ -0,0 +1,113 @@
+from collections.abc import Sequence
+
+import numpy as np
+import torch
+
+
+class Collect(object):
+ def __init__(
+ self,
+ keys,
+ meta_keys=(
+ "filename",
+ "keyframe_idx",
+ "sequence_name",
+ "image_filename",
+ "depth_filename",
+ "image_ori_shape",
+ "camera",
+ "original_camera",
+ "sfm",
+ "image_shape",
+ "resized_shape",
+ "scale_factor",
+ "rotation",
+ "resize_factor",
+ "flip",
+ "flip_direction",
+ "dataset_name",
+ "paddings",
+ "max_value",
+ "log_mean",
+ "log_std",
+ "image_rescale",
+ "focal_rescale",
+ "depth_rescale",
+ ),
+ ):
+ self.keys = keys
+ self.meta_keys = meta_keys
+
+ def __call__(self, results):
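+        # self.keys names groups of field names (e.g. "gt_fields"); every tensor they reference
+        # is collected per sequence index, and everything else is packed under "img_metas".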
+ data_keys = [key for field in self.keys for key in results.get(field, [])]
+ data = {
+ key: {
+ sequence_key: results[key][sequence_key]
+ for sequence_key in results["sequence_fields"]
+ }
+ for key in data_keys
+ }
+ data["img_metas"] = {
+ key: value for key, value in results.items() if key not in data_keys
+ }
+ return data
+
+ def __repr__(self):
+ return (
+ self.__class__.__name__ + f"(keys={self.keys}, meta_keys={self.meta_keys})"
+ )
+
+
+class AnnotationMask(object):
+    def __init__(self, min_value, max_value, custom_fn=lambda x, *args, **kwargs: x):
+ self.min_value = min_value
+ self.max_value = max_value
+ self.custom_fn = custom_fn
+
+ def __call__(self, results):
+ for key in results.get("gt_fields", []):
+ if key + "_mask" in results["mask_fields"]:
+ if "flow" in key:
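+                    # Optical flow is kept as valid only where both components lie within the normalized [-1, 1] range.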
+ for sequence_idx in results.get("sequence_fields", []):
+ boundaries = (results[key][sequence_idx] >= -1) & (
+ results[key][sequence_idx] <= 1
+ )
+ boundaries = boundaries[:, :1] & boundaries[:, 1:]
+ results[key + "_mask"][sequence_idx] = (
+ results[key + "_mask"][sequence_idx].bool() & boundaries
+ )
+ continue
+ for sequence_idx in results.get("sequence_fields", []):
+ # take care of xyz or flow, dim=1 is the channel dim
+ if results[key][sequence_idx].shape[1] == 1:
+ mask = results[key][sequence_idx] > self.min_value
+ else:
+ mask = (
+ results[key][sequence_idx].norm(dim=1, keepdim=True)
+ > self.min_value
+ )
+ if self.max_value is not None:
+ if results[key][sequence_idx].shape[1] == 1:
+ mask = mask & (results[key][sequence_idx] < self.max_value)
+ else:
+ mask = mask & (
+ results[key][sequence_idx].norm(dim=1, keepdim=True)
+ < self.max_value
+ )
+ mask = self.custom_fn(mask, info=results)
+ if key + "_mask" not in results:
+ results[key + "_mask"] = {}
+ if sequence_idx not in results[key + "_mask"]:
+ results[key + "_mask"][sequence_idx] = mask.bool()
+ else:
+ results[key + "_mask"][sequence_idx] = (
+ results[key + "_mask"][sequence_idx].bool() & mask.bool()
+ )
+ results["mask_fields"].add(key + "_mask")
+ return results
+
+ def __repr__(self):
+ return (
+ self.__class__.__name__
+            + f"(min_value={self.min_value}, max_value={self.max_value})"
+ )
diff --git a/unik3d/datasets/pipelines/transforms.py b/unik3d/datasets/pipelines/transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1f035f1b1ce224494c4019ea6bd97b5621f45cc
--- /dev/null
+++ b/unik3d/datasets/pipelines/transforms.py
@@ -0,0 +1,2176 @@
+import os
+import random
+from copy import deepcopy
+from math import ceil, exp, log, log2, log10, tanh
+from typing import Dict, List, Tuple
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+import torchvision.transforms.v2.functional as TF
+
+from unik3d.utils.geometric import downsample
+
+
+def euler_to_rotation_matrix(angles):
+ """
+ Convert Euler angles to a 3x3 rotation matrix.
+
+ Args:
+ angles (torch.Tensor): Euler angles [roll, pitch, yaw].
+
+ Returns:
+ torch.Tensor: 3x3 rotation matrix.
+ """
+ phi, theta, psi = angles
+ cos_phi, sin_phi = torch.cos(phi), torch.sin(phi)
+ cos_theta, sin_theta = torch.cos(theta), torch.sin(theta)
+ cos_psi, sin_psi = torch.cos(psi), torch.sin(psi)
+
+ # Rotation matrices
+ Rx = torch.tensor([[1, 0, 0], [0, cos_phi, -sin_phi], [0, sin_phi, cos_phi]])
+
+ Ry = torch.tensor(
+ [[cos_theta, 0, sin_theta], [0, 1, 0], [-sin_theta, 0, cos_theta]]
+ )
+
+ Rz = torch.tensor([[cos_psi, -sin_psi, 0], [sin_psi, cos_psi, 0], [0, 0, 1]])
+
+ return Rz @ Ry @ Rx
+
+
+def compute_grid(H, W):
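+    # Homogeneous pixel-coordinate grid of shape (1, 3, H*W); rows are (x, y, 1).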
+ meshgrid = torch.meshgrid(torch.arange(W), torch.arange(H), indexing="xy")
+    id_coords = torch.stack(meshgrid, dim=0).to(torch.float32)
+ id_coords = id_coords.reshape(2, -1)
+ id_coords = torch.cat(
+ [id_coords, torch.ones(1, id_coords.shape[-1])], dim=0
+ ) # 3 HW
+ id_coords = id_coords.unsqueeze(0)
+ return id_coords
+
+
+def lexsort(keys):
+ sorted_indices = torch.arange(keys[0].size(0))
+ for key in reversed(keys):
+ _, sorted_indices = key[sorted_indices].sort()
+ return sorted_indices
+
+
+def masked_bilinear_interpolation(input, mask, target_size):
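+    # Bilinearly resizes `input` to `target_size` while ignoring pixels marked invalid in `mask`;
+    # returns the interpolated tensor and a mask of output pixels that had at least one valid neighbor.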
+ B, C, H, W = input.shape
+ target_H, target_W = target_size
+ mask = mask.float()
+
+ # Generate a grid of coordinates in the target space
+    grid_y, grid_x = torch.meshgrid(
+        torch.linspace(0, H - 1, target_H),
+        torch.linspace(0, W - 1, target_W),
+        indexing="ij",
+    )
+ grid_y = grid_y.to(input.device)
+ grid_x = grid_x.to(input.device)
+
+ # Calculate the floor and ceil of the grid coordinates to get the bounding box
+ x0 = torch.floor(grid_x).long().clamp(0, W - 1)
+ x1 = (x0 + 1).clamp(0, W - 1)
+ y0 = torch.floor(grid_y).long().clamp(0, H - 1)
+ y1 = (y0 + 1).clamp(0, H - 1)
+
+ # Gather depth values at the four corners
+ Ia = input[..., y0, x0]
+ Ib = input[..., y1, x0]
+ Ic = input[..., y0, x1]
+ Id = input[..., y1, x1]
+
+ # Gather corresponding mask values
+ ma = mask[..., y0, x0]
+ mb = mask[..., y1, x0]
+ mc = mask[..., y0, x1]
+ md = mask[..., y1, x1]
+
+ # Calculate the areas (weights) for bilinear interpolation
+ wa = (x1.float() - grid_x) * (y1.float() - grid_y)
+ wb = (x1.float() - grid_x) * (grid_y - y0.float())
+ wc = (grid_x - x0.float()) * (y1.float() - grid_y)
+ wd = (grid_x - x0.float()) * (grid_y - y0.float())
+
+ wa = wa.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1)
+ wb = wb.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1)
+ wc = wc.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1)
+ wd = wd.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1)
+
+ # Only consider valid points for interpolation
+ weights_sum = (wa * ma) + (wb * mb) + (wc * mc) + (wd * md)
+ weights_sum = torch.clamp(weights_sum, min=1e-5)
+
+ # Perform the interpolation
+ interpolated_depth = (
+ wa * Ia * ma + wb * Ib * mb + wc * Ic * mc + wd * Id * md
+ ) / weights_sum
+
+ return interpolated_depth, (ma + mb + mc + md) > 0
+
+
+def masked_nearest_interpolation(input, mask, target_size):
+ B, C, H, W = input.shape
+ target_H, target_W = target_size
+ mask = mask.float()
+
+ # Generate a grid of coordinates in the target space
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(0, H - 1, target_H),
+ torch.linspace(0, W - 1, target_W),
+ indexing="ij",
+ )
+ grid_y = grid_y.to(input.device)
+ grid_x = grid_x.to(input.device)
+
+ # Calculate the floor and ceil of the grid coordinates to get the bounding box
+ x0 = torch.floor(grid_x).long().clamp(0, W - 1)
+ x1 = (x0 + 1).clamp(0, W - 1)
+ y0 = torch.floor(grid_y).long().clamp(0, H - 1)
+ y1 = (y0 + 1).clamp(0, H - 1)
+
+ # Gather depth values at the four corners
+ Ia = input[..., y0, x0]
+ Ib = input[..., y1, x0]
+ Ic = input[..., y0, x1]
+ Id = input[..., y1, x1]
+
+ # Gather corresponding mask values
+ ma = mask[..., y0, x0]
+ mb = mask[..., y1, x0]
+ mc = mask[..., y0, x1]
+ md = mask[..., y1, x1]
+
+ # Calculate distances to each neighbor
+ # The distances are calculated from the center (grid_x, grid_y) to each corner
+ dist_a = (grid_x - x0.float()) ** 2 + (grid_y - y0.float()) ** 2 # Top-left
+ dist_b = (grid_x - x0.float()) ** 2 + (grid_y - y1.float()) ** 2 # Bottom-left
+ dist_c = (grid_x - x1.float()) ** 2 + (grid_y - y0.float()) ** 2 # Top-right
+ dist_d = (grid_x - x1.float()) ** 2 + (grid_y - y1.float()) ** 2 # Bottom-right
+
+ # Stack the neighbors, their masks, and distances
+ stacked_values = torch.stack(
+ [Ia, Ib, Ic, Id], dim=-1
+ ) # Shape: (B, C, target_H, target_W, 4)
+ stacked_masks = torch.stack(
+ [ma, mb, mc, md], dim=-1
+ ) # Shape: (B, 1, target_H, target_W, 4)
+ stacked_distances = torch.stack(
+ [dist_a, dist_b, dist_c, dist_d], dim=-1
+ ) # Shape: (target_H, target_W, 4)
+ stacked_distances = (
+ stacked_distances.unsqueeze(0).unsqueeze(1).repeat(B, 1, 1, 1, 1)
+ ) # Shape: (B, 1, target_H, target_W, 4)
+
+ # Set distances to infinity for invalid neighbors (so that invalid neighbors are never chosen)
+ stacked_distances[stacked_masks == 0] = float("inf")
+
+ # Find the index of the nearest valid neighbor (the one with the smallest distance)
+ nearest_indices = stacked_distances.argmin(dim=-1, keepdim=True)[
+ ..., :1
+ ] # Shape: (B, 1, target_H, target_W, 1)
+
+ # Select the corresponding depth value using the nearest valid neighbor index
+ interpolated_depth = torch.gather(
+ stacked_values, dim=-1, index=nearest_indices.repeat(1, C, 1, 1, 1)
+ ).squeeze(-1)
+
+ # Set depth to zero where no valid neighbors were found
+ interpolated_depth = interpolated_depth * stacked_masks.sum(dim=-1).clip(
+ min=0.0, max=1.0
+ )
+
+ return interpolated_depth
+
+
+def masked_nxn_interpolation(input, mask, target_size, N=2):
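+    # Masked interpolation over an N x N neighborhood with bilinear-style weights; for N != 2
+    # the plain 2x2 bilinear result is preferred wherever that smaller window has valid support.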
+ B, C, H, W = input.shape
+ target_H, target_W = target_size
+
+ # Generate a grid of coordinates in the target space
+ grid_y, grid_x = torch.meshgrid(
+ torch.linspace(0, H - 1, target_H),
+ torch.linspace(0, W - 1, target_W),
+ indexing="ij",
+ )
+ grid_y = grid_y.to(input.device)
+ grid_x = grid_x.to(input.device)
+
+ # Calculate the top-left corner of the NxN grid
+ half_N = (N - 1) // 2
+ y0 = torch.floor(grid_y - half_N).long().clamp(0, H - 1)
+ x0 = torch.floor(grid_x - half_N).long().clamp(0, W - 1)
+
+ # Prepare to gather NxN neighborhoods
+ input_patches = []
+ mask_patches = []
+ weights = []
+
+ for i in range(N):
+ for j in range(N):
+ yi = (y0 + i).clamp(0, H - 1)
+ xi = (x0 + j).clamp(0, W - 1)
+
+ # Gather depth and mask values
+ input_patches.append(input[..., yi, xi])
+ mask_patches.append(mask[..., yi, xi])
+
+ # Compute bilinear weights
+ weight_y = 1 - torch.abs(grid_y - yi.float()) / N
+ weight_x = 1 - torch.abs(grid_x - xi.float()) / N
+ weight = (
+ (weight_y * weight_x)
+ .reshape(1, 1, target_H, target_W)
+ .repeat(B, C, 1, 1)
+ )
+ weights.append(weight)
+
+ input_patches = torch.stack(input_patches)
+ mask_patches = torch.stack(mask_patches)
+ weights = torch.stack(weights)
+
+ # Calculate weighted sum and normalize by the sum of weights
+ weighted_sum = (input_patches * mask_patches * weights).sum(dim=0)
+ weights_sum = (mask_patches * weights).sum(dim=0)
+ interpolated_tensor = weighted_sum / torch.clamp(weights_sum, min=1e-8)
+
+ if N != 2:
+ interpolated_tensor_2x2, mask_sum_2x2 = masked_bilinear_interpolation(
+ input, mask, target_size
+ )
+ interpolated_tensor = torch.where(
+ mask_sum_2x2, interpolated_tensor_2x2, interpolated_tensor
+ )
+
+ return interpolated_tensor
+
+
+class PanoCrop:
+ def __init__(self, crop_v=0.15):
+ self.crop_v = crop_v
+
+ def _crop_data(self, results, crop_size):
+ offset_w, offset_h = crop_size
+ left, top, right, bottom = offset_w[0], offset_h[0], offset_w[1], offset_h[1]
+ H, W = results["image"].shape[-2:]
+ for key in results.get("image_fields", ["image"]):
+ img = results[key][..., top : H - bottom, left : W - right]
+ results[key] = img
+ results["image_shape"] = tuple(img.shape)
+
+ for key in results.get("gt_fields", []):
+ results[key] = results[key][..., top : H - bottom, left : W - right]
+
+ for key in results.get("mask_fields", []):
+ results[key] = results[key][..., top : H - bottom, left : W - right]
+
+ results["camera"] = results["camera"].crop(left, top, right, bottom)
+ return results
+
+ def __call__(self, results):
+ H, W = results["image"].shape[-2:]
+ crop_w = (0, 0)
+ crop_h = (int(H * self.crop_v), int(H * self.crop_v))
+ results = self._crop_data(results, (crop_w, crop_h))
+ return results
+
+
+class PanoRoll:
+ def __init__(self, test_mode, roll=[-0.5, 0.5]):
+ self.roll = roll
+ self.test_mode = test_mode
+
+ def __call__(self, results):
+ if self.test_mode:
+ return results
+ W = results["image"].shape[-1]
+ roll = random.randint(int(W * self.roll[0]), int(W * self.roll[1]))
+ for key in results.get("image_fields", ["image"]):
+ img = results[key]
+ img = torch.roll(img, roll, dims=-1)
+ results[key] = img
+ for key in results.get("gt_fields", []):
+ results[key] = torch.roll(results[key], roll, dims=-1)
+ for key in results.get("mask_fields", []):
+ results[key] = torch.roll(results[key], roll, dims=-1)
+ return results
+
+
+class RandomFlip:
+ def __init__(self, direction="horizontal", prob=0.5, consistent=False, **kwargs):
+ self.flip_ratio = prob
+ valid_directions = ["horizontal", "vertical", "diagonal"]
+ if isinstance(direction, str):
+ assert direction in valid_directions
+ elif isinstance(direction, list):
+ assert set(direction).issubset(set(valid_directions))
+ else:
+ raise ValueError("direction must be either str or list of str")
+ self.direction = direction
+ self.consistent = consistent
+
+ def __call__(self, results):
+ if "flip" not in results:
+ # None means non-flip
+ if isinstance(self.direction, list):
+ direction_list = self.direction + [None]
+ else:
+ direction_list = [self.direction, None]
+
+ if isinstance(self.flip_ratio, list):
+ non_flip_ratio = 1 - sum(self.flip_ratio)
+ flip_ratio_list = self.flip_ratio + [non_flip_ratio]
+ else:
+ non_flip_ratio = 1 - self.flip_ratio
+ # exclude non-flip
+ single_ratio = self.flip_ratio / (len(direction_list) - 1)
+ flip_ratio_list = [single_ratio] * (len(direction_list) - 1) + [
+ non_flip_ratio
+ ]
+
+ cur_dir = np.random.choice(direction_list, p=flip_ratio_list)
+
+ results["flip"] = cur_dir is not None
+
+ if "flip_direction" not in results:
+ results["flip_direction"] = cur_dir
+
+ if results["flip"]:
+ # flip image
+ if results["flip_direction"] != "vertical":
+ for key in results.get("image_fields", ["image"]):
+ results[key] = TF.hflip(results[key])
+ for key in results.get("mask_fields", []):
+ results[key] = TF.hflip(results[key])
+ for key in results.get("gt_fields", []):
+ results[key] = TF.hflip(results[key])
+ if "flow" in key: # flip u direction
+ results[key][:, 0] = -results[key][:, 0]
+
+ H, W = results["image"].shape[-2:]
+ results["camera"] = results["camera"].flip(
+ H=H, W=W, direction="horizontal"
+ )
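+                # Reflect the camera-to-world pose about the x axis so the extrinsics match the mirrored image.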
+ flip_transform = torch.tensor(
+ [[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]],
+ dtype=torch.float32,
+ ).unsqueeze(0)
+ repeats = (results["cam2w"].shape[0],) + (1,) * (
+ results["cam2w"].ndim - 1
+ )
+ results["cam2w"] = flip_transform.repeat(*repeats) @ results["cam2w"]
+
+ if results["flip_direction"] != "horizontal":
+ for key in results.get("image_fields", ["image"]):
+ results[key] = TF.vflip(results[key])
+ for key in results.get("mask_fields", []):
+ results[key] = TF.vflip(results[key])
+ for key in results.get("gt_fields", []):
+ results[key] = TF.vflip(results[key])
+ results["K"][..., 1, 2] = (
+ results["image"].shape[-2] - results["K"][..., 1, 2]
+ )
+ results["flip"] = [results["flip"]] * len(results["image"])
+ return results
+
+
+class Crop:
+ def __init__(
+ self,
+ crop_size,
+ crop_type="absolute",
+ crop_offset=(0, 0),
+ ):
+ if crop_type not in [
+ "relative_range",
+ "relative",
+ "absolute",
+ "absolute_range",
+ ]:
+ raise ValueError(f"Invalid crop_type {crop_type}.")
+ if crop_type in ["absolute", "absolute_range"]:
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ assert isinstance(crop_size[0], int) and isinstance(crop_size[1], int)
+ else:
+ assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1
+ self.crop_size = crop_size
+ self.crop_type = crop_type
+ self.offset_h, self.offset_w = (
+ crop_offset[: len(crop_offset) // 2],
+ crop_offset[len(crop_offset) // 2 :],
+ )
+
+ def _get_crop_size(self, image_shape):
+ h, w = image_shape
+ if self.crop_type == "absolute":
+ return (min(self.crop_size[0], h), min(self.crop_size[1], w))
+ elif self.crop_type == "absolute_range":
+ assert self.crop_size[0] <= self.crop_size[1]
+ crop_h = np.random.randint(
+ min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1
+ )
+ crop_w = np.random.randint(
+ min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1
+ )
+ return crop_h, crop_w
+ elif self.crop_type == "relative":
+ crop_h, crop_w = self.crop_size
+ return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
+ elif self.crop_type == "relative_range":
+ crop_size = np.asarray(self.crop_size, dtype=np.float32)
+ crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size)
+ return int(h * crop_h + 0.5), int(w * crop_w + 0.5)
+
+ def _crop_data(self, results, crop_size):
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ for key in results.get("image_fields", ["image"]):
+ img = results[key]
+ img = TF.crop(
+ img, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1]
+ )
+ results[key] = img
+ results["image_shape"] = tuple(img.shape)
+
+ for key in results.get("gt_fields", []):
+ gt = results[key]
+ results[key] = TF.crop(
+ gt, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1]
+ )
+
+ # crop semantic seg
+ for key in results.get("mask_fields", []):
+ mask = results[key]
+ results[key] = TF.crop(
+ mask, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1]
+ )
+
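+        # Shift the principal point by the crop offset so the intrinsics stay aligned with the cropped image.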
+ results["K"][..., 0, 2] = results["K"][..., 0, 2] - self.offset_w[0]
+ results["K"][..., 1, 2] = results["K"][..., 1, 2] - self.offset_h[0]
+ return results
+
+ def __call__(self, results):
+ image_shape = results["image"].shape[-2:]
+ crop_size = self._get_crop_size(image_shape)
+ results = self._crop_data(results, crop_size)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f"(crop_size={self.crop_size}, "
+ repr_str += f"crop_type={self.crop_type}, "
+ return repr_str
+
+
+class KittiCrop:
+ def __init__(self, crop_size):
+ self.crop_size = crop_size
+
+ def _crop_data(self, results, crop_size):
+        """Crop images, ground truth and masks to a fixed KITTI-style window
+        (bottom-aligned, horizontally centered).
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+            crop_size (tuple): Expected absolute size after cropping, (h, w).
+
+        Returns:
+            dict: Cropped results; the 'image_shape' key is updated according
+                to the crop size.
+        """
+ assert crop_size[0] > 0 and crop_size[1] > 0
+ for key in results.get("image_fields", ["image"]):
+ img = results[key]
+ h, w = img.shape[-2:]
+ offset_h, offset_w = int(h - self.crop_size[0]), int(
+ (w - self.crop_size[1]) / 2
+ )
+
+ # crop the image
+ img = TF.crop(img, offset_h, offset_w, crop_size[0], crop_size[1])
+ results[key] = img
+ results["image_shape"] = tuple(img.shape)
+
+ for key in results.get("gt_fields", []):
+ gt = results[key]
+ results[key] = TF.crop(gt, offset_h, offset_w, crop_size[0], crop_size[1])
+
+ # crop semantic seg
+ for key in results.get("mask_fields", []):
+ mask = results[key]
+ results[key] = TF.crop(mask, offset_h, offset_w, crop_size[0], crop_size[1])
+
+ results["camera"] = results["camera"].crop(offset_w, offset_h)
+ return results
+
+ def __call__(self, results):
+        """Call function to apply the fixed KITTI-style crop to images,
+        ground truth and masks.
+
+        Args:
+            results (dict): Result dict from loading pipeline.
+
+        Returns:
+            dict: Cropped results; the 'image_shape' key is updated according
+                to the crop size.
+        """
+ results = self._crop_data(results, self.crop_size)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f"(crop_size={self.crop_size}, "
+ return repr_str
+
+
+class RandomMasking:
+ def __init__(
+ self,
+ mask_ratio,
+ mask_patch=16,
+ prob=0.5,
+ warmup_steps=50000,
+ sampling="random",
+ curriculum=False,
+ ):
+ self.mask_patch = mask_patch
+ self.prob = prob
+ self.mask_ratio = mask_ratio
+ self.warmup_steps = max(1, warmup_steps)
+ self.hard_bound = 1
+ self.idx = 0
+ self.curriculum = curriculum
+ self.sampling = sampling
+ self.low_bound = 0.0
+ self.up_bound = 0.0
+
+ def __call__(self, results):
+ B, _, H, W = results["image"].shape
+ device = results["image"].device
+ down_size = H // self.mask_patch, W // self.mask_patch
+ if np.random.random() > self.prob: # fill with dummy
+ return self._nop(results, down_size, device)
+
+ validity_mask = results["validity_mask"].float().reshape(B, -1, H, W)
+ validity_mask = F.interpolate(validity_mask, size=down_size).bool()
+ validity_mask = validity_mask.reshape(B, 1, *down_size)
+ is_random = self.is_warmup or results.get("guidance") is None
+
+ if not is_random:
+ guidance = F.interpolate(results["guidance"], size=(H, W), mode="bilinear")
+ results["guidance"] = -F.max_pool2d(
+ -guidance, kernel_size=self.mask_patch, stride=self.mask_patch
+ )
+
+ if is_random and self.sampling == "inverse":
+ sampling = self.inverse_sampling
+ elif is_random and self.sampling == "random":
+ sampling = self.random_sampling
+ else:
+ sampling = self.guided_sampling
+ mask_ratio = np.random.uniform(self.low_bound, self.up_bound)
+ for key in results.get("image_fields", ["image"]):
+ mask = sampling(results, mask_ratio, down_size, validity_mask, device)
+ results[key + "_mask"] = mask
+ return results
+
+ def _nop(self, results, down_size, device):
+ B = results["image"].shape[0]
+ for key in results.get("image_fields", ["image"]):
+ mask_blocks = torch.zeros(size=(B, 1, *down_size), device=device)
+ results[key + "_mask"] = mask_blocks
+ return results
+
+ def random_sampling(self, results, mask_ratio, down_size, validity_mask, device):
+ B = results["image"].shape[0]
+ prob_blocks = torch.rand(size=(B, 1, *down_size), device=device)
+ mask_blocks = torch.logical_and(prob_blocks < mask_ratio, validity_mask)
+ return mask_blocks
+
+ def inverse_sampling(self, results, mask_ratio, down_size, validity_mask, device):
+ def area_sample(depth, fx, fy):
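+            # Sampling probability is proportional to 1 / (3D pixel footprint)^2, so nearby
+            # (finer-resolution) regions are masked more often than far-away ones.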
+ dtype = depth.dtype
+ B = depth.shape[0]
+ H, W = down_size
+ depth = downsample(depth, depth.shape[-2] // H)
+            depth[depth > 200] = 50  # treat sky (depth > 200 m) as if it were 50 m away
+ pixel_area3d = depth / torch.sqrt(fx * fy)
+
+ # Set invalid as -1 (no div problem) -> then clip to 0.0
+ pixel_area3d[depth == 0.0] = -1
+ prob_density = (1 / pixel_area3d).clamp(min=0.0).square()
+ prob_density = prob_density / prob_density.sum(
+ dim=(-1, -2), keepdim=True
+ ).clamp(min=1e-5)
+
+ # Sample locations based on prob_density
+ prob_density_flat = prob_density.view(B, -1)
+
+            # Count the valid locations per sample; a fraction mask_ratio of them will be masked.
+            valid_locations = (prob_density_flat > 0).to(dtype).sum(dim=1)
+
+ masks = []
+ for i in range(B):
+ num_samples = int(valid_locations[i] * mask_ratio)
+ mask = torch.zeros_like(prob_density_flat[i])
+ # Sample indices
+ if num_samples > 0:
+ sampled_indices_flat = torch.multinomial(
+ prob_density_flat[i], num_samples, replacement=False
+ )
+ mask.scatter_(0, sampled_indices_flat, 1)
+ masks.append(mask)
+ return torch.stack(masks).bool().view(B, 1, H, W)
+
+ def random_sample(validity_mask):
+ prob_blocks = torch.rand(
+ size=(validity_mask.shape[0], 1, *down_size), device=device
+ )
+ mask = torch.logical_and(prob_blocks < mask_ratio, validity_mask)
+ return mask
+
+ fx = results["K"][..., 0, 0].view(-1, 1, 1, 1) / self.mask_patch
+ fy = results["K"][..., 1, 1].view(-1, 1, 1, 1) / self.mask_patch
+
+ valid = ~results["ssi"] & ~results["si"] & results["valid_camera"]
+ mask_blocks = torch.zeros_like(validity_mask)
+ if valid.any():
+ out = area_sample(results["depth"][valid], fx[valid], fy[valid])
+ mask_blocks[valid] = out
+ if (~valid).any():
+ mask_blocks[~valid] = random_sample(validity_mask[~valid])
+
+ return mask_blocks
+
+ def guided_sampling(self, results, mask_ratio, down_size, validity_mask, device):
+ # get the lowest (based on guidance) "mask_ratio" quantile of the patches that are in validity mask
+ B = results["image"].shape[0]
+ guidance = results["guidance"]
+ mask_blocks = torch.zeros(size=(B, 1, *down_size), device=device)
+ for b in range(B):
+ low_bound = torch.quantile(
+ guidance[b][validity_mask[b]], max(0.0, self.hard_bound - mask_ratio)
+ )
+ up_bound = torch.quantile(
+ guidance[b][validity_mask[b]], min(1.0, self.hard_bound)
+ )
+ mask_blocks[b] = torch.logical_and(
+ guidance[b] < up_bound, guidance[b] > low_bound
+ )
+ mask_blocks = torch.logical_and(mask_blocks, validity_mask)
+ return mask_blocks
+
+ def step(self):
+ self.idx += 1
+ # schedule hard from 1.0 to self.mask_ratio
+ if self.curriculum:
+ step = max(0, self.idx / self.warmup_steps / 2 - 0.5)
+ self.hard_bound = 1 - (1 - self.mask_ratio) * tanh(step)
+ self.up_bound = self.mask_ratio * tanh(step)
+ self.low_bound = 0.1 * tanh(step)
+
+ @property
+ def is_warmup(self):
+ return self.idx < self.warmup_steps
+
+
+class Resize:
+ def __init__(self, image_scale=None, image_shape=None, keep_original=False):
+ assert (image_scale is None) ^ (image_shape is None)
+ if isinstance(image_scale, (float, int)):
+ image_scale = (image_scale, image_scale)
+ if isinstance(image_shape, (float, int)):
+ image_shape = (int(image_shape), int(image_shape))
+ self.image_scale = image_scale
+ self.image_shape = image_shape
+ self.keep_original = keep_original
+
+ def _resize_img(self, results):
+ for key in results.get("image_fields", ["image"]):
+ img = TF.resize(
+ results[key],
+ results["resized_shape"],
+ interpolation=TF.InterpolationMode.BILINEAR,
+ antialias=True,
+ )
+ results[key] = img
+
+ def _resize_masks(self, results):
+ for key in results.get("mask_fields", []):
+ mask = TF.resize(
+ results[key],
+ results["resized_shape"],
+ interpolation=TF.InterpolationMode.NEAREST_EXACT,
+ antialias=True,
+ )
+ results[key] = mask
+
+ def _resize_gt(self, results):
+ for key in results.get("gt_fields", []):
+ gt = TF.resize(
+ results[key],
+ results["resized_shape"],
+ interpolation=TF.InterpolationMode.NEAREST_EXACT,
+ antialias=True,
+ )
+ results[key] = gt
+
+ def __call__(self, results):
+ h, w = results["image"].shape[-2:]
+ results["K_original"] = results["K"].clone()
+ if self.image_scale:
+ image_shape = (
+ int(h * self.image_scale[0] + 0.5),
+ int(w * self.image_scale[1] + 0.5),
+ )
+ image_scale = self.image_scale
+ elif self.image_shape:
+ image_scale = (self.image_shape[0] / h, self.image_shape[1] / w)
+ image_shape = self.image_shape
+ else:
+ raise ValueError(
+                f"In {self.__class__.__name__}: either image_scale or image_shape must be set"
+ )
+
+ results["resized_shape"] = tuple(image_shape)
+ results["resize_factor"] = tuple(image_scale)
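+        # Rescale the intrinsics: the principal point follows the pixel-center convention, focal lengths scale linearly.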
+ results["K"][..., 0, 2] = (results["K"][..., 0, 2] - 0.5) * image_scale[1] + 0.5
+ results["K"][..., 1, 2] = (results["K"][..., 1, 2] - 0.5) * image_scale[0] + 0.5
+ results["K"][..., 0, 0] = results["K"][..., 0, 0] * image_scale[1]
+ results["K"][..., 1, 1] = results["K"][..., 1, 1] * image_scale[0]
+
+ self._resize_img(results)
+ if not self.keep_original:
+ self._resize_masks(results)
+ self._resize_gt(results)
+
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ return repr_str
+
+
+class Rotate:
+ def __init__(
+ self, angle, center=None, img_fill_val=(123.68, 116.28, 103.53), prob=0.5
+ ):
+ if isinstance(img_fill_val, (float, int)):
+ img_fill_val = tuple([float(img_fill_val)] * 3)
+ elif isinstance(img_fill_val, tuple):
+ assert len(img_fill_val) == 3, (
+ "image_fill_val as tuple must "
+ f"have 3 elements. got {len(img_fill_val)}."
+ )
+ img_fill_val = tuple([float(val) for val in img_fill_val])
+ else:
+ raise ValueError("image_fill_val must be float or tuple with 3 elements.")
+ assert np.all(
+ [0 <= val <= 255 for val in img_fill_val]
+        ), f"all elements of img_fill_val should be in the range [0, 255], got {img_fill_val}."
+        assert 0 <= prob <= 1.0, f"The probability should be in the range [0, 1], got {prob}."
+ self.center = center
+ self.img_fill_val = img_fill_val
+ self.prob = prob
+ self.random = not isinstance(angle, (float, int))
+ self.angle = angle
+
+ def _rotate(self, results, angle, center=None, fill_val=0.0):
+ for key in results.get("image_fields", ["image"]):
+ img = results[key]
+ img_rotated = TF.rotate(
+ img,
+ angle,
+ center=center,
+ interpolation=TF.InterpolationMode.NEAREST_EXACT,
+ fill=self.img_fill_val,
+ )
+ results[key] = img_rotated.to(img.dtype)
+ results["image_shape"] = results[key].shape
+
+ for key in results.get("mask_fields", []):
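+        # Standard NYUv2 (Eigen) evaluation crop: rows 45:471 and columns 41:601 on 480x640 frames.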
+ results[key] = TF.rotate(
+ results[key],
+ angle,
+ center=center,
+ interpolation=TF.InterpolationMode.NEAREST_EXACT,
+ fill=fill_val,
+ )
+
+ for key in results.get("gt_fields", []):
+ results[key] = TF.rotate(
+ results[key],
+ angle,
+ center=center,
+ interpolation=TF.InterpolationMode.NEAREST_EXACT,
+ fill=fill_val,
+ )
+
+ def __call__(self, results):
+ """Call function to rotate images, bounding boxes, masks and semantic
+ segmentation maps.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Rotated results.
+ """
+ if np.random.random() > self.prob:
+ return results
+
+ angle = (
+ (self.angle[1] - self.angle[0]) * np.random.rand() + self.angle[0]
+ if self.random
+            else float(np.random.choice([-1, 1])) * self.angle
+ )
+ self._rotate(results, angle, None, fill_val=0.0)
+ results["rotation"] = angle
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f"(angle={self.angle}, "
+ repr_str += f"center={self.center}, "
+ repr_str += f"image_fill_val={self.img_fill_val}, "
+ repr_str += f"prob={self.prob}, "
+ return repr_str
+
+
+class RandomColor:
+ """Apply Color transformation to image. The bboxes, masks, and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Color transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ self.random = not isinstance(level, (float, int))
+ self.level = level
+ self.prob = prob
+
+ def _adjust_color_img(self, results, factor=1.0):
+ """Apply Color transformation to image."""
+ for key in results.get("image_fields", ["image"]):
+ results[key] = TF.adjust_hue(results[key], factor) # .to(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Color transformation.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Colored results.
+ """
+ if np.random.random() > self.prob:
+ return results
+ factor = (
+ ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
+ if self.random
+ else self.level
+ )
+ self._adjust_color_img(results, factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f"(level={self.level}, "
+ repr_str += f"prob={self.prob})"
+ return repr_str
+
+
+class RandomSaturation:
+ """Apply Color transformation to image. The bboxes, masks, and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Color transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ self.random = not isinstance(level, (float, int))
+ self.level = level
+ self.prob = prob
+
+ def _adjust_saturation_img(self, results, factor=1.0):
+ """Apply Color transformation to image."""
+ for key in results.get("image_fields", ["image"]):
+            # NOTE: by default the image is expected to be in BGR format
+ results[key] = TF.adjust_saturation(results[key], factor) # .to(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Color transformation.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Colored results.
+ """
+ if np.random.random() > self.prob:
+ return results
+ factor = (
+ 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
+ if self.random
+ else 2**self.level
+ )
+ self._adjust_saturation_img(results, factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f"(level={self.level}, "
+ repr_str += f"prob={self.prob})"
+ return repr_str
+
+
+class RandomSharpness:
+ """Apply Color transformation to image. The bboxes, masks, and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Color transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ self.random = not isinstance(level, (float, int))
+ self.level = level
+ self.prob = prob
+
+    def _adjust_sharpness_img(self, results, factor=1.0):
+ """Apply Color transformation to image."""
+ for key in results.get("image_fields", ["image"]):
+            # NOTE: by default the image is expected to be in BGR format
+ results[key] = TF.adjust_sharpness(results[key], factor) # .to(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Color transformation.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Colored results.
+ """
+ if np.random.random() > self.prob:
+ return results
+ factor = (
+ 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
+ if self.random
+ else 2**self.level
+ )
+        self._adjust_sharpness_img(results, factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f"(level={self.level}, "
+ repr_str += f"prob={self.prob})"
+ return repr_str
+
+
+class RandomSolarize:
+ """Apply Color transformation to image. The bboxes, masks, and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Color transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ self.random = not isinstance(level, (float, int))
+ self.level = level
+ self.prob = prob
+
+ def _adjust_solarize_img(self, results, factor=255.0):
+ """Apply Color transformation to image."""
+ for key in results.get("image_fields", ["image"]):
+ results[key] = TF.solarize(results[key], factor) # .to(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Color transformation.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Colored results.
+ """
+ if np.random.random() > self.prob:
+ return results
+ factor = (
+ ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
+ if self.random
+ else self.level
+ )
+ self._adjust_solarize_img(results, factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f"(level={self.level}, "
+ repr_str += f"prob={self.prob})"
+ return repr_str
+
+
+class RandomPosterize:
+ """Apply Color transformation to image. The bboxes, masks, and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Color transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ self.random = not isinstance(level, (float, int))
+ self.level = level
+ self.prob = prob
+
+ def _posterize_img(self, results, factor=1.0):
+ """Apply Color transformation to image."""
+ for key in results.get("image_fields", ["image"]):
+ results[key] = TF.posterize(results[key], int(factor)) # .to(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Color transformation.
+
+ Args:
+ results (dict): Result dict from loading pipeline.
+
+ Returns:
+ dict: Colored results.
+ """
+ if np.random.random() > self.prob:
+ return results
+ factor = (
+ ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
+ if self.random
+ else self.level
+ )
+ self._posterize_img(results, factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f"(level={self.level}, "
+ repr_str += f"prob={self.prob})"
+ return repr_str
+
+
+class RandomEqualize:
+ """Apply Equalize transformation to image. The bboxes, masks and
+ segmentations are not modified.
+
+ Args:
+ prob (float): The probability for performing Equalize transformation.
+ """
+
+ def __init__(self, prob=0.5):
+ assert 0 <= prob <= 1.0, "The probability should be in range [0,1]."
+ self.prob = prob
+
+ def _imequalize(self, results):
+ """Equalizes the histogram of one image."""
+ for key in results.get("image_fields", ["image"]):
+ results[key] = TF.equalize(results[key]) # .to(img.dtype)
+
+ def __call__(self, results):
+ """Call function for Equalize transformation.
+
+ Args:
+ results (dict): Results dict from loading pipeline.
+
+ Returns:
+ dict: Results after the transformation.
+ """
+ if np.random.random() > self.prob:
+ return results
+ self._imequalize(results)
+ return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f"(prob={self.prob})"
+        return repr_str
+
+
+class RandomBrightness:
+ """Apply Brightness transformation to image. The bboxes, masks and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Brightness transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ self.random = not isinstance(level, (float, int))
+ self.level = level
+ self.prob = prob
+
+ def _adjust_brightness_img(self, results, factor=1.0):
+ """Adjust the brightness of image."""
+ for key in results.get("image_fields", ["image"]):
+ results[key] = TF.adjust_brightness(results[key], factor) # .to(img.dtype)
+
+ def __call__(self, results, level=None):
+ """Call function for Brightness transformation.
+
+ Args:
+ results (dict): Results dict from loading pipeline.
+
+ Returns:
+ dict: Results after the transformation.
+ """
+ if np.random.random() > self.prob:
+ return results
+ factor = (
+ 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
+ if self.random
+ else 2**self.level
+ )
+ self._adjust_brightness_img(results, factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f"(level={self.level}, "
+ repr_str += f"prob={self.prob})"
+ return repr_str
+
+
+class RandomContrast:
+ """Apply Contrast transformation to image. The bboxes, masks and
+ segmentations are not modified.
+
+ Args:
+ level (int | float): Should be in range [0,_MAX_LEVEL].
+ prob (float): The probability for performing Contrast transformation.
+ """
+
+ def __init__(self, level, prob=0.5):
+ self.random = not isinstance(level, (float, int))
+ self.level = level
+ self.prob = prob
+
+ def _adjust_contrast_img(self, results, factor=1.0):
+ """Adjust the image contrast."""
+ for key in results.get("image_fields", ["image"]):
+ results[key] = TF.adjust_contrast(results[key], factor) # .to(img.dtype)
+
+ def __call__(self, results, level=None):
+ """Call function for Contrast transformation.
+
+ Args:
+ results (dict): Results dict from loading pipeline.
+
+ Returns:
+ dict: Results after the transformation.
+ """
+ if np.random.random() > self.prob:
+ return results
+ factor = (
+ 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
+ if self.random
+ else 2**self.level
+ )
+ self._adjust_contrast_img(results, factor)
+ return results
+
+ def __repr__(self):
+ repr_str = self.__class__.__name__
+ repr_str += f"(level={self.level}, "
+ repr_str += f"prob={self.prob})"
+ return repr_str
+
+
+class RandomGamma:
+ def __init__(self, level, prob=0.5):
+ self.random = not isinstance(level, (float, int))
+ self.level = level
+ self.prob = prob
+
+ def __call__(self, results, level=None):
+ """Call function for Contrast transformation.
+
+ Args:
+ results (dict): Results dict from loading pipeline.
+
+ Returns:
+ dict: Results after the transformation.
+ """
+ if np.random.random() > self.prob:
+ return results
+ factor = (self.level[1] - self.level[0]) * np.random.rand() + self.level[0]
+ for key in results.get("image_fields", ["image"]):
+ if "original" not in key:
+ results[key] = TF.adjust_gamma(results[key], 1 + factor)
+ return results
+
+
+class RandomInvert:
+ def __init__(self, prob=0.5):
+ self.prob = prob
+
+ def __call__(self, results):
+ if np.random.random() > self.prob:
+ return results
+ for key in results.get("image_fields", ["image"]):
+ if "original" not in key:
+ results[key] = TF.invert(results[key]) # .to(img.dtype)
+ return results
+
+
+class RandomAutoContrast:
+ def __init__(self, prob=0.5):
+ self.prob = prob
+
+ def _autocontrast_img(self, results):
+ for key in results.get("image_fields", ["image"]):
+ img = results[key]
+ results[key] = TF.autocontrast(img) # .to(img.dtype)
+
+ def __call__(self, results):
+ if np.random.random() > self.prob:
+ return results
+ self._autocontrast_img(results)
+ return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f"(prob={self.prob})"
+        return repr_str
+
+
+class Dilation:
+ def __init__(self, origin, kernel, border_value=-1.0, iterations=1) -> None:
+ self.structured_element = torch.ones(size=kernel)
+ self.origin = origin
+ self.border_value = border_value
+ self.iterations = iterations
+
+ def dilate(self, image):
+ image_pad = F.pad(
+ image,
+ [
+ self.origin[0],
+ self.structured_element.shape[0] - self.origin[0] - 1,
+ self.origin[1],
+ self.structured_element.shape[1] - self.origin[1] - 1,
+ ],
+ mode="constant",
+ value=self.border_value,
+ )
+ if image_pad.ndim < 4:
+ image_pad = image_pad.unsqueeze(0)
+ # Unfold the image to be able to perform operation on neighborhoods
+ image_unfold = F.unfold(image_pad, kernel_size=self.structured_element.shape)
+        # Flatten the structural element since its two dimensions were flattened when unfolding
+        # structured_element_flatten = torch.flatten(self.structured_element).unsqueeze(0).unsqueeze(-1)
+        # Greyscale dilation would add the structuring element and take the maximum over the
+        # neighborhood (a difference would give erosion); since the values are depths we instead
+        # keep the closest point (perspective effect), i.e. the minimum. "Unknown" pixels are 0,
+        # so they are replaced with a large value before taking the minimum.
+
+        mask = image_unfold < 1e-3  # near-zero values are treated as "unknown"
+
+ # Replace the zero elements with a large value, so they don't affect the minimum operation
+ image_unfold = image_unfold.masked_fill(mask, 1000.0)
+
+ # Calculate the minimum along the neighborhood axis
+ dilate_image = torch.min(image_unfold, dim=1).values
+
+ # Fill the masked values with 0 to propagate zero if all pixels are zero
+ dilate_image[mask.all(dim=1)] = 0
+ return torch.reshape(dilate_image, image.shape)
+
+ def __call__(self, results):
+ for key in results.get("gt_fields", []):
+ gt = results[key]
+ for _ in range(self.iterations):
+ gt[gt < 1e-4] = self.dilate(gt)[gt < 1e-4]
+ results[key] = gt
+
+ return results
+
+
+class RandomShear(object):
+ def __init__(
+ self,
+ level,
+ prob=0.5,
+ direction="horizontal",
+ ):
+ self.random = not isinstance(level, (float, int))
+ self.level = level
+ self.prob = prob
+ self.direction = direction
+
+ def _shear_img(self, results, magnitude):
+ for key in results.get("image_fields", ["image"]):
+ img_sheared = TF.affine(
+ results[key],
+ angle=0.0,
+ translate=[0.0, 0.0],
+ scale=1.0,
+ shear=magnitude,
+ interpolation=TF.InterpolationMode.BILINEAR,
+ fill=0.0,
+ )
+ results[key] = img_sheared
+
+ def _shear_masks(self, results, magnitude):
+ for key in results.get("mask_fields", []):
+ mask_sheared = TF.affine(
+ results[key],
+ angle=0.0,
+ translate=[0.0, 0.0],
+ scale=1.0,
+ shear=magnitude,
+ interpolation=TF.InterpolationMode.NEAREST_EXACT,
+ fill=0.0,
+ )
+ results[key] = mask_sheared
+
+ def _shear_gt(
+ self,
+ results,
+ magnitude,
+ ):
+ for key in results.get("gt_fields", []):
+ mask_sheared = TF.affine(
+ results[key],
+ angle=0.0,
+ translate=[0.0, 0.0],
+ scale=1.0,
+ shear=magnitude,
+ interpolation=TF.InterpolationMode.NEAREST_EXACT,
+ fill=0.0,
+ )
+ results[key] = mask_sheared
+
+ def __call__(self, results):
+ if np.random.random() > self.prob:
+ return results
+ magnitude = (
+ ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0])
+ if self.random
+            else float(np.random.choice([-1, 1])) * self.level
+ )
+ if self.direction == "horizontal":
+ magnitude = [magnitude, 0.0]
+ else:
+ magnitude = [0.0, magnitude]
+ self._shear_img(results, magnitude)
+ self._shear_masks(results, magnitude)
+ self._shear_gt(results, magnitude)
+ return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f"(level={self.level}, "
+        repr_str += f"prob={self.prob}, "
+        repr_str += f"direction={self.direction})"
+        return repr_str
+
+
+class RandomTranslate(object):
+ def __init__(
+ self,
+ range,
+ prob=0.5,
+ direction="horizontal",
+ ):
+ self.range = range
+ self.prob = prob
+ self.direction = direction
+
+ def _translate_img(self, results, magnitude):
+ for key in results.get("image_fields", ["image"]):
+ img_sheared = TF.affine(
+ results[key],
+ angle=0.0,
+ translate=magnitude,
+ scale=1.0,
+ shear=[0.0, 0.0],
+ interpolation=TF.InterpolationMode.BILINEAR,
+ fill=(123.68, 116.28, 103.53),
+ )
+ results[key] = img_sheared
+
+ def _translate_mask(self, results, magnitude):
+ for key in results.get("mask_fields", []):
+ mask_sheared = TF.affine(
+ results[key],
+ angle=0.0,
+ translate=magnitude,
+ scale=1.0,
+ shear=[0.0, 0.0],
+ interpolation=TF.InterpolationMode.NEAREST_EXACT,
+ fill=0.0,
+ )
+ results[key] = mask_sheared
+
+ def _translate_gt(
+ self,
+ results,
+ magnitude,
+ ):
+ for key in results.get("gt_fields", []):
+ mask_sheared = TF.affine(
+ results[key],
+ angle=0.0,
+ translate=magnitude,
+ scale=1.0,
+ shear=[0.0, 0.0],
+ interpolation=TF.InterpolationMode.NEAREST_EXACT,
+ fill=0.0,
+ )
+ results[key] = mask_sheared
+
+ def __call__(self, results):
+ if np.random.random() > self.prob:
+ return results
+ magnitude = (self.range[1] - self.range[0]) * np.random.rand() + self.range[0]
+ if self.direction == "horizontal":
+ magnitude = [magnitude * results["image"].shape[1], 0]
+ else:
+ magnitude = [0, magnitude * results["image"].shape[0]]
+ self._translate_img(results, magnitude)
+ self._translate_mask(results, magnitude)
+ self._translate_gt(results, magnitude)
+ results["K"][..., 0, 2] = results["K"][..., 0, 2] + magnitude[0]
+ results["K"][..., 1, 2] = results["K"][..., 1, 2] + magnitude[1]
+ return results
+
+    def __repr__(self):
+        repr_str = self.__class__.__name__
+        repr_str += f"(range={self.range}, "
+        repr_str += f"prob={self.prob}, "
+        repr_str += f"direction={self.direction})"
+        return repr_str
+
+
+class RandomCut(object):
+ def __init__(self, prob=0.5, direction="all"):
+ self.direction = direction
+ self.prob = prob
+
+ def _cut_img(self, results, coord, dim):
+ for key in results.get("image_fields", ["image"]):
+ img_sheared = torch.roll(
+ results[key], int(coord * results[key].shape[dim]), dims=dim
+ )
+ results[key] = img_sheared
+
+ def _cut_mask(self, results, coord, dim):
+ for key in results.get("mask_fields", []):
+ mask_sheared = torch.roll(
+ results[key], int(coord * results[key].shape[dim]), dims=dim
+ )
+ results[key] = mask_sheared
+
+ def _cut_gt(self, results, coord, dim):
+ for key in results.get("gt_fields", []):
+ gt_sheared = torch.roll(
+ results[key], int(coord * results[key].shape[dim]), dims=dim
+ )
+ results[key] = gt_sheared
+
+ def __call__(self, results):
+ if np.random.random() > self.prob:
+ return results
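+        # torch.roll wraps the tensors around at a random coordinate (10-90% of the
+        # size): the frame is effectively cut there and the two pieces are swapped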
+ coord = 0.8 * random.random() + 0.1
+ if self.direction == "horizontal":
+ dim = -1
+ elif self.direction == "vertical":
+ dim = -2
+ else:
+ dim = -1 if random.random() < 0.5 else -2
+
+ self._cut_img(results, coord, dim)
+ self._cut_mask(results, coord, dim)
+ self._cut_gt(results, coord, dim)
+ return results
+
+
+class DownsamplerGT(object):
+ def __init__(self, downsample_factor: int, min_depth: float = 0.01):
+ assert downsample_factor == round(
+ downsample_factor, 0
+ ), f"Downsample factor needs to be an integer, got {downsample_factor}"
+ self.downsample_factor = downsample_factor
+ self.min_depth = min_depth
+
+ def _downsample_gt(self, results):
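+        # min-pool each downsample_factor x downsample_factor block, ignoring zeros
+        # (invalid depth); blocks with no valid depth remain zero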
+ for key in deepcopy(results.get("gt_fields", [])):
+ gt = results[key]
+ N, H, W = gt.shape
+ gt = gt.view(
+ N,
+ H // self.downsample_factor,
+ self.downsample_factor,
+ W // self.downsample_factor,
+ self.downsample_factor,
+ 1,
+ )
+ gt = gt.permute(0, 1, 3, 5, 2, 4)
+ gt = gt.view(-1, self.downsample_factor * self.downsample_factor)
+ gt_tmp = torch.where(gt == 0.0, 1e5 * torch.ones_like(gt), gt)
+ gt = torch.min(gt_tmp, dim=-1).values
+ gt = gt.view(N, H // self.downsample_factor, W // self.downsample_factor)
+ gt = torch.where(gt > 1000, torch.zeros_like(gt), gt)
+ results[f"{key}_downsample"] = gt
+ results["gt_fields"].append(f"{key}_downsample")
+ results["downsampled"] = True
+ return results
+
+ def __call__(self, results):
+ results = self._downsample_gt(results)
+ return results
+
+
+class RandomColorJitter:
+ def __init__(self, level, prob=0.9):
+ self.level = level
+ self.prob = prob
+ self.list_transform = [
+ self._adjust_brightness_img,
+ # self._adjust_sharpness_img,
+ self._adjust_contrast_img,
+ self._adjust_saturation_img,
+ self._adjust_color_img,
+ ]
+
+ def _adjust_contrast_img(self, results, factor=1.0):
+ for key in results.get("image_fields", ["image"]):
+ if "original" not in key:
+ img = results[key]
+ results[key] = TF.adjust_contrast(img, factor)
+
+ def _adjust_sharpness_img(self, results, factor=1.0):
+ for key in results.get("image_fields", ["image"]):
+ if "original" not in key:
+ img = results[key]
+ results[key] = TF.adjust_sharpness(img, factor)
+
+ def _adjust_brightness_img(self, results, factor=1.0):
+ for key in results.get("image_fields", ["image"]):
+ if "original" not in key:
+ img = results[key]
+ results[key] = TF.adjust_brightness(img, factor)
+
+ def _adjust_saturation_img(self, results, factor=1.0):
+ for key in results.get("image_fields", ["image"]):
+ if "original" not in key:
+ img = results[key]
+ results[key] = TF.adjust_saturation(img, factor / 2.0)
+
+ def _adjust_color_img(self, results, factor=1.0):
+ for key in results.get("image_fields", ["image"]):
+ if "original" not in key:
+ img = results[key]
+ results[key] = TF.adjust_hue(img, (factor - 1.0) / 4.0)
+
+ def __call__(self, results):
+ random.shuffle(self.list_transform)
+ for op in self.list_transform:
+ if np.random.random() < self.prob:
+ factor = 1.0 + (
+ (self.level[1] - self.level[0]) * np.random.random() + self.level[0]
+ )
+ op(results, factor)
+ return results
+
+
+class RandomGrayscale:
+ def __init__(self, prob=0.1, num_output_channels=3):
+ super().__init__()
+ self.prob = prob
+ self.num_output_channels = num_output_channels
+
+ def __call__(self, results):
+ if np.random.random() > self.prob:
+ return results
+
+ for key in results.get("image_fields", ["image"]):
+ if "original" not in key:
+ results[key] = TF.rgb_to_grayscale(
+ results[key], num_output_channels=self.num_output_channels
+ )
+ return results
+
+
+class ContextCrop(Resize):
+ def __init__(
+ self,
+ image_shape,
+ keep_original=False,
+ test_min_ctx=1.0,
+ train_ctx_range=[0.5, 1.5],
+ shape_constraints={},
+ ):
+ super().__init__(image_shape=image_shape, keep_original=keep_original)
+ self.test_min_ctx = test_min_ctx
+ self.train_ctx_range = train_ctx_range
+
+ self.shape_mult = shape_constraints["shape_mult"]
+ self.sample = shape_constraints["sample"]
+ self.ratio_bounds = shape_constraints["ratio_bounds"]
+ pixels_min = shape_constraints["pixels_min"] / (
+ self.shape_mult * self.shape_mult
+ )
+ pixels_max = shape_constraints["pixels_max"] / (
+ self.shape_mult * self.shape_mult
+ )
+ self.pixels_bounds = (pixels_min, pixels_max)
+ self.keepGT = int(os.environ.get("keepGT", 0))
+ self.ctx = None
+
+ def _transform_img(self, results, shapes):
+ for key in results.get("image_fields", ["image"]):
+ img = self.crop(results[key], **shapes)
+ img = TF.resize(
+ img,
+ results["resized_shape"],
+ interpolation=TF.InterpolationMode.BICUBIC,
+ antialias=True,
+ )
+ results[key] = img
+
+ def _transform_masks(self, results, shapes):
+ for key in results.get("mask_fields", []):
+ mask = self.crop(results[key].float(), **shapes).byte()
+ if "flow" in key: # take pad/crop into flow resize
+ mask = TF.resize(
+ mask,
+ results["resized_shape"],
+ interpolation=TF.InterpolationMode.NEAREST_EXACT,
+ antialias=False,
+ )
+ else:
+ mask = masked_nearest_interpolation(
+ mask, mask > 0, results["resized_shape"]
+ )
+ results[key] = mask
+
+ def _transform_gt(self, results, shapes):
+ for key in results.get("gt_fields", []):
+ gt = self.crop(results[key], **shapes)
+ if not self.keepGT:
+ if "flow" in key: # take pad/crop into flow resize
+ gt = self._rescale_flow(gt, results)
+ gt = TF.resize(
+ gt,
+ results["resized_shape"],
+ interpolation=TF.InterpolationMode.NEAREST_EXACT,
+ antialias=False,
+ )
+ else:
+ gt = masked_nearest_interpolation(
+ gt, gt > 0, results["resized_shape"]
+ )
+
+ results[key] = gt
+
+ def _rescale_flow(self, gt, results):
+ h_new, w_new = gt.shape[-2:]
+ h_old, w_old = results["image_ori_shape"]
+ gt[:, 0] = gt[:, 0] * (w_old - 1) / (w_new - 1)
+ gt[:, 1] = gt[:, 1] * (h_old - 1) / (h_new - 1)
+ return gt
+
+ @staticmethod
+ def crop(img, height, width, top, left) -> torch.Tensor:
+ h, w = img.shape[-2:]
+ right = left + width
+ bottom = top + height
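+        # the requested window may extend beyond the image: the overlapping part is
+        # sliced out and the missing borders are zero-padded (left, top, right, bottom)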
+ padding_ltrb = [
+ max(-left + min(0, right), 0),
+ max(-top + min(0, bottom), 0),
+ max(right - max(w, left), 0),
+ max(bottom - max(h, top), 0),
+ ]
+ image_cropped = img[..., max(top, 0) : bottom, max(left, 0) : right]
+ return TF.pad(image_cropped, padding_ltrb)
+
+ def test_closest_shape(self, image_shape):
+ h, w = image_shape
+ input_ratio = w / h
+ if self.sample:
+ input_pixels = int(ceil(h / self.shape_mult * w / self.shape_mult))
+ pixels = max(
+ min(input_pixels, self.pixels_bounds[1]), self.pixels_bounds[0]
+ )
+ ratio = min(max(input_ratio, self.ratio_bounds[0]), self.ratio_bounds[1])
+ h = round((pixels / ratio) ** 0.5)
+ w = h * ratio
+ self.image_shape[0] = int(h) * self.shape_mult
+ self.image_shape[1] = int(w) * self.shape_mult
+
+ def _get_crop_shapes(self, image_shape, ctx=None):
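+        """Compute the crop size, in input pixels, corresponding to a context factor
+        `ctx` of the input image (ctx > 1 enlarges the crop beyond the image, ctx < 1
+        zooms in), while matching the aspect ratio of `self.image_shape`."""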
+ h, w = image_shape
+ input_ratio = w / h
+ if self.keep_original:
+ self.test_closest_shape(image_shape)
+ ctx = 1.0
+ elif ctx is None:
+ ctx = float(
+ torch.empty(1)
+ .uniform_(self.train_ctx_range[0], self.train_ctx_range[1])
+ .item()
+ )
+ output_ratio = self.image_shape[1] / self.image_shape[0]
+
+ if output_ratio <= input_ratio: # out like 4:3 in like kitti
+ if (
+ ctx >= 1
+ ): # fully in -> use just max_length with sqrt(ctx), here max is width
+ new_w = w * ctx**0.5
+                # the crop overshoots the input in one dimension only:
+                # we know that in_width will stick out before in_height (partial overshoot)
+ # new_h > old_h via area -> new_h ** 2 * ratio_new = old_h ** 2 * ratio_old * ctx
+ elif output_ratio / input_ratio * ctx > 1:
+ new_w = w * ctx
+ else: # fully contained -> use area
+ new_w = w * (ctx * output_ratio / input_ratio) ** 0.5
+ new_h = new_w / output_ratio
+ else:
+ if ctx >= 1:
+ new_h = h * ctx**0.5
+ elif input_ratio / output_ratio * ctx > 1:
+ new_h = h * ctx
+ else:
+ new_h = h * (ctx * input_ratio / output_ratio) ** 0.5
+ new_w = new_h * output_ratio
+ return (int(ceil(new_h - 0.5)), int(ceil(new_w - 0.5))), ctx
+
+ # def sample_view(self, results):
+ # original_K = results["K"]
+ # original_image = results["image"]
+ # original_depth = results["depth"]
+ # original_validity_mask = results["validity_mask"].float()
+ # # sample angles and translation
+ # # sample translation:
+ # # 10 max of z
+
+ # x = np.random.normal(0, 0.05 / 2) * original_depth.max()
+ # y = np.random.normal(0, 0.05)
+ # z = np.random.normal(0, 0.05) * original_depth.max()
+
+ # fov = 2 * np.arctan(original_image.shape[-2] / 2 / results["K"][0, 0, 0])
+ # phi = np.random.normal(0, fov / 10)
+ # theta = np.random.normal(0, fov / 10)
+ # psi = np.random.normal(0, np.pi / 60)
+ # translation = torch.tensor([x, y, z]).unsqueeze(0)
+ # angles = torch.tensor([phi, theta, psi])
+ # angles = euler_to_rotation_matrix(angles)
+ # translation = translation @ angles # translation before rotation
+
+ # cam2w = torch.eye(4).unsqueeze(0)
+ # cam2w[..., :3, :3] = angles
+ # cam2w[..., :3, 3] = translation
+ # cam2cam = torch.inverse(cam2w)
+ # image_warped, depth_warped = forward_warping(original_image, original_depth, original_K, original_K, cam2cam=cam2cam)
+ # depth_warped[depth_warped > 0] = depth_warped[depth_warped > 0] - z
+ # validity_mask_warped = image_warped.sum(dim=1, keepdim=True) > 0.0
+
+ # results["K"] = results["K"].repeat(2, 1, 1)
+ # results["cam2w"] = torch.cat([torch.eye(4).unsqueeze(0), cam2w])
+ # results["image"] = torch.cat([original_image, image_warped])
+ # results["depth"] = torch.cat([original_depth, depth_warped])
+ # results["validity_mask"] = torch.cat([original_validity_mask, validity_mask_warped], dim=0)
+
+ # # results["cam2w"] = torch.cat([torch.eye(4).unsqueeze(0), torch.eye(4).unsqueeze(0)])
+ # # results["image"] = torch.cat([original_image, original_image])
+ # # results["depth"] = torch.cat([original_depth, original_depth])
+ # # results["validity_mask"] = torch.cat([original_validity_mask, original_validity_mask], dim=0)
+ # return results
+
+ def __call__(self, results):
+ h, w = results["image"].shape[-2:]
+ results["image_ori_shape"] = (h, w)
+ results["camera_fields"].add("camera_original")
+ results["camera_original"] = results["camera"].clone()
+
+ results.get("mask_fields", set()).add("validity_mask")
+ if "validity_mask" not in results:
+ results["validity_mask"] = torch.ones(
+ (results["image"].shape[0], 1, h, w),
+ dtype=torch.uint8,
+ device=results["image"].device,
+ )
+
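+        # at training time, retry up to 100 times with a progressively smaller context
+        # until the crop keeps enough valid area and stays within the camera's max FoV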
+ n_iter = 1 if self.keep_original or not self.sample else 100
+
+ min_valid_area = 0.5
+ max_hfov, max_vfov = results["camera"].max_fov[0] # it is a 1-dim list
+ ctx = None
+ for ii in range(n_iter):
+
+ (height, width), ctx = self._get_crop_shapes((h, w), ctx=self.ctx or ctx)
+ margin_h = h - height
+ margin_w = w - width
+
+ # keep it centered in y direction
+ top = margin_h // 2
+ left = margin_w // 2
+ if not self.keep_original:
+ left = left + np.random.randint(
+ -self.shape_mult // 2, self.shape_mult // 2 + 1
+ )
+ top = top + np.random.randint(
+ -self.shape_mult // 2, self.shape_mult // 2 + 1
+ )
+
+ right = left + width
+ bottom = top + height
+ x_zoom = self.image_shape[0] / height
+ paddings = [
+ max(-left + min(0, right), 0),
+ max(bottom - max(h, top), 0),
+ max(right - max(w, left), 0),
+ max(-top + min(0, bottom), 0),
+ ]
+
+ valid_area = (
+ h
+ * w
+ / (h + paddings[1] + paddings[3])
+ / (w + paddings[0] + paddings[2])
+ )
+ new_hfov, new_vfov = results["camera_original"].get_new_fov(
+ new_shape=(height, width), original_shape=(h, w)
+ )[0]
+ # if valid_area >= min_valid_area or getattr(self, "ctx", None) is not None:
+ # break
+ if (
+ valid_area >= min_valid_area
+ and new_hfov < max_hfov
+ and new_vfov < max_vfov
+ ):
+ break
+ ctx = (
+ ctx * 0.96
+ ) # if not enough valid area, try again with less ctx (more zoom)
+
+ # save ctx for next iteration of sequences?
+ self.ctx = ctx
+
+ results["resized_shape"] = self.image_shape
+ results["paddings"] = paddings # left ,top ,right, bottom
+ results["image_rescale"] = x_zoom
+ results["scale_factor"] = results.get("scale_factor", 1.0) * x_zoom
+ results["camera"] = results["camera"].crop(
+ left, top, right=w - right, bottom=h - bottom
+ )
+ results["camera"] = results["camera"].resize(x_zoom)
+
+ shapes = dict(height=height, width=width, top=top, left=left)
+ self._transform_img(results, shapes)
+ if not self.keep_original:
+ self._transform_gt(results, shapes)
+ self._transform_masks(results, shapes)
+ else:
+            # only transform validity_mask here (the rgb masks follow the rgb transform)  # FIXME
+ mask = results["validity_mask"].float()
+ mask = self.crop(mask, **shapes).byte()
+ mask = TF.resize(
+ mask,
+ results["resized_shape"],
+ interpolation=TF.InterpolationMode.NEAREST,
+ )
+ results["validity_mask"] = mask
+
+
+ # keep original images before photo-augment
+ results["image_original"] = results["image"].clone()
+        results["image_fields"].update(
+            [
+                field.replace("image", "image_original")
+                for field in results["image_fields"]
+            ]
+        )
+
+ # repeat for batch resized shape and paddings
+ results["paddings"] = [results["paddings"]] * results["image"].shape[0]
+ results["resized_shape"] = [results["resized_shape"]] * results["image"].shape[
+ 0
+ ]
+ return results
+
+
+class RandomFiller:
+ def __init__(self, test_mode, *args, **kwargs):
+ super().__init__()
+ self.test_mode = test_mode
+
+ def _transform(self, results):
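+        # pixels outside the validity mask (e.g. padded borders) are replaced with a
+        # randomly chosen filler (noise / black / white / zero); test time always uses zero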
+ def fill_noise(size, device):
+ return torch.normal(0, 2.0, size=size, device=device)
+
+ def fill_black(size, device):
+ return -4 * torch.ones(size, device=device, dtype=torch.float32)
+
+ def fill_white(size, device):
+ return 4 * torch.ones(size, device=device, dtype=torch.float32)
+
+ def fill_zero(size, device):
+ return torch.zeros(size, device=device, dtype=torch.float32)
+
+ B, C = results["image"].shape[:2]
+ mismatch = B // results["validity_mask"].shape[0]
+ if mismatch:
+ results["validity_mask"] = results["validity_mask"].repeat(
+ mismatch, 1, 1, 1
+ )
+ validity_mask = results["validity_mask"].repeat(1, C, 1, 1).bool()
+ filler_fn = np.random.choice([fill_noise, fill_black, fill_white, fill_zero])
+ if self.test_mode:
+ filler_fn = fill_zero
+ for key in results.get("image_fields", ["image"]):
+ results[key][~validity_mask] = filler_fn(
+ size=results[key][~validity_mask].shape, device=results[key].device
+ )
+
+ def __call__(self, results):
+ # generate mask for filler
+ if "validity_mask" not in results:
+ paddings = results.get("padding_size", [0] * 4)
+ height, width = results["image"].shape[-2:]
+ results.get("mask_fields", set()).add("validity_mask")
+ results["validity_mask"] = torch.zeros_like(results["image"][:, :1])
+ results["validity_mask"][
+ ...,
+ paddings[1] : height - paddings[3],
+ paddings[0] : width - paddings[2],
+ ] = 1.0
+ self._transform(results)
+ return results
+
+
+class GaussianBlur:
+ def __init__(self, kernel_size, sigma=(0.1, 2.0), prob=0.9):
+ super().__init__()
+ self.kernel_size = kernel_size
+ self.sigma = sigma
+ self.prob = prob
+ self.padding = kernel_size // 2
+
+ def apply(self, x, kernel):
+ # Pad the input tensor
+ x = F.pad(
+ x, (self.padding, self.padding, self.padding, self.padding), mode="reflect"
+ )
+ # Apply the convolution with the Gaussian kernel
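+        # groups == channels makes this a depthwise convolution: the same Gaussian
+        # kernel is applied to every channel independently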
+ return F.conv2d(x, kernel, stride=1, padding=0, groups=x.size(1))
+
+ def _create_kernel(self, sigma):
+ # Create a 1D Gaussian kernel
+ kernel_1d = torch.exp(
+ -torch.arange(-self.padding, self.padding + 1) ** 2 / (2 * sigma**2)
+ )
+ kernel_1d = kernel_1d / kernel_1d.sum()
+
+ # Expand the kernel to 2D and match size of the input
+ kernel_2d = kernel_1d.unsqueeze(0) * kernel_1d.unsqueeze(1)
+ kernel_2d = kernel_2d.view(1, 1, self.kernel_size, self.kernel_size).expand(
+ 3, 1, -1, -1
+ )
+ return kernel_2d
+
+ def __call__(self, results):
+ if np.random.random() > self.prob:
+ return results
+ sigma = (self.sigma[1] - self.sigma[0]) * np.random.rand() + self.sigma[0]
+ kernel = self._create_kernel(sigma)
+ for key in results.get("image_fields", ["image"]):
+ if "original" not in key:
+ results[key] = self.apply(results[key], kernel)
+ return results
+
+
+class MotionBlur:
+    def __init__(self, kernel_size=(9, 9), angles=(-180, 180), prob=0.1):
+        super().__init__()
+        # kernel_size may be given as an int or as a square (k, k) tuple
+        if isinstance(kernel_size, (tuple, list)):
+            kernel_size = kernel_size[0]
+        self.kernel_size = kernel_size
+        self.angles = angles
+        self.prob = prob
+        self.padding = kernel_size // 2
+
+    def _create_kernel(self, angle):
+        # Generate a 2D grid of coordinates
+        grid = torch.meshgrid(
+            torch.arange(self.kernel_size),
+            torch.arange(self.kernel_size),
+            indexing="ij",
+        )
+        grid = torch.stack(grid).float()  # Shape: (2, kernel_size, kernel_size)
+
+        # Calculate relative coordinates from the center
+        center = (self.kernel_size - 1) / 2.0
+        x_offset = grid[1] - center
+        y_offset = grid[0] - center
+
+        # Compute motion blur kernel along the sampled direction
+        angle = torch.tensor(float(angle) * torch.pi / 180.0)
+        cos_theta = torch.cos(angle)
+        sin_theta = torch.sin(angle)
+        kernel = (1.0 / self.kernel_size) * (
+            1.0 - torch.abs(x_offset * cos_theta + y_offset * sin_theta)
+        )
+
+        # Expand kernel dimensions to match input image channels
+        kernel = kernel.unsqueeze(0).unsqueeze(0).expand(3, 1, -1, -1)
+        return kernel
+
+    def apply(self, image, kernel):
+        # Reflect-pad so the blurred output keeps the input resolution
+        x = F.pad(
+            image,
+            (self.padding, self.padding, self.padding, self.padding),
+            mode="reflect",
+        )
+        # Apply convolution with the motion blur kernel
+        blurred_image = F.conv2d(x, kernel, stride=1, padding=0, groups=x.size(1))
+        return blurred_image
+
+ def __call__(self, results):
+ if np.random.random() > self.prob:
+ return results
+
+ angle = np.random.uniform(self.angles[0], self.angles[1])
+ kernel = self._create_kernel(angle)
+ for key in results.get("image_fields", ["image"]):
+ if "original" in key:
+ continue
+ results[key] = self.apply(results[key], kernel)
+
+ return results
+
+
+class JPEGCompression:
+ def __init__(self, level=(10, 70), prob=0.1):
+ super().__init__()
+ self.level = level
+ self.prob = prob
+
+ def __call__(self, results):
+ if np.random.random() > self.prob:
+ return results
+
+        level = int(np.random.uniform(self.level[0], self.level[1]))  # JPEG quality must be an integer
+ for key in results.get("image_fields", ["image"]):
+ if "original" in key:
+ continue
+ results[key] = TF.jpeg(results[key], level)
+
+ return results
+
+
+class Compose:
+ def __init__(self, transforms):
+ self.transforms = deepcopy(transforms)
+
+ def __call__(self, results):
+ for t in self.transforms:
+ results = t(results)
+ return results
+
+ def __setattr__(self, name: str, value) -> None:
+ super().__setattr__(name, value)
+ for t in self.transforms:
+ setattr(t, name, value)
+
+ def __repr__(self):
+ format_string = self.__class__.__name__ + "("
+ for t in self.transforms:
+ format_string += f"\n {t}"
+ format_string += "\n)"
+ return format_string
+
+
+class DummyCrop(Resize):
+ def __init__(
+ self,
+ *args,
+ **kwargs,
+ ):
+ # dummy image shape, not really used
+ super().__init__(image_shape=(512, 512))
+
+ def __call__(self, results):
+ h, w = results["image"].shape[-2:]
+ results["image_ori_shape"] = (h, w)
+ results["camera_fields"].add("camera_original")
+ results["camera_original"] = results["camera"].clone()
+ results.get("mask_fields", set()).add("validity_mask")
+ if "validity_mask" not in results:
+ results["validity_mask"] = torch.ones(
+ (results["image"].shape[0], 1, h, w),
+ dtype=torch.uint8,
+ device=results["image"].device,
+ )
+
+ self.ctx = 1.0
+
+ results["resized_shape"] = self.image_shape
+ results["paddings"] = [0, 0, 0, 0]
+ results["image_rescale"] = 1.0
+ results["scale_factor"] = results.get("scale_factor", 1.0) * 1.0
+ results["camera"] = results["camera"].crop(0, 0, right=w, bottom=h)
+ results["camera"] = results["camera"].resize(1)
+
+ # keep original images before photo-augment
+ results["image_original"] = results["image"].clone()
+        results["image_fields"].update(
+            [
+                field.replace("image", "image_original")
+                for field in results["image_fields"]
+            ]
+        )
+
+ # repeat for batch resized shape and paddings
+ results["paddings"] = [results["paddings"]] * results["image"].shape[0]
+ results["resized_shape"] = [results["resized_shape"]] * results["image"].shape[
+ 0
+ ]
+ return results
diff --git a/unik3d/datasets/point_odyssey.py b/unik3d/datasets/point_odyssey.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3cc3de5419518fa1a397ca757abd2175bc33248
--- /dev/null
+++ b/unik3d/datasets/point_odyssey.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class PointOdyssey(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 250.0
+ depth_scale = 1000.0
+ test_split = "test.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences_clean.json"
+    hdf5_paths = ["PointOdyssey.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/proteus.py b/unik3d/datasets/proteus.py
new file mode 100644
index 0000000000000000000000000000000000000000..726fb3e9a5f764fc2c1f605252f923e98eae941a
--- /dev/null
+++ b/unik3d/datasets/proteus.py
@@ -0,0 +1,51 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class Proteus(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 10.0
+ depth_scale = 1000.0
+ default_fps = 5
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["Proteus.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/samplers.py b/unik3d/datasets/samplers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0a33a032c0fc346f256fcd81959a431727cb51c
--- /dev/null
+++ b/unik3d/datasets/samplers.py
@@ -0,0 +1,242 @@
+import itertools
+import warnings
+from operator import itemgetter
+from typing import Any, Optional
+
+import numpy as np
+import torch
+from torch.utils.data import Sampler
+
+from unik3d.utils import get_dist_info
+
+
+def _get_numpy_dtype(size: int) -> Any:
+ return np.int32 if size <= 2**31 else np.int64
+
+
+def _get_torch_dtype(size: int) -> Any:
+ return torch.int32 if size <= 2**31 else torch.int64
+
+
+def _generate_randperm_indices(*, size: int, generator: torch.Generator):
+ """Generate the indices of a random permutation."""
+ dtype = _get_torch_dtype(size)
+ # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921
+ perm = torch.arange(size, dtype=dtype)
+ for i in range(size):
+ j = torch.randint(i, size, size=(1,), generator=generator).item()
+
+ # Always swap even if no-op
+ value = perm[j].item()
+ perm[j] = perm[i].item()
+ perm[i] = value
+ yield value
+
+
+# The following function is somewhat equivalent to _new_shuffle_tensor_slice below,
+# but avoids a full in-place random permutation generation.
+def _shuffle_tensor_slice(
+ *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
+) -> np.ndarray:
+ stop = len(tensor)
+ count = stop // step
+ drop_count = stop - step * count
+ if drop_count:
+ warnings.warn(f"# of dropped samples: {drop_count}")
+
+ dtype = _get_numpy_dtype(stop)
+ result = np.empty(count, dtype=dtype)
+
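+    # inside-out Fisher-Yates: element i of the strided slice is placed at a random
+    # position j <= i, producing a uniform permutation in a single pass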
+ for i in range(count):
+ j = (
+ torch.randint(0, i + 1, size=(1,), generator=generator).item()
+ if i > 0
+ else 0
+ )
+
+ result[i] = result[j]
+ result[j] = tensor[start + i * step].item()
+
+ return result
+
+
+def _new_shuffle_tensor_slice(
+ *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator
+) -> np.ndarray:
+ stop = len(tensor)
+ count = stop // step
+ dtype = torch.int64 # Needed for using randperm result as indices
+ drop_count = stop - step * count
+ if drop_count:
+ warnings.warn(f"# of dropped samples: {drop_count}")
+ indices = torch.randperm(count, dtype=dtype, generator=generator)
+ return tensor[start::step][indices].numpy()
+
+
+def _make_seed(seed: int, start: int, iter_count: int) -> int:
+ # NOTE: Tried a few variants (including iter_count << 32), this one worked best.
+ return seed + start + (iter_count << 24)
+
+
+class ShardedInfiniteSampler(Sampler):
+ def __init__(
+ self,
+ *,
+ sample_count: int,
+ shuffle: bool = False,
+ seed: int = 0,
+ start: Optional[int] = None,
+ step: Optional[int] = None,
+ advance: int = 0,
+ use_new_shuffle_tensor_slice: bool = False,
+ ):
+ self._sample_count = sample_count
+ self._seed = seed
+ self._shuffle = shuffle
+ rank, world_size = get_dist_info()
+ self._start = rank if start is None else start
+ self._step = world_size if step is None else step
+ self._advance = advance
+ self._iter_count = 0
+ self._shuffle_tensor_slice_fn = (
+ _new_shuffle_tensor_slice
+ if use_new_shuffle_tensor_slice
+ else _shuffle_tensor_slice
+ )
+
+ def __iter__(self):
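+        # `advance` supports resuming mid-stream: whole passes over the data are
+        # skipped by bumping the iteration counter, the remainder via islice below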
+ iter_count = self._advance // self._sample_count
+ if iter_count > 0:
+ self._advance -= iter_count * self._sample_count
+ self._iter_count += iter_count
+
+ if self._shuffle:
+ iterator = self._shuffled_iterator()
+ else:
+ iterator = self._iterator()
+
+ yield from itertools.islice(iterator, self._advance, None)
+
+ def _iterator(self):
+ assert not self._shuffle
+
+ while True:
+ iterable = range(self._sample_count)
+ yield from itertools.islice(iterable, self._start, None, self._step)
+
+ def _shuffled_iterator(self):
+ assert self._shuffle
+
+        # Instantiate a generator here (rather than in the ctor) to keep the class
+ # picklable (requirement of mp.spawn)
+ generator = torch.Generator()
+
+ # Always shuffle everything first
+ generator.manual_seed(self._seed)
+ dtype = _get_torch_dtype(self._sample_count)
+ perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator)
+
+ while True:
+ # Re-seed on each iteration to allow skipping whole permutations
+ seed = _make_seed(self._seed, self._start, self._iter_count)
+ generator.manual_seed(seed)
+
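+            # each rank shuffles and consumes only its slice perm[start::step] of the
+            # shared permutation, so the shards are disjoint across ranks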
+ iterable = self._shuffle_tensor_slice_fn(
+ tensor=perm, start=self._start, step=self._step, generator=generator
+ )
+ yield from iterable
+ self._iter_count += 1
+
+
+class DistributedSamplerNoDuplicate(torch.utils.data.DistributedSampler):
+ """A distributed sampler that doesn't add duplicates. Arguments are the same as DistributedSampler"""
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ if not self.drop_last and len(self.dataset) % self.num_replicas != 0:
+ # some ranks may have less samples, that's fine
+ if self.rank >= len(self.dataset) % self.num_replicas:
+ self.num_samples -= 1
+ self.total_size = len(self.dataset)
+
+
+class DatasetFromSampler(torch.utils.data.Dataset):
+ """Dataset to create indexes from `Sampler`.
+
+ Args:
+ sampler: PyTorch sampler
+ """
+
+ def __init__(self, sampler: Sampler):
+ """Initialisation for DatasetFromSampler."""
+ self.sampler = sampler
+ self.sampler_list = None
+
+ def __getitem__(self, index: int):
+ """Gets element of the dataset.
+
+ Args:
+ index: index of the element in the dataset
+
+ Returns:
+ Single element by index
+ """
+ if self.sampler_list is None:
+ self.sampler_list = list(self.sampler)
+ return self.sampler_list[index]
+
+ def __len__(self) -> int:
+ """
+ Returns:
+ int: length of the dataset
+ """
+ return len(self.sampler)
+
+
+class DistributedSamplerWrapper(torch.utils.data.DistributedSampler):
+ """
+ Wrapper over `Sampler` for distributed training
+ Allows you to use any sampler in distributed mode.
+
+ It is especially useful in conjunction with
+ `torch.nn.parallel.DistributedDataParallel`. In such case, each
+ process can pass a DistributedSamplerWrapper instance as a DataLoader
+ sampler, and load a subset of subsampled data of the original dataset
+ that is exclusive to it.
+
+ .. note::
+ Sampler is assumed to be of constant size.
+ """
+
+ def __init__(
+ self,
+ sampler,
+ num_replicas: Optional[int] = None,
+ rank: Optional[int] = None,
+ shuffle: bool = True,
+ ):
+ """
+
+ Args:
+ sampler: Sampler used for subsampling
+ num_replicas (int, optional): Number of processes participating in
+ distributed training
+ rank (int, optional): Rank of the current process
+ within ``num_replicas``
+ shuffle (bool, optional): If true (default),
+ sampler will shuffle the indices
+ """
+ super(DistributedSamplerWrapper, self).__init__(
+ DatasetFromSampler(sampler),
+ num_replicas=num_replicas,
+ rank=rank,
+ shuffle=shuffle,
+ )
+ self.sampler = sampler
+
+ def __iter__(self):
+ self.dataset = DatasetFromSampler(self.sampler)
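+        # the parent DistributedSampler shards and shuffles *positions* into the
+        # wrapped sampler's output; those positions are mapped back to indices below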
+ indexes_of_indexes = super().__iter__()
+ subsampler_indexes = self.dataset
+ return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes))
diff --git a/unik3d/datasets/scannet.py b/unik3d/datasets/scannet.py
new file mode 100644
index 0000000000000000000000000000000000000000..178c601df320a4b7b9ca9136ee0b00f650d4812a
--- /dev/null
+++ b/unik3d/datasets/scannet.py
@@ -0,0 +1,49 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class ScanNet(SequenceDataset):
+ min_depth = 0.005
+ max_depth = 10.0
+ depth_scale = 1000.0
+ test_split = "test.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["ScanNetS.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/scannetpp.py b/unik3d/datasets/scannetpp.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebd1f5b77699b486cec9e61137341b63cf03278c
--- /dev/null
+++ b/unik3d/datasets/scannetpp.py
@@ -0,0 +1,98 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class ScanNetpp(SequenceDataset):
+ min_depth = 0.001
+ max_depth = 10.0
+ depth_scale = 1000.0
+ test_split = "val_iphone.txt"
+ train_split = "train_iphone.txt"
+ sequences_file = "sequences_iphone_clean.json"
+    hdf5_paths = ["ScanNetpp_viz.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
+
+
+class ScanNetpp_F(SequenceDataset):
+ min_depth = 0.001
+ max_depth = 10.0
+ depth_scale = 1000.0
+ train_split = "train.txt"
+ test_split = "val_split.txt"
+
+ sequences_file = "sequences_split.json"
+    hdf5_paths = ["ScanNetpp_F.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=(
+ decode_fields if not test_mode else [*decode_fields, "points"]
+ ),
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/sequence_dataset.py b/unik3d/datasets/sequence_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c0f383f43a88e3c08e36e3cc9827a20cdeaffcc
--- /dev/null
+++ b/unik3d/datasets/sequence_dataset.py
@@ -0,0 +1,301 @@
+import json
+import os
+from functools import partial
+from typing import Any, Dict, Tuple
+
+import h5py
+import numpy as np
+import tables
+import torch
+import torchvision.transforms.v2.functional as TF
+
+from unik3d.datasets.base_dataset import BaseDataset
+from unik3d.datasets.utils import DatasetFromList
+from unik3d.datasets.utils_decode import (decode_camera, decode_depth,
+ decode_flow, decode_K, decode_mask,
+ decode_numpy, decode_rgb,
+ decode_tensor)
+from unik3d.utils.distributed import is_main_process
+
+
+class SequenceDataset(BaseDataset):
+ DECODE_FNS = {
+ "image": partial(decode_rgb, name="image"),
+ "points": partial(decode_numpy, name="points"),
+ "K": partial(decode_K, name="camera"),
+ "camera_params": partial(decode_camera, name="camera"),
+ "cam2w": partial(decode_tensor, name="cam2w"),
+ "depth": partial(decode_depth, name="depth"),
+ "flow_fwd": partial(decode_flow, name="flow_fwd"),
+ "flow_bwd": partial(decode_flow, name="flow_bwd"),
+ "flow_fwd_mask": partial(decode_mask, name="flow_fwd_mask"),
+ "flow_bwd_mask": partial(decode_mask, name="flow_bwd_mask"),
+ }
+ default_fps = 5
+
+ def __init__(
+ self,
+ image_shape: Tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: Dict[str, Any],
+ shape_constraints: Dict[str, Any],
+ resize_method: str,
+ mini: float,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth", "flow_fwd", "flow_fwd_mask"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ shape_constraints=shape_constraints,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.num_frames = num_frames
+ self.original_num_frames = num_frames
+ self.decode_fields = decode_fields
+ self.inplace_fields = inplace_fields
+ self.fps = self.default_fps
+ self.fps_range = kwargs.get("fps_range", None)
+ if self.fps_range is not None:
+ self.fps_range[1] = min(self.default_fps, self.fps_range[1])
+
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+ txt_string = txt_file.tostring().decode("ascii").strip()
+ sequences = np.array(h5file[self.sequences_file]).tostring().decode("ascii")
+ sequences = json.loads(sequences)
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ if len(line.strip().split(" ")) == 1:
+ print(line)
+ continue
+ sequence_name, num_samples = line.strip().split(" ")
+ dataset.append(
+ {
+ "sequence_name": sequence_name,
+ "num_samples": int(num_samples),
+ "chunk_idx": 0,
+ }
+ )
+
+ # filter dataset based on attr "invalid_sequences"
+ invalid_sequences = getattr(self, "invalid_sequences", [])
+ dataset = [
+ sample
+ for sample in dataset
+ if sample["sequence_name"] not in invalid_sequences
+ ]
+
+ self.dataset = DatasetFromList(dataset)
+ self.sequences = DatasetFromList(
+ [sequences[sample["sequence_name"]] for sample in dataset]
+ )
+ self.log_load_dataset()
+
+ def get_random_idxs(self, num_samples_sequence):
+ if self.num_frames == 1:
+ return [np.random.randint(0, num_samples_sequence)], 0
+
+ # Check if we can satisfy the required number of frames
+ if self.num_frames > num_samples_sequence:
+ raise ValueError(
+ "Cannot sample more frames than available in the sequence."
+ )
+
+ # Restrict FPS range to be within default FPS
+ min_fps, max_fps = self.fps_range
+ max_fps = min(max_fps, self.default_fps)
+ if min_fps > self.default_fps:
+ sampled_fps = self.default_fps
+ else:
+ # Compute minimal viable FPS
+ min_required_fps = (
+ self.num_frames / num_samples_sequence
+ ) * self.default_fps
+ min_fps = max(min_fps, min_required_fps)
+
+ # Sample an FPS from the viable range
+ sampled_fps = np.random.uniform(min_fps, max_fps)
+
+ # Compute the stride based on the sampled FPS
+ stride = self.default_fps / sampled_fps
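+        # e.g. default_fps=30 and sampled_fps=10 -> stride=3: every third frame is taken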
+ max_start_index = num_samples_sequence - int(stride * (self.num_frames - 1))
+
+ # Ensure a valid starting position
+ if max_start_index <= 0:
+ raise ValueError(
+ "No valid start position allows sampling num_frames with the chosen FPS."
+ )
+
+ start_index = np.random.randint(0, max_start_index + 1)
+
+ # Compute indices based on the sampled FPS
+ indices = [int(start_index + i * stride) for i in range(self.num_frames)]
+
+ return indices, np.random.randint(0, len(indices))
+
+ def get_test_idxs(self, num_samples_sequence, keyframe_idx):
+ if self.num_frames == 1:
+ return [
+ keyframe_idx if keyframe_idx is not None else num_samples_sequence // 2
+ ], 0
+
+ if self.num_frames == -1:
+ cap_idxs = min(32, num_samples_sequence) # CAP 32 images
+ idxs = list(
+ range(max(0, num_samples_sequence - cap_idxs), num_samples_sequence, 1)
+ )
+ return idxs, keyframe_idx
+
+        # center the window of num_frames indices around keyframe_idx, clamping it to [0, num_samples_sequence)
+ keyframe_idx = (
+ keyframe_idx if keyframe_idx is not None else num_samples_sequence - 1
+ )
+ excess_tail = 0 - min(0, keyframe_idx - self.num_frames // 2)
+ excess_head = (
+ max(num_samples_sequence, keyframe_idx + (self.num_frames - 1) // 2)
+ - num_samples_sequence
+ )
+ start = keyframe_idx - self.num_frames // 2 + excess_tail - excess_head
+ end = keyframe_idx + (self.num_frames - 1) // 2 + excess_head - excess_tail
+ idxs = list(range(start, 1 + end))
+
+ return idxs, idxs.index(keyframe_idx)
+
+ def get_single_sequence(self, idx):
+ self.num_frames = self.original_num_frames
+ # sequence_name = self.dataset[idx]["sequence_name"]
+ sample = self.sequences[idx]
+ chunk_idx = int(sample.get("chunk_idx", 0))
+ h5_path = os.path.join(self.data_root, self.hdf5_paths[chunk_idx])
+
+ num_samples_sequence = len(sample["image"])
+ if self.num_frames > 0 and num_samples_sequence < self.num_frames:
+ raise IndexError(f"Sequence {idx} has less than {self.num_frames} frames")
+ keyframe_idx = None
+
+ if not self.test_mode:
+ idxs, keyframe_idx = self.get_random_idxs(num_samples_sequence)
+ else:
+ idxs, keyframe_idx = self.get_test_idxs(
+ num_samples_sequence, sample.get("keyframe_idx", None)
+ )
+
+ self.num_frames = len(idxs)
+ results = {}
+ results = self.pre_pipeline(results)
+ results["sequence_fields"] = [(i, 0) for i in range(self.num_frames)]
+ results["keyframe_idx"] = keyframe_idx
+ with tables.File(
+ h5_path,
+ mode="r",
+ libver="latest",
+ swmr=True,
+ ) as h5file_chunk:
+
+ for i, j in enumerate(idxs):
+ results[(i, 0)] = {
+ k: v.copy() for k, v in results.items() if "fields" in k
+ }
+ for inplace_field in self.inplace_fields:
+ inplace_field_ = inplace_field.replace("intrinsics", "K")
+ inplace_field_ = inplace_field_.replace("extrinsics", "cam2w")
+ if inplace_field_ == "cam2w":
+ # take care of missing cam2w -> assume 1 frame and assume identity
+ if inplace_field_ not in sample:
+ sample[inplace_field] = [None] * num_samples_sequence
+ if sample[inplace_field][j] is None:
+ sample[inplace_field][j] = np.eye(4, dtype=np.float32)
+ results = self.DECODE_FNS[inplace_field_](
+ results, sample[inplace_field][j], idx=i, sample=sample, j=j
+ )
+
+ for i, j in enumerate(idxs):
+ for decode_field in self.decode_fields:
+ results = self.DECODE_FNS[decode_field](
+ results,
+ h5file_chunk,
+ sample[decode_field][j],
+ idx=i,
+ depth_scale=self.depth_scale,
+ )
+
+ results["filename"] = sample["image"][j]
+
+ results = self.preprocess(results)
+ if not self.test_mode:
+ results = self.augment(results)
+ results = self.postprocess(results)
+ return results
+
+ def preprocess(self, results):
+ results = self.replicate(results)
+ for i, seq in enumerate(results["sequence_fields"]):
+ results[seq] = self.resizer(results[seq])
+ self.resizer.ctx = None if self.num_copies > 1 else self.resizer.ctx
+ num_pts = torch.count_nonzero(results[seq]["depth"] > 0)
+ if num_pts < 50:
+ raise IndexError(f"Too few points in depth map ({num_pts})")
+
+ for key in results[seq].get("image_fields", ["image"]):
+ results[seq][key] = results[seq][key].to(torch.float32) / 255
+
+ # update fields common in sequence
+ for key in [
+ "image_fields",
+ "gt_fields",
+ "mask_fields",
+ "camera_fields",
+ ]:
+ if key in results[(0, 0)]:
+ results[key] = results[(0, 0)][key]
+
+ results = self.pack_batch(results)
+ return results
+
+ def postprocess(self, results):
+        # normalize only now, after the photometric augmentations (they operate on un-normalized images)
+ for key in results.get("image_fields", ["image"]):
+ results[key] = TF.normalize(results[key], **self.normalization_stats)
+ results = self.filler(results)
+ results = self.unpack_batch(results)
+ results = self.masker(results)
+ results = self.collecter(results)
+ return results
+
+ def __getitem__(self, idx):
+ try:
+ if isinstance(idx, (list, tuple)):
+ results = [self.get_single_sequence(i) for i in idx]
+ else:
+ results = self.get_single_sequence(idx)
+ except Exception as e:
+ print(f"Error loading sequence {idx} for {self.__class__.__name__}: {e}")
+ idx = np.random.randint(0, len(self.dataset))
+ results = self[idx]
+ return results
+
+ def log_load_dataset(self):
+ if is_main_process():
+ info = f"Loaded {self.__class__.__name__} with {sum([len(x['image']) for x in self.sequences])} images in {len(self)} sequences."
+ print(info)
diff --git a/unik3d/datasets/sintel.py b/unik3d/datasets/sintel.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b2d0eb34986c81609f42b1e14fa7dd2bd29d7df
--- /dev/null
+++ b/unik3d/datasets/sintel.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class Sintel(SequenceDataset):
+ min_depth = 0.001
+ max_depth = 1000.0
+ depth_scale = 1000.0
+ test_split = "training.txt"
+ train_split = "training.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["Sintel.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth", "flow_fwd", "flow_fwd_mask"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/sunrgbd.py b/unik3d/datasets/sunrgbd.py
new file mode 100644
index 0000000000000000000000000000000000000000..43cc98a4bb25511501010195dae195a0396a6d91
--- /dev/null
+++ b/unik3d/datasets/sunrgbd.py
@@ -0,0 +1,73 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class SUNRGBD(ImageDataset):
+ min_depth = 0.005
+ max_depth = 8.0
+ depth_scale = 1000.0
+ test_split = "alltest.txt"
+ train_split = "alltrain.txt"
+ intrisics_file = "intrinsics.json"
+ hdf5_paths = ["SUNRGB.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.crop = crop
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tostring().decode("ascii")[:-1]  # drop the trailing newline
+ intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val]
+ dataset.append(sample)
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
diff --git a/unik3d/datasets/synscapes.py b/unik3d/datasets/synscapes.py
new file mode 100644
index 0000000000000000000000000000000000000000..dff078984285ca1172622d516becbee8248bf0e5
--- /dev/null
+++ b/unik3d/datasets/synscapes.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class Synscapes(SequenceDataset):
+ min_depth = 0.1
+ max_depth = 1000.0
+ depth_scale = 256.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+    hdf5_paths = ["Synscapes.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/tartanair.py b/unik3d/datasets/tartanair.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcd24d4a200f4800ebb51f8e9f742215bc541a46
--- /dev/null
+++ b/unik3d/datasets/tartanair.py
@@ -0,0 +1,51 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class TartanAir(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 512.0
+ depth_scale = 1000.0
+ default_fps = 15
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["TartanAir.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/taskonomy.py b/unik3d/datasets/taskonomy.py
new file mode 100644
index 0000000000000000000000000000000000000000..69880eacce0289496d9522b233f061feba84f0b3
--- /dev/null
+++ b/unik3d/datasets/taskonomy.py
@@ -0,0 +1,91 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class Taskonomy(ImageDataset):
+ min_depth = 0.005
+ max_depth = 15.0
+ depth_scale = 512.0
+ test_split = "val.txt"
+ train_split = "train_clean.txt"
+ intrisics_file = "intrinsics.json"
+ hdf5_paths = ["Taskonomy.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tostring().decode("ascii")  # trailing-character strip ([:-1]) intentionally disabled here
+ intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+ # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f:
+ # f.write(txt_string)
+ # with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f:
+ # json.dump(intrinsics, f)
+
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename, chunk_idx = line.strip().split(" ")
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val, chunk_idx]
+ dataset.append(sample)
+ h5file.close()
+
+        if not self.test_mode:
+            dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+        else:
+            dataset = self.chunk(dataset, chunk_dim=1, pct=0.01)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def get_mapper(self):
+ return {
+ "image_filename": 0,
+ "depth_filename": 1,
+ "K": 2,
+ }
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [2] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/tat_rmvd.py b/unik3d/datasets/tat_rmvd.py
new file mode 100644
index 0000000000000000000000000000000000000000..08072cb2d4873afebd7f6efc42b40f6edadbf6d4
--- /dev/null
+++ b/unik3d/datasets/tat_rmvd.py
@@ -0,0 +1,63 @@
+import json
+import os
+from copy import deepcopy
+from typing import Any
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.pipelines import AnnotationMask, KittiCrop
+from unik3d.datasets.sequence_dataset import SequenceDataset
+from unik3d.datasets.utils import DatasetFromList
+from unik3d.utils import identity
+
+
+class TATRMVD(SequenceDataset):
+ min_depth = 0.001
+ max_depth = 50.0
+ depth_scale = 1000.0
+ default_fps = 6
+ test_split = "test.txt"
+ train_split = "test.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["tanks_and_temples_rmvd.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [False] * self.num_frames * self.num_copies
+ results["si"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [2] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/theo.py b/unik3d/datasets/theo.py
new file mode 100644
index 0000000000000000000000000000000000000000..9be1b2f3f245a98e582a4359398acb102a9d5d51
--- /dev/null
+++ b/unik3d/datasets/theo.py
@@ -0,0 +1,66 @@
+from typing import Any
+
+import torch
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class Theo(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 10.0
+ depth_scale = 1000.0
+ default_fps = 5
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["THEO.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["camera_params", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def preprocess(self, results):
+ self.resizer.ctx = None
+        for seq in results["sequence_fields"]:
+            # Circular validity mask: keep pixels whose distance from the image
+            # center is below (H - 1) / 2
+ H, W = results[seq]["image"].shape[-2:]
+ x = torch.linspace(-(W - 1) / 2, (W - 1) / 2, W)
+ y = torch.linspace(-(H - 1) / 2, (H - 1) / 2, H)
+ xv, yv = torch.meshgrid(x, y, indexing="xy")
+ distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W)
+ results[seq]["validity_mask"] = distance_from_center < (H - 1) / 2
+
+ return super().preprocess(results)
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/unrealstereo4k.py b/unik3d/datasets/unrealstereo4k.py
new file mode 100644
index 0000000000000000000000000000000000000000..64520dff8ae473c2027340663cd759eed4acb75a
--- /dev/null
+++ b/unik3d/datasets/unrealstereo4k.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class UnrealStereo4K(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 200.0
+ depth_scale = 1000.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+    hdf5_paths = ["UnrealStereo4K.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/urbansyn.py b/unik3d/datasets/urbansyn.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3f72c77cf7fc1ca1981cd7cb09aa97f7701cb2a
--- /dev/null
+++ b/unik3d/datasets/urbansyn.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class UrbanSyn(SequenceDataset):
+ min_depth = 0.1
+ max_depth = 1000.0
+ depth_scale = 256.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+    hdf5_paths = ["UrbanSyn.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/utils.py b/unik3d/datasets/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2994594d53f48665c72096d870be56e3796fbdb8
--- /dev/null
+++ b/unik3d/datasets/utils.py
@@ -0,0 +1,229 @@
+import copy
+import multiprocessing as mp
+import pickle
+import random
+from collections import defaultdict
+from typing import Any, Dict, List
+
+import numpy as np
+import torch
+import torch.utils.data
+
+from unik3d.utils.distributed import (all_gather, get_local_rank,
+ get_local_size, get_rank, get_world_size)
+
+
+class ConcatDataset(torch.utils.data.ConcatDataset):
+ def __init__(self, datasets, shape_constraints: dict[str, list[int]] = {}):
+ super().__init__(datasets)
+
+ self.sample = shape_constraints["sample"]
+ self.shape_mult = shape_constraints["shape_mult"]
+ self.ratio_bounds = shape_constraints["ratio_bounds"]
+ self.pixels_max = float(shape_constraints["pixels_max"])
+ self.pixels_min = float(shape_constraints["pixels_min"])
+
+ self.height_min = shape_constraints["height_min"]
+ self.width_min = shape_constraints["width_min"]
+
+ def sample_shape(self):
+ if not self.sample:
+ return
+ # 1: sample image ratio
+ ratio = np.random.uniform(*self.ratio_bounds)
+ pixels_min = self.pixels_min // (self.shape_mult * self.shape_mult)
+ pixels_max = self.pixels_max // (self.shape_mult * self.shape_mult)
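+        # work in units of shape_mult: the sampled (height, width) below are multiples of
+        # shape_mult, so the final pixel count stays within [pixels_min, pixels_max]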
+ # 2: sample image height or width, if ratio > 1 or < 1
+ if ratio > 1:
+ height_min = max(self.height_min, np.sqrt(pixels_min / ratio))
+ height = np.random.uniform(height_min, np.sqrt(pixels_max / ratio))
+ width = height * ratio
+ else:
+ width_min = max(self.width_min, np.sqrt(pixels_min * ratio))
+ width = np.random.uniform(width_min, np.sqrt(pixels_max * ratio))
+ height = width / ratio
+ # 3: get final shape based on the shape_mult
+ shape = [int(height) * self.shape_mult, int(width) * self.shape_mult]
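+        # propagate the sampled shape to every constituent dataset and its resizer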
+ for dataset in self.datasets:
+ setattr(dataset, "image_shape", shape)
+ setattr(dataset.resizer, "image_shape", shape)
+
+ def __getitem__(self, idxs):
+ self.sample_shape()
+ samples = [super(ConcatDataset, self).__getitem__(idx) for idx in idxs]
+ return samples
+
+
+def _paddings(image_shape, network_shape):
+ cur_h, cur_w = image_shape
+ h, w = network_shape
+ pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2
+ pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2
+ return pad_left, pad_right, pad_top, pad_bottom
+
+
+def collate_fn(in_data: List[List[Dict[str, Any]]], is_batched: bool = True):
+ out_data = defaultdict(list)
+ img_metas = []
+ in_data = in_data[0] if is_batched else in_data
+
+ # get max_shape and paddings
+ shapes = [tensor.shape[-2:] for x in in_data for tensor in x["depth"].values()]
+ max_shape_tuple = tuple(max(elements) for elements in zip(*shapes))
+ paddings = [
+ [
+ _paddings(tensor.shape[-2:], max_shape_tuple)
+ for tensor in x["depth"].values()
+ ]
+ for x in in_data
+ ]
+
+ for x in in_data: # here iter over batches
+ padding = paddings.pop(0)
+ for k, v in x.items():
+ if "img_metas" not in k:
+ values = list(v.values())
+ v = torch.cat(values)
+ out_data[k].append(v)
+ else:
+ v["depth_paddings"] = padding
+ img_metas.append(v)
+
+    # count valid pixels per sample (from depth_mask / validity_mask) and store them,
+    # paired with the batch-wide total, as "num_valid_depth" and "num_valid_pix"
+    num_valid_samples = [
+        x.flatten(1).sum(dim=1).long() for x in out_data["depth_mask"]
+    ]
+    out_data["num_valid_depth"] = [
+        torch.stack([x, sum(num_valid_samples) * torch.ones_like(x)], dim=-1)
+        for x in num_valid_samples
+    ]
+    num_valid_samples = [
+        x.flatten(1).sum(dim=1).long() for x in out_data["validity_mask"]
+    ]
+    out_data["num_valid_pix"] = [
+        torch.stack([x, sum(num_valid_samples) * torch.ones_like(x)], dim=-1)
+        for x in num_valid_samples
+    ]
+ output_dict = {
+ "data": {k: torch.stack(v) for k, v in out_data.items()},
+ "img_metas": img_metas,
+ }
+ if "camera" in output_dict["data"]:
+ output_dict["data"]["camera"] = output_dict["data"]["camera"].reshape(
+ *output_dict["data"]["image"].shape[:2]
+ )
+ return output_dict
+
+
+def local_scatter(array: list[Any]):
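+    # Scatter a per-node list held by the local leader (local rank 0): the leader shares it
+    # via all_gather and every rank on the node then picks the entry for its own local rank.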
+ if get_world_size() == 1:
+ return array[0]
+ if get_local_rank() == 0:
+ assert len(array) == get_local_size()
+ all_gather(array)
+ else:
+ all_data = all_gather(None)
+ array = all_data[get_rank() - get_local_rank()]
+ return array[get_local_rank()]
+
+
+class DatasetFromList(torch.utils.data.Dataset): # type: ignore
+ """Wrap a list to a torch Dataset.
+
+ We serialize and wrap big python objects in a torch.Dataset due to a
+ memory leak when dealing with large python objects using multiple workers.
+ See: https://github.com/pytorch/pytorch/issues/13246
+ """
+
+ def __init__(self, lst: List[Any], deepcopy: bool = False, serialize: bool = True):
+ self._copy = deepcopy
+ self._serialize = serialize
+
+ def _serialize(data: Any):
+ buffer = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL)
+ return torch.frombuffer(buffer, dtype=torch.uint8)
+
+ if self._serialize:
+ # load only on 0th rank
+ if get_local_rank() == 0:
+ _lst = [_serialize(x) for x in lst]
+ self._addr = torch.cumsum(
+ torch.tensor([len(x) for x in _lst], dtype=torch.int64), dim=0
+ )
+ self._lst = torch.concatenate(_lst)
+ # Move data to shared memory, obtain a handle to send to each local worker.
+ handles = [None] + [
+ bytes(mp.reduction.ForkingPickler.dumps((self._addr, self._lst)))
+ for _ in range(get_local_size() - 1)
+ ]
+ else:
+ handles = None
+
+ # Each worker receives the handle from local leader (rank 0)
+ # then materialize the tensor from shared memory
+ handle = local_scatter(handles)
+ if get_local_rank() > 0:
+ self._addr, self._lst = mp.reduction.ForkingPickler.loads(handle)
+
+ else:
+ self._lst = lst
+
+ def __len__(self) -> int:
+ if self._serialize:
+ return len(self._addr)
+ return len(self._lst)
+
+ def __getitem__(self, idx: int) -> Any:
+ if self._serialize:
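+            # _addr holds cumulative byte offsets: entry idx spans bytes [addr[idx-1], addr[idx])
+            # of the flat uint8 buffer and is unpickled on the fly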
+ start_addr = 0 if idx == 0 else self._addr[idx - 1]
+ end_addr = self._addr[idx]
+ bytes_ = memoryview(self._lst[start_addr:end_addr].numpy())
+ return pickle.loads(bytes_)
+ if self._copy:
+ return copy.deepcopy(self._lst[idx])
+
+ return self._lst[idx]
+
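+# Usage sketch (hypothetical records, for illustration only): wrap a large list once and let
+# DataLoader workers read pickled entries from the shared buffer instead of copying the list:
+#   dataset = DatasetFromList([{"image": "a.jpg"}, {"image": "b.jpg"}], serialize=True)
+#   sample = dataset[0]  # -> {"image": "a.jpg"}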
+
+def get_weights(
+ train_datasets: dict[str, torch.utils.data.Dataset], sampling: dict[str, float]
+) -> tuple[torch.Tensor, int]:
+ from .image_dataset import ImageDataset
+ from .sequence_dataset import SequenceDataset
+
+ weights = []
+ num_samples = 0
+ info_weights = {}
+ for dataset_name, dataset in train_datasets.items():
+ assert (
+ dataset_name in sampling
+ ), f"Dataset {dataset_name} not found in {sampling.keys()}"
+
+ if isinstance(dataset, ImageDataset):
+            # each sample gets sampling[dataset_name] / len(dataset), so the dataset as a
+            # whole is drawn in proportion to sampling[dataset_name], uniformly within it
+            weight = sampling[dataset_name] / len(dataset)
+ weights.append(torch.full((len(dataset),), weight).double())
+ num_samples += len(dataset)
+
+ elif isinstance(dataset, SequenceDataset):
+            # weight each sequence by its number of frames, normalized so that the dataset
+            # as a whole is drawn in proportion to sampling[dataset_name]:
+            # sum_i(num_samples_i * sampling[name] / sum(num_samples)) = sampling[name]
+ numerator = [int(data["num_samples"]) for data in dataset.dataset]
+ weights.append(
+ sampling[dataset_name]
+ * torch.tensor(numerator).double()
+ / sum(numerator)
+ )
+ num_samples += sum(numerator)
+
+ else:
+ weight = sampling[dataset_name] / len(dataset)
+ weights.append(torch.full((len(dataset),), weight).double())
+
+ info_weights[dataset_name] = weights[-1][-1]
+
+ return torch.cat(weights), num_samples
diff --git a/unik3d/datasets/utils_decode.py b/unik3d/datasets/utils_decode.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a628a708c9cfa2b1ac18614546dcdeb23f3678a
--- /dev/null
+++ b/unik3d/datasets/utils_decode.py
@@ -0,0 +1,122 @@
+import io
+
+import cv2
+import numpy as np
+import torch
+import torchvision
+import torchvision.transforms.v2.functional as TF
+from PIL import Image
+
+from unik3d.utils.camera import (EUCM, MEI, OPENCV, BatchCamera, Fisheye624,
+ Pinhole, Spherical)
+
+
+def decode_depth(results, h5file, value, idx, depth_scale, name="depth", **kwargs):
+ file = h5file.get_node("/" + value).read()
+ decoded_data = Image.open(io.BytesIO(file))
+ decoded_data = TF.pil_to_tensor(decoded_data).squeeze()
+
+    if decoded_data.ndim == 3:  # depth packed as a 24-bit value across three 8-bit channels
+ decoded_channels = [
+ (decoded_data[0] & 0xFF).to(torch.int32),
+ (decoded_data[1] & 0xFF).to(torch.int32),
+ (decoded_data[2] & 0xFF).to(torch.int32),
+ ]
+ # Reshape and extract the original depth map
+ decoded_data = (
+ decoded_channels[0]
+ | (decoded_channels[1] << 8)
+ | (decoded_channels[2] << 16)
+ )
+
+ decoded_data = decoded_data.to(torch.float32)
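+    # register the decoded field; this mutates the existing "gt_fields" sets in place
+    # (if the key were missing, .get would return a throw-away set and the add would be a no-op)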
+ results.get("gt_fields", set()).add(name)
+ results[(idx, 0)].get("gt_fields", set()).add(name)
+ results[f"{name}_ori_shape"] = decoded_data.shape
+ results[(idx, 0)][name] = (
+ decoded_data.view(1, 1, *decoded_data.shape).contiguous() / depth_scale
+ )
+ return results
+
+
+def decode_numpy(results, h5file, value, idx, name="points", **kwargs):
+ file = h5file.get_node("/" + value).read()
+ decoded_data = np.load(io.BytesIO(file), allow_pickle=False)
+ decoded_data = torch.from_numpy(decoded_data).to(torch.float32)
+ if decoded_data.ndim > 2:
+ decoded_data = decoded_data.permute(2, 0, 1)
+ results.get("gt_fields", set()).add(name)
+ results[(idx, 0)].get("gt_fields", set()).add(name)
+ results[(idx, 0)][name] = decoded_data.unsqueeze(0)
+ return results
+
+
+def decode_tensor(results, value, idx, name, **kwargs):
+ results.get("camera_fields", set()).add(name)
+ results[(idx, 0)].get("camera_fields", set()).add(name)
+ results[(idx, 0)][name] = torch.tensor(value).unsqueeze(0)
+ return results
+
+
+def decode_camera(results, value, idx, name, sample, j, **kwargs):
+ results.get("camera_fields", set()).add(name)
+ results[(idx, 0)].get("camera_fields", set()).add(name)
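+    # resolve the camera class (e.g. Pinhole, MEI, EUCM, Fisheye624) by name from the sample metadata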
+ camera = eval(sample["camera_model"][j])(params=torch.tensor(value).unsqueeze(0))
+ results[(idx, 0)][name] = BatchCamera.from_camera(camera)
+ return results
+
+
+def decode_K(results, value, idx, name, **kwargs):
+ results.get("camera_fields", set()).add(name)
+ results[(idx, 0)].get("camera_fields", set()).add(name)
+ camera = Pinhole(K=torch.tensor(value).unsqueeze(0))
+ results[(idx, 0)][name] = BatchCamera.from_camera(camera)
+ return results
+
+
+def decode_mask(results, h5file, value, idx, name, **kwargs):
+ file = h5file.get_node("/" + value).read()
+ mask = torchvision.io.decode_image(torch.from_numpy(file)).bool().squeeze()
+ results.get("mask_fields", set()).add(name)
+ results[(idx, 0)].get("mask_fields", set()).add(name)
+ results[f"{name}_ori_shape"] = mask.shape[-2:]
+ results[(idx, 0)][name] = mask.view(1, 1, *mask.shape).contiguous()
+ return results
+
+
+def decode_rgb(results, h5file, value, idx, name="image", **kwargs):
+ file = h5file.get_node("/" + value).read()
+ image = (
+ torchvision.io.decode_image(torch.from_numpy(file)).to(torch.uint8).squeeze()
+ )
+ results.get("image_fields", set()).add(name)
+ results[(idx, 0)].get("image_fields", set()).add(name)
+ results[f"{name}_ori_shape"] = image.shape[-2:]
+ if image.ndim == 2:
+ image = image.unsqueeze(0).repeat(3, 1, 1)
+ results[(idx, 0)][name] = image.unsqueeze(0)
+ return results
+
+
+def decode_flow(results, h5file, value, idx, name, **kwargs):
+ file = h5file.get_node("/" + value).read()
+ image = (
+ torchvision.io.decode_image(torch.from_numpy(file)).to(torch.uint8).squeeze()
+ )
+ decoded_channels = [
+ (image[0] & 0xFF).to(torch.int16),
+ (image[1] & 0xFF).to(torch.int16),
+ (image[2] & 0xFF).to(torch.int16),
+ ]
+
+    # Unpack the two 12-bit flow components packed across the three 8-bit channels
+ flow = torch.zeros((2, image.shape[1], image.shape[2]), dtype=torch.int16)
+ flow[0] = (decoded_channels[0] | decoded_channels[1] << 8) & 0xFFF
+ flow[1] = (decoded_channels[1] >> 4 | decoded_channels[2] << 4) & 0xFFF
+
+ results.get("gt_fields", set()).add(name)
+ results[(idx, 0)].get("gt_fields", set()).add(name)
+ results[f"{name}_ori_shape"] = flow.shape[-2:]
+ flow = flow.unsqueeze(0).contiguous().float()
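+    # map the 12-bit codes (0..4095) to roughly [-1, 1], with a half-step offset to center each bin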
+ results[(idx, 0)][name] = (0.5 + flow) / 4095.0 * 2 - 1
+ return results
diff --git a/unik3d/datasets/vkitti.py b/unik3d/datasets/vkitti.py
new file mode 100644
index 0000000000000000000000000000000000000000..5864b068808d0049837cab16b158684065a824c2
--- /dev/null
+++ b/unik3d/datasets/vkitti.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class VKITTI(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 255.0
+ depth_scale = 256.0
+ test_split = "training.txt"
+ train_split = "training.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["VKITTI2.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth", "flow_fwd", "flow_fwd_mask"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["synthetic"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [0] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/void.py b/unik3d/datasets/void.py
new file mode 100644
index 0000000000000000000000000000000000000000..f77f7450979fa2aa9ef577e235ad3a854662eaec
--- /dev/null
+++ b/unik3d/datasets/void.py
@@ -0,0 +1,79 @@
+import json
+import os
+
+import h5py
+import numpy as np
+import torch
+
+from unik3d.datasets.image_dataset import ImageDataset
+from unik3d.datasets.utils import DatasetFromList
+
+
+class VOID(ImageDataset):
+ min_depth = 0.01
+ max_depth = 10.0
+ depth_scale = 256.0
+ test_split = "void_val.txt"
+ train_split = "void_train.txt"
+ intrisics_file = "void_intrinsics.json"
+ hdf5_paths = ["void.hdf5"]
+
+ def __init__(
+ self,
+ image_shape,
+ split_file,
+ test_mode,
+ crop=None,
+ benchmark=False,
+ augmentations_db={},
+ normalize=True,
+ resize_method="hard",
+ mini=1.0,
+ **kwargs,
+ ):
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ **kwargs,
+ )
+ self.test_mode = test_mode
+
+ self.crop = crop
+ self.load_dataset()
+
+ def load_dataset(self):
+ h5file = h5py.File(
+ os.path.join(self.data_root, self.hdf5_paths[0]),
+ "r",
+ libver="latest",
+ swmr=True,
+ )
+ txt_file = np.array(h5file[self.split_file])
+        txt_string = txt_file.tobytes().decode("ascii")[:-1]  # drop the trailing newline
+        intrinsics = np.array(h5file[self.intrisics_file]).tobytes().decode("ascii")
+ intrinsics = json.loads(intrinsics)
+ h5file.close()
+ dataset = []
+ for line in txt_string.split("\n"):
+ image_filename, depth_filename = line.strip().split(" ")
+ intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3]
+ sample = [image_filename, depth_filename, intrinsics_val]
+ dataset.append(sample)
+
+ if not self.test_mode:
+ dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini)
+
+ self.dataset = DatasetFromList(dataset)
+ self.log_load_dataset()
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_copies
+ results["quality"] = [2] * self.num_copies
+ return results
diff --git a/unik3d/datasets/waymo.py b/unik3d/datasets/waymo.py
new file mode 100644
index 0000000000000000000000000000000000000000..63e34ed7c741df4942f4b6a8ac6f2ac3e4fe36a6
--- /dev/null
+++ b/unik3d/datasets/waymo.py
@@ -0,0 +1,50 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class Waymo(SequenceDataset):
+ min_depth = 0.05
+ max_depth = 70.0
+ depth_scale = 256.0
+ test_split = "validation.txt"
+ train_split = "training.txt"
+ sequences_file = "sequences.json"
+    hdf5_paths = ["Waymo_viz.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [False] * self.num_frames * self.num_copies
+ results["synthetic"] = [False] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/datasets/wildrgbd.py b/unik3d/datasets/wildrgbd.py
new file mode 100644
index 0000000000000000000000000000000000000000..db863daa395a3b33c0c8617db467ac6b43e0b166
--- /dev/null
+++ b/unik3d/datasets/wildrgbd.py
@@ -0,0 +1,49 @@
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class WildRGBD(SequenceDataset):
+ min_depth = 0.01
+ max_depth = 10.0
+ depth_scale = 1000.0
+ test_split = "train.txt"
+ train_split = "train.txt"
+ sequences_file = "sequences.json"
+ hdf5_paths = ["WildRGBD.hdf5"]
+
+ def __init__(
+ self,
+ image_shape: tuple[int, int],
+ split_file: str,
+ test_mode: bool,
+ normalize: bool,
+ augmentations_db: dict[str, Any],
+ resize_method: str,
+ mini: float = 1.0,
+ num_frames: int = 1,
+ benchmark: bool = False,
+ decode_fields: list[str] = ["image", "depth"],
+ inplace_fields: list[str] = ["K", "cam2w"],
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ image_shape=image_shape,
+ split_file=split_file,
+ test_mode=test_mode,
+ benchmark=benchmark,
+ normalize=normalize,
+ augmentations_db=augmentations_db,
+ resize_method=resize_method,
+ mini=mini,
+ num_frames=num_frames,
+ decode_fields=decode_fields,
+ inplace_fields=inplace_fields,
+ **kwargs,
+ )
+
+ def pre_pipeline(self, results):
+ results = super().pre_pipeline(results)
+ results["dense"] = [True] * self.num_frames * self.num_copies
+ results["quality"] = [1] * self.num_frames * self.num_copies
+ return results
diff --git a/unik3d/layers/__init__.py b/unik3d/layers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..96d3c6ddce2dda1eb482e876766c65ad8216bbf2
--- /dev/null
+++ b/unik3d/layers/__init__.py
@@ -0,0 +1,20 @@
+from .activation import GEGLU, SwiGLU
+from .attention import AttentionBlock, AttentionDecoderBlock, AttentionLayer
+from .grad_choker import GradChoker
+from .mlp import MLP
+from .positional_encoding import PositionEmbeddingSine
+from .upsample import ResUpsample, ResUpsampleBil, ResUpsampleSH
+
+__all__ = [
+ "SwiGLU",
+ "GEGLU",
+ "AttentionBlock",
+ "AttentionLayer",
+ "PositionEmbeddingSine",
+ "MLP",
+ "AttentionDecoderBlock",
+ "ResUpsample",
+ "ResUpsampleSH",
+ "ResUpsampleBil",
+ "GradChoker",
+]
diff --git a/unik3d/layers/activation.py b/unik3d/layers/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5787a340013ba59e2956b6b829f724d9cfb7fcc
--- /dev/null
+++ b/unik3d/layers/activation.py
@@ -0,0 +1,15 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SwiGLU(nn.Module):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x, gates = x.chunk(2, dim=-1)
+ return x * F.silu(gates)
+
+
+class GEGLU(nn.Module):
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x, gates = x.chunk(2, dim=-1)
+ return x * F.gelu(gates)
diff --git a/unik3d/layers/attention.py b/unik3d/layers/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e4cec63e6202c9d515c3308ef550f6223776966
--- /dev/null
+++ b/unik3d/layers/attention.py
@@ -0,0 +1,378 @@
+from functools import partial
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+
+from .layer_scale import LayerScale
+from .mlp import MLP
+
+
+class SimpleAttention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 4,
+ dropout: float = 0.0,
+ cosine: bool = False,
+ context_dim: int | None = None,
+ ):
+ super().__init__()
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.hidden_dim = dim
+ context_dim = context_dim or dim
+
+ self.kv = nn.Linear(context_dim, dim * 2, bias=False)
+ self.q = nn.Linear(dim, dim, bias=False)
+ self.norm_attnx = nn.LayerNorm(dim)
+ self.norm_attnctx = nn.LayerNorm(context_dim)
+ self.cosine = cosine
+ self.out = nn.Linear(dim, dim, bias=False)
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attn_bias: torch.Tensor | None = None,
+ context: torch.Tensor | None = None,
+ pos_embed: torch.Tensor | None = None,
+ pos_embed_context: torch.Tensor | None = None,
+ rope: nn.Module | None = None,
+ rope_pos: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ context = x if context is None else context
+ x = self.norm_attnx(x)
+ context = self.norm_attnctx(context)
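+        # packed key/value projection, then split heads: (b, n, 2*h*d) -> k, v of shape (b, h, n, d)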
+ k, v = rearrange(
+ self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
+ ).unbind(dim=-1)
+ q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
+
+ if rope is not None:
+ q = rope(q)
+ k = rope(k)
+ else:
+ if pos_embed is not None:
+ pos_embed = rearrange(
+ pos_embed, "b n (h d) -> b h n d", h=self.num_heads
+ )
+ q = q + pos_embed
+ if pos_embed_context is not None:
+ pos_embed_context = rearrange(
+ pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
+ )
+ k = k + pos_embed_context
+
+ if self.cosine:
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
+ x = F.scaled_dot_product_attention(
+ q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
+ )
+ x = rearrange(x, "b h n d -> b n (h d)")
+ x = self.out(x)
+ return x
+
+
+class AttentionBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 4,
+ expansion: int = 4,
+ dropout: float = 0.0,
+ cosine: bool = False,
+ gated: bool = False,
+ layer_scale: float = 1.0,
+ context_dim: int | None = None,
+ detach_query: bool = False,
+ residual_ls: bool = False,
+ ):
+ super().__init__()
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.hidden_dim = dim
+ context_dim = dim if context_dim is None else context_dim
+ self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
+ self.kv = nn.Linear(context_dim, dim * 2, bias=False)
+ self.q = nn.Linear(dim, dim, bias=False)
+ self.norm_attnx = nn.LayerNorm(dim)
+ self.norm_attnctx = nn.LayerNorm(context_dim)
+ self.cosine = cosine
+ self.out = nn.Linear(dim, dim, bias=False)
+ self.ls1_1 = (
+ LayerScale(dim, layer_scale)
+ if layer_scale > 0.0 and not residual_ls
+ else nn.Identity()
+ )
+ self.ls1_2 = (
+ LayerScale(dim, layer_scale)
+ if layer_scale > 0.0 and residual_ls
+ else nn.Identity()
+ )
+ self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
+ self.detach_query = detach_query
+
+ def attn(
+ self,
+ x: torch.Tensor,
+ attn_bias: torch.Tensor | None = None,
+ context: torch.Tensor | None = None,
+ pos_embed: torch.Tensor | None = None,
+ pos_embed_context: torch.Tensor | None = None,
+ rope: nn.Module | None = None,
+ rope_pos: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ if self.detach_query:
+ x = x.detach()
+ x = self.norm_attnx(x)
+ context = self.norm_attnctx(context)
+ k, v = rearrange(
+ self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
+ ).unbind(dim=-1)
+ q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads)
+
+ if rope is not None:
+ q = rope(q.permute(0, 2, 1, 3), input_pos=rope_pos).permute(0, 2, 1, 3)
+ k = rope(k.permute(0, 2, 1, 3), input_pos=rope_pos).permute(0, 2, 1, 3)
+ else:
+ if pos_embed is not None:
+ pos_embed = rearrange(
+ pos_embed, "b n (h d) -> b h n d", h=self.num_heads
+ )
+ q = q + pos_embed
+ if pos_embed_context is not None:
+ pos_embed_context = rearrange(
+ pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads
+ )
+ k = k + pos_embed_context
+
+ if self.cosine:
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
+
+ x = F.scaled_dot_product_attention(
+ q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
+ )
+ x = rearrange(x, "b h n d -> b n (h d)")
+ x = self.out(x)
+ return x
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ context: torch.Tensor | None = None,
+ pos_embed: torch.Tensor | None = None,
+ pos_embed_context: torch.Tensor | None = None,
+ attn_bias: torch.Tensor | None = None,
+ rope: nn.Module | None = None,
+ rope_pos: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ context = x if context is None else context
+ x = self.ls1_1(
+ self.attn(
+ x,
+ rope=rope,
+ rope_pos=rope_pos,
+ attn_bias=attn_bias,
+ context=context,
+ pos_embed=pos_embed,
+ pos_embed_context=pos_embed_context,
+ )
+ ) + self.ls1_2(x)
+ x = self.ls2(self.mlp(x)) + x
+ return x
+
+
+class AttentionLayer(nn.Module):
+ def __init__(
+ self,
+ num_blocks: int,
+ dim: int,
+ num_heads: int = 4,
+ expansion: int = 4,
+ dropout: float = 0.0,
+ cosine: bool = False,
+ gated: bool = False,
+ layer_scale: float = 1.0,
+ context_dim: int | None = None,
+ detach_query: bool = False,
+ residual_ls: bool = False,
+ ):
+ super().__init__()
+ self.layers = nn.ModuleList(
+ [
+ AttentionBlock(
+ dim=dim,
+ num_heads=num_heads,
+ expansion=expansion,
+ dropout=dropout,
+ cosine=cosine,
+ gated=gated,
+ layer_scale=layer_scale,
+ context_dim=context_dim,
+ detach_query=detach_query,
+ residual_ls=residual_ls,
+ )
+ for _ in range(num_blocks)
+ ]
+ )
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ context: torch.Tensor | None = None,
+ pos_embed: torch.Tensor | None = None,
+ pos_embed_context: torch.Tensor | None = None,
+ attn_bias: torch.Tensor | None = None,
+ rope: nn.Module | None = None,
+ rope_pos: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ for layer in self.layers:
+ x = layer(
+ x,
+ context=context,
+ pos_embed=pos_embed,
+ pos_embed_context=pos_embed_context,
+ attn_bias=attn_bias,
+ rope=rope,
+ rope_pos=rope_pos,
+ )
+ return x
+
+
+class AttentionDecoderBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 4,
+ expansion: int = 4,
+ dropout: float = 0.0,
+ cosine: bool = False,
+ gated: bool = False,
+ layer_scale: float = 1.0,
+ context_dim: int | None = None,
+ single_head_ca: bool = True,
+ ):
+ super().__init__()
+ self.dropout = dropout
+ self.num_heads = num_heads
+ self.hidden_dim = dim
+ self.single_head_ca = single_head_ca
+ context_dim = context_dim or dim
+ self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated)
+ self.kv_ca = nn.Linear(context_dim, dim * 2, bias=False)
+ self.q_ca = nn.Linear(dim, dim, bias=False)
+ self.kv_sa = nn.Linear(dim, dim * 2, bias=False)
+ self.q_sa = nn.Linear(dim, dim, bias=False)
+ self.norm_x_sa = nn.LayerNorm(dim)
+ self.norm_x_ca = nn.LayerNorm(dim)
+ self.norm_ctx_ca = nn.LayerNorm(context_dim)
+ self.cosine = cosine
+ self.out_ca = nn.Linear(dim, dim, bias=False)
+ self.out_sa = nn.Linear(dim, dim, bias=False)
+ self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
+ self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
+ self.ls3 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
+
+ def cross_attn(
+ self,
+ x: torch.Tensor,
+ attn_bias: torch.Tensor | None = None,
+ context: torch.Tensor | None = None,
+ pos_embed: torch.Tensor | None = None,
+ pos_embed_context: torch.Tensor | None = None,
+ rope: nn.Module | None = None,
+ rope_pos: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ num_heads = 1 if self.single_head_ca else self.num_heads
+ x = self.norm_x_ca(x)
+ context = self.norm_ctx_ca(context)
+ k, v = rearrange(
+ self.kv_ca(context), "b n (kv h d) -> b h n d kv", h=num_heads, kv=2
+ ).unbind(dim=-1)
+ q = rearrange(self.q_ca(x), "b n (h d) -> b h n d", h=num_heads)
+
+ if rope is not None:
+ q = rope(q)
+ k = rope(k)
+ else:
+ if pos_embed is not None:
+ pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=num_heads)
+ q = q + pos_embed
+ if pos_embed_context is not None:
+ pos_embed_context = rearrange(
+ pos_embed_context, "b n (h d) -> b h n d", h=num_heads
+ )
+ k = k + pos_embed_context
+
+ if self.cosine:
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
+ x = F.scaled_dot_product_attention(
+ q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
+ )
+ x = rearrange(x, "b h n d -> b n (h d)")
+ x = self.out_ca(x)
+ return x
+
+ def self_attn(
+ self,
+ x: torch.Tensor,
+ attn_bias: torch.Tensor | None = None,
+ pos_embed: torch.Tensor | None = None,
+ rope: nn.Module | None = None,
+ rope_pos: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ x = self.norm_x_sa(x)
+ k, v = rearrange(
+ self.kv_sa(x), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2
+ ).unbind(dim=-1)
+ q = rearrange(self.q_sa(x), "b n (h d) -> b h n d", h=self.num_heads)
+
+ if rope is not None:
+ q = rope(q)
+ k = rope(k)
+ elif pos_embed is not None:
+ pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=self.num_heads)
+ q = q + pos_embed
+
+ if self.cosine:
+ q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim
+ x = F.scaled_dot_product_attention(
+ q, k, v, dropout_p=self.dropout, attn_mask=attn_bias
+ )
+ x = rearrange(x, "b h n d -> b n (h d)")
+ x = self.out_sa(x)
+ return x
+
+ def forward(
+ self,
+ x: torch.Tensor,
+ attn_bias: torch.Tensor | None = None,
+ context: torch.Tensor | None = None,
+ pos_embed: torch.Tensor | None = None,
+ pos_embed_context: torch.Tensor | None = None,
+ rope: nn.Module | None = None,
+ rope_pos: torch.Tensor | None = None,
+ ) -> torch.Tensor:
+ context = x if context is None else context
+ x = (
+ self.ls1(
+ self.cross_attn(
+ x,
+ rope=rope,
+ attn_bias=attn_bias,
+ context=context,
+ pos_embed=pos_embed,
+ pos_embed_context=pos_embed_context,
+ )
+ )
+ + x
+ )
+ x = (
+ self.ls2(
+ self.self_attn(x, rope=rope, attn_bias=attn_bias, pos_embed=pos_embed)
+ )
+ + x
+ )
+ x = self.ls3(self.mlp(x)) + x
+ return x
diff --git a/unik3d/layers/convnext.py b/unik3d/layers/convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..d186885fb8499b98f2fe09e5dd5150a2f9f14b04
--- /dev/null
+++ b/unik3d/layers/convnext.py
@@ -0,0 +1,80 @@
+import torch
+import torch.nn as nn
+
+
+class CvnxtBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ kernel_size=7,
+ layer_scale=1.0,
+ expansion=4,
+ dilation=1,
+ padding_mode: str = "zeros",
+ ):
+ super().__init__()
+ self.dwconv = nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=dilation * (kernel_size - 1) // 2,
+ groups=dim,
+ dilation=dilation,
+ padding_mode=padding_mode,
+ ) # depthwise conv
+ self.norm = nn.LayerNorm(dim)
+ self.pwconv1 = nn.Linear(dim, expansion * dim)
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(expansion * dim, dim)
+ self.gamma = (
+ nn.Parameter(layer_scale * torch.ones(1, dim, 1, 1))
+ if layer_scale > 0.0
+ else 1.0
+ )
+ self.skip_add = nn.quantized.FloatFunctional()
+
+ def forward(self, x):
+ input = x
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ return self.skip_add.add(self.gamma * x.permute(0, 3, 1, 2), input)
+
+
+class SimpleCvnxtBlock(nn.Module):
+ def __init__(
+ self,
+ dim,
+ output_dim=None,
+ kernel_size=7,
+ expansion=4,
+ dilation=1,
+ padding_mode: str = "zeros",
+ ):
+ super().__init__()
+ output_dim = output_dim if output_dim is not None else dim
+ self.dwconv = nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=dilation * (kernel_size - 1) // 2,
+ groups=dim,
+ dilation=dilation,
+ padding_mode=padding_mode,
+ ) # depthwise conv
+ self.norm = nn.LayerNorm(dim)
+ self.pwconv1 = nn.Linear(dim, expansion * dim)
+ self.act = nn.GELU()
+ self.pwconv2 = nn.Linear(expansion * dim, output_dim)
+
+ def forward(self, x):
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.pwconv2(x)
+ return x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
diff --git a/unik3d/layers/drop_path.py b/unik3d/layers/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..781ff566500c923b1f199542b0c7dfb862a077ca
--- /dev/null
+++ b/unik3d/layers/drop_path.py
@@ -0,0 +1,25 @@
+import torch
+import torch.nn as nn
+
+
+def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False):
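+    # stochastic depth: randomly drops entire samples during training and rescales the
+    # kept ones by 1 / keep_prob so the expectation is preserved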
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/unik3d/layers/grad_choker.py b/unik3d/layers/grad_choker.py
new file mode 100644
index 0000000000000000000000000000000000000000..12b70174f1922a3d7b962c0cba560d892ac80b60
--- /dev/null
+++ b/unik3d/layers/grad_choker.py
@@ -0,0 +1,25 @@
+import torch
+import torch.nn as nn
+from torch.autograd import Function
+
+
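+# Identity in the forward pass; scales the incoming gradient by `alpha` in the backward pass
+# (alpha = 0 blocks gradients entirely, acting like a detach).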
+class ChockerFunction(Function):
+ @staticmethod
+ def forward(ctx, x, alpha):
+ ctx.alpha = alpha
+ return x
+
+ @staticmethod
+ def backward(ctx, grad_output):
+ grad_input = grad_output * ctx.alpha
+ return grad_input, None
+
+
+class GradChoker(nn.Module):
+ def __init__(self, alpha):
+ super().__init__()
+ self.alpha = alpha
+
+ def forward(self, x):
+ alpha = torch.tensor(self.alpha, requires_grad=False, device=x.device)
+ return ChockerFunction.apply(x, alpha)
diff --git a/unik3d/layers/layer_scale.py b/unik3d/layers/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..01b6662490d7296725f103d1abf8790cac84d0f8
--- /dev/null
+++ b/unik3d/layers/layer_scale.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn as nn
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: float | torch.Tensor = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/unik3d/layers/misc.py b/unik3d/layers/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5d32030e066eaa8107e8ec060658318c675e8a3
--- /dev/null
+++ b/unik3d/layers/misc.py
@@ -0,0 +1,30 @@
+import torch
+import torch.nn as nn
+
+from .layer_scale import LayerScale
+
+
+class Addition(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ layer_scale: float | torch.Tensor = 1e-5,
+ ) -> None:
+ super().__init__()
+ self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity()
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ return x + self.ls1(y)
+
+
+class Concat(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ layer_scale: float | torch.Tensor = 1e-5,
+ ) -> None:
+ super().__init__()
+ self.project = nn.Linear(2 * dim, dim)
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ return self.project(torch.cat([x, y], dim=-1))
diff --git a/unik3d/layers/mlp.py b/unik3d/layers/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f1b9d2e6d020c039a677f566f4eb925aab696d4
--- /dev/null
+++ b/unik3d/layers/mlp.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+
+from unik3d.utils.misc import default
+
+from .activation import SwiGLU
+
+
+class MLP(nn.Module):
+ def __init__(
+ self,
+ input_dim: int,
+ expansion: int = 4,
+ dropout: float = 0.0,
+ gated: bool = False,
+ output_dim: int | None = None,
+ ):
+ super().__init__()
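+        # a gated activation (SwiGLU) splits the hidden width in two, so the expansion is
+        # scaled by 2/3 (the usual SwiGLU parameter-matching factor)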
+ if gated:
+ expansion = int(expansion * 2 / 3)
+ hidden_dim = int(input_dim * expansion)
+ output_dim = default(output_dim, input_dim)
+ self.norm = nn.LayerNorm(input_dim)
+ self.proj1 = nn.Linear(input_dim, hidden_dim)
+ self.proj2 = nn.Linear(hidden_dim, output_dim)
+ self.act = nn.GELU() if not gated else SwiGLU()
+ self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ x = self.norm(x)
+ x = self.proj1(x)
+ x = self.act(x)
+ x = self.proj2(x)
+ x = self.dropout(x)
+ return x
diff --git a/unik3d/layers/positional_encoding.py b/unik3d/layers/positional_encoding.py
new file mode 100644
index 0000000000000000000000000000000000000000..a76e93499eab94c4dd6fcb43c4baab1493f7c2cb
--- /dev/null
+++ b/unik3d/layers/positional_encoding.py
@@ -0,0 +1,303 @@
+from math import log, pi
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+
+class PositionEmbeddingSine(nn.Module):
+ def __init__(
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+ ):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * pi
+ self.scale = scale
+
+ def forward(
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ if mask is None:
+ mask = torch.zeros(
+ (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
+ )
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (
+ 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
+ )
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+ def __repr__(self, _repr_indent=4):
+ head = "Positional encoding " + self.__class__.__name__
+ body = [
+ "num_pos_feats: {}".format(self.num_pos_feats),
+ "temperature: {}".format(self.temperature),
+ "normalize: {}".format(self.normalize),
+ "scale: {}".format(self.scale),
+ ]
+ # _repr_indent = 4
+ lines = [head] + [" " * _repr_indent + line for line in body]
+ return "\n".join(lines)
+
+
+class LearnedSinusoidalPosEmb(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ assert (dim % 2) == 0
+ half_dim = dim // 2
+ self.weights = nn.Parameter(torch.randn(half_dim))
+
+ def forward(self, x):
+ x = rearrange(x, "b -> b 1")
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
+ fouriered = torch.cat((x, fouriered), dim=-1)
+ return fouriered
+
+
+def generate_fourier_features(x, max_freq=64, num_bands=16):
+ x = x.unsqueeze(-1)
+ device, dtype, orig_x = x.device, x.dtype, x
+
+ scales = torch.linspace(
+ -max_freq / 2, max_freq / 2, num_bands, device=device, dtype=dtype
+ )
+ scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
+
+ x = x * scales * pi
+ x = torch.cat([x.sin(), x.cos()], dim=-1)
+ x = torch.cat((x, orig_x), dim=-1)
+ return x.flatten(-2)
+
+
+def broadcat(tensors, dim=-1):
+ num_tensors = len(tensors)
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+ shape_len = list(shape_lens)[0]
+ dim = (dim + shape_len) if dim < 0 else dim
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+ assert all(
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
+    ), "invalid dimensions for broadcastable concatenation"
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
+ expanded_dims.insert(dim, (dim, dims[dim]))
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
+ return torch.cat(tensors, dim=dim)
+
+
+def rotate_half(x):
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return rearrange(x, "... d r -> ... (d r)")
+
+
+class VisionRotaryEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim,
+ pt_seq_len,
+ ft_seq_len=None,
+ custom_freqs=None,
+ freqs_for="lang",
+ theta=10000,
+ max_freq=10,
+ num_freqs=1,
+ ):
+ super().__init__()
+ if custom_freqs:
+ freqs = custom_freqs
+ elif freqs_for == "lang":
+ freqs = 1.0 / (
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+ )
+ elif freqs_for == "pixel":
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+ elif freqs_for == "constant":
+ freqs = torch.ones(num_freqs).float()
+ else:
+ raise ValueError(f"unknown modality {freqs_for}")
+
+ if ft_seq_len is None:
+ ft_seq_len = pt_seq_len
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
+
+ freqs_h = torch.einsum("..., f -> ... f", t, freqs)
+ freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
+
+ freqs_w = torch.einsum("..., f -> ... f", t, freqs)
+ freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
+
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
+
+ self.register_buffer("freqs_cos", freqs.cos())
+ self.register_buffer("freqs_sin", freqs.sin())
+
+ print("======== shape of rope freq", self.freqs_cos.shape, "========")
+
+ def forward(self, t, start_index=0):
+ rot_dim = self.freqs_cos.shape[-1]
+ end_index = start_index + rot_dim
+ assert (
+ rot_dim <= t.shape[-1]
+ ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
+ t_left, t, t_right = (
+ t[..., :start_index],
+ t[..., start_index:end_index],
+ t[..., end_index:],
+ )
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
+ return torch.cat((t_left, t, t_right), dim=-1)
+
+
+class VisionRotaryEmbeddingFast(nn.Module):
+ def __init__(
+ self,
+ dim,
+ pt_seq_len,
+ ft_seq_len=None,
+ custom_freqs=None,
+ freqs_for="lang",
+ theta=10000,
+ max_freq=10,
+ num_freqs=1,
+ ):
+ super().__init__()
+ if custom_freqs:
+ freqs = custom_freqs
+ elif freqs_for == "lang":
+ freqs = 1.0 / (
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+ )
+ elif freqs_for == "pixel":
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+ elif freqs_for == "constant":
+ freqs = torch.ones(num_freqs).float()
+ else:
+ raise ValueError(f"unknown modality {freqs_for}")
+
+ if ft_seq_len is None:
+ ft_seq_len = pt_seq_len
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
+
+ freqs = torch.einsum("..., f -> ... f", t, freqs)
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
+
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
+
+ self.register_buffer("freqs_cos", freqs_cos)
+ self.register_buffer("freqs_sin", freqs_sin)
+
+ def forward(self, t):
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
+
+
+class RotaryPositionalEmbeddings(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ max_seq_len: int = 30,
+ base: int = 10_000,
+ ) -> None:
+ super().__init__()
+ self.dim = dim
+ self.base = base
+ self.max_seq_len = max_seq_len
+ self._rope_init()
+
+ # We need to explicitly define reset_parameters for FSDP initialization, see
+ # https://github.com/pytorch/pytorch/blob/797d4fbdf423dd9320ebe383fb57ffb1135c4a99/torch/distributed/fsdp/_init_utils.py#L885
+ def reset_parameters(self):
+ self._rope_init()
+
+ def _rope_init(self):
+ theta = 1.0 / (
+ self.base
+ ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim)
+ )
+ self.register_buffer("theta", theta, persistent=False)
+ self.build_rope_cache(self.max_seq_len)
+
+ def build_rope_cache(self, max_seq_len: int = 4096) -> None:
+ # Create position indexes `[0, 1, ..., max_seq_len - 1]`
+ seq_idx = torch.arange(
+ max_seq_len, dtype=self.theta.dtype, device=self.theta.device
+ )
+
+ # Outer product of theta and position index; output tensor has
+ # a shape of [max_seq_len, dim // 2]
+ idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float()
+
+ cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1)
+ self.register_buffer("cache", cache, persistent=False)
+
+ def forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor:
+ """
+ Args:
+ x (Tensor): input tensor with shape
+ [bsz, seq_len, num_heads, head_dim]
+            input_pos (Tensor): position indices of the current tokens
+
+ Returns:
+ Tensor: output tensor with RoPE applied
+
+ Notation used for tensor shapes:
+ - b: batch size
+ - s: sequence length
+ - n_h: num heads
+ - h_d: head dim
+ """
+ rope_cache = self.cache[input_pos]
+
+ # reshape input; the last dimension is used for computing the output.
+ # Cast to float to match the reference implementation
+ # tensor has shape [b, s, n_h, n_d // 2, 2]
+ xshaped = x.reshape(*x.shape[:-1], -1, 2)
+
+ # reshape the cache for broadcasting
+ # tensor has shape [b, s, 1, n_d // 2, 2]
+ rope_cache = rope_cache.unsqueeze(2)
+
+ # tensor has shape [b, s, n_h, n_d // 2, 2]
+ x_out = torch.stack(
+ [
+ xshaped[..., 0] * rope_cache[..., 0]
+ - xshaped[..., 1] * rope_cache[..., 1],
+ xshaped[..., 1] * rope_cache[..., 0]
+ + xshaped[..., 0] * rope_cache[..., 1],
+ ],
+ -1,
+ )
+
+ # tensor has shape [b, s, n_h, n_d]
+ return x_out.flatten(3)
diff --git a/unik3d/layers/upsample.py b/unik3d/layers/upsample.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a4b9141ed1e0ce747e7e43769c94b709591f115
--- /dev/null
+++ b/unik3d/layers/upsample.py
@@ -0,0 +1,169 @@
+import torch
+import torch.nn as nn
+from einops import rearrange
+
+from unik3d.utils.constants import VERBOSE
+from unik3d.utils.misc import profile_method
+
+
+class ResidualConvUnit(nn.Module):
+ def __init__(
+ self,
+ dim,
+ kernel_size: int = 3,
+ padding_mode: str = "zeros",
+ dilation: int = 1,
+ layer_scale: float = 1.0,
+ use_norm: bool = False,
+ ):
+ super().__init__()
+ self.conv1 = nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=dilation * (kernel_size - 1) // 2,
+ dilation=dilation,
+ padding_mode=padding_mode,
+ )
+ self.conv2 = nn.Conv2d(
+ dim,
+ dim,
+ kernel_size=kernel_size,
+ padding=dilation * (kernel_size - 1) // 2,
+ dilation=dilation,
+ padding_mode=padding_mode,
+ )
+ self.activation = nn.LeakyReLU()
+ self.skip_add = nn.quantized.FloatFunctional()
+ self.gamma = (
+ nn.Parameter(layer_scale * torch.ones(1, dim, 1, 1))
+ if layer_scale > 0.0
+ else 1.0
+ )
+ self.norm1 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity()
+ self.norm2 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity()
+
+ def forward(self, x):
+ out = self.activation(x)
+ out = self.conv1(out)
+ out = self.norm1(out)
+ out = self.activation(out)
+ out = self.conv2(out)
+ out = self.norm2(out)
+ return self.skip_add.add(self.gamma * out, x)
+
+
+class ResUpsampleBil(nn.Module):
+ def __init__(
+ self,
+ hidden_dim,
+        output_dim: int | None = None,
+ num_layers: int = 2,
+ kernel_size: int = 3,
+ layer_scale: float = 1.0,
+ padding_mode: str = "zeros",
+ use_norm: bool = False,
+ **kwargs,
+ ):
+ super().__init__()
+ output_dim = output_dim if output_dim is not None else hidden_dim // 2
+ self.convs = nn.ModuleList([])
+ for _ in range(num_layers):
+ self.convs.append(
+ ResidualConvUnit(
+ hidden_dim,
+ kernel_size=kernel_size,
+ layer_scale=layer_scale,
+ padding_mode=padding_mode,
+ use_norm=use_norm,
+ )
+ )
+ self.up = nn.Sequential(
+ nn.Conv2d(
+ hidden_dim,
+ output_dim,
+ kernel_size=1,
+ padding=0,
+ padding_mode=padding_mode,
+ ),
+ nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False),
+ )
+
+ @profile_method(verbose=VERBOSE)
+ def forward(self, x: torch.Tensor):
+ for conv in self.convs:
+ x = conv(x)
+ x = self.up(x)
+ return x
+
+
+class ResUpsample(nn.Module):
+ def __init__(
+ self,
+ hidden_dim,
+ num_layers: int = 2,
+ kernel_size: int = 3,
+ layer_scale: float = 1.0,
+ padding_mode: str = "zeros",
+ **kwargs,
+ ):
+ super().__init__()
+ self.convs = nn.ModuleList([])
+ for _ in range(num_layers):
+ self.convs.append(
+ ResidualConvUnit(
+ hidden_dim,
+ kernel_size=kernel_size,
+ layer_scale=layer_scale,
+ padding_mode=padding_mode,
+ )
+ )
+ self.up = nn.ConvTranspose2d(
+ hidden_dim, hidden_dim // 2, kernel_size=2, stride=2, padding=0
+ )
+
+ @profile_method(verbose=VERBOSE)
+ def forward(self, x: torch.Tensor):
+ for conv in self.convs:
+ x = conv(x)
+ x = self.up(x)
+ return x
+
+
+class ResUpsampleSH(nn.Module):
+ def __init__(
+ self,
+ hidden_dim,
+ num_layers: int = 2,
+ kernel_size: int = 3,
+ layer_scale: float = 1.0,
+ padding_mode: str = "zeros",
+ **kwargs,
+ ):
+ super().__init__()
+ self.convs = nn.ModuleList([])
+ for _ in range(num_layers):
+ self.convs.append(
+ ResidualConvUnit(
+ hidden_dim,
+ kernel_size=kernel_size,
+ layer_scale=layer_scale,
+ padding_mode=padding_mode,
+ )
+ )
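+        # PixelShuffle(2) trades channels for resolution (C -> C/4, H x W -> 2H x 2W);
+        # the conv then maps C/4 -> C/2 to match the other upsamplers' output width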
+ self.up = nn.Sequential(
+ nn.PixelShuffle(2),
+ nn.Conv2d(
+ hidden_dim // 4,
+ hidden_dim // 2,
+ kernel_size=3,
+ padding=1,
+ padding_mode=padding_mode,
+ ),
+ )
+
+ def forward(self, x: torch.Tensor):
+ for conv in self.convs:
+ x = conv(x)
+ x = self.up(x)
+ return x
diff --git a/unik3d/models/__init__.py b/unik3d/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..62d77529507b5eb1ae2dc281cd8c152e84212125
--- /dev/null
+++ b/unik3d/models/__init__.py
@@ -0,0 +1,3 @@
+from unik3d.models.unik3d import UniK3D
+
+__all__ = ["UniK3D"]
diff --git a/unik3d/models/backbones/__init__.py b/unik3d/models/backbones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..af81441399818e6a32fd91b23aae7945b8fd4e1c
--- /dev/null
+++ b/unik3d/models/backbones/__init__.py
@@ -0,0 +1,13 @@
+from .convnext import ConvNeXt
+from .convnext2 import ConvNeXtV2
+from .dinov2 import _make_dinov2_model
+from .swinv2 import SwinTransformerV2
+
+# from .svd import StableVideoDiffusion
+
+__all__ = [
+ "SwinTransformerV2",
+ "ConvNeXtV2",
+ "_make_dinov2_model",
+ "ConvNeXt",
+]
diff --git a/unik3d/models/backbones/convnext.py b/unik3d/models/backbones/convnext.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7f9d820f3230e349e6263cc0be4ce1139b2df3f
--- /dev/null
+++ b/unik3d/models/backbones/convnext.py
@@ -0,0 +1,577 @@
+from collections import OrderedDict
+from functools import partial
+from typing import Callable, Optional, Sequence, Tuple, Union
+
+import torch
+import torch.nn as nn
+from timm.layers import (AvgPool2dSame, DropPath, GlobalResponseNormMlp,
+ LayerNorm, LayerNorm2d, Mlp, create_conv2d,
+ get_act_layer, make_divisible, to_ntuple,
+ trunc_normal_)
+from torch.utils.checkpoint import checkpoint
+
+
+def get_num_layer_for_convnext(var_name):
+ """
+ Divide [3, 3, 27, 3] layers into 12 groups; each group is three
+ consecutive blocks, including possible neighboring downsample layers;
+ adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
+ """
+ if var_name.startswith("downsample_layers"):
+ stage_id = int(var_name.split(".")[1])
+ if stage_id == 0:
+ layer_id = 0
+ elif stage_id == 1 or stage_id == 2:
+ layer_id = stage_id + 1
+ elif stage_id == 3:
+ layer_id = 12
+
+ elif var_name.startswith("stages"):
+ stage_id = int(var_name.split(".")[1])
+ block_id = int(var_name.split(".")[3])
+ if stage_id == 0 or stage_id == 1:
+ layer_id = stage_id + 1
+ elif stage_id == 2:
+ layer_id = 3 + block_id // 3
+ elif stage_id == 3:
+ layer_id = 12
+
+ elif var_name.startswith("stem"):
+ return 0
+ else:
+ layer_id = 12
+ return layer_id + 1
+
+
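+# Builds optimizer parameter groups with layer-wise lr decay: parameters are bucketed by
+# (layer id, decay/no-decay), and the group at layer id `l` uses lr * ld ** (num_layers + 1 - l)
+# with num_layers fixed at 12, so groups closer to the output keep a larger lr. Returns the
+# groups and their per-group lrs.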
+def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=None):
+ parameter_group_names = {}
+ parameter_group_vars = {}
+ skip = set()
+ if skip_list is not None:
+        skip = set(skip_list)
+ if hasattr(model, "no_weight_decay"):
+ skip.update(model.no_weight_decay())
+ num_layers = 12
+ layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2))
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if len(param.shape) == 1 or name.endswith(".bias") or name in skip:
+ group_name = "no_decay"
+ this_wd = 0.0
+ else:
+ group_name = "decay"
+ this_wd = wd
+
+ layer_id = get_num_layer_for_convnext(name)
+ group_name = "layer_%d_%s" % (layer_id, group_name)
+
+ if group_name not in parameter_group_names:
+ scale = layer_scale[layer_id]
+ cur_lr = lr * scale
+
+ parameter_group_names[group_name] = {
+ "weight_decay": this_wd,
+ "weight_decay_init": this_wd,
+ "weight_decay_base": this_wd,
+ "params": [],
+ "lr_init": cur_lr,
+ "lr_base": lr,
+ "lr": cur_lr,
+ }
+ parameter_group_vars[group_name] = {
+ "weight_decay": this_wd,
+ "weight_decay_init": this_wd,
+ "weight_decay_base": this_wd,
+ "params": [],
+ "lr_init": cur_lr,
+ "lr_base": lr,
+ "lr": cur_lr,
+ }
+ if this_wd == 0.0:
+ parameter_group_names[group_name]["weight_decay_final"] = 0.0
+ parameter_group_vars[group_name]["weight_decay_final"] = 0.0
+ parameter_group_vars[group_name]["params"].append(param)
+ parameter_group_names[group_name]["params"].append(name)
+
+ return list(parameter_group_vars.values()), [
+ v["lr"] for k, v in parameter_group_vars.items()
+ ]
+
+
+class Downsample(nn.Module):
+ def __init__(self, in_chs, out_chs, stride=1, dilation=1):
+ super().__init__()
+ avg_stride = stride if dilation == 1 else 1
+ if stride > 1 or dilation > 1:
+ avg_pool_fn = (
+ AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+ )
+ self.pool = avg_pool_fn(
+ 2, avg_stride, ceil_mode=True, count_include_pad=False
+ )
+ else:
+ self.pool = nn.Identity()
+
+ if in_chs != out_chs:
+ self.conv = create_conv2d(in_chs, out_chs, 1, stride=1)
+ else:
+ self.conv = nn.Identity()
+
+ def forward(self, x):
+ x = self.pool(x)
+ x = self.conv(x)
+ return x
+
+
+class ConvNeXtBlock(nn.Module):
+ """ConvNeXt Block
+ There are two equivalent implementations:
+ (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W)
+ (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back
+
+    Unlike the official impl, this one allows choosing between (1) and (2). The 1x1 conv can be faster with an
+    appropriate choice of LayerNorm impl, but as model size increases the trade-off shifts and nn.Linear becomes
+    the better choice. This was observed with PyTorch 1.10 on a 3090 GPU; it may change over time and with different HW.
+ """
+
+ def __init__(
+ self,
+ in_chs: int,
+ out_chs: Optional[int] = None,
+ kernel_size: int = 7,
+ stride: int = 1,
+ dilation: Union[int, Tuple[int, int]] = (1, 1),
+ mlp_ratio: float = 4,
+ conv_mlp: bool = False,
+ conv_bias: bool = True,
+ use_grn: bool = False,
+ ls_init_value: Optional[float] = 1e-6,
+ act_layer: Union[str, Callable] = "gelu",
+ norm_layer: Optional[Callable] = None,
+ drop_path: float = 0.0,
+ ):
+ """
+
+ Args:
+ in_chs: Block input channels.
+ out_chs: Block output channels (same as in_chs if None).
+ kernel_size: Depthwise convolution kernel size.
+ stride: Stride of depthwise convolution.
+ dilation: Tuple specifying input and output dilation of block.
+ mlp_ratio: MLP expansion ratio.
+ conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True.
+ conv_bias: Apply bias for all convolution (linear) layers.
+ use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2)
+ ls_init_value: Layer-scale init values, layer-scale applied if not None.
+ act_layer: Activation layer.
+ norm_layer: Normalization layer (defaults to LN if not specified).
+ drop_path: Stochastic depth probability.
+ """
+ super().__init__()
+ out_chs = out_chs or in_chs
+ dilation = to_ntuple(2)(dilation)
+ act_layer = get_act_layer(act_layer)
+ if not norm_layer:
+ norm_layer = LayerNorm2d if conv_mlp else LayerNorm
+ mlp_layer = partial(
+ GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp
+ )
+ self.use_conv_mlp = conv_mlp
+ self.conv_dw = create_conv2d(
+ in_chs,
+ out_chs,
+ kernel_size=kernel_size,
+ stride=stride,
+ dilation=dilation[0],
+ depthwise=True,
+ bias=conv_bias,
+ )
+ self.norm = norm_layer(out_chs)
+ self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer)
+ self.gamma = (
+ nn.Parameter(ls_init_value * torch.ones(out_chs))
+ if ls_init_value is not None
+ else None
+ )
+ if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]:
+ self.shortcut = Downsample(
+ in_chs, out_chs, stride=stride, dilation=dilation[0]
+ )
+ else:
+ self.shortcut = nn.Identity()
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ def forward(self, x):
+ shortcut = x
+ x = self.conv_dw(x.contiguous())
+ if self.use_conv_mlp:
+ x = self.norm(x)
+ x = self.mlp(x)
+ else:
+ x = x.permute(0, 2, 3, 1).contiguous()
+ x = self.norm(x)
+ x = self.mlp(x)
+ x = x.permute(0, 3, 1, 2).contiguous()
+ if self.gamma is not None:
+ x = x.mul(self.gamma.reshape(1, -1, 1, 1))
+
+ x = self.drop_path(x) + self.shortcut(shortcut)
+ return x.contiguous()
+
+
+class ConvNeXtStage(nn.Module):
+ def __init__(
+ self,
+ in_chs,
+ out_chs,
+ kernel_size=7,
+ stride=2,
+ depth=2,
+ dilation=(1, 1),
+ drop_path_rates=None,
+ ls_init_value=1.0,
+ conv_mlp=False,
+ conv_bias=True,
+ use_grn=False,
+ act_layer="gelu",
+ norm_layer=None,
+ norm_layer_cl=None,
+ ):
+ super().__init__()
+ self.grad_checkpointing = False
+
+ if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]:
+ ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1
+ pad = (
+ "same" if dilation[1] > 1 else 0
+ ) # same padding needed if dilation used
+ self.downsample = nn.Sequential(
+ norm_layer(in_chs),
+ create_conv2d(
+ in_chs,
+ out_chs,
+ kernel_size=ds_ks,
+ stride=stride,
+ dilation=dilation[0],
+ padding=pad,
+ bias=conv_bias,
+ ),
+ )
+ in_chs = out_chs
+ else:
+ self.downsample = nn.Identity()
+
+ drop_path_rates = drop_path_rates or [0.0] * depth
+ stage_blocks = []
+ for i in range(depth):
+ stage_blocks.append(
+ ConvNeXtBlock(
+ in_chs=in_chs,
+ out_chs=out_chs,
+ kernel_size=kernel_size,
+ dilation=dilation[1],
+ drop_path=drop_path_rates[i],
+ ls_init_value=ls_init_value,
+ conv_mlp=conv_mlp,
+ conv_bias=conv_bias,
+ use_grn=use_grn,
+ act_layer=act_layer,
+ norm_layer=norm_layer if conv_mlp else norm_layer_cl,
+ )
+ )
+ in_chs = out_chs
+ self.blocks = nn.ModuleList(stage_blocks)
+
+ def forward(self, x):
+ xs = []
+ x = self.downsample(x)
+ for block in self.blocks:
+ if self.grad_checkpointing:
+ x = checkpoint(block, x)
+ else:
+ x = block(x)
+ xs.append(x)
+ return xs
+
+
+class ConvNeXt(nn.Module):
+ def __init__(
+ self,
+ in_chans: int = 3,
+ output_stride: int = 32,
+ depths: Tuple[int, ...] = (3, 3, 9, 3),
+ dims: Tuple[int, ...] = (96, 192, 384, 768),
+ kernel_sizes: Union[int, Tuple[int, ...]] = 7,
+ ls_init_value: Optional[float] = 1e-6,
+ stem_type: str = "patch",
+ patch_size: int = 4,
+ conv_mlp: bool = False,
+ conv_bias: bool = True,
+ use_grn: bool = False,
+ act_layer: Union[str, Callable] = "gelu",
+ norm_layer: Optional[Union[str, Callable]] = None,
+ norm_eps: Optional[float] = None,
+ drop_path_rate: float = 0.0,
+ output_idx=[],
+ use_checkpoint=False,
+ ):
+ """
+ Args:
+            in_chans: Number of input image channels.
+            output_stride: Output stride of network, one of (8, 16, 32).
+            depths: Number of blocks at each stage.
+            dims: Feature dimension at each stage.
+            kernel_sizes: Depthwise convolution kernel-sizes for each stage.
+            ls_init_value: Init value for Layer Scale, disabled if None.
+            stem_type: Type of stem.
+            patch_size: Stem patch size for patch stem.
+            conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last.
+            conv_bias: Use bias layers w/ all convolutions.
+            use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP.
+            act_layer: Activation layer type.
+            norm_layer: Normalization layer type.
+            drop_path_rate: Stochastic depth drop rate.
+ """
+ super().__init__()
+ self.num_layers = len(depths)
+ self.depths = output_idx
+ self.embed_dims = [
+ int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
+ ]
+ self.embed_dim = dims[0]
+
+ assert output_stride in (8, 16, 32)
+ kernel_sizes = to_ntuple(4)(kernel_sizes)
+ if norm_layer is None:
+ norm_layer = LayerNorm2d
+ norm_layer_cl = norm_layer if conv_mlp else LayerNorm
+ if norm_eps is not None:
+ norm_layer = partial(norm_layer, eps=norm_eps)
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
+ else:
+ assert (
+ conv_mlp
+            ), "If a norm_layer is specified, conv MLP must be used so all norms expect rank-4, channels-first input"
+ norm_layer_cl = norm_layer
+ if norm_eps is not None:
+ norm_layer_cl = partial(norm_layer_cl, eps=norm_eps)
+
+ self.feature_info = []
+
+ assert stem_type in ("patch", "overlap", "overlap_tiered")
+ if stem_type == "patch":
+ # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4
+ self.stem = nn.Sequential(
+ nn.Conv2d(
+ in_chans,
+ dims[0],
+ kernel_size=patch_size,
+ stride=patch_size,
+ bias=conv_bias,
+ ),
+ norm_layer(dims[0]),
+ )
+ stem_stride = patch_size
+ else:
+ mid_chs = make_divisible(dims[0] // 2) if "tiered" in stem_type else dims[0]
+ self.stem = nn.Sequential(
+ nn.Conv2d(
+ in_chans,
+ mid_chs,
+ kernel_size=3,
+ stride=2,
+ padding=1,
+ bias=conv_bias,
+ ),
+ nn.Conv2d(
+ mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias
+ ),
+ norm_layer(dims[0]),
+ )
+ stem_stride = 4
+
+ self.stages = nn.Sequential()
+ dp_rates = [
+ x.tolist()
+ for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths)
+ ]
+ stages = []
+ prev_chs = dims[0]
+ curr_stride = stem_stride
+ dilation = 1
+ # 4 feature resolution stages, each consisting of multiple residual blocks
+ for i in range(4):
+ stride = 2 if curr_stride == 2 or i > 0 else 1
+ if curr_stride >= output_stride and stride > 1:
+ dilation *= stride
+ stride = 1
+ curr_stride *= stride
+ first_dilation = 1 if dilation in (1, 2) else 2
+ out_chs = dims[i]
+ stages.append(
+ ConvNeXtStage(
+ prev_chs,
+ out_chs,
+ kernel_size=kernel_sizes[i],
+ stride=stride,
+ dilation=(first_dilation, dilation),
+ depth=depths[i],
+ drop_path_rates=dp_rates[i],
+ ls_init_value=ls_init_value,
+ conv_mlp=conv_mlp,
+ conv_bias=conv_bias,
+ use_grn=use_grn,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ norm_layer_cl=norm_layer_cl,
+ )
+ )
+ prev_chs = out_chs
+ # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2
+ self.feature_info += [
+ dict(num_chs=prev_chs, reduction=curr_stride, module=f"stages.{i}")
+ ]
+ self.stages = nn.ModuleList(stages)
+ self.mask_token = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1))
+ self.num_features = prev_chs
+ self.apply(self._init_weights)
+ self.set_grad_checkpointing(use_checkpoint)
+
+ def _init_weights(self, module):
+ if isinstance(module, nn.Conv2d):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+ elif isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ nn.init.zeros_(module.bias)
+
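+    # Returns one NHWC feature map per block plus its spatial mean as a pooled token.
+    # When `masks` is given, it is resized to the stem resolution and the masked positions
+    # are replaced by the learnable mask_token (masked-image-modeling style masking).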
+ def forward(self, x, masks=None):
+ outs = []
+ x = self.stem(x)
+ if masks is not None:
+ masks = torch.nn.functional.interpolate(
+ masks.float(), size=x.shape[-2:], mode="nearest"
+ )
+ x = torch.where(masks.bool(), self.mask_token.to(x.dtype), x).contiguous()
+ for stage in self.stages:
+ xs = stage(x)
+ outs.extend([x.permute(0, 2, 3, 1).contiguous() for x in xs])
+ x = xs[-1]
+ return outs, [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs]
+
+ @torch.jit.ignore
+ def group_matcher(self, coarse=False):
+ return dict(
+ stem=r"^stem",
+ blocks=(
+ r"^stages\.(\d+)"
+ if coarse
+ else [
+ (r"^stages\.(\d+)\.downsample", (0,)), # blocks
+ (r"^stages\.(\d+)\.blocks\.(\d+)", None),
+ (r"^norm_pre", (99999,)),
+ ]
+ ),
+ )
+
+ @torch.jit.ignore
+ def set_grad_checkpointing(self, enable=True):
+ for s in self.stages:
+ s.grad_checkpointing = enable
+
+ def freeze(self) -> None:
+ for module in self.modules():
+ module.eval()
+ for parameters in self.parameters():
+ parameters.requires_grad = False
+
+ def get_params(self, lr, wd, ld, *args, **kwargs):
+ encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
+ return encoder_p, encoder_lr
+
+ def no_weight_decay(self):
+ return {"mask_token"}
+
+ @classmethod
+ def build(cls, config):
+ obj = globals()[config["model"]["encoder"]["name"]](config)
+ return obj
+
+
+def checkpoint_filter_fn(state_dict, model):
+ """Remap FB checkpoints -> timm"""
+ if "head.norm.weight" in state_dict or "norm_pre.weight" in state_dict:
+ return state_dict # non-FB checkpoint
+ if "model" in state_dict:
+ state_dict = state_dict["model"]
+
+ out_dict = {}
+ if "visual.trunk.stem.0.weight" in state_dict:
+ out_dict = {
+ k.replace("visual.trunk.", ""): v
+ for k, v in state_dict.items()
+ if k.startswith("visual.trunk.")
+ }
+ if "visual.head.proj.weight" in state_dict:
+ out_dict["head.fc.weight"] = state_dict["visual.head.proj.weight"]
+ out_dict["head.fc.bias"] = torch.zeros(
+ state_dict["visual.head.proj.weight"].shape[0]
+ )
+ elif "visual.head.mlp.fc1.weight" in state_dict:
+ out_dict["head.pre_logits.fc.weight"] = state_dict[
+ "visual.head.mlp.fc1.weight"
+ ]
+ out_dict["head.pre_logits.fc.bias"] = state_dict["visual.head.mlp.fc1.bias"]
+ out_dict["head.fc.weight"] = state_dict["visual.head.mlp.fc2.weight"]
+ out_dict["head.fc.bias"] = torch.zeros(
+ state_dict["visual.head.mlp.fc2.weight"].shape[0]
+ )
+ return out_dict
+
+ import re
+
+ for k, v in state_dict.items():
+ k = k.replace("downsample_layers.0.", "stem.")
+ k = re.sub(r"stages.([0-9]+).([0-9]+)", r"stages.\1.blocks.\2", k)
+ k = re.sub(
+ r"downsample_layers.([0-9]+).([0-9]+)", r"stages.\1.downsample.\2", k
+ )
+ k = k.replace("dwconv", "conv_dw")
+ k = k.replace("pwconv", "mlp.fc")
+ if "grn" in k:
+ k = k.replace("grn.beta", "mlp.grn.bias")
+ k = k.replace("grn.gamma", "mlp.grn.weight")
+ v = v.reshape(v.shape[-1])
+ k = k.replace("head.", "head.fc.")
+ if k.startswith("norm."):
+ k = k.replace("norm", "head.norm")
+ if v.ndim == 2 and "head" not in k:
+ model_shape = model.state_dict()[k].shape
+ v = v.reshape(model_shape)
+ out_dict[k] = v
+
+ return out_dict
+
+
+HF_URL = {
+ "convnext_xxlarge_pt": (
+ "laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup",
+ "open_clip_pytorch_model.bin",
+ ),
+ "convnext_large_pt": (
+ "laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup",
+ "open_clip_pytorch_model.bin",
+ ),
+ "convnext_large": (
+ "timm/convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384",
+ "pytorch_model.bin",
+ ),
+}
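+# Mapping from encoder name to (Hugging Face repo id, weight filename); presumably consumed
+# by the checkpoint-loading utilities elsewhere in the repo together with checkpoint_filter_fn
+# above to remap OpenCLIP/FB ConvNeXt weights into this layout.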
diff --git a/unik3d/models/backbones/convnext2.py b/unik3d/models/backbones/convnext2.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d5bbf12bf586552ef35e2e82d59e47b9c9cc42
--- /dev/null
+++ b/unik3d/models/backbones/convnext2.py
@@ -0,0 +1,288 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from timm.models.layers import DropPath, trunc_normal_
+
+
+def get_num_layer_for_convnext_single(var_name, depths):
+ """
+    Each layer is assigned a distinct layer id
+ """
+ if var_name.startswith("downsample_layers"):
+ stage_id = int(var_name.split(".")[1])
+ layer_id = sum(depths[:stage_id]) + 1
+ return layer_id
+
+ elif var_name.startswith("stages"):
+ stage_id = int(var_name.split(".")[1])
+ block_id = int(var_name.split(".")[2])
+ layer_id = sum(depths[:stage_id]) + block_id + 1
+ return layer_id
+
+ else:
+ return sum(depths) + 1
+
+
+def get_num_layer_for_convnext(var_name):
+ """
+ Divide [3, 3, 27, 3] layers into 12 groups; each group is three
+ consecutive blocks, including possible neighboring downsample layers;
+ adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py
+ """
+ num_max_layer = 12
+ if var_name.startswith("downsample_layers"):
+ stage_id = int(var_name.split(".")[1])
+ if stage_id == 0:
+ layer_id = 0
+ elif stage_id == 1 or stage_id == 2:
+ layer_id = stage_id + 1
+ elif stage_id == 3:
+ layer_id = 12
+ return layer_id
+
+ elif var_name.startswith("stages"):
+ stage_id = int(var_name.split(".")[1])
+ block_id = int(var_name.split(".")[2])
+ if stage_id == 0 or stage_id == 1:
+ layer_id = stage_id + 1
+ elif stage_id == 2:
+ layer_id = 3 + block_id // 3
+ elif stage_id == 3:
+ layer_id = 12
+ return layer_id
+ else:
+ return num_max_layer + 1
+
+
+def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()):
+ parameter_group_names = {}
+ parameter_group_vars = {}
+ skip = {}
+ if skip_list is not None:
+ skip = skip_list
+ elif hasattr(model, "no_weight_decay"):
+ skip = model.no_weight_decay()
+ num_layers = 12 # sum(model.depths)
+ layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2))
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if (
+ len(param.shape) == 1
+ or name.endswith(".bias")
+ or name in skip
+ or name.endswith(".gamma")
+ or name.endswith(".beta")
+ ):
+ group_name = "no_decay"
+ this_weight_decay = 0.0
+ else:
+ group_name = "decay"
+ this_weight_decay = wd
+
+ # layer_id = get_num_layer_for_convnext_single(name, model.depths)
+ layer_id = get_num_layer_for_convnext(name)
+ group_name = "layer_%d_%s" % (layer_id, group_name)
+
+ if group_name not in parameter_group_names:
+ scale = layer_scale[layer_id]
+ cur_lr = lr * scale
+
+ parameter_group_names[group_name] = {
+ "weight_decay": this_weight_decay,
+ "params": [],
+ "lr_scale": scale,
+ "lr": cur_lr,
+ }
+ parameter_group_vars[group_name] = {
+ "weight_decay": this_weight_decay,
+ "params": [],
+ "lr_scale": scale,
+ "lr": cur_lr,
+ }
+ parameter_group_vars[group_name]["params"].append(param)
+ parameter_group_names[group_name]["params"].append(name)
+ # if is_main_process():
+ # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2))
+ return list(parameter_group_vars.values()), [
+ v["lr"] for k, v in parameter_group_vars.items()
+ ]
+
+
+class LayerNorm(nn.Module):
+ """LayerNorm that supports two data formats: channels_last (default) or channels_first.
+    The data format refers to the ordering of the dimensions in the inputs: channels_last
+    corresponds to inputs with shape (batch_size, height, width, channels), while
+    channels_first corresponds to inputs with shape (batch_size, channels, height, width).
+ """
+
+ def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"):
+ super().__init__()
+ self.weight = nn.Parameter(torch.ones(normalized_shape))
+ self.bias = nn.Parameter(torch.zeros(normalized_shape))
+ self.eps = eps
+ self.data_format = data_format
+ if self.data_format not in ["channels_last", "channels_first"]:
+ raise NotImplementedError
+ self.normalized_shape = (normalized_shape,)
+
+ def forward(self, x):
+ if self.data_format == "channels_last":
+ return F.layer_norm(
+ x, self.normalized_shape, self.weight, self.bias, self.eps
+ )
+ elif self.data_format == "channels_first":
+ u = x.mean(1, keepdim=True)
+ s = (x - u).pow(2).mean(1, keepdim=True)
+ x = (x - u) / torch.sqrt(s + self.eps)
+ x = self.weight[:, None, None] * x + self.bias[:, None, None]
+ return x
+
+
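+# Global Response Normalization (ConvNeXt V2): Gx aggregates a per-channel global L2 norm
+# over the spatial dimensions, Nx normalizes it across channels, and the result gates the
+# features with a learnable affine (gamma, beta) plus a residual connection.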
+class GRN(nn.Module):
+ """GRN (Global Response Normalization) layer"""
+
+ def __init__(self, dim):
+ super().__init__()
+ self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim))
+ self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim))
+
+ def forward(self, x):
+ Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
+ Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6)
+ return self.gamma * (x * Nx) + self.beta + x
+
+
+class Block(nn.Module):
+ """ConvNeXtV2 Block.
+
+ Args:
+ dim (int): Number of input channels.
+ drop_path (float): Stochastic depth rate. Default: 0.0
+ """
+
+ def __init__(self, dim, drop_path=0.0, mult=4, use_checkpoint=False):
+ super().__init__()
+ self.dwconv = nn.Conv2d(
+ dim, dim, kernel_size=7, padding=3, groups=dim
+ ) # depthwise conv
+ self.norm = LayerNorm(dim, eps=1e-6)
+ self.pwconv1 = nn.Linear(
+ dim, mult * dim
+ ) # pointwise/1x1 convs, implemented with linear layers
+ self.act = nn.GELU()
+ self.grn = GRN(mult * dim)
+ self.pwconv2 = nn.Linear(mult * dim, dim)
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.use_checkpoint = use_checkpoint
+
+ def forward(self, x):
+ input = x
+ x = self.dwconv(x)
+ x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C)
+ x = self.norm(x)
+ x = self.pwconv1(x)
+ x = self.act(x)
+ x = self.grn(x)
+ x = self.pwconv2(x)
+ x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W)
+
+ x = input + self.drop_path(x)
+ return x
+
+
+class ConvNeXtV2(nn.Module):
+ """ConvNeXt V2
+
+ Args:
+        in_chans (int): Number of input image channels. Default: 3
+        depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3]
+        dims (tuple(int)): Feature dimension at each stage. Default: [96, 192, 384, 768]
+        drop_path_rate (float): Stochastic depth rate. Default: 0.
+ """
+
+ def __init__(
+ self,
+ in_chans=3,
+ depths=[3, 3, 9, 3],
+        dims=[96, 192, 384, 768],
+ drop_path_rate=0.0,
+ output_idx=[],
+ use_checkpoint=False,
+ ):
+ super().__init__()
+ self.num_layers = len(depths)
+ self.depths = output_idx
+ self.embed_dims = [
+ int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
+ ]
+ self.embed_dim = dims[0]
+
+ self.downsample_layers = (
+ nn.ModuleList()
+ ) # stem and 3 intermediate downsampling conv layers
+ stem = nn.Sequential(
+ nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4),
+ LayerNorm(dims[0], eps=1e-6, data_format="channels_first"),
+ )
+ self.downsample_layers.append(stem)
+ for i in range(3):
+ downsample_layer = nn.Sequential(
+ LayerNorm(dims[i], eps=1e-6, data_format="channels_first"),
+ nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2),
+ )
+ self.downsample_layers.append(downsample_layer)
+
+ self.stages = (
+ nn.ModuleList()
+ ) # 4 feature resolution stages, each consisting of multiple residual blocks
+ self.out_norms = nn.ModuleList()
+ dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
+ cur = 0
+ for i in range(4):
+ stage = nn.ModuleList(
+ [
+ Block(
+ dim=dims[i],
+ drop_path=dp_rates[cur + j],
+ use_checkpoint=use_checkpoint,
+ )
+ for j in range(depths[i])
+ ]
+ )
+ self.stages.append(stage)
+ cur += depths[i]
+
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, (nn.Conv2d, nn.Linear)):
+ trunc_normal_(m.weight, std=0.02)
+ nn.init.constant_(m.bias, 0)
+
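+    # As in the ConvNeXt backbone above, forward returns one NHWC feature map per block
+    # plus its spatial mean as a pooled "class token".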
+ def forward(self, x):
+ outs = []
+ for i in range(4):
+ x = self.downsample_layers[i](x)
+ for stage in self.stages[i]:
+ x = stage(x)
+ outs.append(x.permute(0, 2, 3, 1))
+ cls_tokens = [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs]
+ return outs, cls_tokens
+
+ def get_params(self, lr, wd, ld, *args, **kwargs):
+ encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
+ return encoder_p, encoder_lr
+
+ def freeze(self) -> None:
+ for module in self.modules():
+ module.eval()
+ for parameters in self.parameters():
+ parameters.requires_grad = False
+
+ @classmethod
+ def build(cls, config):
+ obj = globals()[config["model"]["encoder"]["name"]](config)
+ return obj
diff --git a/unik3d/models/backbones/dinov2.py b/unik3d/models/backbones/dinov2.py
new file mode 100644
index 0000000000000000000000000000000000000000..163fbaac32c723251bb2614f765e30337d143d11
--- /dev/null
+++ b/unik3d/models/backbones/dinov2.py
@@ -0,0 +1,521 @@
+import contextlib
+import logging
+import math
+from functools import partial
+from typing import Callable, Sequence
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+
+from unik3d.models.metadinov2 import (Block, MemEffAttention, Mlp, PatchEmbed,
+ SwiGLUFFNFused)
+
+
+def named_apply(
+ fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
+) -> nn.Module:
+ if not depth_first and include_root:
+ fn(module=module, name=name)
+ for child_name, child_module in module.named_children():
+ child_name = ".".join((name, child_name)) if name else child_name
+ named_apply(
+ fn=fn,
+ module=child_module,
+ name=child_name,
+ depth_first=depth_first,
+ include_root=True,
+ )
+ if depth_first and include_root:
+ fn(module=module, name=name)
+ return module
+
+
+def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()):
+ parameter_group_names = {}
+ parameter_group_vars = {}
+ skip = {}
+ if skip_list is not None:
+ skip = skip_list
+ elif hasattr(model, "no_weight_decay"):
+ skip = model.no_weight_decay()
+
+ num_layers = model.n_blocks
+ layer_scale = list(ld ** (num_layers - i) for i in range(num_layers))
+
+ for name, param in model.named_parameters():
+ if not param.requires_grad:
+ continue
+
+ if len(param.shape) == 1: # norm
+ group_name = "no_decay"
+ this_wd = 0.0
+ # layer scale, bias beta?
+ elif (
+ name in skip
+ or name.endswith(".gamma")
+ or name.endswith(".beta")
+ or name.endswith(".bias")
+ ):
+ group_name = "no_decay"
+ this_wd = 0.0
+ elif "cls_token" in name or "pos_embed" in name or "mask_token" in name:
+ group_name = "no_decay"
+ this_wd = 0.0
+ else:
+ group_name = "decay"
+ this_wd = wd
+
+ if name.startswith("blocks"):
+ layer_id = int(name.split(".")[1])
+ elif name.startswith("patch_embed"):
+ layer_id = 0
+ else:
+ layer_id = 0
+
+ group_name = f"layer_{layer_id}_{group_name}"
+
+ if group_name not in parameter_group_names:
+ scale = layer_scale[layer_id]
+ cur_lr = lr * scale
+
+ parameter_group_names[group_name] = {
+ "weight_decay": this_wd,
+ "params": [],
+ "lr_init": cur_lr,
+ "lr_base": lr,
+ "lr": cur_lr,
+ }
+ parameter_group_vars[group_name] = {
+ "weight_decay": this_wd,
+ "params": [],
+ "lr_init": cur_lr,
+ "lr_base": lr,
+ "lr": cur_lr,
+ }
+ parameter_group_vars[group_name]["params"].append(param)
+ parameter_group_names[group_name]["params"].append(name)
+
+ # for group_name in parameter_group_names.keys():
+ # for k, v in zip(parameter_group_names[group_name]["params"], parameter_group_vars[group_name]["params"]):
+ # print(group_name,k)
+ return list(parameter_group_vars.values()), [
+ v["lr"] for k, v in parameter_group_vars.items()
+ ]
+
+
+class BlockChunk(nn.ModuleList):
+ def forward(self, x):
+ for b in self:
+ x = b(x)
+ return x
+
+
+class DummyModule(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ for i in range(100):
+ setattr(self, f"layer{i}", nn.Linear(2048, 2048))
+
+    def forward(self, x):
+        # run the stacked linear layers sequentially
+        for i in range(100):
+            x = getattr(self, f"layer{i}")(x)
+        return x
+
+
+class DinoVisionTransformer(nn.Module):
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=16,
+ in_chans=3,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ ffn_bias=True,
+ proj_bias=True,
+ drop_path_rate=0.0,
+ drop_path_uniform=False,
+ init_values=None, # for layerscale: None or 0 => no layerscale
+ embed_layer=PatchEmbed,
+ act_layer=nn.GELU,
+ block_fn=Block,
+ ffn_layer="mlp",
+ block_chunks=1,
+ output_idx=[5, 12, 18, 24],
+ checkpoint: bool = False,
+ num_register_tokens=0,
+ interpolate_antialias=False,
+ interpolate_offset=0.0,
+ use_norm=True,
+ frozen_stages=0,
+ freeze_norm=True,
+ ):
+ """
+ Args:
+ img_size (int, tuple): input image size
+ patch_size (int, tuple): patch size
+ in_chans (int): number of input channels
+ embed_dim (int): embedding dimension
+ depth (int): depth of transformer
+ num_heads (int): number of attention heads
+ mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+ qkv_bias (bool): enable bias for qkv if True
+ proj_bias (bool): enable bias for proj in attn if True
+ ffn_bias (bool): enable bias for ffn if True
+ drop_path_rate (float): stochastic depth rate
+ drop_path_uniform (bool): apply uniform drop rate across blocks
+ init_values (float): layer-scale init values
+ embed_layer (nn.Module): patch embedding layer
+ act_layer (nn.Module): MLP activation layer
+ block_fn (nn.Module): transformer block class
+ ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+ block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+ """
+ super().__init__()
+ # norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+ self.num_features = embed_dim # num_features for consistency with other models
+ self.frozen_stages = frozen_stages
+ self.embed_dims = [embed_dim] * output_idx[-1]
+ self.embed_dim = embed_dim
+ self.num_tokens = 1
+ self.freeze_norm = freeze_norm
+ self.n_blocks = depth
+ self.num_heads = num_heads
+ self.patch_size = patch_size
+ self.depths = output_idx
+ self.checkpoint = checkpoint
+ self.num_register_tokens = num_register_tokens
+ self.interpolate_antialias = interpolate_antialias
+ self.interpolate_offset = interpolate_offset
+
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ )
+ num_patches = self.patch_embed.num_patches
+
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+ self.pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches + self.num_tokens, embed_dim)
+ )
+ assert num_register_tokens >= 0
+ self.register_tokens = nn.Parameter(
+ torch.zeros(1, max(1, num_register_tokens), embed_dim)
+ )
+
+ if drop_path_uniform is True:
+ dpr = [drop_path_rate] * depth
+ else:
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, depth)
+ ] # stochastic depth decay rule
+
+ if ffn_layer == "mlp":
+ ffn_layer = Mlp
+ elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+ ffn_layer = SwiGLUFFNFused
+ elif ffn_layer == "identity":
+
+ def f(*args, **kwargs):
+ return nn.Identity()
+
+ ffn_layer = f
+ else:
+ raise NotImplementedError
+
+ blocks_list = [
+ # nn.Identity()
+ block_fn(
+ dim=embed_dim,
+ num_heads=num_heads,
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ ffn_bias=ffn_bias,
+ drop_path=dpr[i],
+ norm_layer=nn.LayerNorm,
+ act_layer=act_layer,
+ ffn_layer=ffn_layer,
+ init_values=init_values,
+ )
+ for i in range(depth)
+ ]
+ self.chunked_blocks = False
+ self.blocks = nn.ModuleList(blocks_list)
+
+ self.norm = nn.LayerNorm(embed_dim)
+ self.use_norm = use_norm
+ self.head = nn.Identity()
+ self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+ self.init_weights()
+
+ def init_weights(self):
+ trunc_normal_(self.pos_embed, std=0.02)
+ nn.init.normal_(self.cls_token, std=1e-6)
+ if self.num_register_tokens:
+ nn.init.normal_(self.register_tokens, std=1e-6)
+ named_apply(init_weights_vit_timm, self)
+
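+    # Bicubically resizes the pretrained patch position-embedding grid (37x37 for the
+    # default 518/14 configuration) to the current token grid, so the backbone accepts
+    # arbitrary input resolutions that are multiples of the patch size.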
+ def interpolate_pos_encoding(self, x, w, h):
+ previous_dtype = x.dtype
+ npatch = x.shape[1] - 1
+ N = self.pos_embed.shape[1] - 1
+ if npatch == N and w == h:
+ return self.pos_embed
+ pos_embed = self.pos_embed.float()
+ class_pos_embed = pos_embed[:, 0]
+ patch_pos_embed = pos_embed[:, 1:]
+ dim = x.shape[-1]
+ w0 = w // self.patch_size
+ h0 = h // self.patch_size
+
+ M = int(math.sqrt(N)) # Recover the number of patches in each dimension
+ assert N == M * M
+ kwargs = {}
+ if self.interpolate_offset:
+ # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+ # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+ sx = float(w0 + self.interpolate_offset) / M
+ sy = float(h0 + self.interpolate_offset) / M
+ kwargs["scale_factor"] = (sx, sy)
+ else:
+ # Simply specify an output size instead of a scale factor
+ kwargs["size"] = (w0, h0)
+
+ patch_pos_embed = nn.functional.interpolate(
+ patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+ mode="bicubic",
+ antialias=self.interpolate_antialias,
+ **kwargs,
+ )
+ assert (w0, h0) == patch_pos_embed.shape[-2:]
+
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+ return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
+ previous_dtype
+ )
+
+ def prepare_tokens_with_masks(self, x, masks=None):
+ B, nc, w, h = x.shape
+ with torch.no_grad() if self.frozen_stages > -1 else contextlib.nullcontext():
+ x = self.patch_embed(x)
+ if masks is not None:
+ masks = masks.bool().view(B, -1, 1)
+ x = torch.where(masks, self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+ x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+ x = x + self.interpolate_pos_encoding(x, w, h)
+
+ if self.num_register_tokens:
+ x = torch.cat(
+ (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]),
+ dim=1,
+ )
+
+ return x
+
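+    # Returns every block's token map reshaped to (B, H/p, W/p, C) together with each
+    # block's class token; register tokens are stripped, and blocks below `frozen_stages`
+    # run under torch.no_grad().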
+ def forward_features(self, x, masks=None):
+ shapes = [val // self.patch_size for val in x.shape[-2:]]
+ batch_size = x.shape[0]
+ outputs = []
+
+ x = self.prepare_tokens_with_masks(x, masks)
+
+ for i, blk in enumerate(self.blocks):
+ with (
+ torch.no_grad() if i < self.frozen_stages else contextlib.nullcontext()
+ ):
+ x = blk(x)
+ outputs.append(x)
+
+ if self.use_norm:
+ with (
+ torch.no_grad()
+ if self.frozen_stages >= len(self.blocks)
+ else contextlib.nullcontext()
+ ):
+ outputs = [self.norm(out) for out in outputs]
+ class_tokens = [out[:, :1] for out in outputs]
+ outputs = [out[:, self.num_register_tokens + 1 :] for out in outputs]
+ outputs = [out.reshape(batch_size, *shapes, -1) for out in outputs]
+
+ return (outputs, class_tokens)
+
+ def get_params(self, lr, wd, ld, *args, **kwargs):
+ encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld)
+ return encoder_p, encoder_lr
+
+ def freeze(self) -> None:
+ for module in self.modules():
+ module.eval()
+ for parameters in self.parameters():
+ parameters.requires_grad = False
+
+ def train(self, mode=True):
+ super().train(mode)
+
+ if self.freeze_norm:
+ for module in self.modules():
+ if isinstance(module, nn.LayerNorm):
+ for param in module.parameters():
+ param.requires_grad = False
+ module.eval()
+
+ if self.frozen_stages > -1:
+ for p in self.patch_embed.parameters():
+ p.requires_grad = False
+
+ for i, blk in enumerate(self.blocks):
+ if i < self.frozen_stages:
+ blk.eval()
+ for p in blk.parameters():
+ p.requires_grad = False
+
+ for p in self.norm.parameters():
+ p.requires_grad = self.frozen_stages <= len(self.blocks)
+
+ self.cls_token.requires_grad = self.frozen_stages < 1
+ self.pos_embed.requires_grad = self.frozen_stages < 1
+ self.mask_token.requires_grad = False
+ self.register_tokens.requires_grad = False
+
+ def forward(self, *args, is_training=False, **kwargs):
+ ret = self.forward_features(*args, **kwargs)
+ return ret
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+ """ViT weight initialization, original timm impl (for reproducibility)"""
+ if isinstance(module, nn.Linear):
+ trunc_normal_(module.weight, std=0.02)
+ if module.bias is not None:
+ nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=384,
+ depth=12,
+ num_heads=6,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=768,
+ depth=12,
+ num_heads=12,
+ mlp_ratio=4,
+ num_register_tokens=num_register_tokens,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1024,
+ depth=24,
+ num_heads=16,
+ mlp_ratio=4,
+ num_register_tokens=num_register_tokens,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+def vit_giant2(patch_size=16, **kwargs):
+ """
+ Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+ """
+ model = DinoVisionTransformer(
+ patch_size=patch_size,
+ embed_dim=1536,
+ depth=40,
+ num_heads=24,
+ mlp_ratio=4,
+ block_fn=partial(Block, attn_class=MemEffAttention),
+ **kwargs,
+ )
+ return model
+
+
+_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str:
+ compact_arch_name = arch_name.replace("_", "")[:4]
+ return f"dinov2_{compact_arch_name}{patch_size}"
+
+
+def _make_dinov2_model(
+ *,
+ arch_name: str = "vit_large",
+ img_size: int = 518,
+ patch_size: int = 14,
+ init_values: float = 1.0,
+ ffn_layer: str = "mlp",
+ block_chunks: int = 0,
+ pretrained: str = "",
+ output_idx: Sequence[int] = [],
+ num_register_tokens: int = 0,
+ drop_path_rate: float = 0.0,
+ use_norm: bool = False,
+ interpolate_offset: float = 0.0,
+ frozen_stages: int = 0,
+ freeze_norm: bool = True,
+ **kwargs,
+):
+ model_name = _make_dinov2_model_name(arch_name, patch_size)
+
+ vit_kwargs = dict(
+ img_size=img_size,
+ patch_size=patch_size,
+ init_values=init_values,
+ ffn_layer=ffn_layer,
+ block_chunks=block_chunks,
+ output_idx=output_idx,
+ drop_path_rate=drop_path_rate,
+ num_register_tokens=num_register_tokens,
+ use_norm=use_norm,
+ interpolate_offset=interpolate_offset,
+ frozen_stages=frozen_stages,
+ freeze_norm=freeze_norm,
+ )
+ vit_kwargs.update(**kwargs)
+ model = eval(arch_name)(**vit_kwargs)
+
+ if pretrained == "":
+ url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}"
+ if num_register_tokens > 0:
+ url += "_reg4"
+ url += "_pretrain.pth"
+ state_dict = torch.hub.load_state_dict_from_url(
+ url, map_location="cpu", progress=False
+ )
+ info = model.load_state_dict(state_dict, strict=False)
+ del state_dict
+ elif pretrained is not None:
+ state_dict = torch.load(pretrained, map_location="cpu", weights_only=False)
+ info = model.load_state_dict(state_dict, strict=False)
+ del state_dict
+ else:
+ info = {}
+
+ print(f"DINOv2 loaded from {pretrained} with info:", info)
+
+ return model
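+
+
+# Example (sketch): build a ViT-L/14 backbone with 4 register tokens. With the default
+# pretrained="" this downloads dinov2_vitl14_reg4_pretrain.pth from the official URL;
+# pass a local checkpoint path to load custom weights instead.
+#   backbone = _make_dinov2_model(
+#       arch_name="vit_large", num_register_tokens=4, output_idx=[5, 12, 18, 24]
+#   )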
diff --git a/unik3d/models/backbones/swinv2.py b/unik3d/models/backbones/swinv2.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f352aa3b784214637c9ced580a2a3e44063008b
--- /dev/null
+++ b/unik3d/models/backbones/swinv2.py
@@ -0,0 +1,942 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.utils.checkpoint as checkpoint
+from timm.models.layers import DropPath, to_2tuple, trunc_normal_
+
+from unik3d.utils.misc import get_params, load_checkpoint_swin
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features,
+ hidden_features=None,
+ out_features=None,
+ act_layer=nn.GELU,
+ drop=0.0,
+ ):
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features)
+ self.drop = nn.Dropout(drop)
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
+
+
+def window_partition(x, window_size):
+ """
+ Args:
+ x: (B, H, W, C)
+ window_size (int): window size
+
+ Returns:
+ windows: (num_windows*B, window_size, window_size, C)
+ """
+ B, H, W, C = x.shape
+ x = x.view(
+ B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C
+ )
+ windows = (
+ x.permute(0, 1, 3, 2, 4, 5)
+ .contiguous()
+ .view(-1, window_size[0], window_size[1], C)
+ )
+ return windows
+
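+# Shape example for window_partition/window_reverse (sketch): with window_size=(8, 8),
+# an input of (2, 64, 64, 96) is tiled into (2 * 8 * 8, 8, 8, 96) = (128, 8, 8, 96)
+# windows, and window_reverse inverts the tiling given H = W = 64.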
+
+def window_reverse(windows, window_size, H, W):
+ """
+ Args:
+ windows: (num_windows*B, window_size, window_size, C)
+ window_size (int): Window size
+ H (int): Height of image
+ W (int): Width of image
+
+ Returns:
+ x: (B, H, W, C)
+ """
+ B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1]))
+ x = windows.view(
+ B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1
+ )
+ x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1)
+ return x
+
+
+class WindowAttention(nn.Module):
+ r"""Window based multi-head self attention (W-MSA) module with relative position bias.
+ It supports both of shifted and non-shifted window.
+
+ Args:
+ dim (int): Number of input channels.
+ window_size (tuple[int]): The height and width of the window.
+ num_heads (int): Number of attention heads.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0
+ proj_drop (float, optional): Dropout ratio of output. Default: 0.0
+ pretrained_window_size (tuple[int]): The height and width of the window in pre-training.
+ """
+
+ def __init__(
+ self,
+ dim,
+ window_size,
+ num_heads,
+ qkv_bias=True,
+ attn_drop=0.0,
+ proj_drop=0.0,
+ pretrained_window_size=[0, 0],
+ ):
+
+ super().__init__()
+ self.dim = dim
+ self.window_size = window_size # Wh, Ww
+ self.pretrained_window_size = pretrained_window_size
+ self.num_heads = num_heads
+
+ self.logit_scale = nn.Parameter(
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
+ )
+
+ # mlp to generate continuous relative position bias
+ self.rpe_mlp = nn.Sequential(
+ nn.Linear(2, 512, bias=True),
+ nn.ReLU(inplace=True),
+ nn.Linear(512, num_heads, bias=False),
+ )
+
+ # get relative_coords_table
+ relative_coords_h = torch.arange(
+ -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32
+ )
+ relative_coords_w = torch.arange(
+ -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32
+ )
+ relative_coords_table = (
+ torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w]))
+ .permute(1, 2, 0)
+ .contiguous()
+ .unsqueeze(0)
+ ) # 1, 2*Wh-1, 2*Ww-1, 2
+ if pretrained_window_size[0] > 0:
+ relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1
+ relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1
+ else:
+ relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1
+ relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1
+ relative_coords_table *= 8 # normalize to -8, 8
+ relative_coords_table = (
+ torch.sign(relative_coords_table)
+ * torch.log2(torch.abs(relative_coords_table) + 1.0)
+ / np.log2(8)
+ )
+
+ self.register_buffer("relative_coords_table", relative_coords_table)
+
+ # get pair-wise relative position index for each token inside the window
+ coords_h = torch.arange(self.window_size[0])
+ coords_w = torch.arange(self.window_size[1])
+ coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww
+ coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww
+ relative_coords = (
+ coords_flatten[:, :, None] - coords_flatten[:, None, :]
+ ) # 2, Wh*Ww, Wh*Ww
+ relative_coords = relative_coords.permute(
+ 1, 2, 0
+ ).contiguous() # Wh*Ww, Wh*Ww, 2
+ relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0
+ relative_coords[:, :, 1] += self.window_size[1] - 1
+ relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
+ relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww
+ self.register_buffer("relative_position_index", relative_position_index)
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=False)
+ if qkv_bias:
+ self.q_bias = nn.Parameter(torch.zeros(dim))
+ self.v_bias = nn.Parameter(torch.zeros(dim))
+ else:
+ self.q_bias = None
+ self.v_bias = None
+ self.attn_drop = nn.Dropout(attn_drop)
+ self.proj = nn.Linear(dim, dim)
+ self.proj_drop = nn.Dropout(proj_drop)
+ self.softmax = nn.Softmax(dim=-1)
+
+ def forward(self, x, mask=None):
+ """
+ Args:
+ x: input features with shape of (num_windows*B, N, C)
+ mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None
+ """
+ B_, N, C = x.shape
+ qkv_bias = None
+ if self.q_bias is not None:
+ qkv_bias = torch.cat(
+ (
+ self.q_bias,
+ torch.zeros_like(self.v_bias, requires_grad=False),
+ self.v_bias,
+ )
+ )
+ qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias)
+ qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4)
+ q, k, v = (
+ qkv[0],
+ qkv[1],
+ qkv[2],
+ ) # make torchscript happy (cannot use tensor as tuple)
+
+ # cosine attention
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
+ logit_scale = torch.clamp(
+ self.logit_scale,
+ max=torch.log(torch.tensor(1.0 / 0.01, device=self.logit_scale.device)),
+ ).exp()
+ attn = attn * logit_scale
+
+ relative_position_bias_table = self.rpe_mlp(self.relative_coords_table).view(
+ -1, self.num_heads
+ )
+ relative_position_bias = relative_position_bias_table[
+ self.relative_position_index.view(-1)
+ ].view(
+ self.window_size[0] * self.window_size[1],
+ self.window_size[0] * self.window_size[1],
+ -1,
+ ) # Wh*Ww,Wh*Ww,nH
+ relative_position_bias = relative_position_bias.permute(
+ 2, 0, 1
+ ).contiguous() # nH, Wh*Ww, Wh*Ww
+ relative_position_bias = 16 * torch.sigmoid(relative_position_bias)
+ attn = attn + relative_position_bias.unsqueeze(0)
+
+ if mask is not None:
+ nW = mask.shape[0]
+ attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze(
+ 1
+ ).unsqueeze(0)
+ attn = attn.view(-1, self.num_heads, N, N)
+
+ attn = self.softmax(attn)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B_, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+ def extra_repr(self) -> str:
+ return (
+ f"dim={self.dim}, window_size={self.window_size}, "
+ f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}"
+ )
+
+ def flops(self, N):
+ # calculate flops for 1 window with token length of N
+ flops = 0
+ # qkv = self.qkv(x)
+ flops += N * self.dim * 3 * self.dim
+ # attn = (q @ k.transpose(-2, -1))
+ flops += self.num_heads * N * (self.dim // self.num_heads) * N
+ # x = (attn @ v)
+ flops += self.num_heads * N * N * (self.dim // self.num_heads)
+ # x = self.proj(x)
+ flops += N * self.dim * self.dim
+ return flops
+
+
+class SwinTransformerBlock(nn.Module):
+ r"""Swin Transformer Block.
+
+ Args:
+ dim (int): Number of input channels.
+        input_resolution (tuple[int]): Input resolution.
+ num_heads (int): Number of attention heads.
+ window_size (int): Window size.
+ shift_size (int): Shift size for SW-MSA.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float, optional): Stochastic depth rate. Default: 0.0
+ act_layer (nn.Module, optional): Activation layer. Default: nn.GELU
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ pretrained_window_size (int): Window size in pre-training.
+ """
+
+ def __init__(
+ self,
+ dim,
+ input_resolution,
+ num_heads,
+ window_size=7,
+ shift_size=0,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ act_layer=nn.GELU,
+ norm_layer=nn.LayerNorm,
+ pretrained_window_size=0,
+ ):
+ super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.num_heads = num_heads
+        # copy so that the per-block adjustments below do not mutate the window/shift
+        # lists shared across blocks of the same BasicLayer
+        self.window_size = list(to_2tuple(window_size))
+        self.shift_size = list(to_2tuple(shift_size))
+        self.mlp_ratio = mlp_ratio
+
+ if input_resolution[0] <= self.window_size[0]:
+ self.shift_size[0] = 0
+ self.window_size[0] = input_resolution[0]
+ if input_resolution[1] <= self.window_size[1]:
+ self.shift_size[1] = 0
+ self.window_size[1] = input_resolution[1]
+
+ assert (
+ 0 <= self.shift_size[1] < self.window_size[1]
+        ), "shift_size must be in 0-window_size"
+ assert (
+ 0 <= self.shift_size[0] < self.window_size[0]
+        ), "shift_size must be in 0-window_size"
+ self.norm1 = norm_layer(dim)
+ self.attn = WindowAttention(
+ dim,
+ window_size=self.window_size,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ pretrained_window_size=pretrained_window_size,
+ )
+
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = Mlp(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ )
+
+ # if self.shift_size > 0:
+ # # calculate attention mask for SW-MSA
+ # H, W = self.input_resolution
+ # img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1
+ # h_slices = (slice(0, -self.window_size),
+ # slice(-self.window_size, -self.shift_size),
+ # slice(-self.shift_size, None))
+ # w_slices = (slice(0, -self.window_size),
+ # slice(-self.window_size, -self.shift_size),
+ # slice(-self.shift_size, None))
+ # cnt = 0
+ # for h in h_slices:
+ # for w in w_slices:
+ # img_mask[:, h, w, :] = cnt
+ # cnt += 1
+
+ # mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1
+ # mask_windows = mask_windows.view(-1, self.window_size * self.window_size)
+ # attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ # attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0))
+ # else:
+ # attn_mask = None
+
+ # self.register_buffer("attn_mask", attn_mask)
+
+ def forward(self, x, mask_matrix):
+ H, W = self.H, self.W
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+
+ shortcut = x
+ x = x.view(B, H, W, C)
+
+ # pad feature maps to multiples of window size
+ pad_l = pad_t = 0
+ pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1]
+ pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0]
+ if pad_r > 0 or pad_b > 0:
+ x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b))
+ _, Hp, Wp, _ = x.shape
+
+ # cyclic shift
+ if self.shift_size[0] > 0 or self.shift_size[1] > 0:
+ shifted_x = torch.roll(
+ x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2)
+ )
+ attn_mask = mask_matrix
+ else:
+ shifted_x = x
+ attn_mask = None
+
+ # partition windows
+ x_windows = window_partition(
+ shifted_x, self.window_size
+ ) # nW*B, window_size, window_size, C
+ x_windows = x_windows.view(
+ -1, self.window_size[0] * self.window_size[1], C
+ ) # nW*B, window_size*window_size, C
+
+ # W-MSA/SW-MSA
+ attn_windows = self.attn(
+ x_windows, mask=attn_mask
+ ) # nW*B, window_size*window_size, C
+
+ # merge windows
+ attn_windows = attn_windows.view(
+ -1, self.window_size[0], self.window_size[1], C
+ )
+ shifted_x = window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C
+
+ # reverse cyclic shift
+ if self.shift_size[0] > 0 or self.shift_size[1] > 0:
+ x = torch.roll(
+ shifted_x, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2)
+ )
+ else:
+ x = shifted_x
+
+ if pad_r > 0 or pad_b > 0:
+ x = x[:, :H, :W, :].contiguous()
+
+ x = x.view(B, H * W, C)
+
+ x = shortcut + self.drop_path(self.norm1(x))
+
+ # FFN
+ x = x + self.drop_path(self.norm2(self.mlp(x)))
+
+ return x
+
+ def extra_repr(self) -> str:
+ return (
+ f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, "
+ f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}"
+ )
+
+ def flops(self):
+ flops = 0
+ H, W = self.input_resolution
+ # norm1
+ flops += self.dim * H * W
+ # W-MSA/SW-MSA
+ nW = H * W / self.window_size / self.window_size
+ flops += nW * self.attn.flops(self.window_size * self.window_size)
+ # mlp
+ flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio
+ # norm2
+ flops += self.dim * H * W
+ return flops
+
+
+class PatchMerging(nn.Module):
+ r"""Patch Merging Layer.
+
+ Args:
+ dim (int): Number of input channels.
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ """
+
+ def __init__(self, dim, norm_layer=nn.LayerNorm):
+ super().__init__()
+ self.dim = dim
+ self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False)
+ self.norm = norm_layer(2 * dim)
+
+ def forward(self, x, H, W):
+ """
+ x: B, H*W, C
+ """
+ B, L, C = x.shape
+ assert L == H * W, "input feature has wrong size"
+ assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even."
+
+ x = x.view(B, H, W, C)
+
+ x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C
+ x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C
+ x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C
+ x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C
+ x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C
+ x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C
+
+ x = self.reduction(x)
+ x = self.norm(x)
+
+ return x
+
+ def extra_repr(self) -> str:
+        return f"dim={self.dim}"
+
+ def flops(self):
+ H, W = self.input_resolution
+ flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim
+ flops += H * W * self.dim // 2
+ return flops
+
+
+class BasicLayer(nn.Module):
+ """A basic Swin Transformer layer for one stage.
+
+ Args:
+ dim (int): Number of input channels.
+ input_resolution (tuple[int]): Input resolution.
+ depth (int): Number of blocks.
+ num_heads (int): Number of attention heads.
+ window_size (int): Local window size.
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim.
+ qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True
+ drop (float, optional): Dropout rate. Default: 0.0
+ attn_drop (float, optional): Attention dropout rate. Default: 0.0
+ drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0
+ norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm
+ downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False.
+ pretrained_window_size (int): Local window size in pre-training.
+ """
+
+ def __init__(
+ self,
+ dim,
+ input_resolution,
+ depth,
+ num_heads,
+ window_size,
+ use_shift=True,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop=0.0,
+ attn_drop=0.0,
+ drop_path=0.0,
+ norm_layer=nn.LayerNorm,
+ downsample=None,
+ use_checkpoint=False,
+ pretrained_window_size=0,
+ ):
+
+ super().__init__()
+        self.dim = dim
+        self.input_resolution = input_resolution
+        self.depth = depth
+        self.use_checkpoint = use_checkpoint
+ self.window_size = list(to_2tuple(window_size))
+ pretrained_window_size = list(to_2tuple(pretrained_window_size))
+ self.shift_size = (
+ [x // 2 for x in window_size]
+ if isinstance(window_size, (tuple, list))
+ else window_size // 2
+ )
+ self.shift_size = list(to_2tuple(self.shift_size))
+
+ # build blocks
+ self.blocks = nn.ModuleList(
+ [
+ SwinTransformerBlock(
+ dim=dim,
+ input_resolution=input_resolution,
+ num_heads=num_heads,
+ window_size=self.window_size,
+ shift_size=self.shift_size if (i % 2 and use_shift) else [0, 0],
+ mlp_ratio=mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop,
+ attn_drop=attn_drop,
+ drop_path=(
+ drop_path[i] if isinstance(drop_path, list) else drop_path
+ ),
+ norm_layer=norm_layer,
+ pretrained_window_size=pretrained_window_size,
+ )
+ for i in range(depth)
+ ]
+ )
+
+ # patch merging layer
+        if downsample is not None:
+            self.downsample = downsample(
+                dim=dim, norm_layer=norm_layer, input_resolution=input_resolution
+            )
+ else:
+ self.downsample = None
+
+ def forward(self, x, H, W):
+ # calculate attention mask for SW-MSA
+ Hp = int(np.ceil(H / self.window_size[0])) * self.window_size[0]
+ Wp = int(np.ceil(W / self.window_size[1])) * self.window_size[1]
+ img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1
+ h_slices = (
+ slice(0, -self.window_size[0]),
+ slice(-self.window_size[0], -self.shift_size[0]),
+ slice(-self.shift_size[0], None),
+ )
+ w_slices = (
+ slice(0, -self.window_size[1]),
+ slice(-self.window_size[1], -self.shift_size[1]),
+ slice(-self.shift_size[1], None),
+ )
+ cnt = 0
+ for h in h_slices:
+ for w in w_slices:
+ img_mask[:, h, w, :] = cnt
+ cnt += 1
+
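+        # windows that straddle different shifted regions get a -100 bias on their
+        # cross-region pairs, which suppresses that attention after the softmax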
+ mask_windows = window_partition(
+ img_mask, self.window_size
+ ) # nW, window_size, window_size, 1
+ mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1])
+ attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
+ attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(
+ attn_mask == 0, float(0.0)
+ )
+
+        x_outs = []
+ for blk in self.blocks:
+ blk.H, blk.W = H, W
+ if self.use_checkpoint:
+ x = checkpoint.checkpoint(blk, x, attn_mask)
+ else:
+ x = blk(x, attn_mask)
+ x_outs.append(x)
+ if self.downsample is not None:
+ x_down = self.downsample(x, H, W)
+ Wh, Ww = (H + 1) // 2, (W + 1) // 2
+ return x_outs, H, W, x_down, Wh, Ww
+ else:
+ return x_outs, H, W, x, H, W
+
+ def extra_repr(self) -> str:
+ return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}"
+
+ def flops(self):
+ flops = 0
+ for blk in self.blocks:
+ flops += blk.flops()
+ if self.downsample is not None:
+ flops += self.downsample.flops()
+ return flops
+
+ def _init_respostnorm(self):
+ for blk in self.blocks:
+ nn.init.constant_(blk.norm1.bias, 0)
+ nn.init.constant_(blk.norm1.weight, 0)
+ nn.init.constant_(blk.norm2.bias, 0)
+ nn.init.constant_(blk.norm2.weight, 0)
+
+
+class PatchEmbed(nn.Module):
+ r"""Image to Patch Embedding
+
+ Args:
+ img_size (int): Image size. Default: 224.
+ patch_size (int): Patch token size. Default: 4.
+ in_chans (int): Number of input image channels. Default: 3.
+ embed_dim (int): Number of linear projection output channels. Default: 96.
+ norm_layer (nn.Module, optional): Normalization layer. Default: None
+ """
+
+ def __init__(
+ self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None
+ ):
+ super().__init__()
+ img_size = to_2tuple(img_size)
+ patch_size = to_2tuple(patch_size)
+ patches_resolution = [
+ img_size[0] // patch_size[0],
+ img_size[1] // patch_size[1],
+ ]
+ self.img_size = img_size
+ self.patch_size = patch_size
+ self.patches_resolution = patches_resolution
+ self.num_patches = patches_resolution[0] * patches_resolution[1]
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=patch_size, stride=patch_size
+ )
+ if norm_layer is not None:
+ self.norm = norm_layer(embed_dim)
+ else:
+ self.norm = None
+
+ def forward(self, x):
+ """Forward function."""
+ # padding
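+        # pad right/bottom so H and W become multiples of the patch size
+        # (F.pad's 2-tuple pads the last dim, the 4-tuple pads W then H)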
+ _, _, H, W = x.size()
+ if W % self.patch_size[1] != 0:
+ x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1]))
+ if H % self.patch_size[0] != 0:
+ x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0]))
+
+ x = self.proj(x) # B C Wh Ww
+ if self.norm is not None:
+ Wh, Ww = x.size(2), x.size(3)
+ x = self.norm(x.flatten(2).transpose(1, 2))
+ x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww)
+
+ return x
+
+ def flops(self):
+ Ho, Wo = self.patches_resolution
+ flops = (
+ Ho
+ * Wo
+ * self.embed_dim
+ * self.in_chans
+ * (self.patch_size[0] * self.patch_size[1])
+ )
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
+
+
+class SwinTransformerV2(nn.Module):
+ r"""Swin Transformer
+ A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` -
+ https://arxiv.org/pdf/2103.14030
+
+ Args:
+ img_size (int | tuple(int)): Input image size. Default 224
+ patch_size (int | tuple(int)): Patch size. Default: 4
+ in_chans (int): Number of input image channels. Default: 3
+ embed_dim (int): Patch embedding dimension. Default: 96
+ depths (tuple(int)): Depth of each Swin Transformer layer.
+ num_heads (tuple(int)): Number of attention heads in different layers.
+ window_size (int): Window size. Default: 7
+ mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4
+ qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True
+ drop_rate (float): Dropout rate. Default: 0
+ attn_drop_rate (float): Attention dropout rate. Default: 0
+ drop_path_rate (float): Stochastic depth rate. Default: 0.1
+ norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm.
+ ape (bool): If True, add absolute position embedding to the patch embedding. Default: False
+ patch_norm (bool): If True, add normalization after patch embedding. Default: True
+ use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False
+ pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer.
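+
+    Example (hypothetical usage, shapes follow the defaults below):
+        >>> model = SwinTransformerV2(img_size=(224, 224), window_size=7)
+        >>> feats, cls_tokens = model(torch.rand(1, 3, 224, 224))
+        >>> # one (B, H, W, C) map and one pooled token per block, i.e. sum(depths) of each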
+ """
+
+ def __init__(
+ self,
+ img_size=224,
+ patch_size=4,
+ in_chans=3,
+ embed_dim=96,
+ depths=[2, 2, 6, 2],
+ num_heads=[3, 6, 12, 24],
+ window_size=7,
+ mlp_ratio=4.0,
+ qkv_bias=True,
+ drop_rate=0.0,
+ attn_drop_rate=0.0,
+ drop_path_rate=0.1,
+ norm_layer=nn.LayerNorm,
+ ape=False,
+ patch_norm=True,
+ use_checkpoint=False,
+ use_shift=True,
+ pretrained_window_sizes=[0, 0, 0, 0],
+ pretrained=None,
+ frozen_stages=-1,
+ output_idx=[2, 4, 22, 24],
+ **kwargs,
+ ):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        self.num_layers = len(depths)
+ self.depths = output_idx
+ self.embed_dim = embed_dim
+ dims = [embed_dim * 2**i for i in range(len(depths))]
+ self.embed_dims = [
+ int(dim) for i, dim in enumerate(dims) for _ in range(depths[i])
+ ]
+
+ self.ape = ape
+ self.patch_norm = patch_norm
+ self.num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)]
+ self.mlp_ratio = mlp_ratio
+ self.frozen_stages = frozen_stages
+ if isinstance(window_size, int):
+ window_size = [window_size] * self.num_layers
+ if isinstance(use_shift, bool):
+ use_shift = [use_shift] * self.num_layers
+
+ # self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim))
+ # trunc_normal_(self.mask_token, mean=0., std=.02)
+ # split image into non-overlapping patches
+ self.patch_embed = PatchEmbed(
+ img_size=img_size,
+ patch_size=patch_size,
+ in_chans=in_chans,
+ embed_dim=embed_dim,
+ norm_layer=norm_layer if self.patch_norm else None,
+ )
+ num_patches = self.patch_embed.num_patches
+ patches_resolution = self.patch_embed.patches_resolution
+ self.patches_resolution = patches_resolution
+
+ # absolute position embedding
+ if self.ape:
+ self.absolute_pos_embed = nn.Parameter(
+ torch.zeros(1, num_patches, embed_dim)
+ )
+ trunc_normal_(self.absolute_pos_embed, std=0.02)
+
+ self.pos_drop = nn.Dropout(p=drop_rate)
+
+ # stochastic depth
+ dpr = [
+ x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))
+ ] # stochastic depth decay rule
+
+ # build layers
+ self.layers = nn.ModuleList()
+ for i_layer in range(self.num_layers):
+ layer = BasicLayer(
+ dim=int(embed_dim * 2**i_layer),
+ input_resolution=[
+ img_size[0] // (2 ** (2 + i_layer)),
+ img_size[1] // (2 ** (2 + i_layer)),
+ ],
+ depth=depths[i_layer],
+ num_heads=num_heads[i_layer],
+ window_size=window_size[i_layer],
+ use_shift=use_shift[i_layer],
+ mlp_ratio=self.mlp_ratio,
+ qkv_bias=qkv_bias,
+ drop=drop_rate,
+ attn_drop=attn_drop_rate,
+ drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])],
+ norm_layer=norm_layer,
+ downsample=PatchMerging if (i_layer < self.num_layers - 1) else None,
+ use_checkpoint=use_checkpoint,
+ pretrained_window_size=pretrained_window_sizes[i_layer],
+ )
+ self.layers.append(layer)
+
+ self.apply(self._init_weights)
+ for bly in self.layers:
+ bly._init_respostnorm()
+
+ if pretrained is not None:
+ pretrained_state = torch.load(pretrained, map_location="cpu")["model"]
+ pretrained_state_filtered = load_checkpoint_swin(self, pretrained_state)
+ msg = self.load_state_dict(pretrained_state_filtered, strict=False)
+
+ self._freeze_stages()
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ nn.init.constant_(m.bias, 0)
+ nn.init.constant_(m.weight, 1.0)
+
+ @torch.jit.ignore
+ def no_weight_decay(self):
+ return {"absolute_pos_embed"}
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {"rpe_mlp", "logit_scale", "relative_position_bias_table", "mask_token"}
+
+ def forward(self, x, mask=None):
+ """Forward function."""
+        # Call requires_grad_() on the inputs to support freezing together with gradient checkpointing!
+ x = self.patch_embed(x.requires_grad_())
+ B, Wh, Ww = x.size(0), x.size(2), x.size(3)
+ if self.ape:
+ # interpolate the position embedding to the corresponding size
+ absolute_pos_embed = F.interpolate(
+ self.absolute_pos_embed,
+ size=(Wh, Ww),
+ mode="bicubic",
+ align_corners=True,
+ )
+ x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C
+ else:
+ x = x.flatten(2).transpose(1, 2)
+ x = self.pos_drop(x)
+
+ # B, L, _ = x.shape
+ # if mask is not None:
+ # mask_tokens = self.mask_token.expand(B, L, -1)
+ # mask = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens)
+ # else:
+ # mask = torch.zeros_like(x)
+ # mask_tokens = torch.zeros_like(self.mask_token).expand(B, L, -1)
+ # x = x * (1. - mask) + mask_tokens * mask
+
+ outs, cls_tokens = [], []
+ for i in range(self.num_layers):
+ layer = self.layers[i]
+ x_outs, H, W, x, Wh, Ww = layer(x.requires_grad_(), Wh, Ww)
+ out = [
+ x_out.view(-1, H, W, self.num_features[i]).contiguous()
+ for x_out in x_outs
+ ]
+ outs.extend(out)
+ cls_token_ = [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in out]
+ cls_tokens.extend(cls_token_)
+ return outs, cls_tokens
+
+ def train(self, mode=True):
+ super().train(mode)
+ self._freeze_stages()
+
+ def freeze(self) -> None:
+ for module in self.modules():
+ module.eval()
+ for parameters in self.parameters():
+ parameters.requires_grad = False
+
+ def _freeze_stages(self):
+ if self.frozen_stages >= 0:
+ self.patch_embed.eval()
+ for param in self.patch_embed.parameters():
+ param.requires_grad = False
+ if self.ape:
+ self.absolute_pos_embed.requires_grad = False
+ self.pos_drop.eval()
+
+ for i in range(1, self.frozen_stages + 1):
+ m = self.layers[i - 1]
+ m.eval()
+ for param in m.parameters():
+ param.requires_grad = False
+
+ def flops(self):
+ flops = 0
+ flops += self.patch_embed.flops()
+        for layer in self.layers:
+            flops += layer.flops()
+        # self.num_features is a per-stage list here, so use the last stage's width
+        flops += (
+            self.num_features[-1]
+            * self.patches_resolution[0]
+            * self.patches_resolution[1]
+            // (2**self.num_layers)
+        )
+ return flops
+
+ def get_params(self, lr, wd, *args, **kwargs):
+ encoder_p, encoder_lr = get_params(self, lr, wd)
+ return encoder_p, encoder_lr
+
+ @classmethod
+ def build(cls, config):
+ obj = globals()[config["name"]](config)
+ return obj
diff --git a/unik3d/models/camera_augmenter.py b/unik3d/models/camera_augmenter.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd088daaac1ea7810da3aa388e481b4be02e1e06
--- /dev/null
+++ b/unik3d/models/camera_augmenter.py
@@ -0,0 +1,145 @@
+import numpy as np
+import torch
+from einops import rearrange
+
+from unik3d.utils.camera import CameraSampler
+from unik3d.utils.coordinate import coords_grid
+from unik3d.utils.geometric import iou
+
+try:
+ from splatting import splatting_function
+except Exception as e:
+ splatting_function = None
+    print(
+        "Splatting not available, please install it from github.com/hperrot/splatting"
+    )
+
+
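+# NOTE: `fill` and `augment_camera` take an explicit `self` argument: they appear to be
+# attached to the augmentation/model class as methods elsewhere in the codebase
+# (see the `self.fill(...)` call inside `augment_camera`).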
+def fill(self, rgb, mask):
+ def fill_noise(size, device):
+ return torch.normal(0, 1.0, size=size, device=device)
+
+ def fill_black(size, device):
+ return -2 * torch.ones(size, device=device, dtype=torch.float32)
+
+ def fill_white(size, device):
+ return 2 * torch.ones(size, device=device, dtype=torch.float32)
+
+ def fill_zero(size, device):
+ return torch.zeros(size, device=device, dtype=torch.float32)
+
+ B, C = rgb.shape[:2]
+ validity_mask = mask.repeat(1, C, 1, 1).bool()
+ for i in range(B):
+ filler_fn = np.random.choice([fill_noise, fill_black, fill_white, fill_zero])
+ rgb[i][~validity_mask[i]] = filler_fn(
+ size=rgb[i][~validity_mask[i]].shape, device=rgb.device
+ )
+ return rgb
+
+
+@torch.autocast(device_type="cuda", enabled=True, dtype=torch.float32)
+def augment_camera(self, inputs, camera_sampler):
+ rgb = inputs["image"]
+ gt = inputs["depth"].clone()
+ guidance = inputs[
+ "depth_guidance"
+ ] # from GT if dense/synthetic or from a model's metric output
+ validity_mask = inputs["validity_mask"].bool()
+ dtype, device = gt.dtype, gt.device
+ B, C, H, W = rgb.shape
+ augmentable_indices = inputs["valid_camera"] & (
+ inputs["depth_mask"].reshape(B, -1).float().mean(dim=1) > 0.0
+ )
+
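+    # sample roughly 10% of the augmentable images for camera augmentation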
+ augment_indices = torch.rand(B, 1, 1, device=device, dtype=dtype) > 0.9
+ augment_indices[~augmentable_indices] = False
+ id_coords = coords_grid(B, H, W, device=device)
+ # get rescaled depth
+ augment_indices = augment_indices.reshape(-1)
+ for i, is_augment in enumerate(augment_indices):
+ if not is_augment:
+ continue
+
+ pinhole_camera = inputs["camera"][i]
+ fov = max(pinhole_camera.hfov[0], pinhole_camera.vfov[0]) * 180 / np.pi
+ ratio = min(70.0 / fov, 1.0) # decrease effect for larger fov
+ if fov < 40.0: # skips ~5%
+ augment_indices[i] = False
+ continue
+
+ rgb_i = rgb[i : i + 1]
+ id_coords_i = id_coords[i : i + 1]
+
+ validity_mask_i = validity_mask[i : i + 1]
+ depth = guidance[i : i + 1]
+
+ if (depth < 0.0).any():
+ augment_indices[i] = False
+ continue
+
+ depth = depth.sqrt() # why sqrt??
+ depth[~validity_mask_i] = depth.max() * 2.0
+
+ fx, fy, cx, cy = pinhole_camera.params[:, :4].unbind(dim=-1)
+ new_camera = camera_sampler(fx, fy, cx, cy, mult=1.0, ratio=ratio, H=H)
+ unprojected = pinhole_camera.reconstruct(depth)
+ projected = new_camera.project(unprojected)
+ projection_mask = new_camera.projection_mask
+ overlap_mask = (
+ new_camera.overlap_mask
+ if new_camera.overlap_mask is not None
+ else torch.ones_like(projection_mask)
+ )
+ mask = validity_mask_i & overlap_mask
+
+        # if the warp actually pushes content out of frame, we need to remember those regions
+        # remember when the tangential distortion was keeping the validity_mask border after re-warping
+        # need a better way to define the overlap class: a vortex-style warp will mask the wrong parts...
+        # also, is_collapse does not take the vortex effect into consideration,
+        # how can we avoid the vortex in the first place????
+ is_collapse = (projected[0, 1, 0, :] >= 0.0).all()
+ if is_collapse:
+ projected[~mask.repeat(1, 2, 1, 1)] = id_coords_i[~mask.repeat(1, 2, 1, 1)]
+ flow = projected - id_coords_i
+ depth[~mask] = depth.max() * 2.0
+
+ if flow.norm(dim=1).median() / max(H, W) > 0.1: # extreme cases
+ augment_indices[i] = False
+ continue
+
+ # warp via soft splat
+ depth_image = torch.cat([rgb_i, guidance[i : i + 1], mask], dim=1)
+ depth_image = splatting_function(
+ "softmax", depth_image, flow, -torch.log(1 + depth.clip(0.01))
+ )
+ rgb_warp = depth_image[:, :3]
+ validity_mask_i = depth_image[:, -1:] > 0.0
+
+ expanding = validity_mask_i.sum() > validity_mask[i : i + 1].sum()
+ threshold = 0.7 if expanding else 0.25
+ _iou = iou(validity_mask_i, validity_mask[i : i + 1])
+ if _iou < threshold: # too strong augmentation, lose most of the image
+ augment_indices[i] = False
+ continue
+
+ # where it goes out
+ mask_unwarpable = projection_mask & overlap_mask
+ inputs["depth_mask"][i] = inputs["depth_mask"][i] & mask_unwarpable.squeeze(0)
+
+        # compute new rays and use them for supervision
+ rays = new_camera.get_rays(shapes=(1, H, W))
+ rays = rearrange(rays, "b c h w -> b (h w) c")
+ inputs["rays"][i] = torch.where(
+ rays.isnan().any(dim=-1, keepdim=True), 0.0, rays
+ )[0]
+
+ # update image, camera and validity_mask
+ inputs["camera"][i] = new_camera
+ inputs["image"][i] = self.fill(rgb_warp, validity_mask_i)[0]
+ inputs["validity_mask"][i] = inputs["validity_mask"][i] & mask_unwarpable[0]
+
+ # needed to reverse the augmentation for loss-computation (i.e. un-warp the prediction)
+ inputs["grid_sample"][i] = projected[0]
+
+ return inputs
diff --git a/unik3d/models/decoder.py b/unik3d/models/decoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d5f8ac3e77e987914945f61e60e57b6a2028243f
--- /dev/null
+++ b/unik3d/models/decoder.py
@@ -0,0 +1,562 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+from math import tanh
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange
+from timm.models.layers import trunc_normal_
+
+from unik3d.layers import (MLP, AttentionBlock, AttentionLayer, GradChoker,
+ PositionEmbeddingSine, ResUpsampleBil)
+from unik3d.utils.coordinate import coords_grid
+from unik3d.utils.geometric import flat_interpolate
+from unik3d.utils.misc import get_params
+from unik3d.utils.positional_embedding import generate_fourier_features
+from unik3d.utils.sht import rsh_cart_3
+
+
+def orthonormal_init(num_tokens, dims):
+ pe = torch.randn(num_tokens, dims)
+
+ # Apply Gram-Schmidt process to make the matrix orthonormal
+ # Awful loop..
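+    # (torch.linalg.qr would presumably give an orthonormal basis in one call; the explicit
+    # Gram-Schmidt is kept as-is and is only valid while num_tokens <= dims)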
+ for i in range(num_tokens):
+ for j in range(i):
+ pe[i] -= torch.dot(pe[i], pe[j]) * pe[j]
+ pe[i] = F.normalize(pe[i], p=2, dim=0)
+
+ return pe
+
+
+class ListAdapter(nn.Module):
+ def __init__(self, input_dims: list[int], hidden_dim: int):
+ super().__init__()
+ self.input_adapters = nn.ModuleList([])
+ self.num_chunks = len(input_dims)
+ self.checkpoint = True
+ for input_dim in input_dims:
+ self.input_adapters.append(nn.Linear(input_dim, hidden_dim))
+
+ def forward(self, xs: torch.Tensor) -> torch.Tensor:
+ outs = [self.input_adapters[i](x) for i, x in enumerate(xs)]
+ return outs
+
+
+class AngularModule(nn.Module):
+ def __init__(
+ self,
+ hidden_dim: int,
+ num_heads: int = 8,
+ expansion: int = 4,
+ dropout: float = 0.0,
+ layer_scale: float = 1.0,
+ **kwargs,
+ ):
+ super().__init__()
+ self.pin_params = 3
+ self.deg1_params = 3
+ self.deg2_params = 5
+ self.deg3_params = 7
+ self.num_params = (
+ self.pin_params + self.deg1_params + self.deg2_params + self.deg3_params
+ )
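+        # 3 pinhole tokens (hfov, cx, cy) plus 3/5/7 tokens matching the degree-1/2/3
+        # real spherical-harmonic bands used in Decoder.run_camera further down this file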
+
+ self.aggregate1 = AttentionBlock(
+ hidden_dim,
+ num_heads=num_heads,
+ expansion=expansion,
+ dropout=dropout,
+ layer_scale=layer_scale,
+ )
+ self.aggregate2 = AttentionBlock(
+ hidden_dim,
+ num_heads=num_heads,
+ expansion=expansion,
+ dropout=dropout,
+ layer_scale=layer_scale,
+ )
+ self.latents_pos = nn.Parameter(
+ torch.randn(1, self.num_params, hidden_dim), requires_grad=True
+ )
+ self.in_features = nn.Identity()
+
+ self.project_pin = nn.Linear(
+ hidden_dim, self.pin_params * hidden_dim, bias=False
+ )
+ self.project_deg1 = nn.Linear(
+ hidden_dim, self.deg1_params * hidden_dim, bias=False
+ )
+ self.project_deg2 = nn.Linear(
+ hidden_dim, self.deg2_params * hidden_dim, bias=False
+ )
+ self.project_deg3 = nn.Linear(
+ hidden_dim, self.deg3_params * hidden_dim, bias=False
+ )
+
+ self.out_pinhole = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=1)
+ self.out_deg1 = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=3)
+ self.out_deg2 = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=3)
+ self.out_deg3 = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=3)
+
+ def fill_intrinsics(self, x):
+ hfov, cx, cy = x.unbind(dim=-1)
+ hfov = torch.sigmoid(hfov - 1.1) # 1.1 magic number s.t hfov = pi/2 for x=0
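+        # sigmoid(-1.1) ~= 0.25 and 0.25 * 2*pi = pi/2, so a zero logit maps to a 90 degree hfov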
+ ratio = self.shapes[0] / self.shapes[1]
+ vfov = hfov * ratio
+ cx = torch.sigmoid(cx)
+ cy = torch.sigmoid(cy)
+ correction_tensor = torch.tensor(
+ [2 * torch.pi, 2 * torch.pi, self.shapes[1], self.shapes[0]],
+ device=x.device,
+ dtype=x.dtype,
+ )
+
+ intrinsics = torch.stack([hfov, vfov, cx, cy], dim=1)
+ intrinsics = correction_tensor.unsqueeze(0) * intrinsics
+ return intrinsics
+
+ def forward(self, cls_tokens) -> torch.Tensor:
+ latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1)
+
+ pin_tokens, deg1_tokens, deg2_tokens, deg3_tokens = cls_tokens.chunk(4, dim=1)
+ pin_tokens = rearrange(
+ self.project_pin(pin_tokens), "b n (h c) -> b (n h) c", h=self.pin_params
+ )
+ deg1_tokens = rearrange(
+ self.project_deg1(deg1_tokens), "b n (h c) -> b (n h) c", h=self.deg1_params
+ )
+ deg2_tokens = rearrange(
+ self.project_deg2(deg2_tokens), "b n (h c) -> b (n h) c", h=self.deg2_params
+ )
+ deg3_tokens = rearrange(
+ self.project_deg3(deg3_tokens), "b n (h c) -> b (n h) c", h=self.deg3_params
+ )
+ tokens = torch.cat([pin_tokens, deg1_tokens, deg2_tokens, deg3_tokens], dim=1)
+
+ tokens = self.aggregate1(tokens, pos_embed=latents_pos)
+ tokens = self.aggregate2(tokens, pos_embed=latents_pos)
+
+ tokens_pinhole, tokens_deg1, tokens_deg2, tokens_deg3 = torch.split(
+ tokens,
+ [self.pin_params, self.deg1_params, self.deg2_params, self.deg3_params],
+ dim=1,
+ )
+ x = self.out_pinhole(tokens_pinhole).squeeze(-1)
+ d1 = self.out_deg1(tokens_deg1)
+ d2 = self.out_deg2(tokens_deg2)
+ d3 = self.out_deg3(tokens_deg3)
+
+ camera_intrinsics = self.fill_intrinsics(x)
+ return camera_intrinsics, torch.cat([d1, d2, d3], dim=1)
+
+ def set_shapes(self, shapes: tuple[int, int]):
+ self.shapes = shapes
+
+
+class RadialModule(nn.Module):
+ def __init__(
+ self,
+ hidden_dim: int,
+ num_heads: int = 8,
+ expansion: int = 4,
+ depths: int | list[int] = 4,
+ camera_dim: int = 256,
+ dropout: float = 0.0,
+ kernel_size: int = 7,
+ layer_scale: float = 1.0,
+ out_dim: int = 1,
+ num_prompt_blocks: int = 1,
+ use_norm: bool = False,
+ **kwargs,
+ ) -> None:
+        super().__init__()
+        if isinstance(depths, int):
+            # the type hint allows a plain int: use the same depth for all four stages
+            depths = [depths] * 4
+        self.camera_dim = camera_dim
+ self.out_dim = out_dim
+ self.hidden_dim = hidden_dim
+
+ self.ups = nn.ModuleList([])
+ self.depth_mlp = nn.ModuleList([])
+ self.process_features = nn.ModuleList([])
+ self.project_features = nn.ModuleList([])
+ self.out = nn.ModuleList([])
+ self.prompt_camera = nn.ModuleList([])
+ mult = 2
+ self.to_latents = nn.Linear(hidden_dim, hidden_dim)
+
+ for _ in range(4):
+ self.prompt_camera.append(
+ AttentionLayer(
+ num_blocks=num_prompt_blocks,
+ dim=hidden_dim,
+ num_heads=num_heads,
+ expansion=expansion,
+ dropout=dropout,
+ layer_scale=-1.0,
+ context_dim=hidden_dim,
+ )
+ )
+
+ for i, depth in enumerate(depths):
+ current_dim = min(hidden_dim, mult * hidden_dim // int(2**i))
+ next_dim = mult * hidden_dim // int(2 ** (i + 1))
+ output_dim = max(next_dim, out_dim)
+ self.process_features.append(
+ nn.ConvTranspose2d(
+ hidden_dim,
+ current_dim,
+ kernel_size=max(1, 2 * i),
+ stride=max(1, 2 * i),
+ padding=0,
+ )
+ )
+ self.ups.append(
+ ResUpsampleBil(
+ current_dim,
+ output_dim=output_dim,
+ expansion=expansion,
+ layer_scale=layer_scale,
+ kernel_size=kernel_size,
+ num_layers=depth,
+ use_norm=use_norm,
+ )
+ )
+ depth_mlp = (
+ nn.Sequential(nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim))
+ if i == len(depths) - 1
+ else nn.Identity()
+ )
+ self.depth_mlp.append(depth_mlp)
+
+ self.confidence_mlp = nn.Sequential(
+ nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim)
+ )
+
+ self.to_depth_lr = nn.Conv2d(
+ output_dim,
+ output_dim // 2,
+ kernel_size=3,
+ padding=1,
+ padding_mode="reflect",
+ )
+ self.to_confidence_lr = nn.Conv2d(
+ output_dim,
+ output_dim // 2,
+ kernel_size=3,
+ padding=1,
+ padding_mode="reflect",
+ )
+ self.to_depth_hr = nn.Sequential(
+ nn.Conv2d(
+ output_dim // 2, 32, kernel_size=3, padding=1, padding_mode="reflect"
+ ),
+ nn.LeakyReLU(),
+ nn.Conv2d(32, 1, kernel_size=1),
+ )
+ self.to_confidence_hr = nn.Sequential(
+ nn.Conv2d(
+ output_dim // 2, 32, kernel_size=3, padding=1, padding_mode="reflect"
+ ),
+ nn.LeakyReLU(),
+ nn.Conv2d(32, 1, kernel_size=1),
+ )
+
+ def set_original_shapes(self, shapes: tuple[int, int]):
+ self.original_shapes = shapes
+
+ def set_shapes(self, shapes: tuple[int, int]):
+ self.shapes = shapes
+
+ def embed_rays(self, rays):
+ rays_embedding = flat_interpolate(
+ rays, old=self.original_shapes, new=self.shapes, antialias=True
+ )
+ rays_embedding = rays_embedding / torch.norm(
+ rays_embedding, dim=-1, keepdim=True
+ ).clip(min=1e-4)
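+        # represent each unit ray by its (polar, azimuth) angles and lift them to
+        # Fourier features (log-spaced frequencies) as a dense camera prompt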
+ x, y, z = rays_embedding[..., 0], rays_embedding[..., 1], rays_embedding[..., 2]
+ polar = torch.acos(z)
+ x_clipped = x.abs().clip(min=1e-3) * (2 * (x >= 0).int() - 1)
+ azimuth = torch.atan2(y, x_clipped)
+ rays_embedding = torch.stack([polar, azimuth], dim=-1)
+ rays_embedding = generate_fourier_features(
+ rays_embedding,
+ dim=self.hidden_dim,
+ max_freq=max(self.shapes) // 2,
+ use_log=True,
+ cat_orig=False,
+ )
+ return rays_embedding
+
+ def condition(self, feat, rays_embeddings):
+ conditioned_features = [
+ prompter(rearrange(feature, "b h w c -> b (h w) c"), rays_embeddings)
+ for prompter, feature in zip(self.prompt_camera, feat)
+ ]
+ return conditioned_features
+
+ def process(self, features_list, rays_embeddings):
+ conditioned_features = self.condition(features_list, rays_embeddings)
+ init_latents = self.to_latents(conditioned_features[0])
+ init_latents = rearrange(
+ init_latents, "b (h w) c -> b c h w", h=self.shapes[0], w=self.shapes[1]
+ ).contiguous()
+ conditioned_features = [
+ rearrange(
+ x, "b (h w) c -> b c h w", h=self.shapes[0], w=self.shapes[1]
+ ).contiguous()
+ for x in conditioned_features
+ ]
+ latents = init_latents
+
+ out_features = []
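+        # coarse-to-fine refinement: each stage fuses the camera-conditioned skip
+        # features into the latents and then upsamples them with ResUpsampleBil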
+ for i, up in enumerate(self.ups):
+ latents = latents + self.process_features[i](conditioned_features[i + 1])
+ latents = up(latents)
+ out_features.append(latents)
+
+ return out_features, init_latents
+
+ def depth_proj(self, out_features):
+ depths = []
+ h_out, w_out = out_features[-1].shape[-2:]
+ # aggregate output and project to depth
+ for i, (layer, features) in enumerate(zip(self.depth_mlp, out_features)):
+ out_depth_features = layer(features.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
+ if i < len(self.depth_mlp) - 1:
+ continue
+ depths.append(out_depth_features)
+ out_depth_features = F.interpolate(
+ out_depth_features, size=(h_out, w_out), mode="bilinear", align_corners=True
+ )
+ logdepth = self.to_depth_lr(out_depth_features)
+ logdepth = F.interpolate(
+ logdepth, size=self.original_shapes, mode="bilinear", align_corners=True
+ )
+ logdepth = self.to_depth_hr(logdepth)
+ return logdepth
+
+ def confidence_proj(self, out_features):
+ highres_features = out_features[-1].permute(0, 2, 3, 1)
+ confidence = self.confidence_mlp(highres_features).permute(0, 3, 1, 2)
+ confidence = self.to_confidence_lr(confidence)
+ confidence = F.interpolate(
+ confidence, size=self.original_shapes, mode="bilinear", align_corners=True
+ )
+ confidence = self.to_confidence_hr(confidence)
+ return confidence
+
+ def decode(self, out_features):
+ logdepth = self.depth_proj(out_features)
+ confidence = self.confidence_proj(out_features)
+ return logdepth, confidence
+
+ def forward(
+ self,
+ features: list[torch.Tensor],
+ rays_hr: torch.Tensor,
+ pos_embed,
+ level_embed,
+ ) -> torch.Tensor:
+ rays_embeddings = self.embed_rays(rays_hr)
+ features, lowres_features = self.process(features, rays_embeddings)
+ logdepth, logconf = self.decode(features)
+ return logdepth, logconf, lowres_features
+
+
+class Decoder(nn.Module):
+ def __init__(
+ self,
+ config,
+ ):
+ super().__init__()
+ self.build(config)
+ self.apply(self._init_weights)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.Conv2d):
+ trunc_normal_(m.weight, std=0.02)
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ elif isinstance(m, nn.LayerNorm):
+ if m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+ if m.weight is not None:
+ nn.init.constant_(m.weight, 1.0)
+
+ def run_camera(self, cls_tokens, original_shapes, rays_gt):
+ H, W = original_shapes
+
+ # camera layer
+ intrinsics, sh_coeffs = self.angular_module(cls_tokens=cls_tokens)
+ B, N = intrinsics.shape
+ device = intrinsics.device
+ dtype = intrinsics.dtype
+
+ id_coords = coords_grid(B, H, W, device=sh_coeffs.device)
+
+        # fov-based mapping: turn pixel offsets from the principal point into viewing angles
+ longitude = (
+ (id_coords[:, 0] - intrinsics[:, 2].view(-1, 1, 1))
+ / W
+ * intrinsics[:, 0].view(-1, 1, 1)
+ )
+ latitude = (
+ (id_coords[:, 1] - intrinsics[:, 3].view(-1, 1, 1))
+ / H
+ * intrinsics[:, 1].view(-1, 1, 1)
+ )
+ x = torch.cos(latitude) * torch.sin(longitude)
+ z = torch.cos(latitude) * torch.cos(longitude)
+ y = -torch.sin(latitude)
+ unit_sphere = torch.stack([x, y, z], dim=-1)
+ unit_sphere = unit_sphere / torch.norm(unit_sphere, dim=-1, keepdim=True).clip(
+ min=1e-5
+ )
+
+ harmonics = rsh_cart_3(unit_sphere)[..., 1:] # remove constant-value harmonic
+ rays_pred = torch.einsum("bhwc,bcd->bhwd", harmonics, sh_coeffs)
+ rays_pred = rays_pred / torch.norm(rays_pred, dim=-1, keepdim=True).clip(
+ min=1e-5
+ )
+ rays_pred = rays_pred.permute(0, 3, 1, 2)
+
+ ### LEGACY CODE for training
+ # if self.training:
+ # prob = 1 - tanh(self.steps / self.num_steps)
+ # where_use_gt_rays = torch.rand(B, 1, 1, device=device, dtype=dtype) < prob
+ # where_use_gt_rays = where_use_gt_rays.int()
+ # rays = rays_gt * where_use_gt_rays + rays_pred * (1 - where_use_gt_rays)
+
+        # NaNs should also be cleaned here
+ if self.training:
+ rays = rays_pred
+ else:
+ rays = rays_gt if rays_gt is not None else rays_pred
+ rays = rearrange(rays, "b c h w -> b (h w) c")
+
+ return intrinsics, rays
+
+ def forward(self, inputs, image_metas) -> torch.Tensor:
+ B, C, H, W = inputs["image"].shape
+ device = inputs["image"].device
+
+ rays_gt = inputs.get("rays", None)
+
+ # get features in b n d format
+ common_shape = inputs["features"][0].shape[1:3]
+
+        # project each encoder feature level to the decoder hidden dimension
+ features = self.input_adapter(inputs["features"])
+
+ # positional embeddings, spatial and level
+ level_embed = self.level_embeds.repeat(
+ B, common_shape[0] * common_shape[1], 1, 1
+ )
+ level_embed = rearrange(level_embed, "b n l d -> b (n l) d")
+ dummy_tensor = torch.zeros(
+ B, 1, common_shape[0], common_shape[1], device=device, requires_grad=False
+ )
+ pos_embed = self.pos_embed(dummy_tensor)
+ pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat(1, 4, 1)
+
+ # get cls tokens projections
+ camera_tokens = inputs["tokens"]
+ camera_tokens = [self.choker(x.contiguous()) for x in camera_tokens]
+ camera_tokens = self.camera_token_adapter(camera_tokens)
+ self.angular_module.set_shapes((H, W))
+
+ intrinsics, rays = self.run_camera(
+ torch.cat(camera_tokens, dim=1),
+ original_shapes=(H, W),
+ rays_gt=rays_gt,
+ )
+
+ # run bulk of the model
+ self.radial_module.set_shapes(common_shape)
+ self.radial_module.set_original_shapes((H, W))
+ logradius, logconfidence, lowres_features = self.radial_module(
+ features=features,
+ rays_hr=rays,
+ pos_embed=pos_embed,
+ level_embed=level_embed,
+ )
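+        # the radial module predicts log-range and log-confidence; clipping before the
+        # exp keeps both numerically safe, and the +2.0 offset biases the distance scale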
+ radius = torch.exp(logradius.clip(min=-8.0, max=8.0) + 2.0)
+ confidence = torch.exp(logconfidence.clip(min=-8.0, max=10.0))
+
+ outputs = {
+ "distance": radius,
+ "lowres_features": lowres_features,
+ "confidence": confidence,
+ "K": intrinsics,
+ "rays": rays,
+ }
+
+ return outputs
+
+ @torch.jit.ignore
+ def no_weight_decay_keywords(self):
+ return {"latents_pos", "level_embeds"}
+
+ def get_params(self, lr, wd):
+ angles_p, _ = get_params(self.angular_module, lr, wd)
+ radius_p, _ = get_params(self.radial_module, lr, wd)
+ tokens_p, _ = get_params(self.camera_token_adapter, lr, wd)
+ input_p, _ = get_params(self.input_adapter, lr, wd)
+ return [*tokens_p, *angles_p, *input_p, *radius_p]
+
+ def build(self, config):
+ input_dims = config["model"]["pixel_encoder"]["embed_dims"]
+ hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"]
+ expansion = config["model"]["expansion"]
+ num_heads = config["model"]["num_heads"]
+ dropout = config["model"]["pixel_decoder"]["dropout"]
+ layer_scale = config["model"]["layer_scale"]
+ depth = config["model"]["pixel_decoder"]["depths"]
+ depths_encoder = config["model"]["pixel_encoder"]["depths"]
+ out_dim = config["model"]["pixel_decoder"]["out_dim"]
+ kernel_size = config["model"]["pixel_decoder"]["kernel_size"]
+ self.slices_encoder = list(zip([d - 1 for d in depths_encoder], depths_encoder))
+ input_dims = [input_dims[d - 1] for d in depths_encoder]
+ self.steps = 0
+ self.num_steps = config["model"].get("num_steps", 100000)
+
+ camera_dims = input_dims
+ self.choker = GradChoker(config["model"]["pixel_decoder"]["detach"])
+ self.input_adapter = ListAdapter(input_dims, hidden_dim)
+ self.camera_token_adapter = ListAdapter(camera_dims, hidden_dim)
+ self.angular_module = AngularModule(
+ hidden_dim=hidden_dim,
+ num_heads=num_heads,
+ expansion=expansion,
+ dropout=dropout,
+ layer_scale=layer_scale,
+ )
+ self.radial_module = RadialModule(
+ hidden_dim=hidden_dim,
+ num_heads=num_heads,
+ expansion=expansion,
+ depths=depth,
+ dropout=dropout,
+ camera_dim=96,
+ layer_scale=layer_scale,
+ out_dim=out_dim,
+ kernel_size=kernel_size,
+ num_prompt_blocks=config["model"]["pixel_decoder"]["num_prompt_blocks"],
+ use_norm=False,
+ )
+ self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True)
+ self.level_embeds = nn.Parameter(
+ orthonormal_init(len(input_dims), hidden_dim).reshape(
+ 1, 1, len(input_dims), hidden_dim
+ ),
+ requires_grad=False,
+ )
diff --git a/unik3d/models/encoder.py b/unik3d/models/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3e8c4551983ec6566797645bed92927669cce60
--- /dev/null
+++ b/unik3d/models/encoder.py
@@ -0,0 +1,231 @@
+from functools import partial
+
+import torch
+import torch.nn as nn
+from timm.models.vision_transformer import _cfg
+
+from unik3d.models.backbones import (ConvNeXt, ConvNeXtV2, SwinTransformerV2,
+ _make_dinov2_model)
+
+
+def swin2_tiny(
+ config,
+ pretrained=None,
+ *args,
+ **kwargs,
+):
+ model = SwinTransformerV2(
+ img_size=config["image_shape"],
+ patch_size=4,
+ window_size=config.get("window_size", 16),
+ embed_dim=96,
+ num_heads=[3, 6, 12, 24],
+        mlp_ratio=4.0,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ depths=[2, 2, 6, 2],
+ drop_path_rate=0.2,
+ pretrained=pretrained,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ output_idx=config.get("output_idx", [2, 4, 10, 12]),
+ use_shift=config.get("use_shift", True),
+ use_checkpoint=config.get("use_checkpoint", False),
+ frozen_stages=-1,
+ )
+ model.default_cfg = _cfg()
+ return model
+
+
+def swin2_base(
+ config,
+ pretrained=None,
+ *args,
+ **kwargs,
+):
+ model = SwinTransformerV2(
+ img_size=config["image_shape"],
+ patch_size=4,
+ window_size=config.get("window_size", 12),
+ embed_dim=128,
+ num_heads=[4, 8, 16, 32],
+        mlp_ratio=4.0,
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ depths=[2, 2, 18, 2],
+ drop_path_rate=0.3,
+ pretrained=pretrained,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ use_shift=config.get("use_shift", True),
+ use_checkpoint=config["use_checkpoint"],
+ frozen_stages=-1,
+ )
+ model.default_cfg = _cfg()
+ return model
+
+
+def swin2_large(
+ config,
+ pretrained=None,
+ *args,
+ **kwargs,
+):
+ model = SwinTransformerV2(
+ img_size=config["image_shape"],
+ patch_size=4,
+ window_size=config.get("window_size", 12),
+ embed_dim=192,
+ num_heads=[6, 12, 24, 48],
+ mlp_ratios=[4, 4, 4, 4],
+ qkv_bias=True,
+ norm_layer=partial(nn.LayerNorm, eps=1e-6),
+ depths=[2, 2, 18, 2],
+ drop_path_rate=0.3,
+ pretrained=pretrained,
+ pretrained_window_sizes=[12, 12, 12, 6],
+ use_shift=config.get("use_shift", True),
+ use_checkpoint=config["use_checkpoint"],
+ frozen_stages=-1,
+ )
+ model.default_cfg = _cfg()
+ return model
+
+
+def convnextv2_base(config, **kwargs):
+ model = ConvNeXtV2(
+ depths=[3, 3, 27, 3],
+ dims=[128, 256, 512, 1024],
+ output_idx=config.get("output_idx", [3, 6, 33, 36]),
+ use_checkpoint=config["use_checkpoint"],
+ **kwargs,
+ )
+ url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt"
+ state_dict = torch.hub.load_state_dict_from_url(
+ url, map_location="cpu", progress=False
+ )["model"]
+ info = model.load_state_dict(state_dict, strict=False)
+ print(info)
+ return model
+
+
+def convnextv2_large(config, **kwargs):
+ model = ConvNeXtV2(
+ depths=[3, 3, 27, 3],
+ dims=[192, 384, 768, 1536],
+ output_idx=config.get("output_idx", [3, 6, 33, 36]),
+ use_checkpoint=config["use_checkpoint"],
+ **kwargs,
+ )
+ url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt"
+ state_dict = torch.hub.load_state_dict_from_url(
+ url, map_location="cpu", progress=False
+ )["model"]
+ info = model.load_state_dict(state_dict, strict=False)
+ print(info)
+ return model
+
+
+def convnextv2_large_mae(config, **kwargs):
+ model = ConvNeXtV2(
+ depths=[3, 3, 27, 3],
+ dims=[192, 384, 768, 1536],
+ output_idx=config.get("output_idx", [3, 6, 33, 36]),
+ use_checkpoint=config["use_checkpoint"],
+ **kwargs,
+ )
+ url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt"
+ state_dict = torch.hub.load_state_dict_from_url(
+ url, map_location="cpu", progress=False
+ )["model"]
+ info = model.load_state_dict(state_dict, strict=False)
+ print(info)
+ return model
+
+
+def convnext_large(config, **kwargs):
+ model = ConvNeXt(
+ depths=[3, 3, 27, 3],
+ dims=[192, 384, 768, 1536],
+ output_idx=config.get("output_idx", [3, 6, 33, 36]),
+ use_checkpoint=config.get("use_checkpoint", False),
+ drop_path_rate=config.get("drop_path", 0.0),
+ **kwargs,
+ )
+ from huggingface_hub import hf_hub_download
+ from huggingface_hub.utils import disable_progress_bars
+
+ from unik3d.models.backbones.convnext import HF_URL, checkpoint_filter_fn
+
+ disable_progress_bars()
+ repo_id, filename = HF_URL["convnext_large"]
+ state_dict = torch.load(hf_hub_download(repo_id=repo_id, filename=filename))
+ state_dict = checkpoint_filter_fn(state_dict, model)
+ info = model.load_state_dict(state_dict, strict=False)
+ print(info)
+ return model
+
+
+def dinov2_vits14(config, pretrained: bool = True, **kwargs):
+ vit = _make_dinov2_model(
+ arch_name="vit_small",
+ pretrained=config["pretrained"],
+ output_idx=config.get("output_idx", [3, 6, 9, 12]),
+ checkpoint=config.get("use_checkpoint", False),
+ drop_path_rate=config.get("drop_path", 0.0),
+ num_register_tokens=config.get("num_register_tokens", 0),
+ use_norm=config.get("use_norm", False),
+ interpolate_offset=config.get("interpolate_offset", 0.0),
+ frozen_stages=config.get("frozen_stages", 0),
+ freeze_norm=config.get("freeze_norm", False),
+ **kwargs,
+ )
+ return vit
+
+
+def dinov2_vitb14(config, pretrained: bool = True, **kwargs):
+ vit = _make_dinov2_model(
+ arch_name="vit_base",
+ pretrained=config["pretrained"],
+ output_idx=config.get("output_idx", [3, 6, 9, 12]),
+ checkpoint=config.get("use_checkpoint", False),
+ drop_path_rate=config.get("drop_path", 0.0),
+ num_register_tokens=config.get("num_register_tokens", 0),
+ use_norm=config.get("use_norm", False),
+ interpolate_offset=config.get("interpolate_offset", 0.0),
+ frozen_stages=config.get("frozen_stages", 0),
+ freeze_norm=config.get("freeze_norm", False),
+ **kwargs,
+ )
+ return vit
+
+
+def dinov2_vitl14(config, pretrained: str = "", **kwargs):
+ vit = _make_dinov2_model(
+ arch_name="vit_large",
+ pretrained=config["pretrained"],
+ output_idx=config.get("output_idx", [5, 12, 18, 24]),
+ checkpoint=config.get("use_checkpoint", False),
+ drop_path_rate=config.get("drop_path", 0.0),
+ num_register_tokens=config.get("num_register_tokens", 0),
+ use_norm=config.get("use_norm", False),
+ interpolate_offset=config.get("interpolate_offset", 0.0),
+ frozen_stages=config.get("frozen_stages", 0),
+ freeze_norm=config.get("freeze_norm", False),
+ **kwargs,
+ )
+ return vit
+
+
+def dinov2_vitg14(config, pretrained: str = "", **kwargs):
+ vit = _make_dinov2_model(
+ arch_name="vit_giant2",
+ ffn_layer="swiglufused",
+ pretrained=config["pretrained"],
+ output_idx=config.get("output_idx", [10, 20, 30, 40]),
+ checkpoint=config.get("use_checkpoint", False),
+ drop_path_rate=config.get("drop_path", 0.0),
+ num_register_tokens=config.get("num_register_tokens", 0),
+ use_norm=config.get("use_norm", False),
+ interpolate_offset=config.get("interpolate_offset", 0.0),
+ **kwargs,
+ )
+ return vit
diff --git a/unik3d/models/export.py b/unik3d/models/export.py
new file mode 100644
index 0000000000000000000000000000000000000000..65a8704816f30850307342f10442b532726c949a
--- /dev/null
+++ b/unik3d/models/export.py
@@ -0,0 +1,165 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+import argparse
+import json
+import os
+from math import ceil
+
+import huggingface_hub
+import torch.nn.functional as F
+import torch.onnx
+
+from unik3d.models.unik3d import UniK3D
+
+
+class UniK3DONNX(UniK3D):
+ def __init__(
+ self,
+ config,
+ eps: float = 1e-6,
+ **kwargs,
+ ):
+ super().__init__(config, eps)
+
+ def forward(self, rgbs):
+ B, _, H, W = rgbs.shape
+ features, tokens = self.pixel_encoder(rgbs)
+
+ inputs = {}
+ inputs["image"] = rgbs
+ inputs["features"] = [
+ self.stacking_fn(features[i:j]).contiguous()
+ for i, j in self.slices_encoder_range
+ ]
+ inputs["tokens"] = [
+ self.stacking_fn(tokens[i:j]).contiguous()
+ for i, j in self.slices_encoder_range
+ ]
+ outputs = self.pixel_decoder(inputs, [])
+ outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W)
+ pts_3d = outputs["rays"] * outputs["radius"]
+
+ return pts_3d, outputs["confidence"]
+
+
+class UniK3DONNXcam(UniK3D):
+ def __init__(
+ self,
+ config,
+ eps: float = 1e-6,
+ **kwargs,
+ ):
+ super().__init__(config, eps)
+
+ def forward(self, rgbs, rays):
+ B, _, H, W = rgbs.shape
+ features, tokens = self.pixel_encoder(rgbs)
+
+ inputs = {}
+ inputs["image"] = rgbs
+ inputs["rays"] = rays
+ inputs["features"] = [
+ self.stacking_fn(features[i:j]).contiguous()
+ for i, j in self.slices_encoder_range
+ ]
+ inputs["tokens"] = [
+ self.stacking_fn(tokens[i:j]).contiguous()
+ for i, j in self.slices_encoder_range
+ ]
+ outputs = self.pixel_decoder(inputs, [])
+ outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W)
+ pts_3d = outputs["rays"] * outputs["radius"]
+
+ return pts_3d, outputs["confidence"]
+
+
+def export(model, path, shape=(462, 630), with_camera=False):
+ model.eval()
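+    # dummy inputs only trace the graph; just the batch dimension is exported as
+    # dynamic, H and W stay baked into the ONNX graph (hence the fixed --shape flag)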
+ image = torch.rand(1, 3, *shape)
+ dynamic_axes_in = {"rgbs": {0: "batch"}}
+ inputs = [image]
+ if with_camera:
+ rays = torch.rand(1, 3, *shape)
+ inputs.append(rays)
+ dynamic_axes_in["rays"] = {0: "batch"}
+
+ dynamic_axes_out = {
+ "pts_3d": {0: "batch"},
+ "confidence": {0: "batch"},
+ }
+ torch.onnx.export(
+ model,
+ tuple(inputs),
+ path,
+ input_names=list(dynamic_axes_in.keys()),
+ output_names=list(dynamic_axes_out.keys()),
+ opset_version=14,
+ dynamic_axes={**dynamic_axes_in, **dynamic_axes_out},
+ )
+ print(f"Model exported to {path}")
+
+
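+# Example invocation (illustrative, run from the repository root):
+#   python unik3d/models/export.py --backbone vitl --shape 462 630 --output-path unik3d_vitl.onnx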
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(description="Export UniK3D model to ONNX")
+ parser.add_argument(
+ "--backbone",
+ type=str,
+ default="vitl",
+ choices=["vits", "vitb", "vitl"],
+ help="Backbone model",
+ )
+ parser.add_argument(
+ "--shape",
+ type=int,
+ nargs=2,
+ default=(462, 630),
+ help="Input shape. No dyamic shape supported!",
+ )
+ parser.add_argument(
+ "--output-path", type=str, default="unik3d.onnx", help="Output ONNX file"
+ )
+ parser.add_argument(
+ "--with-camera",
+ action="store_true",
+ help="Export model that expects GT camera as unprojected rays at inference",
+ )
+ args = parser.parse_args()
+
+ backbone = args.backbone
+ shape = args.shape
+ output_path = args.output_path
+ with_camera = args.with_camera
+
+ # force shape to be multiple of 14
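+    # (the exportable ViT backbones use a 14x14 patch size, so H and W must be divisible by 14)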
+ shape_rounded = [14 * ceil(x // 14 - 0.5) for x in shape]
+ if list(shape) != list(shape_rounded):
+ print(f"Shape {shape} is not multiple of 14. Rounding to {shape_rounded}")
+ shape = shape_rounded
+
+ # assumes command is from root of repo
+ with open(os.path.join("configs", f"config_{backbone}.json")) as f:
+ config = json.load(f)
+
+ # tell DINO not to use efficient attention: not exportable
+ config["training"]["export"] = True
+
+ model = UniK3DONNX(config) if not with_camera else UniK3DONNXcam(config)
+ path = huggingface_hub.hf_hub_download(
+ repo_id=f"lpiccinelli/unik3d-{backbone}",
+ filename=f"pytorch_model.bin",
+ repo_type="model",
+ )
+ info = model.load_state_dict(torch.load(path), strict=False)
+ print(f"UUniK3D_{backbone} is loaded with:")
+ print(f"\t missing keys: {info.missing_keys}")
+ print(f"\t additional keys: {info.unexpected_keys}")
+
+ export(
+ model=model,
+ path=os.path.join(os.environ.get("TMPDIR", "."), output_path),
+ shape=shape,
+ with_camera=with_camera,
+ )
diff --git a/unik3d/models/metadinov2/__init__.py b/unik3d/models/metadinov2/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0534f6bab5034b20875641adba68f8aa1f857ac
--- /dev/null
+++ b/unik3d/models/metadinov2/__init__.py
@@ -0,0 +1,12 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .attention import Attention, MemEffAttention
+from .block import Block, NestedTensorBlock
+from .dino_head import DINOHead
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
diff --git a/unik3d/models/metadinov2/attention.py b/unik3d/models/metadinov2/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea22c9f4e8e1f6d78e1bbd0f4899d56e69424c4e
--- /dev/null
+++ b/unik3d/models/metadinov2/attention.py
@@ -0,0 +1,84 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+import torch.nn as nn
+from torch import Tensor
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import fmha, memory_efficient_attention, unbind
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int = 8,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ attn_drop: float = 0.0,
+ proj_drop: float = 0.0,
+ ) -> None:
+ super().__init__()
+ self.num_heads = num_heads
+ head_dim = dim // num_heads
+ self.scale = head_dim**-0.5
+
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+ self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0.0 else nn.Identity()
+ self.proj = nn.Linear(dim, dim, bias=proj_bias)
+ self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ B, N, C = x.shape
+ qkv = (
+ self.qkv(x)
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+ .permute(2, 0, 3, 1, 4)
+ )
+
+ q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+ attn = q @ k.transpose(-2, -1)
+
+ attn = attn.softmax(dim=-1)
+ attn = self.attn_drop(attn)
+
+ x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
+
+
+class MemEffAttention(Attention):
+ def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+ if not XFORMERS_AVAILABLE or x.device.type == "cpu":
+ assert attn_bias is None, "xFormers is required for nested tensors usage"
+ return super().forward(x)
+
+ B, N, C = x.shape
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+ q, k, v = unbind(qkv, 2)
+
+ x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+ x = x.reshape([B, N, C])
+
+ x = self.proj(x)
+ x = self.proj_drop(x)
+ return x
diff --git a/unik3d/models/metadinov2/block.py b/unik3d/models/metadinov2/block.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e6cfc5232905b3e0ddfe1c126bd8c5798f974f4
--- /dev/null
+++ b/unik3d/models/metadinov2/block.py
@@ -0,0 +1,281 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Any, Callable, Dict, List, Tuple
+
+import torch
+import torch.nn as nn
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+ from xformers.ops import fmha, index_select_cat, scaled_index_add
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ logger.warning("xFormers not available")
+ XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_heads: int,
+ mlp_ratio: float = 4.0,
+ qkv_bias: bool = False,
+ proj_bias: bool = True,
+ ffn_bias: bool = True,
+ drop: float = 0.0,
+ attn_drop: float = 0.0,
+ init_values=None,
+ drop_path: float = 0.0,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+ attn_class: Callable[..., nn.Module] = Attention,
+ ffn_layer: Callable[..., nn.Module] = Mlp,
+ ) -> None:
+ super().__init__()
+ self.norm1 = norm_layer(dim)
+ self.attn = attn_class(
+ dim,
+ num_heads=num_heads,
+ qkv_bias=qkv_bias,
+ proj_bias=proj_bias,
+ attn_drop=attn_drop,
+ proj_drop=drop,
+ )
+ self.ls1 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.norm2 = norm_layer(dim)
+ mlp_hidden_dim = int(dim * mlp_ratio)
+ self.mlp = ffn_layer(
+ in_features=dim,
+ hidden_features=mlp_hidden_dim,
+ act_layer=act_layer,
+ drop=drop,
+ bias=ffn_bias,
+ )
+ self.ls2 = (
+ LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+ )
+ self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+ self.sample_drop_ratio = drop_path
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ def attn_residual_func(x: torch.Tensor) -> torch.Tensor:
+ return self.ls1(self.attn(self.norm1(x)))
+
+ def ffn_residual_func(x: torch.Tensor) -> torch.Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ if self.training and self.sample_drop_ratio > 0.1:
+ # the overhead is compensated only for a drop path rate larger than 0.1
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ x = drop_add_residual_stochastic_depth(
+ x,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ )
+ elif self.training and self.sample_drop_ratio > 0.0:
+ x = x + self.drop_path1(attn_residual_func(x))
+ x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
+ else:
+ x = x + attn_residual_func(x)
+ x = x + ffn_residual_func(x)
+ return x
+
+
+def drop_add_residual_stochastic_depth(
+ x: torch.Tensor,
+ residual_func, #: Callable[[torch.Tensor], torch.Tensor],
+ sample_drop_ratio: float = 0.0,
+) -> torch.Tensor:
+ # 1) extract subset using permutation
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ x_subset = x[brange]
+
+ # 2) apply residual_func to get residual
+ residual = residual_func(x_subset)
+
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+
+ residual_scale_factor = b / sample_subset_size
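+    # scaling the kept residuals by b / subset_size keeps the expected update equal
+    # to the full-batch residual, as in standard stochastic depth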
+
+ # 3) add the residual
+ x_plus_residual = torch.index_add(
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+ )
+ return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+ b, n, d = x.shape
+ sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+ brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+ residual_scale_factor = b / sample_subset_size
+ return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+ if scaling_vector is None:
+ x_flat = x.flatten(1)
+ residual = residual.flatten(1)
+ x_plus_residual = torch.index_add(
+ x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
+ )
+ else:
+ x_plus_residual = scaled_index_add(
+ x,
+ brange,
+ residual.to(dtype=x.dtype),
+ scaling=scaling_vector,
+ alpha=residual_scale_factor,
+ )
+ return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+ """
+ this will perform the index select, cat the tensors, and provide the attn_bias from cache
+ """
+ batch_sizes = (
+ [b.shape[0] for b in branges]
+ if branges is not None
+ else [x.shape[0] for x in x_list]
+ )
+ all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+ if all_shapes not in attn_bias_cache.keys():
+ seqlens = []
+ for b, x in zip(batch_sizes, x_list):
+ for _ in range(b):
+ seqlens.append(x.shape[1])
+ attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+ attn_bias._batch_sizes = batch_sizes
+ attn_bias_cache[all_shapes] = attn_bias
+
+ if branges is not None:
+ cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
+ 1, -1, x_list[0].shape[-1]
+ )
+ else:
+ tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+ cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+ return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+ x_list: List[torch.Tensor],
+ residual_func, #: Callable[[torch.Tensor, Any], torch.Tensor],
+ sample_drop_ratio: float = 0.0,
+ scaling_vector=None,
+) -> torch.Tensor:
+ # 1) generate random set of indices for dropping samples in the batch
+ branges_scales = [
+ get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
+ ]
+ branges = [s[0] for s in branges_scales]
+ residual_scale_factors = [s[1] for s in branges_scales]
+
+ # 2) get attention bias and index+concat the tensors
+ attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+ # 3) apply residual_func to get residual, and split the result
+ residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
+
+ outputs = []
+ for x, brange, residual, residual_scale_factor in zip(
+ x_list, branges, residual_list, residual_scale_factors
+ ):
+ outputs.append(
+ add_residual(
+ x, brange, residual, residual_scale_factor, scaling_vector
+ ).view_as(x)
+ )
+ return outputs
+
+
+class NestedTensorBlock(Block):
+ def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
+ """
+ x_list contains a list of tensors to nest together and run
+ """
+ assert isinstance(self.attn, MemEffAttention)
+
+ if self.training and self.sample_drop_ratio > 0.0:
+
+ def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+ def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ return self.mlp(self.norm2(x))
+
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=attn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=(
+ self.ls1.gamma if isinstance(self.ls1, LayerScale) else None
+ ),
+ )
+ x_list = drop_add_residual_stochastic_depth_list(
+ x_list,
+ residual_func=ffn_residual_func,
+ sample_drop_ratio=self.sample_drop_ratio,
+ scaling_vector=(
+ self.ls2.gamma if isinstance(self.ls2, LayerScale) else None
+ ),
+ )
+ return x_list
+ else:
+
+ def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+ def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor:
+ return self.ls2(self.mlp(self.norm2(x)))
+
+ attn_bias, x = get_attn_bias_and_cat(x_list)
+ x = x + attn_residual_func(x, attn_bias=attn_bias)
+ x = x + ffn_residual_func(x)
+ return attn_bias.split(x)
+
+ def forward(self, x_or_x_list):
+ if isinstance(x_or_x_list, torch.Tensor):
+ return super().forward(x_or_x_list)
+ elif isinstance(x_or_x_list, list):
+ assert (
+ XFORMERS_AVAILABLE
+ ), "Please install xFormers for nested tensors usage"
+ return self.forward_nested(x_or_x_list)
+ else:
+ raise AssertionError
diff --git a/unik3d/models/metadinov2/dino_head.py b/unik3d/models/metadinov2/dino_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..1147dd3a3c046aee8d427b42b1055f38a218275b
--- /dev/null
+++ b/unik3d/models/metadinov2/dino_head.py
@@ -0,0 +1,68 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import torch.nn as nn
+from torch.nn.init import trunc_normal_
+from torch.nn.utils import weight_norm
+
+
+class DINOHead(nn.Module):
+ def __init__(
+ self,
+ in_dim,
+ out_dim,
+ use_bn=False,
+ nlayers=3,
+ hidden_dim=2048,
+ bottleneck_dim=256,
+ mlp_bias=True,
+ ):
+ super().__init__()
+ nlayers = max(nlayers, 1)
+ self.mlp = _build_mlp(
+ nlayers,
+ in_dim,
+ bottleneck_dim,
+ hidden_dim=hidden_dim,
+ use_bn=use_bn,
+ bias=mlp_bias,
+ )
+ self.apply(self._init_weights)
+ self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+ self.last_layer.weight_g.data.fill_(1)
+
+ def _init_weights(self, m):
+ if isinstance(m, nn.Linear):
+ trunc_normal_(m.weight, std=0.02)
+ if isinstance(m, nn.Linear) and m.bias is not None:
+ nn.init.constant_(m.bias, 0)
+
+ def forward(self, x):
+ x = self.mlp(x)
+ eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+ x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+ x = self.last_layer(x)
+ return x
+
+
+def _build_mlp(
+ nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
+):
+ if nlayers == 1:
+ return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+ else:
+ layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ for _ in range(nlayers - 2):
+ layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+ if use_bn:
+ layers.append(nn.BatchNorm1d(hidden_dim))
+ layers.append(nn.GELU())
+ layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+ return nn.Sequential(*layers)
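+
+
+# Illustrative sizes (not taken from this repo's configs): DINOHead(in_dim=768, out_dim=65536)
+# maps (..., 768) features through the MLP to a 256-d bottleneck, L2-normalizes them, and
+# projects to 65536 prototype logits with the weight-normalized last layer.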
diff --git a/unik3d/models/metadinov2/drop_path.py b/unik3d/models/metadinov2/drop_path.py
new file mode 100644
index 0000000000000000000000000000000000000000..35b1a620d06ba862ea05297d271d8c2c625b5f93
--- /dev/null
+++ b/unik3d/models/metadinov2/drop_path.py
@@ -0,0 +1,37 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+import torch.nn as nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+ if drop_prob == 0.0 or not training:
+ return x
+ keep_prob = 1 - drop_prob
+ shape = (x.shape[0],) + (1,) * (
+ x.ndim - 1
+ ) # work with diff dim tensors, not just 2D ConvNets
+ random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+ if keep_prob > 0.0:
+ random_tensor.div_(keep_prob)
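+ # dividing the surviving samples by keep_prob keeps the expectation unchanged:
+ # E[x * random_tensor / keep_prob] == x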
+ output = x * random_tensor
+ return output
+
+
+class DropPath(nn.Module):
+ """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+ def __init__(self, drop_prob=None):
+ super(DropPath, self).__init__()
+ self.drop_prob = drop_prob
+
+ def forward(self, x):
+ return drop_path(x, self.drop_prob, self.training)
diff --git a/unik3d/models/metadinov2/layer_scale.py b/unik3d/models/metadinov2/layer_scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..40d18b5427183534d5516652b076f9883a609fc6
--- /dev/null
+++ b/unik3d/models/metadinov2/layer_scale.py
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+import torch.nn as nn
+from torch import Tensor
+
+
+class LayerScale(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ init_values: Union[float, Tensor] = 1e-5,
+ inplace: bool = False,
+ ) -> None:
+ super().__init__()
+ self.inplace = inplace
+ self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+ def forward(self, x: Tensor) -> Tensor:
+ return x.mul_(self.gamma) if self.inplace else x * self.gamma
diff --git a/unik3d/models/metadinov2/mlp.py b/unik3d/models/metadinov2/mlp.py
new file mode 100644
index 0000000000000000000000000000000000000000..af598999855d897948142cc986fce82abc9e3b53
--- /dev/null
+++ b/unik3d/models/metadinov2/mlp.py
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = nn.GELU,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+ self.act = act_layer()
+ self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+ self.drop = nn.Dropout(drop) if drop > 0.0 else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ x = self.fc1(x)
+ x = self.act(x)
+ x = self.drop(x)
+ x = self.fc2(x)
+ x = self.drop(x)
+ return x
diff --git a/unik3d/models/metadinov2/patch_embed.py b/unik3d/models/metadinov2/patch_embed.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5a56c02609e67922eb8f859588ef274e5298b55
--- /dev/null
+++ b/unik3d/models/metadinov2/patch_embed.py
@@ -0,0 +1,101 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+import torch.nn as nn
+from torch import Tensor
+
+
+def make_2tuple(x):
+ if isinstance(x, tuple):
+ assert len(x) == 2
+ return x
+
+ assert isinstance(x, int)
+ return (x, x)
+
+
+class PatchEmbed(nn.Module):
+ """
+ 2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+ Args:
+ img_size: Image size.
+ patch_size: Patch token size.
+ in_chans: Number of input image channels.
+ embed_dim: Number of linear projection output channels.
+ norm_layer: Normalization layer.
+ """
+
+ def __init__(
+ self,
+ img_size: Union[int, Tuple[int, int]] = 224,
+ patch_size: Union[int, Tuple[int, int]] = 16,
+ in_chans: int = 3,
+ embed_dim: int = 768,
+ norm_layer: Optional[Callable] = None,
+ flatten_embedding: bool = True,
+ ) -> None:
+ super().__init__()
+
+ image_HW = make_2tuple(img_size)
+ patch_HW = make_2tuple(patch_size)
+ patch_grid_size = (
+ image_HW[0] // patch_HW[0],
+ image_HW[1] // patch_HW[1],
+ )
+
+ self.img_size = image_HW
+ self.patch_size = patch_HW
+ self.patches_resolution = patch_grid_size
+ self.num_patches = patch_grid_size[0] * patch_grid_size[1]
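+ # e.g. a 224x224 image with 16x16 patches gives a 14x14 grid, i.e. 196 patch tokens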
+
+ self.in_chans = in_chans
+ self.embed_dim = embed_dim
+
+ self.flatten_embedding = flatten_embedding
+
+ self.proj = nn.Conv2d(
+ in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
+ )
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+ def forward(self, x: Tensor) -> Tensor:
+ _, _, H, W = x.shape
+ patch_H, patch_W = self.patch_size
+
+ assert (
+ H % patch_H == 0
+ ), f"Input image height {H} is not a multiple of patch height {patch_H}"
+ assert (
+ W % patch_W == 0
+ ), f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+ x = self.proj(x) # B C H W
+ H, W = x.size(2), x.size(3)
+ x = x.flatten(2).transpose(1, 2) # B HW C
+ x = self.norm(x)
+ if not self.flatten_embedding:
+ x = x.reshape(-1, H, W, self.embed_dim) # B H W C
+ return x
+
+ def flops(self) -> float:
+ Ho, Wo = self.patches_resolution
+ flops = (
+ Ho
+ * Wo
+ * self.embed_dim
+ * self.in_chans
+ * (self.patch_size[0] * self.patch_size[1])
+ )
+ if self.norm is not None:
+ flops += Ho * Wo * self.embed_dim
+ return flops
diff --git a/unik3d/models/metadinov2/swiglu_ffn.py b/unik3d/models/metadinov2/swiglu_ffn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e82999e9b09b41cd6aba9edbc4c05d51ab663a1e
--- /dev/null
+++ b/unik3d/models/metadinov2/swiglu_ffn.py
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+import torch.nn.functional as F
+from torch import Tensor, nn
+
+
+class SwiGLUFFN(nn.Module):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ super().__init__()
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+ self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+ def forward(self, x: Tensor) -> Tensor:
+ x12 = self.w12(x)
+ x1, x2 = x12.chunk(2, dim=-1)
+ hidden = F.silu(x1) * x2
+ return self.w3(hidden)
+
+
+try:
+ from xformers.ops import SwiGLU
+
+ XFORMERS_AVAILABLE = True
+except ImportError:
+ SwiGLU = SwiGLUFFN
+ XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+ def __init__(
+ self,
+ in_features: int,
+ hidden_features: Optional[int] = None,
+ out_features: Optional[int] = None,
+ act_layer: Callable[..., nn.Module] = None,
+ drop: float = 0.0,
+ bias: bool = True,
+ ) -> None:
+ out_features = out_features or in_features
+ hidden_features = hidden_features or in_features
+ hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
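+ # shrink the hidden width to ~2/3 (SwiGLU uses two input projections) and round it up to a
+ # multiple of 8 for efficient kernels, e.g. 1024 -> int(682.67) = 682 -> 688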
+ super().__init__(
+ in_features=in_features,
+ hidden_features=hidden_features,
+ out_features=out_features,
+ bias=bias,
+ )
diff --git a/unik3d/models/unik3d.py b/unik3d/models/unik3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..6366d137598514a76fe5bb8df57c8b2782e461f9
--- /dev/null
+++ b/unik3d/models/unik3d.py
@@ -0,0 +1,475 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+import importlib
+import warnings
+from copy import deepcopy
+from math import ceil
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torchvision.transforms.v2.functional as TF
+from einops import rearrange
+from huggingface_hub import PyTorchModelHubMixin
+
+from unik3d.models.decoder import Decoder
+from unik3d.utils.camera import BatchCamera, Camera
+from unik3d.utils.constants import IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD
+from unik3d.utils.distributed import is_main_process
+from unik3d.utils.misc import get_params, last_stack, match_gt
+
+DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def orthonormal_init(num_tokens, dims):
+ pe = torch.randn(num_tokens, dims)
+ # use Gram-Schmidt process to make the matrix orthonormal
+ for i in range(num_tokens):
+ for j in range(i):
+ pe[i] -= torch.dot(pe[i], pe[j]) * pe[j]
+ pe[i] = F.normalize(pe[i], p=2, dim=0)
+ return pe
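+
+
+# Sanity-check sketch (illustrative sizes): for num_tokens <= dims the returned rows are
+# orthonormal, so pe @ pe.T should be close to the identity:
+#   pe = orthonormal_init(32, 128)
+#   assert torch.allclose(pe @ pe.T, torch.eye(32), atol=1e-4)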
+
+
+def get_paddings(original_shape, aspect_ratio_range):
+ # Original dimensions
+ H_ori, W_ori = original_shape
+ orig_aspect_ratio = W_ori / H_ori
+
+ # Determine the closest aspect ratio within the range
+ min_ratio, max_ratio = aspect_ratio_range
+ target_aspect_ratio = min(max_ratio, max(min_ratio, orig_aspect_ratio))
+
+ if orig_aspect_ratio > target_aspect_ratio: # Too wide
+ W_new = W_ori
+ H_new = int(W_ori / target_aspect_ratio)
+ pad_top = (H_new - H_ori) // 2
+ pad_bottom = H_new - H_ori - pad_top
+ pad_left, pad_right = 0, 0
+ else: # Too tall
+ H_new = H_ori
+ W_new = int(H_ori * target_aspect_ratio)
+ pad_left = (W_new - W_ori) // 2
+ pad_right = W_new - W_ori - pad_left
+ pad_top, pad_bottom = 0, 0
+
+ return (pad_left, pad_right, pad_top, pad_bottom), (H_new, W_new)
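+
+
+# Worked example: get_paddings((480, 640), (0.5, 2.0)) keeps the 4:3 image as-is and returns
+# ((0, 0, 0, 0), (480, 640)); a 100x400 strip (ratio 4.0) would instead be padded vertically
+# to (200, 400) to reach the maximum allowed ratio of 2.0.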
+
+
+def get_resize_factor(original_shape, pixels_range, shape_multiplier=14):
+ # Original dimensions
+ H_ori, W_ori = original_shape
+ n_pixels_ori = W_ori * H_ori
+
+ # Determine the closest number of pixels within the range
+ min_pixels, max_pixels = pixels_range
+ target_pixels = min(max_pixels, max(min_pixels, n_pixels_ori))
+
+ # Calculate the resize factor
+ resize_factor = (target_pixels / n_pixels_ori) ** 0.5
+ new_width = int(W_ori * resize_factor)
+ new_height = int(H_ori * resize_factor)
+ new_height = ceil(new_height / shape_multiplier) * shape_multiplier
+ new_width = ceil(new_width / shape_multiplier) * shape_multiplier
+
+ return resize_factor, (new_height, new_width)
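+
+
+# Worked example (shape_multiplier=14, the default): a 480x640 input with an illustrative
+# pixels_range of (200_000, 600_000) already lies in range (307_200 px), so the resize factor
+# is 1.0 and the shape is only rounded up to multiples of 14, giving (490, 644).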
+
+
+def _postprocess(tensor, shapes, paddings, interpolation_mode="bilinear"):
+
+ # interpolate to original size
+ tensor = F.interpolate(
+ tensor, size=shapes, mode=interpolation_mode, align_corners=False
+ )
+
+ # remove paddings
+ pad1_l, pad1_r, pad1_t, pad1_b = paddings
+ tensor = tensor[..., pad1_t : shapes[0] - pad1_b, pad1_l : shapes[1] - pad1_r]
+ return tensor
+
+
+class UniK3D(
+ nn.Module,
+ PyTorchModelHubMixin,
+ library_name="UniK3D",
+ repo_url="https://github.com/lpiccinelli-eth/UniK3D",
+ tags=["monocular-metric-3D-estimation"],
+):
+ def __init__(
+ self,
+ config,
+ eps: float = 1e-6,
+ **kwargs,
+ ):
+ super().__init__()
+ self.eps = eps
+ self.build(config)
+ self.build_losses(config)
+
+ def pack_sequence(
+ self,
+ inputs: dict[str, torch.Tensor],
+ ):
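+ # flatten the temporal dimension into the batch: e.g. an image tensor of shape
+ # (B, T, 3, H, W) becomes (B*T, 3, H, W); unpack_sequence below restores (B, T, ...)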
+ for key, value in inputs.items():
+ if isinstance(value, torch.Tensor):
+ inputs[key] = value.reshape(-1, *value.shape[2:])
+ elif isinstance(value, BatchCamera):
+ inputs[key] = value.reshape(-1)
+ return inputs
+
+ def unpack_sequence(self, inputs: dict[str, torch.Tensor], B: int, T: int):
+ for key, value in inputs.items():
+ if isinstance(value, torch.Tensor):
+ inputs[key] = value.reshape(B, T, *value.shape[1:])
+ elif isinstance(value, BatchCamera):
+ inputs[key] = value.reshape(B, T)
+ return inputs
+
+ def forward_train(self, inputs, image_metas):
+ losses = {"opt": {}, "stat": {}}
+ B, T = inputs["image"].shape[:2]
+ image_metas[0]["B"], image_metas[0]["T"] = B, T
+ inputs = self.pack_sequence(inputs) # move from B, T, ... -> B*T, ...
+
+ inputs, outputs = self.encode_decode(inputs, image_metas)
+ validity_mask = inputs["validity_mask"]
+
+ # guard against possible NaNs in the reconstructed 3D points (out-of-bounds unprojection)
+ pts_gt = inputs["camera"].reconstruct(inputs["depth"]) * validity_mask.float()
+ pts_gt = torch.where(pts_gt.isnan().any(dim=1, keepdim=True), 0.0, pts_gt)
+ mask_pts_gt_nan = ~pts_gt.isnan().any(dim=1, keepdim=True)
+ mask = (
+ inputs["depth_mask"].bool() & validity_mask.bool() & mask_pts_gt_nan.bool()
+ )
+
+ # compute loss!
+ inputs["radius"] = torch.norm(pts_gt, dim=1, keepdim=True)
+ inputs["points"] = pts_gt
+ inputs["depth_mask"] = mask
+ losses = self.compute_losses(outputs, inputs, image_metas)
+
+ outputs = self.unpack_sequence(outputs, B, T)
+ return (
+ outputs,
+ losses,
+ )
+
+ def forward_test(self, inputs, image_metas):
+ B, T = inputs["image"].shape[:2]
+ image_metas[0]["B"], image_metas[0]["T"] = B, T
+ # move from B, T, ... -> B*T, ...
+ inputs = self.pack_sequence(inputs)
+ inputs, outputs = self.encode_decode(inputs, image_metas)
+
+ # depth_gt is only needed for its spatial shape; a dummy tensor with the target output shape would also work
+ depth_gt = inputs["depth"]
+
+ outs = {}
+ outs["points"] = match_gt(
+ outputs["points"], depth_gt, padding1=inputs["paddings"], padding2=None
+ )
+ outs["confidence"] = match_gt(
+ outputs["confidence"], depth_gt, padding1=inputs["paddings"], padding2=None
+ )
+ outs["distance"] = outs["points"].norm(dim=1, keepdim=True)
+ outs["depth"] = outs["points"][:, -1:]
+ outs["rays"] = outs["points"] / torch.norm(
+ outs["points"], dim=1, keepdim=True
+ ).clip(min=1e-5)
+
+ outs = self.unpack_sequence(outs, B, T)
+ return outs
+
+ def forward(self, inputs, image_metas):
+ if self.training:
+ return self.forward_train(inputs, image_metas)
+ else:
+ return self.forward_test(inputs, image_metas)
+
+ def encode_decode(self, inputs, image_metas=[]):
+ B, _, H, W = inputs["image"].shape
+
+ # shortcut: only use paddings when provided in the metas (skipped at inference) to avoid key errors
+ if len(image_metas) and "paddings" in image_metas[0]:
+ # reorder paddings to lrtb order (left, right, top, bottom)
+ inputs["paddings"] = torch.tensor(
+ [image_meta["paddings"] for image_meta in image_metas],
+ device=self.device,
+ )[..., [0, 2, 1, 3]]
+ inputs["depth_paddings"] = torch.tensor(
+ [image_meta["depth_paddings"] for image_meta in image_metas],
+ device=self.device,
+ )
+ # at inference there are no image paddings on top of the depth ones (no "crop" is applied to the GT in ContextCrop)
+ if self.training:
+ inputs["depth_paddings"] = inputs["depth_paddings"] + inputs["paddings"]
+ else:
+ inputs["paddings"] = inputs["paddings"].squeeze(0)
+ inputs["depth_paddings"] = inputs["depth_paddings"].squeeze(0)
+
+ if inputs.get("camera", None) is not None:
+ inputs["rays"] = inputs["camera"].get_rays(shapes=(B, H, W))
+
+ features, tokens = self.pixel_encoder(inputs["image"])
+ inputs["features"] = [
+ self.stacking_fn(features[i:j]).contiguous()
+ for i, j in self.slices_encoder_range
+ ]
+ inputs["tokens"] = [
+ self.stacking_fn(tokens[i:j]).contiguous()
+ for i, j in self.slices_encoder_range
+ ]
+
+ outputs = self.pixel_decoder(inputs, image_metas)
+ outputs["rays"] = rearrange(outputs["rays"], "b (h w) c -> b c h w", h=H, w=W)
+ pts_3d = outputs["rays"] * outputs["distance"]
+ outputs.update({"points": pts_3d, "depth": pts_3d[:, -1:]})
+
+ return inputs, outputs
+
+ def compute_losses(self, outputs, inputs, image_metas):
+ B, _, H, W = inputs["image"].shape
+ losses = {"opt": {}, "stat": {}}
+ losses_to_be_computed = list(self.losses.keys())
+
+ # depth loss
+ si = torch.tensor(
+ [x.get("si", False) for x in image_metas], device=self.device
+ ).reshape(B)
+ loss = self.losses["depth"]
+ depth_losses = loss(
+ outputs["depth"],
+ target=inputs["depth"],
+ mask=inputs["depth_mask"].clone(),
+ si=si,
+ )
+ losses["opt"][loss.name] = loss.weight * depth_losses.mean()
+ losses_to_be_computed.remove("depth")
+
+ loss = self.losses["camera"]
+ camera_losses = loss(
+ outputs["rays"], target=inputs["rays"], mask=inputs["validity_mask"].bool()
+ )
+ losses["opt"][loss.name] = loss.weight * camera_losses.mean()
+ losses_to_be_computed.remove("camera")
+
+ # confidence loss: the last remaining loss; nothing else should be left to compute afterwards
+ loss = self.losses["confidence"]
+ conf_losses = loss(
+ outputs["confidence"],
+ target_gt=inputs["depth"],
+ target_pred=outputs["depth"],
+ mask=inputs["depth_mask"].clone(),
+ )
+ losses["opt"][loss.name + "_conf"] = loss.weight * conf_losses.mean()
+ losses_to_be_computed.remove("confidence")
+
+ assert (
+ not losses_to_be_computed
+ ), f"Losses {losses_to_be_computed} not computed, revise `compute_loss` method"
+
+ return losses
+
+ @torch.no_grad()
+ @torch.autocast(device_type=DEVICE, enabled=True, dtype=torch.float16)
+ def infer(
+ self,
+ rgb: torch.Tensor,
+ camera: torch.Tensor | Camera | None = None,
+ rays=None,
+ normalize=True,
+ ):
+ ratio_bounds = self.shape_constraints["ratio_bounds"]
+ pixels_bounds = [
+ self.shape_constraints["pixels_min"],
+ self.shape_constraints["pixels_max"],
+ ]
+ if hasattr(self, "resolution_level"):
+ assert (
+ 0 <= self.resolution_level < 10
+ ), "resolution_level should be in [0, 10)"
+ pixels_range = pixels_bounds[1] - pixels_bounds[0]
+ interval = pixels_range / 10
+ new_lowbound = self.resolution_level * interval + pixels_bounds[0]
+ new_upbound = (self.resolution_level + 1) * interval + pixels_bounds[0]
+ pixels_bounds = (new_lowbound, new_upbound)
+ else:
+ warnings.warn("!! self.resolution_level not set, using default bounds !!")
+
+ # housekeeping: move tensors to the right device and add a batch dimension if needed
+ if rgb.ndim == 3:
+ rgb = rgb.unsqueeze(0)
+ if camera is not None:
+ camera = BatchCamera.from_camera(camera)
+ camera = camera.to(self.device)
+
+ B, _, H, W = rgb.shape
+ rgb = rgb.to(self.device)
+
+ # preprocess
+ paddings, (padded_H, padded_W) = get_paddings((H, W), ratio_bounds)
+ (pad_left, pad_right, pad_top, pad_bottom) = paddings
+ resize_factor, (new_H, new_W) = get_resize_factor(
+ (padded_H, padded_W), pixels_bounds
+ )
+ # -> rgb preprocessing (input standardized, padded, and resized)
+ if normalize:
+ rgb = TF.normalize(
+ rgb.float() / 255.0,
+ mean=IMAGENET_DATASET_MEAN,
+ std=IMAGENET_DATASET_STD,
+ )
+ rgb = F.pad(rgb, (pad_left, pad_right, pad_top, pad_bottom), value=0.0)
+ rgb = F.interpolate(
+ rgb, size=(new_H, new_W), mode="bilinear", align_corners=False
+ )
+ # -> camera preprocess
+ if camera is not None:
+ camera = camera.crop(
+ left=-pad_left, top=-pad_top, right=-pad_right, bottom=-pad_bottom
+ )
+ camera = camera.resize(resize_factor)
+
+ # prepare inputs
+ inputs = {"image": rgb}
+ if camera is not None:
+ inputs["camera"] = camera
+ rays = camera.get_rays(shapes=(B, new_H, new_W), noisy=False).reshape(
+ B, 3, new_H, new_W
+ )
+ inputs["rays"] = rays
+
+ if rays is not None:
+ rays = rays.to(self.device)
+ if rays.ndim == 3:
+ rays = rays.unsqueeze(0)
+ rays = F.pad(
+ rays,
+ (
+ max(0, pad_left),
+ max(0, pad_right),
+ max(0, pad_top),
+ max(0, pad_bottom),
+ ),
+ value=0.0,
+ )
+ rays = F.interpolate(
+ rays, size=(new_H, new_W), mode="bilinear", align_corners=False
+ )
+ inputs["rays"] = rays
+
+ # run model
+ _, model_outputs = self.encode_decode(inputs, image_metas={})
+
+ # collect outputs
+ out = {}
+ out["confidence"] = _postprocess(
+ model_outputs["confidence"],
+ (padded_H, padded_W),
+ paddings=paddings,
+ interpolation_mode=self.interpolation_mode,
+ )
+ points = _postprocess(
+ model_outputs["points"],
+ (padded_H, padded_W),
+ paddings=paddings,
+ interpolation_mode=self.interpolation_mode,
+ )
+ rays = _postprocess(
+ model_outputs["rays"],
+ (padded_H, padded_W),
+ paddings=paddings,
+ interpolation_mode=self.interpolation_mode,
+ )
+
+ out["distance"] = points.norm(dim=1, keepdim=True)
+ out["depth"] = points[:, -1:]
+ out["points"] = points
+ out["rays"] = rays / torch.norm(rays, dim=1, keepdim=True).clip(min=1e-5)
+ out["lowres_features"] = model_outputs["lowres_features"]
+ return out
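+
+ # Minimal usage sketch (hypothetical checkpoint id and image loading, shown only for context):
+ #   model = UniK3D.from_pretrained("lpiccinelli/unik3d-vitl").to("cuda").eval()
+ #   rgb = torch.from_numpy(np.array(Image.open("img.png"))).permute(2, 0, 1)  # (3, H, W) uint8
+ #   outs = model.infer(rgb)  # camera-free inference; pass `camera` or `rays` when available
+ #   depth, points, conf = outs["depth"], outs["points"], outs["confidence"]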
+
+ def load_pretrained(self, model_file):
+ dict_model = torch.load(model_file, map_location="cpu", weights_only=False)
+ if "model" in dict_model:
+ dict_model = dict_model["model"]
+ info = self.load_state_dict(dict_model, strict=False)
+ if is_main_process():
+ print(
+ f"Loaded from {model_file} for {self.__class__.__name__} results in:",
+ info,
+ )
+
+ def build(self, config):
+ mod = importlib.import_module("unik3d.models.encoder")
+ pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"])
+ pixel_encoder_config = {
+ **config["training"],
+ **config["model"]["pixel_encoder"],
+ **config["data"],
+ }
+ pixel_encoder = pixel_encoder_factory(pixel_encoder_config)
+ pixel_encoder_embed_dims = (
+ pixel_encoder.embed_dims
+ if hasattr(pixel_encoder, "embed_dims")
+ else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)]
+ )
+ config["model"]["pixel_encoder"]["embed_dim"] = getattr(
+ pixel_encoder, "embed_dim"
+ )
+ config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims
+ config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths
+ config["model"]["pixel_encoder"]["cls_token_embed_dims"] = getattr(
+ pixel_encoder, "cls_token_embed_dims", pixel_encoder_embed_dims
+ )
+
+ pixel_decoder = Decoder(config)
+
+ self.pixel_encoder = pixel_encoder
+ self.pixel_decoder = pixel_decoder
+ self.slices_encoder_range = list(
+ zip([0, *self.pixel_encoder.depths[:-1]], self.pixel_encoder.depths)
+ )
+ self.stacking_fn = last_stack
+ self.shape_constraints = config["data"]["shape_constraints"]
+ self.interpolation_mode = "bilinear"
+
+ def build_losses(self, config):
+ self.losses = {}
+ for loss_name, loss_config in config["training"]["losses"].items():
+ mod = importlib.import_module("unik3d.ops.losses")
+ loss_factory = getattr(mod, loss_config["name"])
+ self.losses[loss_name] = loss_factory.build(loss_config)
+
+ def get_params(self, config):
+ if hasattr(self.pixel_encoder, "get_params"):
+ encoder_p, _ = self.pixel_encoder.get_params(
+ config["model"]["pixel_encoder"]["lr"],
+ config["training"]["wd"],
+ config["training"]["ld"],
+ )
+ else:
+ encoder_p, _ = get_params(
+ self.pixel_encoder,
+ config["model"]["pixel_encoder"]["lr"],
+ config["training"]["wd"],
+ )
+ decoder_p = self.pixel_decoder.get_params(
+ config["training"]["lr"], config["training"]["wd"]
+ )
+ return [*encoder_p, *decoder_p]
+
+ def step(self):
+ self.pixel_decoder.steps += 1
+
+ def parameters_grad(self):
+ for p in self.parameters():
+ if p.requires_grad:
+ yield p
+
+ @property
+ def device(self):
+ return next(self.parameters()).device
diff --git a/unik3d/ops/__init__.py b/unik3d/ops/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb598ed6fbf6ffae6052dce0f11645adb0d3fda3
--- /dev/null
+++ b/unik3d/ops/__init__.py
@@ -0,0 +1,18 @@
+from .losses import (Confidence, Dummy, LocalNormal, LocalSSI, PolarRegression,
+ Regression, RobustLoss, Scale, SILog, SpatialGradient)
+from .scheduler import CosineScheduler, PlainCosineScheduler
+
+__all__ = [
+ "Dummy",
+ "SpatialGradient",
+ "LocalSSI",
+ "Regression",
+ "LocalNormal",
+ "RobustLoss",
+ "SILog",
+ "CosineScheduler",
+ "PlainCosineScheduler",
+ "PolarRegression",
+ "Scale",
+ "Confidence",
+]
diff --git a/unik3d/ops/knn/__init__.py b/unik3d/ops/knn/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba469c8fcb582177fd183702a468d6e72d262a7f
--- /dev/null
+++ b/unik3d/ops/knn/__init__.py
@@ -0,0 +1,6 @@
+from .functions.knn import knn_gather, knn_points
+
+__all__ = [
+ "knn_points",
+ "knn_gather",
+]
diff --git a/unik3d/ops/knn/compile.sh b/unik3d/ops/knn/compile.sh
new file mode 100755
index 0000000000000000000000000000000000000000..ccf4338cd44a1291d75c260fa180a24790ed98b4
--- /dev/null
+++ b/unik3d/ops/knn/compile.sh
@@ -0,0 +1,5 @@
+#!/usr/bin/env bash
+
+export TORCH_CUDA_ARCH_LIST="8.0 8.6+PTX"
+# export FORCE_CUDA=1  # workaround if CUDA is not actually available at build time
+python setup.py build install
\ No newline at end of file
diff --git a/unik3d/ops/knn/functions/__init__.py b/unik3d/ops/knn/functions/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a54211c5e208129ba33507b0011733cb4a31491
--- /dev/null
+++ b/unik3d/ops/knn/functions/__init__.py
@@ -0,0 +1,6 @@
+from .knn import knn_gather, knn_points
+
+__all__ = [
+ "knn_points",
+ "knn_gather",
+]
diff --git a/unik3d/ops/knn/functions/knn.py b/unik3d/ops/knn/functions/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..42c36344d42952f123290cb0bec10e39a6f2da3b
--- /dev/null
+++ b/unik3d/ops/knn/functions/knn.py
@@ -0,0 +1,249 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from collections import namedtuple
+from typing import Union
+
+import torch
+from KNN import knn_points_backward, knn_points_idx
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
+
+_KNN = namedtuple("KNN", "dists idx knn")
+
+
+class _knn_points(Function):
+ """
+ Torch autograd Function wrapper for KNN C++/CUDA implementations.
+ """
+
+ @staticmethod
+ # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
+ def forward(
+ ctx,
+ p1,
+ p2,
+ lengths1,
+ lengths2,
+ K,
+ version,
+ norm: int = 2,
+ return_sorted: bool = True,
+ ):
+ """
+ K-Nearest neighbors on point clouds.
+
+ Args:
+ p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
+ containing up to P1 points of dimension D.
+ p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
+ containing up to P2 points of dimension D.
+ lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
+ length of each pointcloud in p1. Or None to indicate that every cloud has
+ length P1.
+ lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
+ length of each pointcloud in p2. Or None to indicate that every cloud has
+ length P2.
+ K: Integer giving the number of nearest neighbors to return.
+ version: Which KNN implementation to use in the backend. If version=-1,
+ the correct implementation is selected based on the shapes of the inputs.
+ norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2).
+ return_sorted: (bool) whether to return the nearest neighbors sorted in
+ ascending order of distance.
+
+ Returns:
+ p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
+ the nearest neighbors. This is padded with zeros both where a cloud in p2
+ has fewer than K points and where a cloud in p1 has fewer than P1 points.
+
+ p1_idx: LongTensor of shape (N, P1, K) giving the indices of the
+ K nearest neighbors from points in p1 to points in p2.
+ Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
+ neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
+ in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
+ """
+ if not ((norm == 1) or (norm == 2)):
+ raise ValueError("Support for 1 or 2 norm.")
+
+ idx, dists = knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)
+
+ # sort KNN in ascending order if K > 1
+ if K > 1 and return_sorted:
+ if lengths2.min() < K:
+ P1 = p1.shape[1]
+ mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None]
+ # mask has shape [N, K], true where dists irrelevant
+ mask = mask[:, None].expand(-1, P1, -1)
+ # mask has shape [N, P1, K], true where dists irrelevant
+ dists[mask] = float("inf")
+ dists, sort_idx = dists.sort(dim=2)
+ dists[mask] = 0
+ else:
+ dists, sort_idx = dists.sort(dim=2)
+ idx = idx.gather(2, sort_idx)
+
+ ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
+ ctx.mark_non_differentiable(idx)
+ ctx.norm = norm
+ return dists, idx
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_dists, grad_idx):
+ p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
+ norm = ctx.norm
+ # TODO(gkioxari) Change cast to floats once we add support for doubles.
+ if not (grad_dists.dtype == torch.float32):
+ grad_dists = grad_dists.float()
+ if not (p1.dtype == torch.float32):
+ p1 = p1.float()
+ if not (p2.dtype == torch.float32):
+ p2 = p2.float()
+ grad_p1, grad_p2 = knn_points_backward(
+ p1, p2, lengths1, lengths2, idx, norm, grad_dists
+ )
+ return grad_p1, grad_p2, None, None, None, None, None, None
+
+
+def knn_points(
+ p1: torch.Tensor,
+ p2: torch.Tensor,
+ lengths1: Union[torch.Tensor, None] = None,
+ lengths2: Union[torch.Tensor, None] = None,
+ norm: int = 2,
+ K: int = 1,
+ version: int = -1,
+ return_nn: bool = False,
+ return_sorted: bool = True,
+) -> _KNN:
+ """
+ K-Nearest neighbors on point clouds.
+
+ Args:
+ p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
+ containing up to P1 points of dimension D.
+ p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
+ containing up to P2 points of dimension D.
+ lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
+ length of each pointcloud in p1. Or None to indicate that every cloud has
+ length P1.
+ lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
+ length of each pointcloud in p2. Or None to indicate that every cloud has
+ length P2.
+ norm: Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2.
+ K: Integer giving the number of nearest neighbors to return.
+ version: Which KNN implementation to use in the backend. If version=-1,
+ the correct implementation is selected based on the shapes of the inputs.
+ return_nn: If set to True returns the K nearest neighbors in p2 for each point in p1.
+ return_sorted: (bool) whether to return the nearest neighbors sorted in
+ ascending order of distance.
+
+ Returns:
+ dists: Tensor of shape (N, P1, K) giving the squared distances to
+ the nearest neighbors. This is padded with zeros both where a cloud in p2
+ has fewer than K points and where a cloud in p1 has fewer than P1 points.
+
+ idx: LongTensor of shape (N, P1, K) giving the indices of the
+ K nearest neighbors from points in p1 to points in p2.
+ Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
+ neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
+ in p2 has fewer than K points and where a cloud in p1 has fewer than P1
+ points.
+
+ nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for
+ each point in p1. Concretely, `p2_nn[n, i, k]` gives the k-th nearest neighbor
+ for `p1[n, i]`. Returned if `return_nn` is True.
+ The nearest neighbors are collected using `knn_gather`
+
+ .. code-block::
+
+ p2_nn = knn_gather(p2, p1_idx, lengths2)
+
+ which is a helper function that allows indexing any tensor of shape (N, P2, U) with
+ the indices `p1_idx` returned by `knn_points`. The output is a tensor
+ of shape (N, P1, K, U).
+
+ """
+ if p1.shape[0] != p2.shape[0]:
+ raise ValueError("pts1 and pts2 must have the same batch dimension.")
+ if p1.shape[2] != p2.shape[2]:
+ raise ValueError("pts1 and pts2 must have the same point dimension.")
+
+ p1 = p1.contiguous()
+ p2 = p2.contiguous()
+
+ P1 = p1.shape[1]
+ P2 = p2.shape[1]
+
+ if lengths1 is None:
+ lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device)
+ if lengths2 is None:
+ lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device)
+
+ p1_dists, p1_idx = _knn_points.apply(
+ p1, p2, lengths1, lengths2, K, version, norm, return_sorted
+ )
+
+ p2_nn = None
+ if return_nn:
+ p2_nn = knn_gather(p2, p1_idx, lengths2)
+
+ return _KNN(dists=p1_dists, idx=p1_idx, knn=p2_nn if return_nn else None)
+
+
+def knn_gather(
+ x: torch.Tensor, idx: torch.Tensor, lengths: Union[torch.Tensor, None] = None
+):
+ """
+ A helper function for knn that allows indexing a tensor x with the indices `idx`
+ returned by `knn_points`.
+
+ For example, if `dists, idx = knn_points(p, x, lengths_p, lengths, K)`
+ where p is a tensor of shape (N, L, D) and x a tensor of shape (N, M, D),
+ then one can compute the K nearest neighbors of p with `p_nn = knn_gather(x, idx, lengths)`.
+ It can also be applied for any tensor x of shape (N, M, U) where U != D.
+
+ Args:
+ x: Tensor of shape (N, M, U) containing U-dimensional features to
+ be gathered.
+ idx: LongTensor of shape (N, L, K) giving the indices returned by `knn_points`.
+ lengths: LongTensor of shape (N,) of values in the range [0, M], giving the
+ length of each example in the batch in x. Or None to indicate that every
+ example has length M.
+ Returns:
+ x_out: Tensor of shape (N, L, K, U) resulting from gathering the elements of x
+ with idx, s.t. `x_out[n, l, k] = x[n, idx[n, l, k]]`.
+ If `k > lengths[n]` then `x_out[n, l, k]` is filled with 0.0.
+ """
+ N, M, U = x.shape
+ _N, L, K = idx.shape
+
+ if N != _N:
+ raise ValueError("x and idx must have same batch dimension.")
+
+ if lengths is None:
+ lengths = torch.full((x.shape[0],), M, dtype=torch.int64, device=x.device)
+
+ idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, U)
+ # idx_expanded has shape [N, L, K, U]
+
+ x_out = x[:, :, None].expand(-1, -1, K, -1).gather(1, idx_expanded)
+ # x_out has shape [N, L, K, U]
+
+ needs_mask = lengths.min() < K
+ if needs_mask:
+ # mask has shape [N, K], true where idx is irrelevant because
+ # there are fewer points in p2 than K
+ mask = lengths[:, None] <= torch.arange(K, device=x.device)[None]
+
+ # expand mask to shape [N, L, K, U]
+ mask = mask[:, None].expand(-1, L, -1)
+ mask = mask[:, :, :, None].expand(-1, -1, -1, U)
+ x_out[mask] = 0.0
+
+ return x_out
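+
+
+# Minimal usage sketch (assumes the compiled KNN extension from this package is importable):
+#   p1 = torch.rand(2, 1024, 3, device="cuda")
+#   p2 = torch.rand(2, 4096, 3, device="cuda")
+#   dists, idx, nn = knn_points(p1, p2, K=8, return_nn=True)
+#   # dists: (2, 1024, 8) squared L2 distances, idx: (2, 1024, 8), nn: (2, 1024, 8, 3)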
diff --git a/unik3d/ops/knn/setup.py b/unik3d/ops/knn/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e7f7dd9953b7da22310cdb7d227acccf2749f36
--- /dev/null
+++ b/unik3d/ops/knn/setup.py
@@ -0,0 +1,61 @@
+import glob
+import os
+
+import torch
+from setuptools import find_packages, setup
+from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension
+
+requirements = ["torch", "torchvision"]
+
+
+def get_extensions():
+ this_dir = os.path.dirname(os.path.abspath(__file__))
+ extensions_dir = os.path.join(this_dir, "src")
+
+ main_file = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+ source_cpu = glob.glob(os.path.join(extensions_dir, "*.cpp"))
+ source_cuda = glob.glob(os.path.join(extensions_dir, "*.cu"))
+
+ sources = main_file + source_cpu
+ extension = CppExtension
+ extra_compile_args = {"cxx": ["-O3"]}
+ define_macros = []
+
+ if torch.cuda.is_available() and CUDA_HOME is not None:
+ extension = CUDAExtension
+ sources += source_cuda
+ define_macros += [("WITH_CUDA", None)]
+ extra_compile_args["nvcc"] = [
+ "-O3",
+ ]
+ else:
+ raise NotImplementedError("Cuda is not available")
+
+ sources = list(set([os.path.join(extensions_dir, s) for s in sources]))
+ include_dirs = [extensions_dir]
+ ext_modules = [
+ extension(
+ "KNN",
+ sources,
+ include_dirs=include_dirs,
+ define_macros=define_macros,
+ extra_compile_args=extra_compile_args,
+ )
+ ]
+
+ return ext_modules
+
+
+setup(
+ name="KNN",
+ version="0.1",
+ author="Luigi Piccinelli",
+ ext_modules=get_extensions(),
+ packages=find_packages(
+ exclude=(
+ "configs",
+ "tests",
+ )
+ ),
+ cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension},
+)
diff --git a/unik3d/ops/knn/src/knn.cu b/unik3d/ops/knn/src/knn.cu
new file mode 100644
index 0000000000000000000000000000000000000000..ba0732dca6ee00732286d5e3a15191aaaf7c5409
--- /dev/null
+++ b/unik3d/ops/knn/src/knn.cu
@@ -0,0 +1,587 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <ATen/ATen.h>
+#include <ATen/cuda/CUDAContext.h>
+#include <c10/cuda/CUDAGuard.h>
+#include <float.h>
+#include <iostream>
+#include <tuple>
+
+#include "utils/dispatch.cuh"
+#include "utils/mink.cuh"
+
+// A chunk of work is blocksize-many points of P1.
+// The number of potential chunks to do is N*(1+(P1-1)/blocksize)
+// call (1+(P1-1)/blocksize) chunks_per_cloud
+// These chunks are divided among the gridSize-many blocks.
+// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc .
+// In chunk i, we work on cloud i/chunks_per_cloud on points starting from
+// blocksize*(i%chunks_per_cloud).
+
+template <typename scalar_t>
+__global__ void KNearestNeighborKernelV0(
+ const scalar_t* __restrict__ points1,
+ const scalar_t* __restrict__ points2,
+ const int64_t* __restrict__ lengths1,
+ const int64_t* __restrict__ lengths2,
+ scalar_t* __restrict__ dists,
+ int64_t* __restrict__ idxs,
+ const size_t N,
+ const size_t P1,
+ const size_t P2,
+ const size_t D,
+ const size_t K,
+ const size_t norm) {
+ // Store both dists and indices for knn in global memory.
+ const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+ const int64_t chunks_to_do = N * chunks_per_cloud;
+ for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+ const int64_t n = chunk / chunks_per_cloud;
+ const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+ int64_t p1 = start_point + threadIdx.x;
+ if (p1 >= lengths1[n])
+ continue;
+ int offset = n * P1 * K + p1 * K;
+ int64_t length2 = lengths2[n];
+ MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
+ for (int p2 = 0; p2 < length2; ++p2) {
+ // Find the distance between points1[n, p1] and points2[n, p2]
+ scalar_t dist = 0;
+ for (int d = 0; d < D; ++d) {
+ scalar_t coord1 = points1[n * P1 * D + p1 * D + d];
+ scalar_t coord2 = points2[n * P2 * D + p2 * D + d];
+ scalar_t diff = coord1 - coord2;
+ scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+ dist += norm_diff;
+ }
+ mink.add(dist, p2);
+ }
+ }
+}
+
+template <typename scalar_t, int64_t D>
+__global__ void KNearestNeighborKernelV1(
+ const scalar_t* __restrict__ points1,
+ const scalar_t* __restrict__ points2,
+ const int64_t* __restrict__ lengths1,
+ const int64_t* __restrict__ lengths2,
+ scalar_t* __restrict__ dists,
+ int64_t* __restrict__ idxs,
+ const size_t N,
+ const size_t P1,
+ const size_t P2,
+ const size_t K,
+ const size_t norm) {
+ // Same idea as the previous version, but hoist D into a template argument
+ // so we can cache the current point in a thread-local array. We still store
+ // the current best K dists and indices in global memory, so this should work
+ // for very large K and fairly large D.
+ scalar_t cur_point[D];
+ const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+ const int64_t chunks_to_do = N * chunks_per_cloud;
+ for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+ const int64_t n = chunk / chunks_per_cloud;
+ const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+ int64_t p1 = start_point + threadIdx.x;
+ if (p1 >= lengths1[n])
+ continue;
+ for (int d = 0; d < D; ++d) {
+ cur_point[d] = points1[n * P1 * D + p1 * D + d];
+ }
+ int offset = n * P1 * K + p1 * K;
+ int64_t length2 = lengths2[n];
+ MinK<scalar_t, int64_t> mink(dists + offset, idxs + offset, K);
+ for (int p2 = 0; p2 < length2; ++p2) {
+ // Find the distance between cur_point and points2[n, p2]
+ scalar_t dist = 0;
+ for (int d = 0; d < D; ++d) {
+ scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d];
+ scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+ dist += norm_diff;
+ }
+ mink.add(dist, p2);
+ }
+ }
+}
+
+// This is a shim functor to allow us to dispatch using DispatchKernel1D
+template <typename scalar_t, int64_t D>
+struct KNearestNeighborV1Functor {
+ static void run(
+ size_t blocks,
+ size_t threads,
+ const scalar_t* __restrict__ points1,
+ const scalar_t* __restrict__ points2,
+ const int64_t* __restrict__ lengths1,
+ const int64_t* __restrict__ lengths2,
+ scalar_t* __restrict__ dists,
+ int64_t* __restrict__ idxs,
+ const size_t N,
+ const size_t P1,
+ const size_t P2,
+ const size_t K,
+ const size_t norm) {
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ KNearestNeighborKernelV1<scalar_t, D><<<blocks, threads, 0, stream>>>(
+ points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm);
+ }
+};
+
+template <typename scalar_t, int64_t D, int64_t K>
+__global__ void KNearestNeighborKernelV2(
+ const scalar_t* __restrict__ points1,
+ const scalar_t* __restrict__ points2,
+ const int64_t* __restrict__ lengths1,
+ const int64_t* __restrict__ lengths2,
+ scalar_t* __restrict__ dists,
+ int64_t* __restrict__ idxs,
+ const int64_t N,
+ const int64_t P1,
+ const int64_t P2,
+ const size_t norm) {
+ // Same general implementation as V1, but also hoist K into a template arg.
+ scalar_t cur_point[D];
+ scalar_t min_dists[K];
+ int min_idxs[K];
+ const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+ const int64_t chunks_to_do = N * chunks_per_cloud;
+ for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+ const int64_t n = chunk / chunks_per_cloud;
+ const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+ int64_t p1 = start_point + threadIdx.x;
+ if (p1 >= lengths1[n])
+ continue;
+ for (int d = 0; d < D; ++d) {
+ cur_point[d] = points1[n * P1 * D + p1 * D + d];
+ }
+ int64_t length2 = lengths2[n];
+ MinK<scalar_t, int> mink(min_dists, min_idxs, K);
+ for (int p2 = 0; p2 < length2; ++p2) {
+ scalar_t dist = 0;
+ for (int d = 0; d < D; ++d) {
+ int offset = n * P2 * D + p2 * D + d;
+ scalar_t diff = cur_point[d] - points2[offset];
+ scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+ dist += norm_diff;
+ }
+ mink.add(dist, p2);
+ }
+ for (int k = 0; k < mink.size(); ++k) {
+ idxs[n * P1 * K + p1 * K + k] = min_idxs[k];
+ dists[n * P1 * K + p1 * K + k] = min_dists[k];
+ }
+ }
+}
+
+// This is a shim so we can dispatch using DispatchKernel2D
+template <typename scalar_t, int64_t D, int64_t K>
+struct KNearestNeighborKernelV2Functor {
+ static void run(
+ size_t blocks,
+ size_t threads,
+ const scalar_t* __restrict__ points1,
+ const scalar_t* __restrict__ points2,
+ const int64_t* __restrict__ lengths1,
+ const int64_t* __restrict__ lengths2,
+ scalar_t* __restrict__ dists,
+ int64_t* __restrict__ idxs,
+ const int64_t N,
+ const int64_t P1,
+ const int64_t P2,
+ const size_t norm) {
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ KNearestNeighborKernelV2<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
+ points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
+ }
+};
+
+template <typename scalar_t, int64_t D, int64_t K>
+__global__ void KNearestNeighborKernelV3(
+ const scalar_t* __restrict__ points1,
+ const scalar_t* __restrict__ points2,
+ const int64_t* __restrict__ lengths1,
+ const int64_t* __restrict__ lengths2,
+ scalar_t* __restrict__ dists,
+ int64_t* __restrict__ idxs,
+ const size_t N,
+ const size_t P1,
+ const size_t P2,
+ const size_t norm) {
+ // Same idea as V2, but use register indexing for thread-local arrays.
+ // Enabling sorting for this version leads to huge slowdowns; I suspect
+ // that it forces min_dists into local memory rather than registers.
+ // As a result this version is always unsorted.
+ scalar_t cur_point[D];
+ scalar_t min_dists[K];
+ int min_idxs[K];
+ const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x);
+ const int64_t chunks_to_do = N * chunks_per_cloud;
+ for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) {
+ const int64_t n = chunk / chunks_per_cloud;
+ const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud);
+ int64_t p1 = start_point + threadIdx.x;
+ if (p1 >= lengths1[n])
+ continue;
+ for (int d = 0; d < D; ++d) {
+ cur_point[d] = points1[n * P1 * D + p1 * D + d];
+ }
+ int64_t length2 = lengths2[n];
+ RegisterMinK<scalar_t, int, K> mink(min_dists, min_idxs);
+ for (int p2 = 0; p2 < length2; ++p2) {
+ scalar_t dist = 0;
+ for (int d = 0; d < D; ++d) {
+ int offset = n * P2 * D + p2 * D + d;
+ scalar_t diff = cur_point[d] - points2[offset];
+ scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff);
+ dist += norm_diff;
+ }
+ mink.add(dist, p2);
+ }
+ for (int k = 0; k < mink.size(); ++k) {
+ idxs[n * P1 * K + p1 * K + k] = min_idxs[k];
+ dists[n * P1 * K + p1 * K + k] = min_dists[k];
+ }
+ }
+}
+
+// This is a shim so we can dispatch using DispatchKernel2D
+template <typename scalar_t, int64_t D, int64_t K>
+struct KNearestNeighborKernelV3Functor {
+ static void run(
+ size_t blocks,
+ size_t threads,
+ const scalar_t* __restrict__ points1,
+ const scalar_t* __restrict__ points2,
+ const int64_t* __restrict__ lengths1,
+ const int64_t* __restrict__ lengths2,
+ scalar_t* __restrict__ dists,
+ int64_t* __restrict__ idxs,
+ const size_t N,
+ const size_t P1,
+ const size_t P2,
+ const size_t norm) {
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+ KNearestNeighborKernelV3<scalar_t, D, K><<<blocks, threads, 0, stream>>>(
+ points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm);
+ }
+};
+
+constexpr int V1_MIN_D = 1;
+constexpr int V1_MAX_D = 32;
+
+constexpr int V2_MIN_D = 1;
+constexpr int V2_MAX_D = 8;
+constexpr int V2_MIN_K = 1;
+constexpr int V2_MAX_K = 32;
+
+constexpr int V3_MIN_D = 1;
+constexpr int V3_MAX_D = 8;
+constexpr int V3_MIN_K = 1;
+constexpr int V3_MAX_K = 4;
+
+bool InBounds(const int64_t min, const int64_t x, const int64_t max) {
+ return min <= x && x <= max;
+}
+
+bool KnnCheckVersion(int version, const int64_t D, const int64_t K) {
+ if (version == 0) {
+ return true;
+ } else if (version == 1) {
+ return InBounds(V1_MIN_D, D, V1_MAX_D);
+ } else if (version == 2) {
+ return InBounds(V2_MIN_D, D, V2_MAX_D) && InBounds(V2_MIN_K, K, V2_MAX_K);
+ } else if (version == 3) {
+ return InBounds(V3_MIN_D, D, V3_MAX_D) && InBounds(V3_MIN_K, K, V3_MAX_K);
+ }
+ return false;
+}
+
+int ChooseVersion(const int64_t D, const int64_t K) {
+ for (int version = 3; version >= 1; version--) {
+ if (KnnCheckVersion(version, D, K)) {
+ return version;
+ }
+ }
+ return 0;
+}
+
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const int norm,
+ const int K,
+ int version) {
+ // Check inputs are on the same device
+ at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
+ lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4};
+ at::CheckedFrom c = "KNearestNeighborIdxCuda";
+ at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t});
+ at::checkAllSameType(c, {p1_t, p2_t});
+
+ // Set the device for the kernel launch based on the device of the input
+ at::cuda::CUDAGuard device_guard(p1.device());
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ const auto N = p1.size(0);
+ const auto P1 = p1.size(1);
+ const auto P2 = p2.size(1);
+ const auto D = p2.size(2);
+ const int64_t K_64 = K;
+
+ TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2.");
+
+ TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension");
+ auto long_dtype = lengths1.options().dtype(at::kLong);
+ auto idxs = at::zeros({N, P1, K}, long_dtype);
+ auto dists = at::zeros({N, P1, K}, p1.options());
+
+ if (idxs.numel() == 0) {
+ AT_CUDA_CHECK(cudaGetLastError());
+ return std::make_tuple(idxs, dists);
+ }
+
+ if (version < 0) {
+ version = ChooseVersion(D, K);
+ } else if (!KnnCheckVersion(version, D, K)) {
+ int new_version = ChooseVersion(D, K);
+ std::cout << "WARNING: Requested KNN version " << version
+ << " is not compatible with D = " << D << "; K = " << K
+ << ". Falling back to version = " << new_version << std::endl;
+ version = new_version;
+ }
+
+ // At this point we should have a valid version no matter what data the user
+ // gave us. But we can check once more to be sure; however this time
+ // assert fail since failing at this point means we have a bug in our version
+ // selection or checking code.
+ AT_ASSERTM(KnnCheckVersion(version, D, K), "Invalid version");
+
+ const size_t threads = 256;
+ const size_t blocks = 256;
+ if (version == 0) {
+ AT_DISPATCH_FLOATING_TYPES(
+ p1.scalar_type(), "knn_kernel_cuda", ([&] {
+ KNearestNeighborKernelV0<scalar_t><<<blocks, threads, 0, stream>>>(
+ p1.contiguous().data_ptr<scalar_t>(),
+ p2.contiguous().data_ptr<scalar_t>(),
+ lengths1.contiguous().data_ptr<int64_t>(),
+ lengths2.contiguous().data_ptr<int64_t>(),
+ dists.data_ptr<scalar_t>(),
+ idxs.data_ptr<int64_t>(),
+ N,
+ P1,
+ P2,
+ D,
+ K,
+ norm);
+ }));
+ } else if (version == 1) {
+ AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
+ DispatchKernel1D<
+ KNearestNeighborV1Functor,
+ scalar_t,
+ V1_MIN_D,
+ V1_MAX_D>(
+ D,
+ blocks,
+ threads,
+ p1.contiguous().data_ptr<scalar_t>(),
+ p2.contiguous().data_ptr<scalar_t>(),
+ lengths1.contiguous().data_ptr<int64_t>(),
+ lengths2.contiguous().data_ptr<int64_t>(),
+ dists.data_ptr<scalar_t>(),
+ idxs.data_ptr<int64_t>(),
+ N,
+ P1,
+ P2,
+ K,
+ norm);
+ }));
+ } else if (version == 2) {
+ AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
+ DispatchKernel2D<
+ KNearestNeighborKernelV2Functor,
+ scalar_t,
+ V2_MIN_D,
+ V2_MAX_D,
+ V2_MIN_K,
+ V2_MAX_K>(
+ D,
+ K_64,
+ blocks,
+ threads,
+ p1.contiguous().data_ptr<scalar_t>(),
+ p2.contiguous().data_ptr<scalar_t>(),
+ lengths1.contiguous().data_ptr<int64_t>(),
+ lengths2.contiguous().data_ptr<int64_t>(),
+ dists.data_ptr<scalar_t>(),
+ idxs.data_ptr<int64_t>(),
+ N,
+ P1,
+ P2,
+ norm);
+ }));
+ } else if (version == 3) {
+ AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] {
+ DispatchKernel2D<
+ KNearestNeighborKernelV3Functor,
+ scalar_t,
+ V3_MIN_D,
+ V3_MAX_D,
+ V3_MIN_K,
+ V3_MAX_K>(
+ D,
+ K_64,
+ blocks,
+ threads,
+ p1.contiguous().data_ptr<scalar_t>(),
+ p2.contiguous().data_ptr<scalar_t>(),
+ lengths1.contiguous().data_ptr<int64_t>(),
+ lengths2.contiguous().data_ptr<int64_t>(),
+ dists.data_ptr<scalar_t>(),
+ idxs.data_ptr<int64_t>(),
+ N,
+ P1,
+ P2,
+ norm);
+ }));
+ }
+ AT_CUDA_CHECK(cudaGetLastError());
+ return std::make_tuple(idxs, dists);
+}
+
+// ------------------------------------------------------------- //
+// Backward Operators //
+// ------------------------------------------------------------- //
+
+// TODO(gkioxari) support all data types once AtomicAdd supports doubles.
+// Currently, support is for floats only.
+__global__ void KNearestNeighborBackwardKernel(
+ const float* __restrict__ p1, // (N, P1, D)
+ const float* __restrict__ p2, // (N, P2, D)
+ const int64_t* __restrict__ lengths1, // (N,)
+ const int64_t* __restrict__ lengths2, // (N,)
+ const int64_t* __restrict__ idxs, // (N, P1, K)
+ const float* __restrict__ grad_dists, // (N, P1, K)
+ float* __restrict__ grad_p1, // (N, P1, D)
+ float* __restrict__ grad_p2, // (N, P2, D)
+ const size_t N,
+ const size_t P1,
+ const size_t P2,
+ const size_t K,
+ const size_t D,
+ const size_t norm) {
+ const size_t tid = blockIdx.x * blockDim.x + threadIdx.x;
+ const size_t stride = gridDim.x * blockDim.x;
+
+ for (size_t i = tid; i < N * P1 * K * D; i += stride) {
+ const size_t n = i / (P1 * K * D); // batch index
+ size_t rem = i % (P1 * K * D);
+ const size_t p1_idx = rem / (K * D); // index of point in p1
+ rem = rem % (K * D);
+ const size_t k = rem / D; // k-th nearest neighbor
+ const size_t d = rem % D; // d-th dimension in the feature vector
+
+ const size_t num1 = lengths1[n]; // number of valid points in p1 in batch
+ const size_t num2 = lengths2[n]; // number of valid points in p2 in batch
+ if ((p1_idx < num1) && (k < num2)) {
+ const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k];
+ // index of point in p2 corresponding to the k-th nearest neighbor
+ const int64_t p2_idx = idxs[n * P1 * K + p1_idx * K + k];
+ // If the index is the pad value of -1 then ignore it
+ if (p2_idx == -1) {
+ continue;
+ }
+ float diff = 0.0;
+ if (norm == 1) {
+ float sign =
+ (p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d])
+ ? 1.0
+ : -1.0;
+ diff = grad_dist * sign;
+ } else { // norm is 2
+ diff = 2.0 * grad_dist *
+ (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]);
+ }
+ atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff);
+ atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff);
+ }
+ }
+}
+
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const at::Tensor& idxs,
+ int norm,
+ const at::Tensor& grad_dists) {
+ // Check inputs are on the same device
+ at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2},
+ lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4},
+ idxs_t{idxs, "idxs", 5}, grad_dists_t{grad_dists, "grad_dists", 6};
+ at::CheckedFrom c = "KNearestNeighborBackwardCuda";
+ at::checkAllSameGPU(
+ c, {p1_t, p2_t, lengths1_t, lengths2_t, idxs_t, grad_dists_t});
+ at::checkAllSameType(c, {p1_t, p2_t, grad_dists_t});
+
+ // This is nondeterministic because atomicAdd
+ at::globalContext().alertNotDeterministic("KNearestNeighborBackwardCuda");
+
+ // Set the device for the kernel launch based on the device of the input
+ at::cuda::CUDAGuard device_guard(p1.device());
+ cudaStream_t stream = at::cuda::getCurrentCUDAStream();
+
+ const auto N = p1.size(0);
+ const auto P1 = p1.size(1);
+ const auto P2 = p2.size(1);
+ const auto D = p2.size(2);
+ const auto K = idxs.size(2);
+
+ TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension");
+ TORCH_CHECK(idxs.size(0) == N, "KNN idxs must have the same batch dimension");
+ TORCH_CHECK(
+ idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1");
+ TORCH_CHECK(grad_dists.size(0) == N);
+ TORCH_CHECK(grad_dists.size(1) == P1);
+ TORCH_CHECK(grad_dists.size(2) == K);
+
+ auto grad_p1 = at::zeros({N, P1, D}, p1.options());
+ auto grad_p2 = at::zeros({N, P2, D}, p2.options());
+
+ if (grad_p1.numel() == 0 || grad_p2.numel() == 0) {
+ AT_CUDA_CHECK(cudaGetLastError());
+ return std::make_tuple(grad_p1, grad_p2);
+ }
+
+ const int blocks = 64;
+ const int threads = 512;
+
+ KNearestNeighborBackwardKernel<<<blocks, threads, 0, stream>>>(
+ p1.contiguous().data_ptr<float>(),
+ p2.contiguous().data_ptr<float>(),
+ lengths1.contiguous().data_ptr<int64_t>(),
+ lengths2.contiguous().data_ptr<int64_t>(),
+ idxs.contiguous().data_ptr<int64_t>(),
+ grad_dists.contiguous().data_ptr<float>(),
+ grad_p1.data_ptr<float>(),
+ grad_p2.data_ptr<float>(),
+ N,
+ P1,
+ P2,
+ K,
+ D,
+ norm);
+
+ AT_CUDA_CHECK(cudaGetLastError());
+ return std::make_tuple(grad_p1, grad_p2);
+}
\ No newline at end of file
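
The backward kernel above accumulates, for the L2 norm, `2 * grad_dist * (p1 - p2[idx])` into `grad_p1` and its negation into `grad_p2`. A minimal PyTorch sketch of that analytic gradient, checked against autograd on a brute-force forward pass; it is an illustration only and does not call the compiled kernel.

```python
import torch

N, P1, P2, D, K = 2, 5, 7, 3, 3
p1 = torch.randn(N, P1, D, requires_grad=True)
p2 = torch.randn(N, P2, D)

# Brute-force forward: squared L2 distances to the K nearest neighbors.
dists_full = ((p1[:, :, None, :] - p2[:, None, :, :]) ** 2).sum(-1)  # (N, P1, P2)
dists, idxs = dists_full.topk(K, dim=2, largest=False)               # (N, P1, K)

grad_dists = torch.rand_like(dists)
dists.backward(grad_dists)

# Analytic gradient, mirroring `diff = 2.0 * grad_dist * (p1 - p2)` in the kernel.
p2_nn = p2[torch.arange(N)[:, None, None], idxs]                      # (N, P1, K, D)
grad_p1_ref = (2.0 * grad_dists[..., None] * (p1.detach()[:, :, None] - p2_nn)).sum(2)
assert torch.allclose(p1.grad, grad_p1_ref, atol=1e-5)
```
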
diff --git a/unik3d/ops/knn/src/knn.h b/unik3d/ops/knn/src/knn.h
new file mode 100644
index 0000000000000000000000000000000000000000..e43a0231119bf170502e1078d77110a3cd4510aa
--- /dev/null
+++ b/unik3d/ops/knn/src/knn.h
@@ -0,0 +1,157 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include <torch/extension.h>
+#include <tuple>
+#include "utils/pytorch3d_cutils.h"
+
+// Compute indices of K nearest neighbors in pointcloud p2 to points
+// in pointcloud p1.
+//
+// Args:
+// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
+// containing P1 points of dimension D.
+// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
+// containing P2 points of dimension D.
+// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
+// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
+// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
+// K: int giving the number of nearest points to return.
+// version: Integer telling which implementation to use.
+//
+// Returns:
+// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
+// p1_neighbor_idx[n, i, k] = j means that the kth nearest
+// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
+// It is padded with zeros so that it can be used easily in a later
+// gather() operation.
+//
+// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared
+// distance from each point p1[n, p, :] to its K neighbors
+// p2[n, p1_neighbor_idx[n, p, k], :].
+
+// CPU implementation.
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const int norm,
+ const int K);
+
+// CUDA implementation
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCuda(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const int norm,
+ const int K,
+ const int version);
+
+// Implementation which is exposed.
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdx(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const int norm,
+ const int K,
+ const int version) {
+ if (p1.is_cuda() || p2.is_cuda()) {
+#ifdef WITH_CUDA
+ CHECK_CUDA(p1);
+ CHECK_CUDA(p2);
+ return KNearestNeighborIdxCuda(
+ p1, p2, lengths1, lengths2, norm, K, version);
+#else
+ AT_ERROR("Not compiled with GPU support.");
+#endif
+ }
+ return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K);
+}
+
+// Compute gradients with respect to p1 and p2
+//
+// Args:
+// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each
+// containing P1 points of dimension D.
+// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each
+// containing P2 points of dimension D.
+// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud.
+// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud.
+// p1_neighbor_idx: LongTensor of shape (N, P1, K), where
+// p1_neighbor_idx[n, i, k] = j means that the kth nearest
+// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j].
+// It is padded with zeros so that it can be used easily in a later
+// gather() operation. This is computed from the forward pass.
+// norm: int specifying the norm for the distance (1 for L1, 2 for L2)
+// grad_dists: FloatTensor of shape (N, P1, K) which contains the input
+// gradients.
+//
+// Returns:
+// grad_p1: FloatTensor of shape (N, P1, D) containing the output gradients
+// wrt p1.
+// grad_p2: FloatTensor of shape (N, P2, D) containing the output gradients
+// wrt p2.
+
+// CPU implementation.
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const at::Tensor& idxs,
+ const int norm,
+ const at::Tensor& grad_dists);
+
+// CUDA implementation
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCuda(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const at::Tensor& idxs,
+ const int norm,
+ const at::Tensor& grad_dists);
+
+// Implementation which is exposed.
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackward(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const at::Tensor& idxs,
+ const int norm,
+ const at::Tensor& grad_dists) {
+ if (p1.is_cuda() || p2.is_cuda()) {
+#ifdef WITH_CUDA
+ CHECK_CUDA(p1);
+ CHECK_CUDA(p2);
+ return KNearestNeighborBackwardCuda(
+ p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
+#else
+ AT_ERROR("Not compiled with GPU support.");
+#endif
+ }
+ return KNearestNeighborBackwardCpu(
+ p1, p2, lengths1, lengths2, idxs, norm, grad_dists);
+}
+
+// Utility to check whether a KNN version can be used.
+//
+// Args:
+// version: Integer in the range 0 <= version <= 3 indicating one of our
+// KNN implementations.
+// D: Number of dimensions for the input and query point clouds
+// K: Number of neighbors to be found
+//
+// Returns:
+// Whether the indicated KNN version can be used.
+bool KnnCheckVersion(int version, const int64_t D, const int64_t K);
\ No newline at end of file
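
For reference, a small pure-PyTorch sketch of the contract documented in this header (zero-padded `idxs`/`dists`, squared distances for the 2-norm, ragged clouds via `lengths1`/`lengths2`). It is a reading aid, not the extension itself.

```python
import torch

def knn_ref(p1, p2, lengths1, lengths2, K, norm=2):
    N, P1, _ = p1.shape
    idxs = p1.new_zeros(N, P1, K, dtype=torch.int64)   # zero-padded, gather-friendly
    dists = p1.new_zeros(N, P1, K)
    for n in range(N):
        a = p1[n, : lengths1[n]]                       # (L1, D) valid points of cloud 1
        b = p2[n, : lengths2[n]]                       # (L2, D) valid points of cloud 2
        diff = a[:, None] - b[None]                    # (L1, L2, D)
        d = diff.abs().sum(-1) if norm == 1 else (diff ** 2).sum(-1)
        k = min(K, b.shape[0])
        val, idx = d.topk(k, dim=1, largest=False)
        dists[n, : a.shape[0], :k] = val
        idxs[n, : a.shape[0], :k] = idx
    return idxs, dists
```
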
diff --git a/unik3d/ops/knn/src/knn_cpu.cpp b/unik3d/ops/knn/src/knn_cpu.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..694ab11ea62c2eb285e1a03f6e4c0862498c24ae
--- /dev/null
+++ b/unik3d/ops/knn/src/knn_cpu.cpp
@@ -0,0 +1,128 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#include <torch/extension.h>
+#include <queue>
+#include <tuple>
+
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborIdxCpu(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const int norm,
+ const int K) {
+ const int N = p1.size(0);
+ const int P1 = p1.size(1);
+ const int D = p1.size(2);
+
+ auto long_opts = lengths1.options().dtype(torch::kInt64);
+ torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts);
+ torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options());
+
+ auto p1_a = p1.accessor<float, 3>();
+ auto p2_a = p2.accessor<float, 3>();
+ auto lengths1_a = lengths1.accessor<int64_t, 1>();
+ auto lengths2_a = lengths2.accessor<int64_t, 1>();
+ auto idxs_a = idxs.accessor<int64_t, 3>();
+ auto dists_a = dists.accessor<float, 3>();
+
+ for (int n = 0; n < N; ++n) {
+ const int64_t length1 = lengths1_a[n];
+ const int64_t length2 = lengths2_a[n];
+ for (int64_t i1 = 0; i1 < length1; ++i1) {
+ // Use a priority queue to store (distance, index) tuples.
+ std::priority_queue<std::tuple<float, int64_t>> q;
+ for (int64_t i2 = 0; i2 < length2; ++i2) {
+ float dist = 0;
+ for (int d = 0; d < D; ++d) {
+ float diff = p1_a[n][i1][d] - p2_a[n][i2][d];
+ if (norm == 1) {
+ dist += abs(diff);
+ } else { // norm is 2 (default)
+ dist += diff * diff;
+ }
+ }
+ int size = static_cast<int>(q.size());
+ if (size < K || dist < std::get<0>(q.top())) {
+ q.emplace(dist, i2);
+ if (size >= K) {
+ q.pop();
+ }
+ }
+ }
+ while (!q.empty()) {
+ auto t = q.top();
+ q.pop();
+ const int k = q.size();
+ dists_a[n][i1][k] = std::get<0>(t);
+ idxs_a[n][i1][k] = std::get<1>(t);
+ }
+ }
+ }
+ return std::make_tuple(idxs, dists);
+}
+
+// ------------------------------------------------------------- //
+// Backward Operators //
+// ------------------------------------------------------------- //
+
+std::tuple<at::Tensor, at::Tensor> KNearestNeighborBackwardCpu(
+ const at::Tensor& p1,
+ const at::Tensor& p2,
+ const at::Tensor& lengths1,
+ const at::Tensor& lengths2,
+ const at::Tensor& idxs,
+ const int norm,
+ const at::Tensor& grad_dists) {
+ const int N = p1.size(0);
+ const int P1 = p1.size(1);
+ const int D = p1.size(2);
+ const int P2 = p2.size(1);
+ const int K = idxs.size(2);
+
+ torch::Tensor grad_p1 = torch::full({N, P1, D}, 0, p1.options());
+ torch::Tensor grad_p2 = torch::full({N, P2, D}, 0, p2.options());
+
+ auto p1_a = p1.accessor<float, 3>();
+ auto p2_a = p2.accessor<float, 3>();
+ auto lengths1_a = lengths1.accessor<int64_t, 1>();
+ auto lengths2_a = lengths2.accessor<int64_t, 1>();
+ auto idxs_a = idxs.accessor<int64_t, 3>();
+ auto grad_dists_a = grad_dists.accessor<float, 3>();
+ auto grad_p1_a = grad_p1.accessor<float, 3>();
+ auto grad_p2_a = grad_p2.accessor<float, 3>();
+
+ for (int n = 0; n < N; ++n) {
+ const int64_t length1 = lengths1_a[n];
+ int64_t length2 = lengths2_a[n];
+ length2 = (length2 < K) ? length2 : K;
+ for (int64_t i1 = 0; i1 < length1; ++i1) {
+ for (int64_t k = 0; k < length2; ++k) {
+ const int64_t i2 = idxs_a[n][i1][k];
+ // If the index is the pad value of -1 then ignore it
+ if (i2 == -1) {
+ continue;
+ }
+ for (int64_t d = 0; d < D; ++d) {
+ float diff = 0.0;
+ if (norm == 1) {
+ float sign = (p1_a[n][i1][d] > p2_a[n][i2][d]) ? 1.0 : -1.0;
+ diff = grad_dists_a[n][i1][k] * sign;
+ } else { // norm is 2 (default)
+ diff = 2.0f * grad_dists_a[n][i1][k] *
+ (p1_a[n][i1][d] - p2_a[n][i2][d]);
+ }
+ grad_p1_a[n][i1][d] += diff;
+ grad_p2_a[n][i2][d] += -1.0f * diff;
+ }
+ }
+ }
+ }
+ return std::make_tuple(grad_p1, grad_p2);
+}
\ No newline at end of file
diff --git a/unik3d/ops/knn/src/knn_ext.cpp b/unik3d/ops/knn/src/knn_ext.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..f2fc9b4374ea5a2196d8ddee4e3cd5c0a907ddaa
--- /dev/null
+++ b/unik3d/ops/knn/src/knn_ext.cpp
@@ -0,0 +1,10 @@
+#include <torch/extension.h>
+#include "knn.h"
+
+PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
+#ifdef WITH_CUDA
+ m.def("knn_check_version", &KnnCheckVersion);
+#endif
+ m.def("knn_points_idx", &KNearestNeighborIdx);
+ m.def("knn_points_backward", &KNearestNeighborBackward);
+}
\ No newline at end of file
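
A hedged sketch of how these bindings might be called once the extension is built. The module name below is a placeholder (the real one comes from the package's setup script), and passing `version=-1` to auto-select a kernel follows the upstream PyTorch3D convention; `KnnCheckVersion` documents explicit values 0-3.

```python
import torch
import unik3d_knn  # placeholder name for the compiled TORCH_EXTENSION_NAME module

p1 = torch.randn(2, 128, 3, device="cuda")
p2 = torch.randn(2, 256, 3, device="cuda")
lengths1 = torch.full((2,), 128, dtype=torch.int64, device="cuda")
lengths2 = torch.full((2,), 256, dtype=torch.int64, device="cuda")

# norm=2, K=8; version=-1 is assumed to let the extension pick a specialization.
idxs, dists = unik3d_knn.knn_points_idx(p1, p2, lengths1, lengths2, 2, 8, -1)
```
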
diff --git a/unik3d/ops/knn/src/utils/dispatch.cuh b/unik3d/ops/knn/src/utils/dispatch.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..af197b4392106391236f9f00d999a02e6ac2defa
--- /dev/null
+++ b/unik3d/ops/knn/src/utils/dispatch.cuh
@@ -0,0 +1,357 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// This file provides utilities for dispatching to specialized versions of
+// functions. This is especially useful for CUDA kernels, since specializing
+// them to particular input sizes can often allow the compiler to unroll loops
+// and place arrays into registers, which can give huge performance speedups.
+//
+// As an example, suppose we have the following function which is specialized
+// based on a compile-time int64_t value:
+//
+// template <typename T, int64_t x>
+// struct SquareOffset {
+// static void run(T y) {
+// T val = x * x + y;
+// std::cout << val << std::endl;
+// }
+// }
+//
+// This function takes one compile-time argument x, and one run-time argument y.
+// We might want to compile specialized versions of this for x=0, x=1, etc and
+// then dispatch to the correct one based on the runtime value of x.
+// One simple way to achieve this is with a lookup table:
+//
+// template <typename T>
+// void DispatchSquareOffset(const int64_t x, T y) {
+// if (x == 0) {
+// SquareOffset<T, 0>::run(y);
+// } else if (x == 1) {
+// SquareOffset<T, 1>::run(y);
+// } else if (x == 2) {
+// SquareOffset<T, 2>::run(y);
+// }
+// }
+//
+// This function takes both x and y as run-time arguments, and dispatches to
+// different specialized versions of SquareOffset based on the run-time value
+// of x. This works, but it's tedious and error-prone. If we want to change the
+// set of x values for which we provide compile-time specializations, then we
+// will need to do a lot of tedious editing of the dispatch function. Also, if we
+// want to provide compile-time specializations for another function other than
+// SquareOffset, we will need to duplicate the entire lookup table.
+//
+// To solve these problems, we can use the DispatchKernel1D function provided by
+// this file instead:
+//
+// template <typename T>
+// void DispatchSquareOffset(const int64_t x, T y) {
+// constexpr int64_t xmin = 0;
+// constexpr int64_t xmax = 2;
+// DispatchKernel1D<SquareOffset, T, xmin, xmax>(x, y);
+// }
+//
+// DispatchKernel1D uses template metaprogramming to compile specialized
+// versions of SquareOffset for all values of x with xmin <= x <= xmax, and
+// then dispatches to the correct one based on the run-time value of x. If we
+// want to change the range of x values for which SquareOffset is specialized
+// at compile-time, then all we have to do is change the values of the
+// compile-time constants xmin and xmax.
+//
+// This file also allows us to similarly dispatch functions that depend on two
+// compile-time int64_t values, using the DispatchKernel2D function like this:
+//
+// template <typename T, int64_t x, int64_t y>
+// struct Sum {
+// static void run(T z, T w) {
+// T val = x + y + z + w;
+// std::cout << val << std::endl;
+// }
+// }
+//
+// template <typename T>
+// void DispatchSum(const int64_t x, const int64_t y, int z, int w) {
+// constexpr int64_t xmin = 1;
+// constexpr int64_t xmax = 3;
+// constexpr int64_t ymin = 2;
+// constexpr int64_t ymax = 5;
+// DispatchKernel2D<Sum, T, xmin, xmax, ymin, ymax>(x, y, z, w);
+// }
+//
+// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to
+// compile specialized versions of sum for all values of (x, y) with
+// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct
+// specialized version based on the runtime values of x and y.
+
+// Define some helper structs in an anonymous namespace.
+namespace {
+
+// 1D dispatch: general case.
+// Kernel is the function we want to dispatch to; it should take a typename and
+// an int64_t as template args, and it should define a static void function
+// run which takes any number of arguments of any type.
+// In order to dispatch, we will take an additional template argument curN,
+// and increment it via template recursion until it is equal to the run-time
+// argument N.
+template <
+ template <typename, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ int64_t curN,
+ typename... Args>
+struct DispatchKernelHelper1D {
+ static void run(const int64_t N, Args... args) {
+ if (curN == N) {
+ // The compile-time value curN is equal to the run-time value N, so we
+ // can dispatch to the run method of the Kernel.
+ Kernel<T, curN>::run(args...);
+ } else if (curN < N) {
+ // Increment curN via template recursion
+ DispatchKernelHelper1D<Kernel, T, minN, maxN, curN + 1, Args...>::run(
+ N, args...);
+ }
+ // We shouldn't get here -- throw an error?
+ }
+};
+
+// 1D dispatch: Specialization when curN == maxN
+// We need this base case to avoid infinite template recursion.
+template <
+ template <typename, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ typename... Args>
+struct DispatchKernelHelper1D<Kernel, T, minN, maxN, maxN, Args...> {
+ static void run(const int64_t N, Args... args) {
+ if (N == maxN) {
+ Kernel<T, maxN>::run(args...);
+ }
+ // We shouldn't get here -- throw an error?
+ }
+};
+
+// 2D dispatch, general case.
+// This is similar to the 1D case: we take additional template args curN and
+// curM, and increment them via template recursion until they are equal to
+// the run-time values of N and M, at which point we dispatch to the run
+// method of the kernel.
+template <
+ template <typename, int64_t, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ int64_t curN,
+ int64_t minM,
+ int64_t maxM,
+ int64_t curM,
+ typename... Args>
+struct DispatchKernelHelper2D {
+ static void run(const int64_t N, const int64_t M, Args... args) {
+ if (curN == N && curM == M) {
+ Kernel<T, curN, curM>::run(args...);
+ } else if (curN < N && curM < M) {
+ // Increment both curN and curM. This isn't strictly necessary; we could
+ // just increment one or the other at each step. But this helps to cut
+ // down on the number of recursive calls we make.
+ DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ curN + 1,
+ minM,
+ maxM,
+ curM + 1,
+ Args...>::run(N, M, args...);
+ } else if (curN < N) {
+ // Increment curN only
+ DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ curN + 1,
+ minM,
+ maxM,
+ curM,
+ Args...>::run(N, M, args...);
+ } else if (curM < M) {
+ // Increment curM only
+ DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ curN,
+ minM,
+ maxM,
+ curM + 1,
+ Args...>::run(N, M, args...);
+ }
+ }
+};
+
+// 2D dispatch, specialization for curN == maxN
+template <
+ template <typename, int64_t, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ int64_t minM,
+ int64_t maxM,
+ int64_t curM,
+ typename... Args>
+struct DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ maxN,
+ minM,
+ maxM,
+ curM,
+ Args...> {
+ static void run(const int64_t N, const int64_t M, Args... args) {
+ if (maxN == N && curM == M) {
+ Kernel<T, maxN, curM>::run(args...);
+ } else if (curM < maxM) {
+ DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ maxN,
+ minM,
+ maxM,
+ curM + 1,
+ Args...>::run(N, M, args...);
+ }
+ // We should not get here -- throw an error?
+ }
+};
+
+// 2D dispatch, specialization for curM == maxM
+template <
+ template <typename, int64_t, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ int64_t curN,
+ int64_t minM,
+ int64_t maxM,
+ typename... Args>
+struct DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ curN,
+ minM,
+ maxM,
+ maxM,
+ Args...> {
+ static void run(const int64_t N, const int64_t M, Args... args) {
+ if (curN == N && maxM == M) {
+ Kernel<T, curN, maxM>::run(args...);
+ } else if (curN < maxN) {
+ DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ curN + 1,
+ minM,
+ maxM,
+ maxM,
+ Args...>::run(N, M, args...);
+ }
+ // We should not get here -- throw an error?
+ }
+};
+
+// 2D dispatch, specialization for curN == maxN, curM == maxM
+template <
+ template <typename, int64_t, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ int64_t minM,
+ int64_t maxM,
+ typename... Args>
+struct DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ maxN,
+ minM,
+ maxM,
+ maxM,
+ Args...> {
+ static void run(const int64_t N, const int64_t M, Args... args) {
+ if (maxN == N && maxM == M) {
+ Kernel<T, maxN, maxM>::run(args...);
+ }
+ // We should not get here -- throw an error?
+ }
+};
+
+} // namespace
+
+// This is the function we expect users to call to dispatch to 1D functions
+template <
+ template <typename, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ typename... Args>
+void DispatchKernel1D(const int64_t N, Args... args) {
+ if (minN <= N && N <= maxN) {
+ // Kick off the template recursion by calling the Helper with curN = minN
+ DispatchKernelHelper1D<Kernel, T, minN, maxN, minN, Args...>::run(
+ N, args...);
+ }
+ // Maybe throw an error if we tried to dispatch outside the allowed range?
+}
+
+// This is the function we expect users to call to dispatch to 2D functions
+template <
+ template <typename, int64_t, int64_t>
+ class Kernel,
+ typename T,
+ int64_t minN,
+ int64_t maxN,
+ int64_t minM,
+ int64_t maxM,
+ typename... Args>
+void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) {
+ if (minN <= N && N <= maxN && minM <= M && M <= maxM) {
+ // Kick off the template recursion by calling the Helper with curN = minN
+ // and curM = minM
+ DispatchKernelHelper2D<
+ Kernel,
+ T,
+ minN,
+ maxN,
+ minN,
+ minM,
+ maxM,
+ minM,
+ Args...>::run(N, M, args...);
+ }
+ // Maybe throw an error if we tried to dispatch outside the specified range?
+}
\ No newline at end of file
diff --git a/unik3d/ops/knn/src/utils/index_utils.cuh b/unik3d/ops/knn/src/utils/index_utils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..d3f7f7afa0d9fdc8c2c23187cd1c12d2ccd670f5
--- /dev/null
+++ b/unik3d/ops/knn/src/utils/index_utils.cuh
@@ -0,0 +1,224 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// This converts dynamic array lookups into static array lookups, for small
+// arrays up to size 32.
+//
+// Suppose we have a small thread-local array:
+//
+// float vals[10];
+//
+// Ideally we should only index this array using static indices:
+//
+// for (int i = 0; i < 10; ++i) vals[i] = i * i;
+//
+// If we do so, then the CUDA compiler may be able to place the array into
+// registers, which can have a big performance improvement. However if we
+// access the array dynamically, then the compiler may force the array into
+// local memory, which has the same latency as global memory.
+//
+// These functions convert dynamic array access into static array access
+// using a brute-force lookup table. It can be used like this:
+//
+// float vals[10];
+// int idx = 3;
+// float val = 3.14f;
+// RegisterIndexUtils<float, 10>::set(vals, idx, val);
+// float val2 = RegisterIndexUtils<float, 10>::get(vals, idx);
+//
+// The implementation is based on fbcuda/RegisterUtils.cuh:
+// https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh
+// To avoid depending on the entire library, we just reimplement these two
+// functions. The fbcuda implementation is a bit more sophisticated, and uses
+// the preprocessor to generate switch statements that go up to N for each
+// value of N. We are lazy and just have a giant explicit switch statement.
+//
+// We might be able to use a template metaprogramming approach similar to
+// DispatchKernel1D for this. However DispatchKernel1D is intended to be used
+// for dispatching to the correct CUDA kernel on the host, while this is
+// intended to run on the device. I was concerned that a metaprogramming
+// approach for this might lead to extra function calls at runtime if the
+// compiler fails to optimize them away, which could be very slow on device.
+// However I didn't actually benchmark or test this.
+template <typename T, int N>
+struct RegisterIndexUtils {
+ __device__ __forceinline__ static T get(const T arr[N], int idx) {
+ if (idx < 0 || idx >= N)
+ return T();
+ switch (idx) {
+ case 0:
+ return arr[0];
+ case 1:
+ return arr[1];
+ case 2:
+ return arr[2];
+ case 3:
+ return arr[3];
+ case 4:
+ return arr[4];
+ case 5:
+ return arr[5];
+ case 6:
+ return arr[6];
+ case 7:
+ return arr[7];
+ case 8:
+ return arr[8];
+ case 9:
+ return arr[9];
+ case 10:
+ return arr[10];
+ case 11:
+ return arr[11];
+ case 12:
+ return arr[12];
+ case 13:
+ return arr[13];
+ case 14:
+ return arr[14];
+ case 15:
+ return arr[15];
+ case 16:
+ return arr[16];
+ case 17:
+ return arr[17];
+ case 18:
+ return arr[18];
+ case 19:
+ return arr[19];
+ case 20:
+ return arr[20];
+ case 21:
+ return arr[21];
+ case 22:
+ return arr[22];
+ case 23:
+ return arr[23];
+ case 24:
+ return arr[24];
+ case 25:
+ return arr[25];
+ case 26:
+ return arr[26];
+ case 27:
+ return arr[27];
+ case 28:
+ return arr[28];
+ case 29:
+ return arr[29];
+ case 30:
+ return arr[30];
+ case 31:
+ return arr[31];
+ };
+ return T();
+ }
+
+ __device__ __forceinline__ static void set(T arr[N], int idx, T val) {
+ if (idx < 0 || idx >= N)
+ return;
+ switch (idx) {
+ case 0:
+ arr[0] = val;
+ break;
+ case 1:
+ arr[1] = val;
+ break;
+ case 2:
+ arr[2] = val;
+ break;
+ case 3:
+ arr[3] = val;
+ break;
+ case 4:
+ arr[4] = val;
+ break;
+ case 5:
+ arr[5] = val;
+ break;
+ case 6:
+ arr[6] = val;
+ break;
+ case 7:
+ arr[7] = val;
+ break;
+ case 8:
+ arr[8] = val;
+ break;
+ case 9:
+ arr[9] = val;
+ break;
+ case 10:
+ arr[10] = val;
+ break;
+ case 11:
+ arr[11] = val;
+ break;
+ case 12:
+ arr[12] = val;
+ break;
+ case 13:
+ arr[13] = val;
+ break;
+ case 14:
+ arr[14] = val;
+ break;
+ case 15:
+ arr[15] = val;
+ break;
+ case 16:
+ arr[16] = val;
+ break;
+ case 17:
+ arr[17] = val;
+ break;
+ case 18:
+ arr[18] = val;
+ break;
+ case 19:
+ arr[19] = val;
+ break;
+ case 20:
+ arr[20] = val;
+ break;
+ case 21:
+ arr[21] = val;
+ break;
+ case 22:
+ arr[22] = val;
+ break;
+ case 23:
+ arr[23] = val;
+ break;
+ case 24:
+ arr[24] = val;
+ break;
+ case 25:
+ arr[25] = val;
+ break;
+ case 26:
+ arr[26] = val;
+ break;
+ case 27:
+ arr[27] = val;
+ break;
+ case 28:
+ arr[28] = val;
+ break;
+ case 29:
+ arr[29] = val;
+ break;
+ case 30:
+ arr[30] = val;
+ break;
+ case 31:
+ arr[31] = val;
+ break;
+ }
+ }
+};
\ No newline at end of file
diff --git a/unik3d/ops/knn/src/utils/mink.cuh b/unik3d/ops/knn/src/utils/mink.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..7512aabc2ca2a536eba8b4f14562ef47718cd064
--- /dev/null
+++ b/unik3d/ops/knn/src/utils/mink.cuh
@@ -0,0 +1,165 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#define MINK_H
+
+#include "index_utils.cuh"
+
+// A data structure to keep track of the smallest K keys seen so far as well
+// as their associated values, intended to be used in device code.
+// This data structure doesn't allocate any memory; keys and values are stored
+// in arrays passed to the constructor.
+//
+// The implementation is generic; it can be used for any key type that supports
+// the < operator, and can be used with any value type.
+//
+// Example usage:
+//
+// float keys[K];
+// int values[K];
+// MinK<float, int> mink(keys, values, K);
+// for (...) {
+// // Produce some key and value from somewhere
+// mink.add(key, value);
+// }
+// mink.sort();
+//
+// Now keys and values store the smallest K keys seen so far and the values
+// associated to these keys:
+//
+// for (int k = 0; k < K; ++k) {
+// float key_k = keys[k];
+// int value_k = values[k];
+// }
+template <typename key_t, typename value_t>
+class MinK {
+ public:
+ // Constructor.
+ //
+ // Arguments:
+ // keys: Array in which to store keys
+ // values: Array in which to store values
+ // K: How many values to keep track of
+ __device__ MinK(key_t* keys, value_t* vals, int K)
+ : keys(keys), vals(vals), K(K), _size(0) {}
+
+ // Try to add a new key and associated value to the data structure. If the key
+ // is one of the smallest K seen so far then it will be kept; otherwise it
+// will not be kept.
+ //
+ // This takes O(1) operations if the new key is not kept, or if the structure
+ // currently contains fewer than K elements. Otherwise this takes O(K) time.
+ //
+ // Arguments:
+ // key: The key to add
+ // val: The value associated to the key
+ __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
+ if (_size < K) {
+ keys[_size] = key;
+ vals[_size] = val;
+ if (_size == 0 || key > max_key) {
+ max_key = key;
+ max_idx = _size;
+ }
+ _size++;
+ } else if (key < max_key) {
+ keys[max_idx] = key;
+ vals[max_idx] = val;
+ max_key = key;
+ for (int k = 0; k < K; ++k) {
+ key_t cur_key = keys[k];
+ if (cur_key > max_key) {
+ max_key = cur_key;
+ max_idx = k;
+ }
+ }
+ }
+ }
+
+ // Get the number of items currently stored in the structure.
+ // This takes O(1) time.
+ __device__ __forceinline__ int size() {
+ return _size;
+ }
+
+ // Sort the items stored in the structure using bubble sort.
+ // This takes O(K^2) time.
+ __device__ __forceinline__ void sort() {
+ for (int i = 0; i < _size - 1; ++i) {
+ for (int j = 0; j < _size - i - 1; ++j) {
+ if (keys[j + 1] < keys[j]) {
+ key_t key = keys[j];
+ value_t val = vals[j];
+ keys[j] = keys[j + 1];
+ vals[j] = vals[j + 1];
+ keys[j + 1] = key;
+ vals[j + 1] = val;
+ }
+ }
+ }
+ }
+
+ private:
+ key_t* keys;
+ value_t* vals;
+ int K;
+ int _size;
+ key_t max_key;
+ int max_idx;
+};
+
+// This is a version of MinK that only touches the arrays using static indexing
+// via RegisterIndexUtils. If the keys and values are stored in thread-local
+// arrays, then this may allow the compiler to place them in registers for
+// fast access.
+//
+// This has the same API as MinK, but doesn't support sorting.
+// We found that sorting via RegisterIndexUtils gave very poor performance,
+// and suspect it may have prevented the compiler from placing the arrays
+// into registers.
+template <typename key_t, typename value_t, int K>
+class RegisterMinK {
+ public:
+ __device__ RegisterMinK(key_t* keys, value_t* vals)
+ : keys(keys), vals(vals), _size(0) {}
+
+ __device__ __forceinline__ void add(const key_t& key, const value_t& val) {
+ if (_size < K) {
+ RegisterIndexUtils<key_t, K>::set(keys, _size, key);
+ RegisterIndexUtils<value_t, K>::set(vals, _size, val);
+ if (_size == 0 || key > max_key) {
+ max_key = key;
+ max_idx = _size;
+ }
+ _size++;
+ } else if (key < max_key) {
+ RegisterIndexUtils<key_t, K>::set(keys, max_idx, key);
+ RegisterIndexUtils<value_t, K>::set(vals, max_idx, val);
+ max_key = key;
+ for (int k = 0; k < K; ++k) {
+ key_t cur_key = RegisterIndexUtils<key_t, K>::get(keys, k);
+ if (cur_key > max_key) {
+ max_key = cur_key;
+ max_idx = k;
+ }
+ }
+ }
+ }
+
+ __device__ __forceinline__ int size() {
+ return _size;
+ }
+
+ private:
+ key_t* keys;
+ value_t* vals;
+ int _size;
+ key_t max_key;
+ int max_idx;
+};
\ No newline at end of file
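
A host-side Python analog of `MinK::add` above, useful for reasoning about its O(1)/O(K) behaviour; the device class itself is only usable from CUDA kernels.

```python
class MinKRef:
    """Keep the K smallest (key, value) pairs, tracking the largest kept key."""

    def __init__(self, K):
        self.K, self.keys, self.vals = K, [], []
        self.max_key, self.max_idx = None, -1

    def add(self, key, val):
        if len(self.keys) < self.K:                      # O(1): still filling up
            self.keys.append(key)
            self.vals.append(val)
            if self.max_key is None or key > self.max_key:
                self.max_key, self.max_idx = key, len(self.keys) - 1
        elif key < self.max_key:                         # O(K): evict current max, rescan
            self.keys[self.max_idx], self.vals[self.max_idx] = key, val
            self.max_idx = max(range(self.K), key=lambda i: self.keys[i])
            self.max_key = self.keys[self.max_idx]

mk = MinKRef(3)
for key, val in [(5, "a"), (1, "b"), (4, "c"), (2, "d"), (9, "e")]:
    mk.add(key, val)
assert sorted(mk.keys) == [1, 2, 4]   # the three smallest keys are retained
```
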
diff --git a/unik3d/ops/knn/src/utils/pytorch3d_cutils.h b/unik3d/ops/knn/src/utils/pytorch3d_cutils.h
new file mode 100644
index 0000000000000000000000000000000000000000..c46b5aea5682bbef6960c7eb2cc1bf052fb6e32a
--- /dev/null
+++ b/unik3d/ops/knn/src/utils/pytorch3d_cutils.h
@@ -0,0 +1,17 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+#pragma once
+#include <torch/extension.h>
+
+#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.")
+#define CHECK_CONTIGUOUS(x) \
+ TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.")
+#define CHECK_CONTIGUOUS_CUDA(x) \
+ CHECK_CUDA(x); \
+ CHECK_CONTIGUOUS(x)
\ No newline at end of file
diff --git a/unik3d/ops/losses/__init__.py b/unik3d/ops/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..07032e80bee5dd6da40eafb943919a5034296c1d
--- /dev/null
+++ b/unik3d/ops/losses/__init__.py
@@ -0,0 +1,22 @@
+from .confidence import Confidence
+from .dummy import Dummy
+from .edge import SpatialGradient
+from .local_ssi import LocalSSI
+from .normals import LocalNormal
+from .regression import PolarRegression, Regression
+from .robust_loss import RobustLoss
+from .scale import Scale
+from .silog import SILog
+
+__all__ = [
+ "Confidence",
+ "Dummy",
+ "SpatialGradient",
+ "LocalSSI",
+ "Regression",
+ "LocalNormal",
+ "RobustLoss",
+ "SILog",
+ "Scale",
+ "PolarRegression",
+]
diff --git a/unik3d/ops/losses/confidence.py b/unik3d/ops/losses/confidence.py
new file mode 100644
index 0000000000000000000000000000000000000000..dca6ea4f63fadc10b3916e82ef2b82f076137219
--- /dev/null
+++ b/unik3d/ops/losses/confidence.py
@@ -0,0 +1,63 @@
+import torch
+import torch.nn as nn
+
+from .utils import FNS, masked_mean
+
+
+class Confidence(nn.Module):
+ def __init__(
+ self,
+ weight: float,
+ output_fn: str = "sqrt",
+ input_fn: str = "linear",
+ rescale: bool = True,
+ eps: float = 1e-5,
+ ):
+ super(Confidence, self).__init__()
+ self.name: str = self.__class__.__name__
+ self.weight = weight
+ self.rescale = rescale
+ self.eps = eps
+ self.output_fn = FNS[output_fn]
+ self.input_fn = FNS[input_fn]
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def forward(
+ self,
+ input: torch.Tensor,
+ target_pred: torch.Tensor,
+ target_gt: torch.Tensor,
+ mask: torch.Tensor,
+ ):
+ B, C = target_gt.shape[:2]
+ mask = mask.bool()
+ target_gt = target_gt.float().reshape(B, C, -1)
+ target_pred = target_pred.float().reshape(B, C, -1)
+ input = input.float().reshape(B, -1)
+ mask = mask.reshape(B, -1)
+
+ if self.rescale:
+ target_pred = torch.stack(
+ [
+ p * torch.median(gt[:, m]) / torch.median(p[:, m])
+ for p, gt, m in zip(target_pred, target_gt, mask)
+ ]
+ )
+
+ error = torch.abs(
+ (self.input_fn(target_pred) - self.input_fn(target_gt)).norm(dim=1) - input
+ )
+ losses = masked_mean(error, dim=[-1], mask=mask).squeeze(dim=-1)
+ losses = self.output_fn(losses)
+
+ return losses
+
+ @classmethod
+ def build(cls, config):
+ obj = cls(
+ weight=config["weight"],
+ output_fn=config["output_fn"],
+ input_fn=config["input_fn"],
+ rescale=config.get("rescale", True),
+ )
+ return obj
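
A minimal usage sketch for `Confidence` (not from the source tree). The `FNS` keys `"sqrt"` and `"linear"` are assumptions borrowed from defaults used by the other losses; the shapes follow the reshape logic in `forward`.

```python
import torch

loss_fn = Confidence.build({"weight": 1.0, "output_fn": "sqrt", "input_fn": "linear"})

B, C, H, W = 2, 1, 32, 32
conf = torch.rand(B, H, W)            # predicted confidence, flattened to (B, H*W)
pred = torch.rand(B, C, H, W) + 0.1   # predicted depth
gt = torch.rand(B, C, H, W) + 0.1     # ground-truth depth
mask = torch.ones(B, 1, H, W, dtype=torch.bool)

loss = loss_fn(conf, pred, gt, mask)  # (B,) per-image confidence losses
```
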
diff --git a/unik3d/ops/losses/dummy.py b/unik3d/ops/losses/dummy.py
new file mode 100644
index 0000000000000000000000000000000000000000..77a4c99dc888bf2a5d80c7d324257db863622d28
--- /dev/null
+++ b/unik3d/ops/losses/dummy.py
@@ -0,0 +1,17 @@
+import torch
+import torch.nn as nn
+
+
+class Dummy(nn.Module):
+ def __init__(self, *args, **kwargs):
+ super().__init__()
+ self.name: str = self.__class__.__name__
+ self.weight = 1.0
+
+ def forward(self, dummy: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+ return torch.tensor([0.0] * dummy.shape[0], device=dummy.device)
+
+ @classmethod
+ def build(cls, config):
+ obj = cls()
+ return obj
diff --git a/unik3d/ops/losses/edge.py b/unik3d/ops/losses/edge.py
new file mode 100644
index 0000000000000000000000000000000000000000..20e5265532a2eb462cf94b2858483c48df6d8922
--- /dev/null
+++ b/unik3d/ops/losses/edge.py
@@ -0,0 +1,191 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from unik3d.utils.constants import VERBOSE
+from unik3d.utils.geometric import dilate, erode
+from unik3d.utils.misc import profile_method
+
+from .utils import (FNS, REGRESSION_DICT, masked_mean, masked_mean_var,
+ masked_quantile)
+
+
+class SpatialGradient(torch.nn.Module):
+ def __init__(
+ self,
+ weight: float,
+ input_fn: str,
+ output_fn: str,
+ fn: str,
+ scales: int = 1,
+ gamma: float = 1.0,
+ quantile: float = 0.0,
+ laplacian: bool = False,
+ canny_edge: bool = False,
+ **kwargs,
+ ):
+ super().__init__()
+ self.name: str = self.__class__.__name__
+ self.weight = weight
+ self.output_fn = FNS[output_fn]
+ self.input_fn = FNS[input_fn]
+ self.fn = REGRESSION_DICT[fn]
+ self.gamma = gamma
+ sobel_kernel_x = (
+ torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32)
+ .unsqueeze(0)
+ .unsqueeze(0)
+ )
+ sobel_kernel_y = (
+ torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32)
+ .unsqueeze(0)
+ .unsqueeze(0)
+ )
+ laplacian_kernel = (
+ torch.tensor([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=torch.float32)
+ .unsqueeze(0)
+ .unsqueeze(0)
+ )
+ ones = torch.ones(1, 1, 3, 3, dtype=torch.float32)
+ self.sobel_kernel_x = nn.Parameter(sobel_kernel_x, requires_grad=False)
+ self.sobel_kernel_y = nn.Parameter(sobel_kernel_y, requires_grad=False)
+ self.ones = nn.Parameter(ones, requires_grad=False)
+ self.laplacian_kernel = nn.Parameter(laplacian_kernel, requires_grad=False)
+
+ self.quantile = quantile
+ self.scales = scales
+ self.laplacian = laplacian
+ self.canny_edge = canny_edge
+
+ @profile_method(verbose=VERBOSE)
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def forward(
+ self,
+ input,
+ target,
+ mask,
+ quality=None,
+ ):
+ B = input.shape[0]
+ input = self.input_fn(input.float())
+ target = self.input_fn(target.float())
+
+ # normalize to avoid scale issue, shift is not important as we are computing gradients
+ input_mean, input_var = masked_mean_var(input.detach(), mask, dim=(-3, -2, -1))
+ target_mean, target_var = masked_mean_var(target, mask, dim=(-3, -2, -1))
+ input = (input - input_mean) / (input_var + 1e-6) ** 0.5
+ target = (target - target_mean) / (target_var + 1e-6) ** 0.5
+
+ loss = 0.0
+ norm_factor = sum([(i + 1) ** 2 for i in range(self.scales)])
+ for scale in range(self.scales):
+ if scale > 0:
+ input = F.interpolate(
+ input,
+ size=(input.shape[-2] // 2, input.shape[-1] // 2),
+ mode="bilinear",
+ align_corners=False,
+ antialias=True,
+ )
+ target = F.interpolate(
+ target,
+ size=(target.shape[-2] // 2, target.shape[-1] // 2),
+ mode="nearest",
+ )
+ mask = (
+ F.interpolate(
+ mask.float(),
+ size=(mask.shape[-2] // 2, mask.shape[-1] // 2),
+ mode="nearest",
+ )
+ > 0.9
+ )
+ grad_loss = self.loss(input, target, mask, quality)
+ # keep per pixel same weight
+ loss = loss + grad_loss * (self.scales - scale) ** 2 / norm_factor
+
+ loss = self.output_fn(loss)
+ return loss
+
+ def loss(self, input, target, mask, quality):
+ device, dtype = input.device, input.dtype
+ B, C, H, W = input.shape
+
+ # sobel
+ input_edge_x = (
+ F.conv2d(input, self.sobel_kernel_x.repeat(C, 1, 1, 1), groups=C) / 8
+ )
+ target_edge_x = (
+ F.conv2d(target, self.sobel_kernel_x.repeat(C, 1, 1, 1), groups=C) / 8
+ )
+ input_edge_y = (
+ F.conv2d(input, self.sobel_kernel_y.repeat(C, 1, 1, 1), groups=C) / 8
+ )
+ target_edge_y = (
+ F.conv2d(target, self.sobel_kernel_y.repeat(C, 1, 1, 1), groups=C) / 8
+ )
+ input_edge = torch.stack([input_edge_x, input_edge_y], dim=-1)
+ target_edge = torch.stack([target_edge_x, target_edge_y], dim=-1)
+
+ mask = F.conv2d(mask.clone().to(input.dtype), self.ones) == 9
+ mask = mask.squeeze(1)
+
+ error = input_edge - target_edge
+ error = error.norm(dim=-1).norm(
+ dim=1
+ ) # take RMSE over xy-dir (isotropic) and over channel-dir (isotropic)
+
+ if quality is not None:
+ for quality_level in [1, 2]:
+ current_quality = quality == quality_level
+ if current_quality.sum() > 0:
+ error_qtl = error[current_quality].detach()
+ mask_qtl = error_qtl < masked_quantile(
+ error_qtl,
+ mask[current_quality],
+ dims=[1, 2],
+ q=1 - self.quantile * quality_level,
+ ).view(-1, 1, 1)
+ mask[current_quality] = mask[current_quality] & mask_qtl
+ else:
+ error_qtl = error.detach()
+ mask = mask & (
+ error_qtl
+ < masked_quantile(
+ error_qtl, mask, dims=[1, 2], q=1 - self.quantile
+ ).view(-1, 1, 1)
+ )
+
+ loss = masked_mean(error, mask, dim=(-2, -1)).squeeze(dim=(-2, -1))
+
+ if self.laplacian:
+ input_laplacian = (
+ F.conv2d(input, self.laplacian_kernel.repeat(C, 1, 1, 1), groups=C) / 8
+ )
+ target_laplacian = (
+ F.conv2d(target, self.laplacian_kernel.repeat(C, 1, 1, 1), groups=C) / 8
+ )
+ error_laplacian = self.fn(
+ input_laplacian - target_laplacian, gamma=self.gamma
+ )
+ error_laplacian = (torch.mean(error_laplacian**2, dim=1) + 1e-6) ** 0.5
+ loss_laplacian = masked_mean(error_laplacian, mask, dim=(-2, -1)).squeeze(
+ dim=(-2, -1)
+ )
+ loss = loss + 0.1 * loss_laplacian
+
+ return loss
+
+ @classmethod
+ def build(cls, config):
+ obj = cls(
+ weight=config["weight"],
+ input_fn=config["input_fn"],
+ output_fn=config["output_fn"],
+ fn=config["fn"],
+ gamma=config["gamma"],
+ quantile=config["quantile"],
+ scales=config["scales"],
+ laplacian=config["laplacian"],
+ )
+ return obj
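
A minimal usage sketch for `SpatialGradient`. The `FNS`/`REGRESSION_DICT` keys (`"linear"`, `"sqrt"`, `"l1"`) are assumptions reused from the defaults of the other losses in this package.

```python
import torch

cfg = {"weight": 1.0, "input_fn": "linear", "output_fn": "sqrt",
       "fn": "l1", "gamma": 1.0, "quantile": 0.1, "scales": 2, "laplacian": False}
loss_fn = SpatialGradient.build(cfg)

B, C, H, W = 2, 1, 64, 64
pred = torch.rand(B, C, H, W)
gt = torch.rand(B, C, H, W)
mask = torch.ones(B, C, H, W, dtype=torch.bool)

loss = loss_fn(pred, gt, mask)  # (B,) multi-scale gradient-matching losses
```
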
diff --git a/unik3d/ops/losses/local_ssi.py b/unik3d/ops/losses/local_ssi.py
new file mode 100644
index 0000000000000000000000000000000000000000..933506c5a581d3cc7f4fb2b1021b314c08c60dfd
--- /dev/null
+++ b/unik3d/ops/losses/local_ssi.py
@@ -0,0 +1,315 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from unik3d.utils.constants import VERBOSE
+from unik3d.utils.geometric import downsample, erode
+from unik3d.utils.misc import profile_method
+
+from .utils import (FNS, REGRESSION_DICT, ind2sub, masked_mean,
+ masked_quantile, ssi, ssi_nd)
+
+
+def sample_strong_edges(edges_img, quantile=0.95, reshape=8):
+ # flat
+ edges_img = F.interpolate(
+ edges_img, scale_factor=1 / reshape, mode="bilinear", align_corners=False
+ )
+ edges_img_flat = edges_img.flatten(1)
+
+ # Find strong edges
+ edges_mask = edges_img_flat > torch.quantile(
+ edges_img_flat, quantile, dim=-1, keepdim=True
+ )
+ num_samples = edges_mask.sum(dim=-1)
+ if (num_samples < 10).any():
+ # sample random locations where fewer than 10 strong edges were found
+ random = torch.rand_like(edges_img_flat[num_samples < 10, :]) > quantile
+ edges_mask[num_samples < 10, :] = torch.logical_or(
+ edges_mask[num_samples < 10, :], random
+ )
+ num_samples = edges_mask.sum(dim=-1)
+
+ min_samples = num_samples.min()
+
+ # Compute the coordinates of the strong edges as B, N, 2
+ edges_coords = torch.stack(
+ [torch.nonzero(x, as_tuple=False)[:min_samples].squeeze() for x in edges_mask]
+ )
+ edges_coords = (
+ torch.stack(ind2sub(edges_coords, edges_img.shape[-1]), dim=-1) * reshape
+ )
+ return edges_coords
+
+
+@torch.jit.script
+def extract_patches(tensor, sample_coords, patch_size: tuple[int, int] = (32, 32)):
+ """
+ Extracts patches around specified edge coordinates, with zero padding.
+
+ Parameters:
+ - tensor: tensor to be gathered based on sampling (B, 1, H, W).
+ - sample_coords: Batch of edge coordinates as a PyTorch tensor of shape (B, num_coords, 2).
+ - patch_size: Tuple (width, height) representing the size of the patches.
+
+ Returns:
+ - Patches as a PyTorch tensor of shape (B, num_coords, patch_height, patch_width).
+ """
+
+ N, _, H, W = tensor.shape
+ device = tensor.device
+ dtype = tensor.dtype
+ patch_width, patch_height = patch_size
+ pad_width = patch_width // 2
+ pad_height = patch_height // 2
+
+ # Pad the tensor so that patches centered near the border stay in bounds
+ tensor_padded = F.pad(
+ tensor,
+ (pad_width, pad_width, pad_height, pad_height),
+ mode="constant",
+ value=0.0,
+ )
+
+ # Adjust edge coordinates to account for padding
+ sample_coords_padded = sample_coords + torch.tensor(
+ [pad_height, pad_width], dtype=dtype, device=device
+ ).reshape(1, 1, 2)
+
+ # Calculate the indices for gather operation
+ x_centers = sample_coords_padded[:, :, 1].int()
+ y_centers = sample_coords_padded[:, :, 0].int()
+
+ all_patches = []
+ for tensor_i, x_centers_i, y_centers_i in zip(tensor_padded, x_centers, y_centers):
+ patches = []
+ for x_center, y_center in zip(x_centers_i, y_centers_i):
+ y_start, y_end = y_center - pad_height, y_center + pad_height + 1
+ x_start, x_end = x_center - pad_width, x_center + pad_width + 1
+ patches.append(tensor_i[..., y_start:y_end, x_start:x_end])
+ all_patches.append(torch.stack(patches, dim=0))
+
+ return torch.stack(all_patches, dim=0).reshape(N, -1, patch_height * patch_width)
+
+
+class LocalSSI(nn.Module):
+ def __init__(
+ self,
+ weight: float,
+ output_fn: str = "sqrt",
+ patch_size: tuple[int, int] = (32, 32),
+ min_samples: int = 4,
+ num_levels: int = 4,
+ fn: str = "l1",
+ rescale_fn: str = "ssi",
+ input_fn: str = "linear",
+ quantile: float = 0.1,
+ gamma: float = 1.0,
+ alpha: float = 1.0,
+ relative: bool = False,
+ eps: float = 1e-5,
+ ):
+ super(LocalSSI, self).__init__()
+ self.name: str = self.__class__.__name__
+ self.weight = weight
+ self.output_fn = FNS[output_fn]
+ self.input_fn = FNS[input_fn]
+ self.fn = REGRESSION_DICT[fn]
+ self.min_samples = min_samples
+ self.eps = eps
+ patch_logrange = np.linspace(
+ start=np.log2(min(patch_size)),
+ stop=np.log2(max(patch_size)),
+ endpoint=True,
+ num=num_levels + 1,
+ )
+ self.patch_logrange = [
+ (x, y) for x, y in zip(patch_logrange[:-1], patch_logrange[1:])
+ ]
+ self.rescale_fn = eval(rescale_fn)
+ self.quantile = quantile
+ self.gamma = gamma
+ self.alpha = alpha
+ self.relative = relative
+
+ @profile_method(verbose=VERBOSE)
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def forward(
+ self,
+ input: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor,
+ quality: torch.Tensor = None,
+ down_ratio: int = 1,
+ *args,
+ **kwargs,
+ ) -> torch.Tensor:
+ mask = mask.bool()
+
+ if down_ratio > 1:
+ input = downsample(input, down_ratio)
+ target = downsample(target, down_ratio)
+ # downsample ignores 0s in the patch "min"; if there is a 1 in the patch, set the mask to 1 there
+ mask = downsample(mask.float(), down_ratio).bool()
+
+ input = self.input_fn(input.float())
+ target = self.input_fn(target.float())
+ B, C, H, W = input.shape
+ total_errors = []
+
+ # save = random() < - 0.001 and is_main_process()
+ for ii, patch_logrange in enumerate(self.patch_logrange):
+
+ log_kernel = (
+ np.random.uniform(*patch_logrange)
+ if self.training
+ else np.mean(patch_logrange)
+ )
+ kernel_size = int(
+ (2**log_kernel) * min(input.shape[-2:])
+ ) # always smaller than min_shape
+ kernel_size = (kernel_size, kernel_size)
+ stride = (int(kernel_size[0] * 0.9), int(kernel_size[1] * 0.9))
+ # stride = kernel_size
+
+ # unfold always spills over on the right/bottom, so roll the image by a
+ # non-positive offset to bring those regions back into the unfolding window
+ max_roll = (
+ (W - kernel_size[1]) % stride[1],
+ (H - kernel_size[0]) % stride[0],
+ )
+ roll_x, roll_y = np.random.randint(-max_roll[0], 1), np.random.randint(
+ -max_roll[1], 1
+ )
+ input_fold = torch.roll(input, shifts=(roll_y, roll_x), dims=(2, 3))
+ target_fold = torch.roll(target, shifts=(roll_y, roll_x), dims=(2, 3))
+ mask_fold = torch.roll(mask.float(), shifts=(roll_y, roll_x), dims=(2, 3))
+
+ # unfold in patches
+ input_fold = F.unfold(
+ input_fold, kernel_size=kernel_size, stride=stride
+ ).permute(
+ 0, 2, 1
+ ) # B N C*H_p*W_p
+ target_fold = F.unfold(
+ target_fold, kernel_size=kernel_size, stride=stride
+ ).permute(0, 2, 1)
+ mask_fold = (
+ F.unfold(mask_fold, kernel_size=kernel_size, stride=stride)
+ .bool()
+ .permute(0, 2, 1)
+ )
+
+ # compute the error patchwise, then average within each patch, then over the image where the sample size is significant
+ input_fold, target_fold, _ = self.rescale_fn(
+ input_fold, target_fold, mask_fold, dim=(-1,)
+ )
+ error = self.fn(
+ input_fold - target_fold, alpha=self.alpha, gamma=self.gamma
+ )
+
+ # mask out error elements above the top quantile (scaled by quality level when provided)
+ if quality is not None:
+ N_patches = mask_fold.shape[1]
+ for quality_level in [1, 2]:
+ current_quality = quality == quality_level
+ if current_quality.sum() > 0:
+ error_qtl = error[current_quality].detach()
+ mask_qtl = error_qtl < masked_quantile(
+ error_qtl,
+ mask_fold[current_quality],
+ dims=[2],
+ q=1 - self.quantile * quality_level,
+ ).view(-1, N_patches, 1)
+ mask_fold[current_quality] = (
+ mask_fold[current_quality] & mask_qtl
+ )
+ else:
+ error_qtl = error.detach()
+ mask_fold = mask_fold & (
+ error_qtl
+ < masked_quantile(
+ error_qtl, mask_fold, dims=[2], q=1 - self.quantile
+ ).view(B, -1, 1)
+ )
+
+ valid_patches = mask_fold.sum(dim=-1) >= self.min_samples
+ error_mean_patch = masked_mean(error, mask_fold, dim=(-1,)).squeeze(-1)
+ error_mean_image = self.output_fn(error_mean_patch.clamp(min=self.eps))
+ error_mean_image = masked_mean(
+ error_mean_image, mask=valid_patches, dim=(-1,)
+ )
+ total_errors.append(error_mean_image.squeeze(-1))
+
+ # global
+ input_rescale = input.reshape(B, C, -1).clone()
+ target_rescale = target.reshape(B, C, -1)
+ mask = mask.reshape(B, 1, -1).clone()
+ input, target, _ = self.rescale_fn(
+ input_rescale,
+ target_rescale,
+ mask,
+ dim=(-1,),
+ target_info=target_rescale.norm(dim=1, keepdim=True),
+ input_info=input_rescale.norm(dim=1, keepdim=True),
+ )
+ error = input - target
+ error = error.norm(dim=1) if C > 1 else error.squeeze(1)
+ if self.relative:
+ error = error * torch.log(
+ 1.0 + 10.0 / target_rescale.norm(dim=1).clip(min=0.01)
+ )
+
+ error = self.fn(error, alpha=self.alpha, gamma=self.gamma).squeeze(1)
+
+ mask = mask.squeeze(1)
+ valid_patches = mask.sum(dim=-1) >= 3 * self.min_samples # 30 samples per image
+ if quality is not None:
+ for quality_level in [1, 2]:
+ current_quality = quality == quality_level
+ if current_quality.sum() > 0:
+ error_qtl = error[current_quality].detach()
+ mask_qtl = error_qtl < masked_quantile(
+ error_qtl,
+ mask[current_quality],
+ dims=[1],
+ q=1 - self.quantile * quality_level,
+ ).view(-1, 1)
+ mask[current_quality] = mask[current_quality] & mask_qtl
+ else:
+ error_qtl = error.detach()
+ mask = mask & (
+ error_qtl
+ < masked_quantile(error_qtl, mask, dims=[1], q=1 - self.quantile).view(
+ -1, 1
+ )
+ )
+
+ error_mean_image = masked_mean(error, mask, dim=(-1,)).squeeze(-1)
+ error_mean_image = (
+ self.output_fn(error_mean_image.clamp(min=self.eps)) * valid_patches.float()
+ )
+
+ total_errors.append(error_mean_image)
+
+ errors = torch.stack(total_errors).mean(dim=0)
+ return errors
+
+ @classmethod
+ def build(cls, config):
+ obj = cls(
+ weight=config["weight"],
+ patch_size=config["patch_size"],
+ output_fn=config["output_fn"],
+ min_samples=config["min_samples"],
+ num_levels=config["num_levels"],
+ input_fn=config["input_fn"],
+ quantile=config["quantile"],
+ gamma=config["gamma"],
+ alpha=config["alpha"],
+ rescale_fn=config["rescale_fn"],
+ fn=config["fn"],
+ relative=config["relative"],
+ )
+ return obj
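
A minimal usage sketch for `LocalSSI`. From the `kernel_size` computation in `forward`, `patch_size` is read here as a fraction of the shorter image side; that reading, and the `FNS`/`rescale_fn`/`fn` keys, are assumptions.

```python
import torch

cfg = {"weight": 1.0, "patch_size": (0.125, 0.5), "output_fn": "sqrt",
       "min_samples": 4, "num_levels": 2, "input_fn": "linear",
       "quantile": 0.1, "gamma": 1.0, "alpha": 1.0,
       "rescale_fn": "ssi", "fn": "l1", "relative": False}
loss_fn = LocalSSI.build(cfg).eval()  # eval(): use the mean patch scale per level

B, C, H, W = 2, 1, 64, 64
pred = torch.rand(B, C, H, W) + 0.1
gt = torch.rand(B, C, H, W) + 0.1
mask = torch.ones(B, 1, H, W, dtype=torch.bool)

loss = loss_fn(pred, gt, mask)  # (B,) scale/shift-invariant patch losses
```
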
diff --git a/unik3d/ops/losses/normals.py b/unik3d/ops/losses/normals.py
new file mode 100644
index 0000000000000000000000000000000000000000..18f243cb70699d635f5eaae71514239a32e18f17
--- /dev/null
+++ b/unik3d/ops/losses/normals.py
@@ -0,0 +1,229 @@
+import itertools
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from unik3d.utils.geometric import dilate, downsample, erode
+
+from .utils import FNS, masked_mean, masked_quantile
+
+
+class LocalNormal(nn.Module):
+ def __init__(
+ self,
+ weight: float,
+ output_fn: str = "sqrt",
+ min_samples: int = 4,
+ quantile: float = 0.2,
+ eps: float = 1e-5,
+ ):
+ super(LocalNormal, self).__init__()
+ self.name: str = self.__class__.__name__
+ self.weight = weight
+ self.output_fn = FNS[output_fn]
+ self.min_samples = min_samples
+ self.eps = eps
+ self.patch_weight = torch.ones(1, 1, 3, 3, device="cuda")
+ self.quantile = quantile
+
+ def bilateral_filter(self, rgb, surf, mask, patch_size=(9, 9)):
+ B, _, H, W = rgb.shape
+ sigma_surf = 0.4
+ sigma_color = 0.3
+ sigma_loc = 0.3 * max(H, W)
+
+ grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W))
+ grid = torch.stack([grid_x, grid_y], dim=0).to(rgb.device)
+ grid = grid.unsqueeze(0).repeat(B, 1, 1, 1)
+
+ paddings = [patch_size[0] // 2, patch_size[1] // 2]
+ rgbd = torch.cat([rgb, grid.float(), surf], dim=1)
+
+ # format to B,H*W,C,H_p*W_p format
+ rgbd_neigh = F.pad(rgbd, 2 * paddings, mode="constant")
+ rgbd_neigh = F.unfold(rgbd_neigh, kernel_size=patch_size)
+ rgbd_neigh = rgbd_neigh.permute(0, 2, 1).reshape(
+ B, H * W, 8, -1
+ ) # B N 8 H_p*W_p
+ mask_neigh = F.pad(mask.float(), 2 * paddings, mode="constant")
+ mask_neigh = F.unfold(mask_neigh, kernel_size=patch_size)
+ mask_neigh = mask_neigh.permute(0, 2, 1).reshape(B, H * W, -1)
+ rgbd = rgbd.permute(0, 2, 3, 1).reshape(B, H * W, 8, 1) # B H*W 8 1
+ rgb_neigh = rgbd_neigh[:, :, :3, :]
+ grid_neigh = rgbd_neigh[:, :, 3:5, :]
+ surf_neigh = rgbd_neigh[:, :, 5:, :]
+ rgb = rgbd[:, :, :3, :]
+ grid = rgbd[:, :, 3:5, :]
+ surf = rgbd[:, :, 5:, :]
+
+ # calc distance
+ rgb_dist = torch.norm(rgb - rgb_neigh, dim=-2, p=2) ** 2
+ grid_dist = torch.norm(grid - grid_neigh, dim=-2, p=2) ** 2
+ surf_dist = torch.norm(surf - surf_neigh, dim=-2, p=2) ** 2
+ rgb_sim = torch.exp(-rgb_dist / 2 / sigma_color**2)
+ grid_sim = torch.exp(-grid_dist / 2 / sigma_loc**2)
+ surf_sim = torch.exp(-surf_dist / 2 / sigma_surf**2)
+
+ weight = mask_neigh * rgb_sim * grid_sim * surf_sim # B H*W H_p*W_p
+ weight = weight / weight.sum(dim=-1, keepdim=True).clamp(min=1e-5)
+ z = (surf_neigh * weight.unsqueeze(-2)).sum(dim=-1)
+ return z.reshape(B, H, W, 3).permute(0, 3, 1, 2)
+
+ def get_surface_normal(self, xyz: torch.Tensor, mask: torch.Tensor):
+ P0 = xyz
+ mask = mask.float()
+ normals, masks_valid_triangle = [], []
+ combinations = list(itertools.combinations_with_replacement([-2, -1, 1, 2], 2))
+ combinations += [c[::-1] for c in combinations]
+ # combinations = [(1, 1), (-1, -1), (1, -1), (-1, 1)]
+ for shift_0, shift_1 in set(combinations):
+ P1 = torch.roll(xyz, shifts=(0, shift_0), dims=(-1, -2))
+ P2 = torch.roll(xyz, shifts=(shift_1, 0), dims=(-1, -2))
+ if (shift_0 > 0) ^ (shift_1 > 0):
+ P1, P2 = P2, P1
+ vec1, vec2 = P1 - P0, P2 - P0
+ normal = torch.cross(vec1, vec2, dim=1)
+ vec1_norm = torch.norm(vec1, dim=1, keepdim=True).clip(min=1e-8)
+ vec2_norm = torch.norm(vec2, dim=1, keepdim=True).clip(min=1e-8)
+ normal_norm = torch.norm(normal, dim=1, keepdim=True).clip(min=1e-8)
+ normals.append(normal / normal_norm)
+ is_valid = (
+ torch.roll(mask, shifts=(0, shift_0), dims=(-1, -2))
+ + torch.roll(mask, shifts=(shift_1, 0), dims=(-1, -2))
+ + mask
+ == 3
+ )
+ is_valid = (
+ (normal_norm > 1e-6)
+ & (vec1_norm > 1e-6)
+ & (vec2_norm > 1e-6)
+ & is_valid
+ )
+ masks_valid_triangle.append(is_valid)
+
+ normals = torch.stack(normals, dim=-1)
+ mask_valid_triangle = torch.stack(masks_valid_triangle, dim=-1).float()
+ mask_valid = mask_valid_triangle.sum(dim=-1)
+ normals = (normals * mask_valid_triangle).sum(dim=-1) / mask_valid.clamp(
+ min=1.0
+ )
+ normals_norm = torch.norm(normals, dim=1, keepdim=True).clip(min=1e-8)
+ normals = normals / normals_norm
+ mask_valid = (
+ (mask_valid > 0.001)
+ & (~normals.sum(dim=1, keepdim=True).isnan())
+ & (normals_norm > 1e-6)
+ )
+ return normals, mask_valid # B 3 H W, B 1 H W
+
+ # def get_surface_normal(self, xyz: torch.Tensor, mask: torch.Tensor):
+ # x, y, z = torch.unbind(xyz, dim=1) # B 3 H W
+ # x = x.unsqueeze(1) # B 1 H W
+ # y = y.unsqueeze(1)
+ # z = z.unsqueeze(1)
+
+ # mask_float = mask.float()
+ # paddings = [self.patch_weight.shape[-2] // 2, self.patch_weight.shape[-1] // 2]
+ # num_samples = F.conv2d(mask_float, weight=self.patch_weight, padding=paddings).clamp(min=1.0) # B 1 H W
+ # mask_invalid = num_samples < self.min_samples
+
+ # xx = x * x
+ # yy = y * y
+ # zz = z * z
+ # xy = x * y
+ # xz = x * z
+ # yz = y * z
+ # xx_patch = F.conv2d(xx * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
+ # yy_patch = F.conv2d(yy * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
+ # zz_patch = F.conv2d(zz * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
+ # xy_patch = F.conv2d(xy * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
+ # xz_patch = F.conv2d(xz * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
+ # yz_patch = F.conv2d(yz * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
+
+ # x_patch = F.conv2d(x * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
+ # y_patch = F.conv2d(y * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
+ # z_patch = F.conv2d(z * mask_float, weight=self.patch_weight, padding=paddings) / num_samples
+
+ # ATA = torch.stack([xx_patch, xy_patch, xz_patch, xy_patch, yy_patch, yz_patch, xz_patch, yz_patch, zz_patch], dim=-1).squeeze(1) # B H W 9
+ # ATA = torch.reshape(ATA, (ATA.shape[0], ATA.shape[1], ATA.shape[2], 3, 3)) # B H W 3 3
+ # eps_identity = torch.eye(3, device=ATA.device, dtype=ATA.dtype).unsqueeze(0) # 1 3 3
+ # ATA = ATA + 1e-6 * eps_identity
+
+ # AT1 = torch.stack([x_patch, y_patch, z_patch], dim=-1).squeeze(1).unsqueeze(-1) # B H W 3 1
+
+ # det = torch.linalg.det(ATA)
+ # mask_invalid_inverse = det.abs() < 1e-12
+ # mask_invalid = mask_invalid.squeeze(1) | mask_invalid_inverse
+ # AT1[mask_invalid, :, :] = 0
+ # ATA[mask_invalid, :, :] = eps_identity
+
+ # ATA_inv = torch.linalg.inv(ATA)
+ # normals = (ATA_inv @ AT1).squeeze(dim=-1) # B H W 3
+ # normals = normals / torch.norm(normals, dim=-1, keepdim=True).clip(min=1e-8)
+ # mask_invalid = mask_invalid | (torch.sum(normals, dim=-1) == 0.0)
+
+ # # flip normals, based if a * x + b * y + c * z < 0 -> change sign of normals
+ # mean_patch_xyz = AT1.squeeze(-1)
+ # orient_mask = torch.sum(normals * mean_patch_xyz, dim=-1).unsqueeze(-1) > 0
+ # normals = (2 * orient_mask.to(ATA.dtype) - 1) * normals
+
+ # return normals.permute(0, 3, 1, 2), ~mask_invalid.unsqueeze(1) # B 3 H W, B H W
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def forward(self, input: torch.Tensor, target: torch.Tensor, mask, valid):
+ if not valid.any():
+ return 0.0 * input.mean(dim=(1, 2, 3))
+
+ input = input.float()
+ target = target.float()
+
+ mask = erode(mask, kernel_size=3)
+ target_normal, mask_target = self.get_surface_normal(target[valid], mask[valid])
+ input_normal, mask_input = self.get_surface_normal(
+ input[valid], torch.ones_like(mask[valid])
+ )
+
+ gt_similarity = F.cosine_similarity(input_normal, target_normal, dim=1) # B H W
+ mask_target = (
+ mask_target.squeeze(1) & (gt_similarity < 0.999) & (gt_similarity > -0.999)
+ )
+ error = F.relu((1 - gt_similarity) / 2 - 0.01)
+
+ error_full = torch.ones_like(mask.squeeze(1).float())
+ error_full[valid] = error
+ mask_full = torch.ones_like(mask.squeeze(1))
+ mask_full[valid] = mask_target
+
+ error_qtl = error_full.detach()
+ mask_full = mask_full & (
+ error_qtl
+ < masked_quantile(
+ error_qtl, mask_full, dims=[1, 2], q=1 - self.quantile
+ ).view(-1, 1, 1)
+ )
+
+ loss = masked_mean(error_full, mask=mask_full, dim=(-2, -1)).squeeze(
+ dim=(-2, -1)
+ ) # B
+ loss = self.output_fn(loss)
+ return loss
+
+ def von_mises(self, input, target, mask, kappa):
+ score = torch.cosine_similarity(input, target, dim=1).unsqueeze(1)
+ mask_cosine = torch.logical_and(
+ mask, torch.logical_and(score.detach() < 0.999, score.detach() > -0.999)
+ )
+ nll = masked_mean(
+ kappa * (1 - score), mask=mask_cosine, dim=(-1, -2, -3)
+ ).squeeze()
+ return nll
+
+ @classmethod
+ def build(cls, config):
+ obj = cls(
+ weight=config["weight"],
+ output_fn=config["output_fn"],
+ quantile=config.get("quantile", 0.2),
+ )
+ return obj
diff --git a/unik3d/ops/losses/regression.py b/unik3d/ops/losses/regression.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f16af625325cadadcc77744fae8508d495da372
--- /dev/null
+++ b/unik3d/ops/losses/regression.py
@@ -0,0 +1,166 @@
+import torch
+import torch.nn as nn
+
+from .utils import FNS, REGRESSION_DICT, masked_mean, masked_quantile
+
+
+class Regression(nn.Module):
+ def __init__(
+ self,
+ weight: float,
+ gamma: float,
+ fn: str,
+ input_fn: str,
+ output_fn: str,
+ alpha: float = 1.0,
+ dims: tuple[int] = (-1,),
+ quantile: float = 0.0,
+ **kwargs,
+ ):
+ super().__init__()
+ self.name = self.__class__.__name__
+ self.output_fn = FNS[output_fn]
+ self.input_fn = FNS[input_fn]
+ self.fn = REGRESSION_DICT[fn]
+ self.weight = weight
+ self.gamma = gamma
+ self.alpha = alpha
+ self.dims = dims
+ self.quantile = quantile
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def forward(
+ self,
+ input: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor | None = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ if mask is not None: # usually it is just repeated
+ mask = mask[:, 0]
+
+ input = self.input_fn(input.float())
+ target = self.input_fn(target.float())
+ error = self.fn(input - target, gamma=self.gamma, alpha=self.alpha).mean(dim=1)
+ if self.quantile > 0.0:
+ mask_quantile = error < masked_quantile(
+ error, mask, dims=self.dims, q=1 - self.quantile
+ ).view(-1, *((1,) * len(self.dims)))
+ mask = mask & mask_quantile if mask is not None else mask_quantile
+ mean_error = masked_mean(data=error, mask=mask, dim=self.dims).squeeze(
+ self.dims
+ )
+ mean_error = self.output_fn(mean_error)
+ return mean_error
+
+ @classmethod
+ def build(cls, config):
+ obj = cls(
+ weight=config["weight"],
+ fn=config["fn"],
+ gamma=config["gamma"],
+ alpha=config.get("alpha", 1.0),
+ output_fn=config["output_fn"],
+ input_fn=config["input_fn"],
+ dims=config.get("dims", (-1,)),
+ quantile=config.get("quantile", 0.0),
+ )
+ return obj
+
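+# Illustrative usage sketch (comment only, not executed by this module): an L1
+# regression over the last dimension with 20% quantile-based outlier rejection.
+# The keys "l1", "linear" and "sqrt" refer to REGRESSION_DICT / FNS in .utils;
+# the tensor shapes below are hypothetical.
+#
+#   criterion = Regression(weight=1.0, gamma=1.0, fn="l1", input_fn="linear",
+#                          output_fn="sqrt", quantile=0.2)
+#   pred = torch.randn(4, 3, 128)
+#   gt = torch.randn(4, 3, 128)
+#   valid = torch.ones(4, 1, 128, dtype=torch.bool)
+#   per_sample_loss = criterion(pred, gt, mask=valid)  # shape (4,)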
+
+class PolarRegression(nn.Module):
+ def __init__(
+ self,
+ weight: float,
+ gamma: float,
+ fn: str,
+ input_fn: str,
+ output_fn: str,
+ alpha: float = 1.0,
+ dims: list[int] = [-1, -2],
+ polar_weight: float = 1.0,
+ polar_asym: float = 0.5,
+ **kwargs,
+ ):
+ super().__init__()
+ self.name = self.__class__.__name__
+ self.output_fn = FNS[output_fn]
+ self.input_fn = FNS[input_fn]
+ self.fn = REGRESSION_DICT[fn]
+ self.weight = weight
+ self.gamma = gamma
+ self.alpha = alpha
+ self.dims = dims
+ self.polar_weight = polar_weight
+ self.polar_asym = polar_asym
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def forward(
+ self,
+ input: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor | None = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ if mask is not None: # usually it is just repeated
+ mask = mask.squeeze(1)
+
+ input = self.input_fn(input.float())
+ target = self.input_fn(target.float())
+ input = input / torch.norm(input, dim=1, keepdim=True).clamp(min=1e-5)
+ target = target / torch.norm(target, dim=1, keepdim=True).clamp(min=1e-5)
+
+ x_target, y_target, z_target = target.unbind(dim=1)
+ z_clipped = z_target.clip(min=-0.99999, max=0.99999)
+ x_clipped = x_target.abs().clip(min=1e-5) * (2 * (x_target > 0).float() - 1)
+ polar_target = torch.arccos(z_clipped)
+ azimuth_target = torch.atan2(y_target, x_clipped)
+
+ x_input, y_input, z_input = input.unbind(dim=1)
+ z_clipped = z_input.clip(min=-0.99999, max=0.99999)
+ x_clipped = x_input.abs().clip(min=1e-5) * (2 * (x_input > 0).float() - 1)
+ polar_input = torch.arccos(z_clipped)
+ azimuth_input = torch.atan2(y_input, x_clipped)
+
+ polar_error = self.fn(
+ polar_input - polar_target, gamma=self.gamma, alpha=self.alpha
+ )
+ azimuth_error = self.fn(
+ azimuth_input - azimuth_target, gamma=self.gamma, alpha=self.alpha
+ )
+
+ quantile_weight = torch.ones_like(polar_input)
+ quantile_weight[
+ (polar_target > polar_input) & (polar_target > torch.pi / 2)
+ ] = (2 * self.polar_asym)
+ quantile_weight[
+ (polar_target <= polar_input) & (polar_target > torch.pi / 2)
+ ] = 2 * (1 - self.polar_asym)
+
+ mean_polar_error = masked_mean(
+ data=polar_error * quantile_weight, mask=mask, dim=self.dims
+ ).squeeze(self.dims)
+ mean_azimuth_error = masked_mean(
+ data=azimuth_error, mask=mask, dim=self.dims
+ ).squeeze(self.dims)
+ mean_error = (self.polar_weight * mean_polar_error + mean_azimuth_error) / (
+ 1 + self.polar_weight
+ )
+
+ mean_error = self.output_fn(mean_error)
+ return mean_error
+
+ @classmethod
+ def build(cls, config):
+ obj = cls(
+ weight=config["weight"],
+ fn=config["fn"],
+ gamma=config["gamma"],
+ alpha=config.get("alpha", 1.0),
+ output_fn=config["output_fn"],
+ input_fn=config["input_fn"],
+ dims=config.get("dims", (-1,)),
+ polar_weight=config["polar_weight"],
+ polar_asym=config["polar_asym"],
+ )
+ return obj
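+# Worked example for the angle decomposition above (illustrative): after
+# normalization a direction d = (x, y, z) maps to polar = arccos(z) in [0, pi]
+# and azimuth = atan2(y, x) in (-pi, pi], e.g. d = (0, 0, 1) gives roughly
+# (0, 0) and d = (1, 0, 0) gives (pi/2, 0). For targets with polar > pi/2 the
+# polar error is re-weighted by 2 * polar_asym when the prediction
+# under-estimates the polar angle and by 2 * (1 - polar_asym) otherwise.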
diff --git a/unik3d/ops/losses/resources/__init__.py b/unik3d/ops/losses/resources/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/unik3d/ops/losses/resources/partition_spline.npz b/unik3d/ops/losses/resources/partition_spline.npz
new file mode 100644
index 0000000000000000000000000000000000000000..e2813d0d8397d603eeed56c3b75e045ad5bd226b
Binary files /dev/null and b/unik3d/ops/losses/resources/partition_spline.npz differ
diff --git a/unik3d/ops/losses/robust_loss.py b/unik3d/ops/losses/robust_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ef99be1a9b117e3b537dbe5fc6a3829fea7f081
--- /dev/null
+++ b/unik3d/ops/losses/robust_loss.py
@@ -0,0 +1,709 @@
+import numpy as np
+import torch
+import torch.nn as nn
+
+from .utils import FNS, masked_mean
+
+
+def log_safe(x):
+ """The same as torch.log(x), but clamps the input to prevent NaNs."""
+ x = torch.as_tensor(x)
+ return torch.log(torch.min(x, torch.tensor(33e37).to(x)))
+
+
+def log1p_safe(x):
+ """The same as torch.log1p(x), but clamps the input to prevent NaNs."""
+ x = torch.as_tensor(x)
+ return torch.log1p(torch.min(x, torch.tensor(33e37).to(x)))
+
+
+def exp_safe(x):
+ """The same as torch.exp(x), but clamps the input to prevent NaNs."""
+ x = torch.as_tensor(x)
+ return torch.exp(torch.min(x, torch.tensor(87.5).to(x)))
+
+
+def expm1_safe(x):
+ """The same as tf.math.expm1(x), but clamps the input to prevent NaNs."""
+ x = torch.as_tensor(x)
+ return torch.expm1(torch.min(x, torch.tensor(87.5).to(x)))
+
+
+def inv_softplus(y):
+ """The inverse of tf.nn.softplus()."""
+ y = torch.as_tensor(y)
+ return torch.where(y > 87.5, y, torch.log(torch.expm1(y)))
+
+
+def logit(y):
+ """The inverse of tf.nn.sigmoid()."""
+ y = torch.as_tensor(y)
+ return -torch.log(1.0 / y - 1.0)
+
+
+def affine_sigmoid(logits, lo=0, hi=1):
+ """Maps reals to (lo, hi), where 0 maps to (lo+hi)/2."""
+ if not lo < hi:
+ raise ValueError("`lo` (%g) must be < `hi` (%g)" % (lo, hi))
+ logits = torch.as_tensor(logits)
+ lo = torch.as_tensor(lo)
+ hi = torch.as_tensor(hi)
+ alpha = torch.sigmoid(logits) * (hi - lo) + lo
+ return alpha
+
+
+def inv_affine_sigmoid(probs, lo=0, hi=1):
+ """The inverse of affine_sigmoid(., lo, hi)."""
+ if not lo < hi:
+ raise ValueError("`lo` (%g) must be < `hi` (%g)" % (lo, hi))
+ probs = torch.as_tensor(probs)
+ lo = torch.as_tensor(lo)
+ hi = torch.as_tensor(hi)
+ logits = logit((probs - lo) / (hi - lo))
+ return logits
+
+
+def affine_softplus(x, lo=0, ref=1):
+ """Maps real numbers to (lo, infinity), where 0 maps to ref."""
+ if not lo < ref:
+ raise ValueError("`lo` (%g) must be < `ref` (%g)" % (lo, ref))
+ x = torch.as_tensor(x)
+ lo = torch.as_tensor(lo)
+ ref = torch.as_tensor(ref)
+ shift = inv_softplus(torch.tensor(1.0))
+ y = (ref - lo) * torch.nn.Softplus()(x + shift) + lo
+ return y
+
+
+def inv_affine_softplus(y, lo=0, ref=1):
+ """The inverse of affine_softplus(., lo, ref)."""
+ if not lo < ref:
+ raise ValueError("`lo` (%g) must be < `ref` (%g)" % (lo, ref))
+ y = torch.as_tensor(y)
+ lo = torch.as_tensor(lo)
+ ref = torch.as_tensor(ref)
+ shift = inv_softplus(torch.tensor(1.0))
+ x = inv_softplus((y - lo) / (ref - lo)) - shift
+ return x
+
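+# Quick consistency check (illustrative): the zero point maps to `ref` by
+# construction, since softplus(log(e - 1)) = log(1 + (e - 1)) = 1, so
+# affine_softplus(torch.tensor(0.0)) == 1.0 and
+# inv_affine_softplus(torch.tensor(1.0)) == 0.0 up to float precision.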
+
+def students_t_nll(x, df, scale):
+ """The NLL of a Generalized Student's T distribution (w/o including TFP)."""
+ x = torch.as_tensor(x)
+ df = torch.as_tensor(df)
+ scale = torch.as_tensor(scale)
+ log_partition = (
+ torch.log(torch.abs(scale))
+ + torch.lgamma(0.5 * df)
+ - torch.lgamma(0.5 * df + torch.tensor(0.5))
+ + torch.tensor(0.5 * np.log(np.pi))
+ )
+ return (
+ 0.5 * ((df + 1.0) * torch.log1p((x / scale) ** 2.0 / df) + torch.log(df))
+ + log_partition
+ )
+
+
+def lossfun(x, alpha, scale, approximate=False, epsilon=1e-6):
+ r"""Implements the general form of the loss.
+
+ This implements the rho(x, \alpha, c) function described in "A General and
+ Adaptive Robust Loss Function", Jonathan T. Barron,
+ https://arxiv.org/abs/1701.03077.
+
+ Args:
+ x: The residual for which the loss is being computed. x can have any shape,
+ and alpha and scale will be broadcasted to match x's shape if necessary.
+ Must be a tensor of floats.
+ alpha: The shape parameter of the loss (\alpha in the paper), where more
+ negative values produce a loss with more robust behavior (outliers "cost"
+ less), and more positive values produce a loss with less robust behavior
+ (outliers are penalized more heavily). Alpha can be any value in
+ [-infinity, infinity], but the gradient of the loss with respect to alpha
+ is 0 at -infinity, infinity, 0, and 2. Must be a tensor of floats with the
+ same precision as `x`. Varying alpha allows
+ for smooth interpolation between a number of discrete robust losses:
+ alpha=-Infinity: Welsch/Leclerc Loss.
+ alpha=-2: Geman-McClure loss.
+ alpha=0: Cauchy/Lorentzian loss.
+ alpha=1: Charbonnier/pseudo-Huber loss.
+ alpha=2: L2 loss.
+ scale: The scale parameter of the loss. When |x| < scale, the loss is an
+ L2-like quadratic bowl, and when |x| > scale the loss function takes on a
+ different shape according to alpha. Must be a tensor of single-precision
+ floats.
+ approximate: a bool, where if True, this function returns an approximate and
+ faster form of the loss, as described in the appendix of the paper. This
+ approximation holds well everywhere except as x and alpha approach zero.
+ epsilon: A float that determines how inaccurate the "approximate" version of
+ the loss will be. Larger values are less accurate but more numerically
+ stable. Must be greater than single-precision machine epsilon.
+
+ Returns:
+ The losses for each element of x, in the same shape and precision as x.
+ """
+ assert (scale > 0).all()
+ if approximate:
+ # Compute an approximate form of the loss which is faster, but inaccurate
+ # when x and alpha are near zero.
+ b = torch.abs(alpha - 2) + epsilon
+ d = torch.where(alpha >= 0, alpha + epsilon, alpha - epsilon)
+ loss = (b / d) * (torch.pow((x / scale) ** 2 / b + 1.0, 0.5 * d) - 1.0)
+ else:
+ # This will be used repeatedly.
+ squared_scaled_x = (x / scale) ** 2
+
+ # The loss when alpha == 2.
+ loss_two = 0.5 * squared_scaled_x
+ # The loss when alpha == 0.
+ loss_zero = log1p_safe(0.5 * squared_scaled_x)
+ # The loss when alpha == -infinity.
+ loss_neginf = -torch.expm1(-0.5 * squared_scaled_x)
+ # The loss when alpha == +infinity.
+ loss_posinf = expm1_safe(0.5 * squared_scaled_x)
+
+ # The loss when not in one of the above special cases.
+ machine_epsilon = torch.tensor(torch.finfo(alpha.dtype).eps).to(x)
+ # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by.
+ beta_safe = torch.max(machine_epsilon, torch.abs(alpha - 2.0))
+ # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by.
+ alpha_safe = torch.where(
+ alpha >= 0, torch.ones_like(alpha), -torch.ones_like(alpha)
+ ) * torch.max(machine_epsilon, torch.abs(alpha))
+ loss_otherwise = (beta_safe / alpha_safe) * (
+ torch.pow(squared_scaled_x / beta_safe + 1.0, 0.5 * alpha) - 1.0
+ )
+
+ # Select which of the cases of the loss to return.
+ loss = torch.where(
+ alpha == -float("inf"),
+ loss_neginf,
+ torch.where(
+ alpha == 0,
+ loss_zero,
+ torch.where(
+ alpha == 2,
+ loss_two,
+ torch.where(alpha == float("inf"), loss_posinf, loss_otherwise),
+ ),
+ ),
+ )
+
+ return loss
+
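+# Illustrative check (comment only): the special cases listed in the docstring
+# can be verified directly; with scale = 1, alpha = 2 reduces to 0.5 * x**2 and
+# alpha = 0 to log1p(0.5 * x**2).
+#
+#   x = torch.linspace(-3.0, 3.0, steps=7)
+#   l2_like = lossfun(x, torch.tensor(2.0), torch.tensor(1.0))
+#   cauchy_like = lossfun(x, torch.tensor(0.0), torch.tensor(1.0))
+#   assert torch.allclose(l2_like, 0.5 * x**2)
+#   assert torch.allclose(cauchy_like, torch.log1p(0.5 * x**2))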
+
+from pkg_resources import resource_stream
+
+
+def interpolate1d(x, values, tangents):
+ r"""Perform cubic hermite spline interpolation on a 1D spline.
+
+ The x coordinates of the spline knots are at [0 : 1 : len(values)-1].
+ Queries outside of the range of the spline are computed using linear
+ extrapolation. See https://en.wikipedia.org/wiki/Cubic_Hermite_spline
+ for details, where "x" corresponds to `x`, "p" corresponds to `values`, and
+ "m" corresponds to `tangents`.
+
+ Args:
+ x: A tensor of any size of single or double precision floats containing the
+ set of values to be used for interpolation into the spline.
+ values: A vector of single or double precision floats containing the value
+ of each knot of the spline being interpolated into. Must be the same
+ length as `tangents` and the same type as `x`.
+ tangents: A vector of single or double precision floats containing the
+ tangent (derivative) of each knot of the spline being interpolated into.
+ Must be the same length as `values` and the same type as `x`.
+
+ Returns:
+ The result of interpolating along the spline defined by `values`, and
+ `tangents`, using `x` as the query values. Will be the same length and type
+ as `x`.
+ """
+ assert torch.is_tensor(x)
+ assert torch.is_tensor(values)
+ assert torch.is_tensor(tangents)
+ float_dtype = x.dtype
+ assert values.dtype == float_dtype
+ assert tangents.dtype == float_dtype
+ assert len(values.shape) == 1
+ assert len(tangents.shape) == 1
+ assert values.shape[0] == tangents.shape[0]
+
+ x_lo = torch.floor(torch.clamp(x, torch.as_tensor(0), values.shape[0] - 2)).type(
+ torch.int64
+ )
+ x_hi = x_lo + 1
+
+ # Compute the relative distance between each `x` and the knot below it.
+ t = x - x_lo.type(float_dtype)
+
+ # Compute the cubic hermite expansion of `t`.
+ t_sq = t**2
+ t_cu = t * t_sq
+ h01 = -2.0 * t_cu + 3.0 * t_sq
+ h00 = 1.0 - h01
+ h11 = t_cu - t_sq
+ h10 = h11 - t_sq + t
+
+ # Linearly extrapolate above and below the extents of the spline for all
+ # values.
+ value_before = tangents[0] * t + values[0]
+ value_after = tangents[-1] * (t - 1.0) + values[-1]
+
+ # Cubically interpolate between the knots below and above each query point.
+ neighbor_values_lo = values[x_lo]
+ neighbor_values_hi = values[x_hi]
+ neighbor_tangents_lo = tangents[x_lo]
+ neighbor_tangents_hi = tangents[x_hi]
+ value_mid = (
+ neighbor_values_lo * h00
+ + neighbor_values_hi * h01
+ + neighbor_tangents_lo * h10
+ + neighbor_tangents_hi * h11
+ )
+
+ # Return the interpolated or extrapolated values for each query point,
+ # depending on whether or not the query lies within the span of the spline.
+ return torch.where(
+ t < 0.0, value_before, torch.where(t > 1.0, value_after, value_mid)
+ )
+
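+# Quick check (illustrative, comment only): with evenly spaced knot values and
+# unit tangents the spline reproduces the identity map.
+#
+#   xs = torch.tensor([0.5, 1.75])
+#   vals = torch.tensor([0.0, 1.0, 2.0])
+#   tans = torch.ones(3)
+#   interpolate1d(xs, vals, tans)  # -> tensor([0.5000, 1.7500])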
+
+def partition_spline_curve(alpha):
+ """Applies a curve to alpha >= 0 to compress its range before interpolation.
+
+ This is a weird hand-crafted function designed to take in alpha values and
+ curve them to occupy a short finite range that works well when using spline
+ interpolation to model the partition function Z(alpha). Because Z(alpha)
+ is only varied in [0, 4] and is especially interesting around alpha=2, this
+ curve is roughly linear in [0, 4] with a slope of ~1 at alpha=0 and alpha=4
+ but a slope of ~10 at alpha=2. When alpha > 4 the curve becomes logarithmic.
+ Some (input, output) pairs for this function are:
+ [(0, 0), (1, ~1.2), (2, 4), (3, ~6.8), (4, 8), (8, ~8.8), (400000, ~12)]
+ This function is continuously differentiable.
+
+ Args:
+ alpha: A numpy array or tensor (float32 or float64) with values >= 0.
+
+ Returns:
+ An array/tensor of curved values >= 0 with the same type as `alpha`, to be
+ used as input x-coordinates for spline interpolation.
+ """
+ alpha = torch.as_tensor(alpha)
+ x = torch.where(
+ alpha < 4,
+ (2.25 * alpha - 4.5) / (torch.abs(alpha - 2) + 0.25) + alpha + 2,
+ 5.0 / 18.0 * log_safe(4 * alpha - 15) + 8,
+ )
+ return x
+
+
+def inv_partition_spline_curve(x):
+ """The inverse of partition_spline_curve()."""
+ x = torch.as_tensor(x)
+ assert (x >= 0).all()
+ alpha = torch.where(
+ x < 8,
+ 0.5 * x
+ + torch.where(
+ x <= 4,
+ 1.25 - torch.sqrt(1.5625 - x + 0.25 * x**2),
+ -1.25 + torch.sqrt(9.5625 - 3 * x + 0.25 * x**2),
+ ),
+ 3.75 + 0.25 * exp_safe(x * 3.6 - 28.8),
+ )
+ return alpha
+
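+# Round-trip check (illustrative): inv_partition_spline_curve inverts
+# partition_spline_curve on alpha >= 0. For example, alpha = 2 maps to
+# x = (2.25 * 2 - 4.5) / (|2 - 2| + 0.25) + 2 + 2 = 4, and at x = 4 the inverse
+# gives 0.5 * 4 + 1.25 - sqrt(1.5625 - 4 + 4) = 2.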
+
+class Distribution:
+ # This is only a class so that we can pre-load the partition function spline.
+ def __init__(self):
+ # Load the values, tangents, and x-coordinate scaling of a spline that
+ # approximates the partition function. This was produced by running
+ # the script in fit_partition_spline.py
+ with resource_stream(__name__, "resources/partition_spline.npz") as spline_file:
+ with np.load(spline_file, allow_pickle=False) as f:
+ self._spline_x_scale = torch.tensor(f["x_scale"])
+ self._spline_values = torch.tensor(f["values"])
+ self._spline_tangents = torch.tensor(f["tangents"])
+
+ def log_base_partition_function(self, alpha):
+ r"""Approximate the distribution's log-partition function with a 1D spline.
+
+ Because the partition function (Z(\alpha) in the paper) of the distribution
+ is difficult to model analytically, we approximate it with a (transformed)
+ cubic hermite spline: Each alpha is pushed through a nonlinearity before
+ being used to interpolate into a spline, which allows us to use a relatively
+ small spline to accurately model the log partition function over the range
+ of all non-negative input values.
+
+ Args:
+ alpha: A tensor or scalar of single or double precision floats containing
+ the set of alphas for which we would like an approximate log partition
+ function. Must be non-negative, as the partition function is undefined
+ when alpha < 0.
+
+ Returns:
+ An approximation of log(Z(alpha)) accurate to within 1e-6
+ """
+ alpha = torch.as_tensor(alpha)
+ assert (alpha >= 0).all()
+ # Transform `alpha` to the form expected by the spline.
+ x = partition_spline_curve(alpha)
+ # Interpolate into the spline.
+ return interpolate1d(
+ x * self._spline_x_scale.to(x),
+ self._spline_values.to(x),
+ self._spline_tangents.to(x),
+ )
+
+ def nllfun(self, x, alpha, scale):
+ r"""Implements the negative log-likelihood (NLL).
+
+ Specifically, we implement -log(p(x | 0, \alpha, c) of Equation 16 in the
+ paper as nllfun(x, alpha, shape).
+
+ Args:
+ x: The residual for which the NLL is being computed. x can have any shape,
+ and alpha and scale will be broadcasted to match x's shape if necessary.
+ Must be a tensor or numpy array of floats.
+ alpha: The shape parameter of the NLL (\alpha in the paper), where more
+ negative values cause outliers to "cost" more and inliers to "cost"
+ less. Alpha can be any non-negative value, but the gradient of the NLL
+ with respect to alpha has singularities at 0 and 2 so you may want to
+ limit usage to (0, 2) during gradient descent. Must be a tensor or numpy
+ array of floats. Varying alpha in that range allows for smooth
+ interpolation between a Cauchy distribution (alpha = 0) and a Normal
+ distribution (alpha = 2) similar to a Student's T distribution.
+ scale: The scale parameter of the loss. When |x| < scale, the NLL is like
+ that of a (possibly unnormalized) normal distribution, and when |x| >
+ scale the NLL takes on a different shape according to alpha. Must be a
+ tensor or numpy array of floats.
+
+ Returns:
+ The NLLs for each element of x, in the same shape and precision as x.
+ """
+ # `scale` and `alpha` must have the same type as `x`.
+ # x = torch.as_tensor(x)
+ # alpha = torch.as_tensor(alpha)
+ # scale = torch.as_tensor(scale)
+ # assert (alpha >= 0).all()
+ # assert (scale >= 0).all()
+ # float_dtype = x.dtype
+ # assert alpha.dtype == float_dtype
+ # assert scale.dtype == float_dtype
+
+ loss = lossfun(x, alpha, scale, approximate=False)
+ log_partition = torch.log(scale) + self.log_base_partition_function(alpha)
+ nll = loss + log_partition
+ return nll
+
+ def draw_samples(self, alpha, scale):
+ r"""Draw samples from the robust distribution.
+
+ This function implements Algorithm 1 in the paper. This code is written to
+ allow for sampling from a set of different distributions, each parametrized
+ by its
+ own alpha and scale values, as opposed to the more standard approach of
+ drawing N samples from the same distribution. This is done by repeatedly
+ performing N instances of rejection sampling for each of the N distributions
+ until at least one proposal for each of the N distributions has been
+ accepted.
+ All samples are drawn with a zero mean, to use a non-zero mean just add each
+ mean to each sample.
+
+ Args:
+ alpha: A tensor/scalar or numpy array/scalar of floats where each element
+ is the shape parameter of that element's distribution.
+ scale: A tensor/scalar or numpy array/scalar of floats where each element
+ is the scale parameter of that element's distribution. Must be the same
+ shape as `alpha`.
+
+ Returns:
+ A tensor with the same shape and precision as `alpha` and `scale` where
+ each element is a sample drawn from the distribution specified for that
+ element by `alpha` and `scale`.
+ """
+ alpha = torch.as_tensor(alpha)
+ scale = torch.as_tensor(scale)
+ assert (alpha >= 0).all()
+ assert (scale >= 0).all()
+ float_dtype = alpha.dtype
+ assert scale.dtype == float_dtype
+
+ cauchy = torch.distributions.cauchy.Cauchy(0.0, np.sqrt(2.0))
+ uniform = torch.distributions.uniform.Uniform(0, 1)
+ samples = torch.zeros_like(alpha)
+ accepted = torch.zeros(alpha.shape).type(torch.bool)
+ while not accepted.type(torch.uint8).all():
+ # Draw N samples from a Cauchy, our proposal distribution.
+ cauchy_sample = torch.reshape(
+ cauchy.sample((np.prod(alpha.shape),)), alpha.shape
+ )
+ cauchy_sample = cauchy_sample.type(alpha.dtype)
+
+ # Compute the likelihood of each sample under its target distribution.
+ nll = self.nllfun(
+ cauchy_sample,
+ torch.as_tensor(alpha).to(cauchy_sample),
+ torch.tensor(1).to(cauchy_sample),
+ )
+
+ # Bound the NLL. We don't use the approximate loss as it may cause
+ # unpredictable behavior in the context of sampling.
+ nll_bound = lossfun(
+ cauchy_sample,
+ torch.tensor(0.0, dtype=cauchy_sample.dtype),
+ torch.tensor(1.0, dtype=cauchy_sample.dtype),
+ approximate=False,
+ ) + self.log_base_partition_function(alpha)
+
+ # Draw N samples from a uniform distribution, and use each uniform sample
+ # to decide whether or not to accept each proposal sample.
+ uniform_sample = torch.reshape(
+ uniform.sample((np.prod(alpha.shape),)), alpha.shape
+ )
+ uniform_sample = uniform_sample.type(alpha.dtype)
+ accept = uniform_sample <= torch.exp(nll_bound - nll)
+
+ # If a sample is accepted, replace its element in `samples` with the
+ # proposal sample, and set its bit in `accepted` to True.
+ samples = torch.where(accept, cauchy_sample, samples)
+ accepted = accepted | accept
+
+ # Because our distribution is a location-scale family, we sample from
+ # p(x | 0, \alpha, 1) and then scale each sample by `scale`.
+ samples *= scale
+ return samples
+
+
+class AdaptiveLossFunction(nn.Module):
+ """The adaptive loss function on a matrix.
+
+ This class behaves differently from lossfun() and
+ distribution.nllfun(), which are "stateless", allow the caller to specify the
+ shape and scale of the loss, and allow for arbitrary sized inputs. This
+ class only allows for rank-2 inputs for the residual `x`, and expects that
+ `x` is of the form [batch_index, dimension_index]. This class then
+ constructs free parameters (torch Parameters) that define the alpha and scale
+ parameters for each dimension of `x`, such that all alphas are in
+ (`alpha_lo`, `alpha_hi`) and all scales are in (`scale_lo`, Infinity).
+ The assumption is that `x` is, say, a matrix where x[i,j] corresponds to a
+ pixel at location j for image i, with the idea being that all pixels at
+ location j should be modeled with the same shape and scale parameters across
+ all images in the batch. If the user wants to fix alpha or scale to be a
+ constant, this can be done by setting alpha_lo=alpha_hi or scale_lo=scale_init
+ respectively.
+ """
+
+ def __init__(
+ self,
+ num_dims,
+ alpha_lo=0.001,
+ alpha_hi=1.999,
+ alpha_init=None,
+ scale_lo=1e-5,
+ scale_init=1.0,
+ ):
+ """Sets up the loss function.
+
+ Args:
+ num_dims: The number of dimensions of the input to come.
+ alpha_lo: The lowest possible value for loss's alpha parameters, must be
+ >= 0 and a scalar. Should probably be in (0, 2).
+ alpha_hi: The highest possible value for loss's alpha parameters, must be
+ >= alpha_lo and a scalar. Should probably be in (0, 2).
+ alpha_init: The value that the loss's alpha parameters will be initialized
+ to, must be in (`alpha_lo`, `alpha_hi`), unless `alpha_lo` == `alpha_hi`
+ in which case this will be ignored. Defaults to (`alpha_lo` +
+ `alpha_hi`) / 2
+ scale_lo: The lowest possible value for the loss's scale parameters. Must
+ be > 0 and a scalar. This value may have more of an effect than you
+ think, as the loss is unbounded as scale approaches zero (say, at a
+ delta function).
+ scale_init: The initial value used for the loss's scale parameters. This
+ also defines the zero-point of the latent representation of scales, so
+ SGD may cause optimization to gravitate towards producing scales near
+ this value.
+ """
+ super(AdaptiveLossFunction, self).__init__()
+
+ if not np.isscalar(alpha_lo):
+ raise ValueError(
+ "`alpha_lo` must be a scalar, but is of type {}".format(type(alpha_lo))
+ )
+ if not np.isscalar(alpha_hi):
+ raise ValueError(
+ "`alpha_hi` must be a scalar, but is of type {}".format(type(alpha_hi))
+ )
+ if alpha_init is not None and not np.isscalar(alpha_init):
+ raise ValueError(
+ "`alpha_init` must be None or a scalar, but is of type {}".format(
+ type(alpha_init)
+ )
+ )
+ if not alpha_lo >= 0:
+ raise ValueError("`alpha_lo` must be >= 0, but is {}".format(alpha_lo))
+ if not alpha_hi >= alpha_lo:
+ raise ValueError(
+ "`alpha_hi` = {} must be >= `alpha_lo` = {}".format(alpha_hi, alpha_lo)
+ )
+ if alpha_init is not None and alpha_lo != alpha_hi:
+ if not (alpha_init > alpha_lo and alpha_init < alpha_hi):
+ raise ValueError(
+ "`alpha_init` = {} must be in (`alpha_lo`, `alpha_hi`) = ({} {})".format(
+ alpha_init, alpha_lo, alpha_hi
+ )
+ )
+ if not np.isscalar(scale_lo):
+ raise ValueError(
+ "`scale_lo` must be a scalar, but is of type {}".format(type(scale_lo))
+ )
+ if not np.isscalar(scale_init):
+ raise ValueError(
+ "`scale_init` must be a scalar, but is of type {}".format(
+ type(scale_init)
+ )
+ )
+ if not scale_lo > 0:
+ raise ValueError("`scale_lo` must be > 0, but is {}".format(scale_lo))
+ if not scale_init >= scale_lo:
+ raise ValueError(
+ "`scale_init` = {} must be >= `scale_lo` = {}".format(
+ scale_init, scale_lo
+ )
+ )
+
+ self.num_dims = num_dims
+
+ self.distribution = Distribution()
+
+ if alpha_lo == alpha_hi:
+ # If the range of alphas is a single item, then we just fix `alpha` to be a constant.
+ fixed_alpha = torch.tensor([[alpha_lo]]).repeat(1, self.num_dims)
+ self.register_parameter(
+ "fixed_alpha", torch.nn.Parameter(fixed_alpha, requires_grad=False)
+ )
+ self.alpha = lambda: self.fixed_alpha
+ else:
+ # Otherwise we construct a "latent" alpha variable and define `alpha`
+ # as an affine function of a sigmoid on that latent variable, initialized
+ # such that `alpha` starts off as `alpha_init`.
+ if alpha_init is None:
+ alpha_init = (alpha_lo + alpha_hi) / 2.0
+ latent_alpha_init = inv_affine_sigmoid(alpha_init, lo=alpha_lo, hi=alpha_hi)
+ self.register_parameter(
+ "latent_alpha",
+ torch.nn.Parameter(
+ latent_alpha_init.clone()
+ .detach()
+ .view(1, 1)
+ .repeat(1, self.num_dims),
+ requires_grad=True,
+ ),
+ )
+ self.alpha = lambda: affine_sigmoid(
+ self.latent_alpha, lo=alpha_lo, hi=alpha_hi
+ )
+
+ if scale_lo == scale_init:
+ # If the difference between the minimum and initial scale is zero, then
+ # we just fix `scale` to be a constant.
+ fixed_scale = torch.tensor([[scale_init]]).repeat(1, self.num_dims)
+ self.register_parameter(
+ "fixed_scale", torch.nn.Parameter(fixed_scale, requires_grad=False)
+ )
+
+ self.scale = lambda: self.fixed_scale
+ else:
+ # Otherwise we construct a "latent" scale variable and define `scale`
+ # as an affine function of a softplus on that latent variable.
+ self.register_parameter(
+ "latent_scale",
+ torch.nn.Parameter(torch.zeros((1, self.num_dims)), requires_grad=True),
+ )
+ self.scale = lambda: affine_softplus(
+ self.latent_scale, lo=scale_lo, ref=scale_init
+ )
+
+ def lossfun(self, x, **kwargs):
+ """Computes the loss on a matrix.
+
+ Args:
+ x: The residual for which the loss is being computed. Must be a rank-2
+ tensor, where the innermost dimension is the batch index, and the
+ outermost dimension must be equal to self.num_dims. Must be a tensor or
+ numpy array of type self.float_dtype.
+ **kwargs: Arguments to be passed to the underlying distribution.nllfun().
+
+ Returns:
+ A tensor of the same type and shape as input `x`, containing the loss at
+ each element of `x`. These "losses" are actually negative log-likelihoods
+ (as produced by distribution.nllfun()) and so they are not actually
+ bounded from below by zero. You'll probably want to minimize their sum or
+ mean.
+ """
+ assert len(x.shape) == 2
+ return self.distribution.nllfun(x, self.alpha(), self.scale(), **kwargs)
+
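+# Usage sketch (illustrative, comment only): per-dimension adaptive alpha and
+# scale on a rank-2 residual matrix; the shapes below are hypothetical.
+#
+#   adaptive = AdaptiveLossFunction(num_dims=4)
+#   x = torch.randn(8, 4)             # [batch_index, dimension_index]
+#   nll = adaptive.lossfun(x).mean()  # scalar objective to minimize
+#   # adaptive.alpha() and adaptive.scale() expose the current (1, 4) values;
+#   # with the default ranges they are trainable nn.Parameters.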
+
+class RobustLoss(nn.Module):
+ def __init__(
+ self,
+ weight: float = 1.0,
+ adaptive: bool = False,
+ scale: float = 1.0,
+ alpha: float = 1.0,
+ input_fn: str = "linear",
+ output_fn: str = "linear",
+ num_dims: int = 3,
+ ):
+ super().__init__()
+ self.name: str = self.__class__.__name__
+ self.output_fn = FNS[output_fn]
+ self.input_fn = FNS[input_fn]
+ self.weight: float = weight
+ if not adaptive:
+ self.loss = AdaptiveLossFunction(
+ num_dims=num_dims,
+ alpha_lo=alpha,
+ alpha_hi=alpha,
+ scale_lo=scale,
+ scale_init=scale,
+ )
+ else:
+ self.loss = AdaptiveLossFunction(
+ num_dims=num_dims, alpha_init=alpha, scale_init=scale
+ )
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def forward(
+ self,
+ input: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor | None = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ error = self.input_fn(input) - self.input_fn(target)
+ loss_map = self.loss.lossfun(error.reshape(-1, error.shape[-1]))
+
+ mean_error = masked_mean(
+ data=loss_map.view(*error.shape).mean(dim=-1), mask=mask, dim=(-1,)
+ ).mean(dim=-1)
+ mean_error = self.output_fn(mean_error)
+ return mean_error
+
+ @classmethod
+ def build(cls, config):
+ obj = cls(
+ weight=config["weight"],
+ alpha=config["alpha"],
+ scale=config["scale"],
+ adaptive=config["adaptive"],
+ output_fn=config["output_fn"],
+ input_fn=config["input_fn"],
+ )
+ return obj
diff --git a/unik3d/ops/losses/scale.py b/unik3d/ops/losses/scale.py
new file mode 100644
index 0000000000000000000000000000000000000000..080fac9e162b3dede7a0084e7bdc8e4275dab129
--- /dev/null
+++ b/unik3d/ops/losses/scale.py
@@ -0,0 +1,87 @@
+import torch
+import torch.nn as nn
+
+from unik3d.utils.constants import VERBOSE
+from unik3d.utils.misc import profile_method
+
+from .utils import FNS, REGRESSION_DICT, masked_mean, masked_quantile
+
+
+class Scale(nn.Module):
+ def __init__(
+ self,
+ weight: float,
+ output_fn: str = "sqrt",
+ input_fn: str = "disp",
+ fn: str = "l1",
+ quantile: float = 0.0,
+ gamma: float = 1.0,
+ alpha: float = 1.0,
+ eps: float = 1e-5,
+ ):
+ super().__init__()
+ self.name: str = self.__class__.__name__
+ self.weight: float = weight
+ self.dims = [-2, -1]
+ self.output_fn = FNS[output_fn]
+ self.input_fn = FNS[input_fn]
+ self.fn = REGRESSION_DICT[fn]
+ self.gamma = gamma
+ self.alpha = alpha
+ self.quantile = quantile
+ self.eps = eps
+
+ @profile_method(verbose=VERBOSE)
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def forward(
+ self,
+ input: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor,
+ quality: torch.Tensor | None = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ mask = mask.bool()
+ input = self.input_fn(input.float())
+ target = self.input_fn(target.float())
+ error = self.fn(target - input, alpha=self.alpha, gamma=self.gamma)
+
+ if self.quantile > 0.0:
+ if quality is not None:
+ for quality_level in [1, 2]:
+ current_quality = quality == quality_level
+ if current_quality.sum() > 0:
+ error_qtl = error[current_quality].detach().abs()
+ mask_qtl = error_qtl < masked_quantile(
+ error_qtl,
+ mask[current_quality],
+ dims=[1, 2, 3],
+ q=1 - self.quantile * quality_level,
+ ).view(-1, 1, 1, 1)
+ mask[current_quality] = mask[current_quality] & mask_qtl
+ else:
+ error_qtl = error.detach().abs()
+ mask = mask & (
+ error_qtl
+ < masked_quantile(
+ error_qtl, mask, dims=[1, 2, 3], q=1 - self.quantile
+ ).view(-1, 1, 1, 1)
+ )
+
+ error_image = masked_mean(data=error, mask=mask, dim=self.dims).squeeze(1, 2, 3)
+
+ error_image = self.output_fn(error_image)
+ return error_image
+
+ @classmethod
+ def build(cls, config):
+ obj = cls(
+ weight=config["weight"],
+ input_fn=config["input_fn"],
+ fn=config["fn"],
+ output_fn=config["output_fn"],
+ gamma=config["gamma"],
+ alpha=config["alpha"],
+ quantile=config.get("quantile", 0.1),
+ )
+ return obj
diff --git a/unik3d/ops/losses/silog.py b/unik3d/ops/losses/silog.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7b4d4b46ba1db34c328a507a015abbb46b4a3e2
--- /dev/null
+++ b/unik3d/ops/losses/silog.py
@@ -0,0 +1,127 @@
+import torch
+import torch.nn as nn
+
+from unik3d.utils.constants import VERBOSE
+from unik3d.utils.misc import profile_method
+
+from .utils import (FNS, REGRESSION_DICT, masked_mean, masked_mean_var,
+ masked_quantile)
+
+
+class SILog(nn.Module):
+ def __init__(
+ self,
+ weight: float,
+ input_fn: str = "linear",
+ output_fn: str = "sqrt",
+ fn: str = "l1",
+ integrated: float = 0.0,
+ dims: tuple[int, ...] = (-3, -2, -1),
+ quantile: float = 0.0,
+ alpha: float = 1.0,
+ gamma: float = 1.0,
+ eps: float = 1e-5,
+ ):
+ super().__init__()
+ self.name: str = self.__class__.__name__
+ self.weight: float = weight
+
+ self.dims = dims
+ self.input_fn = FNS[input_fn]
+ self.output_fn = FNS[output_fn]
+ self.fn = REGRESSION_DICT[fn]
+ self.eps: float = eps
+ self.integrated = integrated
+ self.quantile = quantile
+ self.alpha = alpha
+ self.gamma = gamma
+
+ @profile_method(verbose=VERBOSE)
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def forward(
+ self,
+ input: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor | None = None,
+ si: torch.Tensor | None = None,
+ quality: torch.Tensor | None = None,
+ **kwargs,
+ ) -> torch.Tensor:
+ mask = mask.bool()
+
+ if si.any():
+ rescale = torch.stack(
+ [x[m > 0].median() for x, m in zip(target, target)]
+ ) / torch.stack([x[m > 0].detach().median() for x, m in zip(input, target)])
+ if rescale.isnan().any():
+ print(
+ "NaN in rescale", rescale.isnan().squeeze(), mask.sum(dim=[1, 2, 3])
+ )
+ rescale = torch.nan_to_num(rescale, nan=1.0)
+ input = (1 - si.int()).view(-1, 1, 1, 1) * input + (
+ rescale * si.int()
+ ).view(-1, 1, 1, 1) * input
+
+ error = self.input_fn(input.float()) - self.input_fn(target.float())
+ if quality is not None:
+ for quality_level in [1, 2]:
+ current_quality = quality == quality_level
+ if current_quality.sum() > 0:
+ error_qtl = error[current_quality].detach().abs()
+ mask_qtl = error_qtl < masked_quantile(
+ error_qtl,
+ mask[current_quality],
+ dims=[1, 2, 3],
+ q=1 - self.quantile * quality_level,
+ ).view(-1, 1, 1, 1)
+ mask[current_quality] = mask[current_quality] & mask_qtl
+ else:
+ error_qtl = error.detach().abs()
+ mask = mask & (
+ error_qtl
+ < masked_quantile(
+ error_qtl, mask, dims=[1, 2, 3], q=1 - self.quantile
+ ).view(-1, 1, 1, 1)
+ )
+
+ mean_error, var_error = masked_mean_var(
+ data=error, mask=mask, dim=self.dims, keepdim=False
+ )
+ if var_error.ndim > 1:
+ var_error = var_error.mean(dim=-1)
+
+ if self.integrated > 0.0:
+ scale_error = masked_mean(
+ self.fn(error, alpha=self.alpha, gamma=self.gamma),
+ mask=mask,
+ dim=self.dims,
+ ).reshape(-1)
+ var_error = var_error + self.integrated * scale_error
+
+ out_loss = self.output_fn(var_error)
+ if out_loss.isnan().any():
+ print(
+ "NaN in SILog variance, input, target, mask, target>0, error",
+ var_error.isnan().squeeze(),
+ input[mask].isnan().any(),
+ target[mask].isnan().any(),
+ mask.any(dim=[1, 2, 3]),
+ (target > 0.0).any(dim=[1, 2, 3]),
+ error[mask].isnan().any(),
+ )
+ return out_loss
+
+ @classmethod
+ def build(cls, config):
+ obj = cls(
+ weight=config["weight"],
+ dims=config["dims"],
+ output_fn=config["output_fn"],
+ input_fn=config["input_fn"],
+ fn=config["fn"],
+ alpha=config["alpha"],
+ gamma=config["gamma"],
+ integrated=config.get("integrated", False),
+ quantile=config["quantile"],
+ )
+ return obj
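+# Note on the formulation (illustrative): with input_fn="log" the error is
+# d = log(pred) - log(gt) and the returned quantity is its masked variance,
+# var(d) = mean(d**2) - mean(d)**2. Multiplying the prediction by a global
+# scale s shifts every d by log(s) and leaves var(d) unchanged (up to the
+# quantile-based outlier masking), i.e. the scale-invariant log loss with
+# lambda = 1; `integrated` optionally adds a weighted robust term on d.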
diff --git a/unik3d/ops/losses/utils.py b/unik3d/ops/losses/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a0dc5928a0a71448ba27e5e4b19a2f0f8436fd9
--- /dev/null
+++ b/unik3d/ops/losses/utils.py
@@ -0,0 +1,314 @@
+from math import log, pi, prod
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+
+FNS = {
+ "sqrto": lambda x: torch.sqrt(x + 1),
+ "sqrt": lambda x: torch.sqrt(x + 1e-4),
+ "log": lambda x: torch.log(x + 1e-4),
+ "log1": lambda x: torch.log(x + 1),
+ # smooth transition from a log(1/x)-like penalty for small x to a 1/x-like
+ # penalty for large x: as x -> 0, log(1 + 50 / x) ~ log(50 / x); as x -> inf,
+ # log(1 + 50 / x) ~ 50 / x plus higher-order terms
+ "log1i": lambda x: torch.log(1 + 50 / (1e-4 + x)),
+ "log10": lambda x: torch.log10(1e-4 + x),
+ "log2": lambda x: torch.log2(1e-4 + x),
+ "linear": lambda x: x,
+ "square": torch.square,
+ "disp": lambda x: 1 / (x + 1e-4),
+ "disp1": lambda x: 1 / (1 + x),
+}
+
+
+FNS_INV = {
+ "sqrt": torch.square,
+ "log": torch.exp,
+ "log1": lambda x: torch.exp(x) - 1,
+ "linear": lambda x: x,
+ "square": torch.sqrt,
+ "disp": lambda x: 1 / x,
+}
+
+
+def masked_mean_var(
+ data: torch.Tensor, mask: torch.Tensor, dim: List[int], keepdim: bool = True
+):
+ if mask is None:
+ return data.mean(dim=dim, keepdim=keepdim), data.var(dim=dim, keepdim=keepdim)
+ # if data[mask].isnan().any():
+ # print("Warning: NaN in masked_mean_var, valid_pixels before and after", mask.sum(dim=dim).squeeze(), (mask & ~data.isnan()).sum(dim=dim).squeeze())
+ mask = (mask & ~data.isnan().any(dim=1, keepdim=True)).float()
+ data = torch.nan_to_num(data, nan=0.0)
+ mask_sum = torch.sum(mask, dim=dim, keepdim=True)
+ mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
+ mask_sum, min=1.0
+ )
+ mask_var = torch.sum(
+ mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
+ ) / torch.clamp(mask_sum, min=1.0)
+ if not keepdim:
+ mask_mean, mask_var = mask_mean.squeeze(dim), mask_var.squeeze(dim)
+ return mask_mean, mask_var
+
+
+def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]):
+ if mask is None:
+ return data.mean(dim=dim, keepdim=True)
+ mask = mask.float()
+ mask_sum = torch.sum(mask, dim=dim, keepdim=True)
+ mask_mean = torch.sum(
+ torch.nan_to_num(data, nan=0.0) * mask, dim=dim, keepdim=True
+ ) / mask_sum.clamp(min=1.0)
+ return mask_mean
+
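+# Illustrative check: masked-out entries are excluded from the average, e.g.
+#
+#   data = torch.tensor([[1.0, 2.0, 100.0]])
+#   mask = torch.tensor([[True, True, False]])
+#   masked_mean(data, mask, dim=[-1])  # -> tensor([[1.5000]])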
+
+def masked_quantile(
+ data: torch.Tensor, mask: torch.Tensor | None, dims: List[int], q: float
+):
+ """
+ Compute the quantile of the data only where the mask is 1 along specified dimensions.
+
+ Args:
+ data (torch.Tensor): The input data tensor.
+ mask (torch.Tensor): The mask tensor with the same shape as data, containing 1s where data should be considered.
+ dims (list of int): The dimensions to compute the quantile over.
+ q (float): The quantile to compute, must be between 0 and 1.
+
+ Returns:
+ torch.Tensor: The quantile computed over the specified dimensions, ignoring masked values.
+ """
+ masked_data = data * mask if mask is not None else data
+
+ # Get a list of all dimensions
+ all_dims = list(range(masked_data.dim()))
+
+ # Revert negative dimensions
+ dims = [d % masked_data.dim() for d in dims]
+
+ # Find the dimensions to keep (not included in the `dims` list)
+ keep_dims = [d for d in all_dims if d not in dims]
+
+ # Permute dimensions to bring `dims` to the front
+ permute_order = dims + keep_dims
+ permuted_data = masked_data.permute(permute_order)
+
+ # Reshape into 2D: (-1, remaining_dims)
+ collapsed_shape = (
+ -1,
+ prod([permuted_data.size(d) for d in range(len(dims), permuted_data.dim())]),
+ )
+ reshaped_data = permuted_data.reshape(collapsed_shape)
+ if mask is None:
+ return torch.quantile(reshaped_data, q, dim=0)
+
+ permuted_mask = mask.permute(permute_order)
+ reshaped_mask = permuted_mask.reshape(collapsed_shape)
+
+ # Calculate quantile along the first dimension where mask is true
+ quantiles = []
+ for i in range(reshaped_data.shape[1]):
+ valid_data = reshaped_data[:, i][reshaped_mask[:, i]]
+ if valid_data.numel() == 0:
+ # print("Warning: No valid data found for quantile calculation.")
+ quantiles.append(reshaped_data[:, i].min() * 0.99)
+ else:
+ quantiles.append(torch.quantile(valid_data, q, dim=0))
+
+ # Stack back into a tensor with reduced dimensions
+ quantiles = torch.stack(quantiles)
+ quantiles = quantiles.reshape(
+ [permuted_data.size(d) for d in range(len(dims), permuted_data.dim())]
+ )
+
+ return quantiles
+
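+# Illustrative check: the quantile is computed only over unmasked entries,
+# independently for every index of the kept dimensions.
+#
+#   data = torch.arange(12, dtype=torch.float32).view(2, 6)
+#   mask = data % 2 == 0                          # keep even entries only
+#   masked_quantile(data, mask, dims=[1], q=0.5)  # -> tensor([2., 8.])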
+
+def masked_median(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
+ ndim = data.ndim
+ data = data.flatten(ndim - len(dim))
+ mask = mask.flatten(ndim - len(dim))
+ mask_median = torch.median(data[..., mask], dim=-1).values
+ return mask_median
+
+
+def masked_median_mad(data: torch.Tensor, mask: torch.Tensor, dim: List[int]):
+ ndim = data.ndim
+ data = data.flatten(ndim - len(dim))
+ mask = mask.flatten(ndim - len(dim))
+ mask_median = torch.median(data[mask], dim=-1, keepdim=True).values
+ mask_mad = masked_mean((data - mask_median).abs(), mask, dim=(-1,))
+ return mask_median, mask_mad
+
+
+def masked_weighted_mean_var(
+ data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...]
+):
+ if mask is None:
+ return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True)
+ mask = mask.float()
+ mask_mean = torch.sum(data * mask * weights, dim=dim, keepdim=True) / torch.sum(
+ mask * weights, dim=dim, keepdim=True
+ ).clamp(min=1.0)
+ # V1**2 - V2, V1: sum w_i, V2: sum w_i**2
+ denom = torch.sum(weights * mask, dim=dim, keepdim=True).square() - torch.sum(
+ (mask * weights).square(), dim=dim, keepdim=True
+ )
+ # correction is V1 / (V1**2 - V2); with w_i = 1 this is N / (N**2 - N) = 1 / (N - 1), i.e. the unbiased (Bessel-corrected) variance estimator
+ correction_factor = torch.sum(mask * weights, dim=dim, keepdim=True) / denom.clamp(
+ min=1.0
+ )
+ mask_var = correction_factor * torch.sum(
+ weights * mask * (data - mask_mean) ** 2, dim=dim, keepdim=True
+ )
+ return mask_mean, mask_var
+
+
+def stable_masked_mean_var(
+ input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, dim: list[int]
+):
+ # recalculate mask with points in 95% confidence interval
+ input_detach = input.detach()
+ input_mean, input_var = masked_mean_var(input_detach, mask=mask, dim=dim)
+ target_mean, target_var = masked_mean_var(target, mask=mask, dim=dim)
+ input_std = (input_var).clip(min=1e-6).sqrt()
+ target_std = (target_var).clip(min=1e-6).sqrt()
+ stable_points_input = torch.logical_and(
+ input_detach > input_mean - 1.96 * input_std,
+ input_detach < input_mean + 1.96 * input_std,
+ )
+ stable_points_target = torch.logical_and(
+ target > target_mean - 1.96 * target_std,
+ target < target_mean + 1.96 * target_std,
+ )
+ stable_mask = stable_points_target & stable_points_input & mask
+
+ input_mean, input_var = masked_mean_var(input, mask=stable_mask, dim=dim)
+ target_mean, target_var = masked_mean_var(target, mask=stable_mask, dim=dim)
+ return input_mean, input_var, target_mean, target_var, stable_mask
+
+
+def ssi(
+ input: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor,
+ dim: list[int],
+ *args,
+ **kwargs,
+) -> torch.Tensor:
+ # recalculate mask with points in 95% confidence interval
+ input_mean, input_var, target_mean, target_var, stable_mask = (
+ stable_masked_mean_var(input, target, mask, dim)
+ )
+
+ # if target_var.min() < 1e-6:
+ # print(
+ # "Warning: target low",
+ # list(zip(target_var.squeeze().cpu().numpy(),
+ # target_mean.squeeze().cpu().numpy(),
+ # mask.reshape(target_var.shape[0], -1).sum(dim=-1).squeeze().cpu().numpy(),
+ # stable_mask.reshape(target_var.shape[0], -1).sum(dim=-1).squeeze().cpu().numpy()))
+ # )
+ # if input_var.min() < 1e-6:
+ # print("Warning: input variance is too low", input_var.squeeze(), input_mean.squeeze())
+ if input_var.isnan().any():
+ print("Warning: input variance is nan")
+ if input_var.isinf().any():
+ print("Warning: input variance is inf")
+ if input_mean.isnan().any():
+ print("Warning: input mean is nan")
+ if input_mean.isinf().any():
+ print("Warning: input mean is inf")
+ target_normalized = (target - target_mean) / FNS["sqrt"](target_var)
+ input_normalized = (input - input_mean) / FNS["sqrt"](input_var)
+ return input_normalized, target_normalized, stable_mask
+
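+# Illustrative sketch (comment only): `ssi` standardizes input and target with
+# the mean/variance of their stable (95% confidence) subsets, so a prediction
+# that is a positive affine transform of the target yields nearly identical
+# normalized maps; the shapes below are hypothetical.
+#
+#   gt = torch.rand(2, 1, 32, 32) + 0.5
+#   pred = 3.0 * gt + 1.0
+#   m = torch.ones_like(gt, dtype=torch.bool)
+#   pred_n, gt_n, stable = ssi(pred, gt, m, dim=[-3, -2, -1])
+#   # pred_n ~= gt_n up to the +1e-4 regularizer inside FNS["sqrt"]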
+
+def ssi_nd(
+ input: torch.Tensor,
+ target: torch.Tensor,
+ mask: torch.Tensor,
+ dim: list[int],
+ input_info: torch.Tensor,
+ target_info: torch.Tensor,
+) -> torch.Tensor:
+ input_mean, input_var, target_mean, target_var, stable_mask = (
+ stable_masked_mean_var(input_info, target_info, mask, dim)
+ )
+ if input_var.isnan().any():
+ print("Warning: input variance is nan")
+ if input_var.isinf().any():
+ print("Warning: input variance is inf")
+ if input_mean.isnan().any():
+ print("Warning: input mean is nan")
+ if input_mean.isinf().any():
+ print("Warning: input mean is inf")
+ target_normalized = (target - target_mean) / FNS["sqrt"](target_var)
+ input_normalized = (input - input_mean) / FNS["sqrt"](input_var)
+ return input_normalized, target_normalized, stable_mask
+
+
+def stable_ssi(
+ input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, dim: list[int]
+) -> torch.Tensor:
+ input_mean, input_var = masked_mean_var(input, mask=mask, dim=dim)
+ target_mean, target_var = masked_mean_var(target, mask=mask, dim=dim)
+ target_normalized = (target - target_mean) / torch.sqrt(target_var.clamp(min=1e-6))
+ input_normalized = (input - input_mean) / torch.sqrt(input_var.clamp(min=1e-6))
+ return input_normalized, target_normalized, mask
+
+
+def ind2sub(idx, cols):
+ r = idx // cols
+ c = idx % cols
+ return r, c
+
+
+def sub2ind(r, c, cols):
+ idx = r * cols + c
+ return idx
+
+
+def l2(input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs) -> torch.Tensor:
+ return (input_tensor / gamma) ** 2
+
+
+def l1(input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs) -> torch.Tensor:
+ return torch.abs(input_tensor)
+
+
+def charbonnier(
+ input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs
+) -> torch.Tensor:
+ return gamma * (torch.sqrt(torch.square(input_tensor / gamma) + 1) - 1)
+
+
+def cauchy(
+ input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs
+) -> torch.Tensor:
+ return gamma * torch.log(torch.square(input_tensor / gamma) + 1) + log(gamma * pi)
+
+
+def geman_mcclure(
+ input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs
+) -> torch.Tensor:
+ return gamma * torch.square(input_tensor) / (torch.square(input_tensor) + gamma)
+
+
+def robust_loss(
+ input_tensor: torch.Tensor, alpha: float, gamma: float = 1.0, *args, **kwargs
+) -> torch.Tensor:
+ coeff = abs(alpha - 2) / alpha
+ power = torch.square(input_tensor / gamma) / abs(alpha - 2) + 1
+ return (
+ gamma * coeff * (torch.pow(power, alpha / 2) - 1)
+ ) # mult gamma to keep grad magnitude invariant wrt gamma
+
+
+REGRESSION_DICT = {
+ "l2": l2,
+ "l1": l1,
+ "cauchy": cauchy,
+ "charbonnier": charbonnier,
+ "geman_mcclure": geman_mcclure,
+ "robust_loss": robust_loss,
+}
diff --git a/unik3d/ops/scheduler.py b/unik3d/ops/scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c7776b667a4c79603b14615223e5cc12857ab06
--- /dev/null
+++ b/unik3d/ops/scheduler.py
@@ -0,0 +1,128 @@
+import weakref
+
+import numpy as np
+
+
+class PlainCosineScheduler(object):
+ def __init__(
+ self,
+ klass,
+ key,
+ warmup_iters,
+ total_iters,
+ overwrite=False,
+ init_value=None,
+ base_value=None,
+ final_value=None,
+ step_init=-1,
+ ):
+ super().__init__()
+ self.iter = step_init
+ self.overwrite = overwrite
+ self.base_value = base_value
+ self.init_value = init_value if init_value is not None else base_value
+ self.final_value = final_value
+ self.total_iters = total_iters
+ self.warmup_iters = warmup_iters
+ self.key = key
+ self.klass = klass
+ self.schedulers = [self.get_scheduler()]
+
+ def get_scheduler(self):
+ init_value = self.init_value
+ base_value = self.base_value
+ final_value = self.final_value
+ warmup_iters = self.warmup_iters
+ total_iters = self.total_iters
+
+ # normalize in 0,1, then apply function (power) and denormalize
+ normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True)
+ normalized_schedule = np.power(normalized_schedule, 1)
+ warmup_schedule = (base_value - init_value) * normalized_schedule + init_value
+
+ # main scheduling
+ iters = np.arange(total_iters - warmup_iters + 1)
+ schedule = final_value + 0.5 * (base_value - final_value) * (
+ 1 + np.cos(np.pi * iters / (len(iters) - 1))
+ )
+ return np.concatenate((warmup_schedule, schedule))
+
+ def step(self):
+ self.iter = self.iter + 1
+ vals = self[self.iter]
+ for i, val in enumerate(vals):
+ setattr(self.klass, self.key, val)
+
+ def __getitem__(self, it):
+ it = min(it, self.total_iters)
+ return [scheduler[it] for scheduler in self.schedulers]
+
+
+class CosineScheduler(object):
+ def __init__(
+ self,
+ optimizer,
+ warmup_iters,
+ total_iters,
+ key,
+ overwrite=False,
+ init_value=None,
+ base_value=None,
+ final_value=None,
+ flat_iters=0,
+ step_init=-1,
+ ):
+ super().__init__()
+ self.iter = step_init
+ self.overwrite = overwrite
+ self.optimizer = optimizer
+ self.base_value = base_value
+ self.init_value = init_value
+ self.final_value = final_value
+ self.total_iters = total_iters
+ self.warmup_iters = warmup_iters
+ self.flat_iters = flat_iters
+ self.key = key
+ self.schedulers = [
+ self.get_schedulers(group) for group in optimizer.param_groups
+ ]
+
+ def get_schedulers(self, group):
+ init_value = group.get(self.key + "_init", self.init_value)
+ base_value = group.get(self.key + "_base", self.base_value)
+ final_value = group.get(self.key + "_final", self.final_value)
+ warmup_iters = self.warmup_iters
+ total_iters = self.total_iters
+ flat_iters = self.flat_iters
+ if self.overwrite:
+ final_value = self.final_value
+
+ # normalize in 0,1, then apply function (power) and denormalize
+ normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True)
+ normalized_schedule = np.power(normalized_schedule, 1)
+ warmup_schedule = (base_value - init_value) * normalized_schedule + init_value
+
+ # flat scheduling
+ flat_schedule = np.ones(flat_iters) * base_value
+
+ # decay scheduling
+ decay_iters = np.arange(total_iters - warmup_iters - flat_iters + 1)
+ decay_schedule = final_value + 0.5 * (base_value - final_value) * (
+ 1 + np.cos(np.pi * decay_iters / (len(decay_iters) - 1))
+ )
+ return np.concatenate((warmup_schedule, flat_schedule, decay_schedule))
+
+ def step(self):
+ self.iter = self.iter + 1
+ vals = self[self.iter]
+ for group, val in zip(self.optimizer.param_groups, vals):
+ if isinstance(group[self.key], (tuple, list)):
+ val = (val, *group[self.key][1:])
+ group[self.key] = val
+
+ def __getitem__(self, it):
+ it = min(it, self.total_iters)
+ return [scheduler[it] for scheduler in self.schedulers]
+
+ def get(self):
+ return [group[self.key] for group in self.optimizer.param_groups]
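+# Usage sketch (illustrative, comment only): cosine learning-rate schedule with
+# linear warmup; `model` and the hyperparameters below are placeholders.
+#
+#   optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
+#   lr_sched = CosineScheduler(optimizer, warmup_iters=100, total_iters=10000,
+#                              key="lr", init_value=1e-6, base_value=1e-4,
+#                              final_value=1e-6)
+#   for _ in range(10000):
+#       ...  # forward / backward / optimizer.step()
+#       lr_sched.step()  # writes the scheduled value into each param group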
diff --git a/unik3d/utils/__init__.py b/unik3d/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6a7f7619dab9e9c117bc0e9348175b10f072dc7
--- /dev/null
+++ b/unik3d/utils/__init__.py
@@ -0,0 +1,41 @@
+from .camera import invert_pinhole, project_pinhole, unproject_pinhole
+from .distributed import (barrier, get_dist_info, get_rank, get_world_size,
+ is_main_process, setup_multi_processes, setup_slurm,
+ sync_tensor_across_gpus)
+from .evaluation_depth import (DICT_METRICS, DICT_METRICS_3D, eval_3d,
+ eval_depth)
+from .geometric import spherical_zbuffer_to_euclidean, unproject_points
+from .misc import (format_seconds, get_params, identity, recursive_index,
+ remove_padding, to_cpu)
+from .validation import validate
+from .visualization import colorize, image_grid, log_train_artifacts
+
+__all__ = [
+ "eval_depth",
+ "eval_3d",
+ "DICT_METRICS",
+ "DICT_METRICS_3D",
+ "colorize",
+ "image_grid",
+ "log_train_artifacts",
+ "format_seconds",
+ "remove_padding",
+ "get_params",
+ "identity",
+ "is_main_process",
+ "setup_multi_processes",
+ "setup_slurm",
+ "sync_tensor_across_gpus",
+ "barrier",
+ "get_world_size",
+ "get_rank",
+ "unproject_points",
+ "spherical_zbuffer_to_euclidean",
+ "validate",
+ "get_dist_info",
+ "to_cpu",
+ "recursive_index",
+ "invert_pinhole",
+ "unproject_pinhole",
+ "project_pinhole",
+]
diff --git a/unik3d/utils/camera.py b/unik3d/utils/camera.py
new file mode 100644
index 0000000000000000000000000000000000000000..561c3b3377237043c0f2e94451607fb35d694363
--- /dev/null
+++ b/unik3d/utils/camera.py
@@ -0,0 +1,1487 @@
+"""
+Author: Luigi Piccinelli
+Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/)
+"""
+
+from copy import deepcopy
+
+import numpy as np
+import torch
+import torch.nn.functional as F
+
+from .coordinate import coords_grid
+from .misc import recursive_apply, squeeze_list
+
+
+def invert_pinhole(K):
+ fx = K[..., 0, 0]
+ fy = K[..., 1, 1]
+ cx = K[..., 0, 2]
+ cy = K[..., 1, 2]
+ K_inv = torch.zeros_like(K)
+ K_inv[..., 0, 0] = 1.0 / fx
+ K_inv[..., 1, 1] = 1.0 / fy
+ K_inv[..., 0, 2] = -cx / fx
+ K_inv[..., 1, 2] = -cy / fy
+ K_inv[..., 2, 2] = 1.0
+ return K_inv
+
+
+def unproject_pinhole(depth, K):
+ b, _, h, w = depth.shape
+ K_inv = invert_pinhole(K)
+ grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W]
+ grid_flat = grid.reshape(b, -1, h * w) # [B, 3, H*W]
+ cam_coords = K_inv @ grid_flat
+ pcd = cam_coords.reshape(b, -1, h, w) * depth
+ return pcd
+
+
+def project_pinhole(pcd, K):
+ b, _, h, w = pcd.shape
+ pcd_flat = pcd.reshape(b, 3, -1) # [B, 3, H*W]
+ cam_coords = K @ pcd_flat
+ pcd_proj = cam_coords[:, :2] / cam_coords[:, 2:].clamp(min=0.01)
+ pcd_proj = pcd_proj.reshape(b, 2, h, w)
+ return pcd_proj
+
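+# Round-trip sketch (illustrative, comment only): unprojecting a depth map and
+# re-projecting the resulting point cloud recovers the pixel grid (up to the
+# z clamp in project_pinhole). The intrinsics below are hypothetical.
+#
+#   K = torch.eye(3).unsqueeze(0)
+#   K[:, 0, 0], K[:, 1, 1] = 100.0, 100.0  # fx, fy
+#   K[:, 0, 2], K[:, 1, 2] = 32.0, 24.0    # cx, cy
+#   depth = torch.full((1, 1, 48, 64), 2.0)
+#   pcd = unproject_pinhole(depth, K)      # [B, 3, H, W] camera-space points
+#   uv = project_pinhole(pcd, K)           # [B, 2, H, W] pixel coordinates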
+
+class Camera:
+ def __init__(self, params=None, K=None):
+ if params.ndim == 1:
+ params = params.unsqueeze(0)
+
+ if K is None:
+ K = (
+ torch.eye(3, device=params.device, dtype=params.dtype)
+ .unsqueeze(0)
+ .repeat(params.shape[0], 1, 1)
+ )
+ K[..., 0, 0] = params[..., 0]
+ K[..., 1, 1] = params[..., 1]
+ K[..., 0, 2] = params[..., 2]
+ K[..., 1, 2] = params[..., 3]
+
+ self.params = params
+ self.K = K
+ self.overlap_mask = None
+ self.projection_mask = None
+
+ def project(self, xyz):
+ raise NotImplementedError
+
+ def unproject(self, uv):
+ raise NotImplementedError
+
+ def get_projection_mask(self):
+ return self.projection_mask
+
+ def get_overlap_mask(self):
+ return self.overlap_mask
+
+ def reconstruct(self, depth):
+ id_coords = coords_grid(
+ 1, depth.shape[-2], depth.shape[-1], device=depth.device
+ )
+ rays = self.unproject(id_coords)
+ return (
+ rays / rays[:, -1:].clamp(min=1e-4) * depth.clamp(min=1e-4)
+ ) # assumption z>0!!!
+
+ def resize(self, factor):
+ self.K[..., :2, :] *= factor
+ self.params[..., :4] *= factor
+ return self
+
+ def to(self, device, non_blocking=False):
+ self.params = self.params.to(device, non_blocking=non_blocking)
+ self.K = self.K.to(device, non_blocking=non_blocking)
+ return self
+
+ def get_rays(self, shapes, noisy=False):
+ b, h, w = shapes
+ uv = coords_grid(1, h, w, device=self.K.device, noisy=noisy)
+ rays = self.unproject(uv)
+ return rays / torch.norm(rays, dim=1, keepdim=True).clamp(min=1e-4)
+
+ def get_pinhole_rays(self, shapes, noisy=False):
+ b, h, w = shapes
+ uv = coords_grid(b, h, w, device=self.K.device, homogeneous=True, noisy=noisy)
+ rays = (invert_pinhole(self.K) @ uv.reshape(b, 3, -1)).reshape(b, 3, h, w)
+ return rays / torch.norm(rays, dim=1, keepdim=True).clamp(min=1e-4)
+
+ def flip(self, H, W, direction="horizontal"):
+ new_cx = (
+ W - self.params[:, 2] if direction == "horizontal" else self.params[:, 2]
+ )
+ new_cy = H - self.params[:, 3] if direction == "vertical" else self.params[:, 3]
+ self.params = torch.stack(
+ [self.params[:, 0], self.params[:, 1], new_cx, new_cy], dim=1
+ )
+ self.K[..., 0, 2] = new_cx
+ self.K[..., 1, 2] = new_cy
+ return self
+
+ def clone(self):
+ return deepcopy(self)
+
+ def crop(self, left, top, right=None, bottom=None):
+ self.K[..., 0, 2] -= left
+ self.K[..., 1, 2] -= top
+ self.params[..., 2] -= left
+ self.params[..., 3] -= top
+ return self
+
+    # helper to compute how the FoV changes when resizing from original_shape to new_shape
+ def get_new_fov(self, new_shape, original_shape):
+ new_hfov = 2 * torch.atan(
+ self.params[..., 2] / self.params[..., 0] * new_shape[1] / original_shape[1]
+ )
+ new_vfov = 2 * torch.atan(
+ self.params[..., 3] / self.params[..., 1] * new_shape[0] / original_shape[0]
+ )
+ return new_hfov, new_vfov
+
+ def mask_overlap_projection(self, projected):
+ B, _, H, W = projected.shape
+ id_coords = coords_grid(B, H, W, device=projected.device)
+
+        # build a mask where the flow would overlap with another part of the image;
+        # elements coming from the border are then masked out
+ flow = projected - id_coords
+ gamma = 0.1
+ sample_grid = gamma * flow + id_coords # sample along the flow
+ sample_grid[:, 0] = sample_grid[:, 0] / (W - 1) * 2 - 1
+ sample_grid[:, 1] = sample_grid[:, 1] / (H - 1) * 2 - 1
+ sampled_flow = F.grid_sample(
+ flow,
+ sample_grid.permute(0, 2, 3, 1),
+ mode="bilinear",
+ align_corners=False,
+ padding_mode="border",
+ )
+ mask = (
+ (1 - gamma) * torch.norm(flow, dim=1, keepdim=True)
+ < torch.norm(sampled_flow, dim=1, keepdim=True)
+ ) | (torch.norm(flow, dim=1, keepdim=True) < 1)
+ return mask
+
+ def _pad_params(self):
+ # Ensure params are padded to length 16
+ if self.params.shape[1] < 16:
+ padding = torch.zeros(
+ 16 - self.params.shape[1],
+ device=self.params.device,
+ dtype=self.params.dtype,
+ )
+ padding = padding.view(*[(self.params.ndim - 1) * [1] + [-1]])
+ padding = padding.repeat(self.params.shape[:-1] + (1,))
+ return torch.cat([self.params, padding], dim=-1)
+ return self.params
+
+ @staticmethod
+ def flatten_cameras(cameras): # -> list[Camera]:
+ # Recursively flatten BatchCamera into primitive cameras
+ flattened_cameras = []
+ for camera in cameras:
+ if isinstance(camera, BatchCamera):
+ flattened_cameras.extend(BatchCamera.flatten_cameras(camera.cameras))
+ elif isinstance(camera, list):
+ flattened_cameras.extend(camera)
+ else:
+ flattened_cameras.append(camera)
+ return flattened_cameras
+
+ @staticmethod
+ def _stack_or_cat_cameras(cameras, func, **kwargs):
+ # Generalized method to handle stacking or concatenation
+ flat_cameras = BatchCamera.flatten_cameras(cameras)
+ K_matrices = [camera.K for camera in flat_cameras]
+ padded_params = [camera._pad_params() for camera in flat_cameras]
+
+ stacked_K = func(K_matrices, **kwargs)
+ stacked_params = func(padded_params, **kwargs)
+
+ # Keep track of the original classes
+ original_class = [x.__class__.__name__ for x in flat_cameras]
+ return BatchCamera(stacked_params, stacked_K, original_class, flat_cameras)
+
+ @classmethod
+ def __torch_function__(cls, func, types, args=(), kwargs=None):
+ if kwargs is None:
+ kwargs = {}
+
+ if func is torch.cat:
+ return Camera._stack_or_cat_cameras(args[0], func, **kwargs)
+
+ if func is torch.stack:
+ return Camera._stack_or_cat_cameras(args[0], func, **kwargs)
+
+ if func is torch.flatten:
+ return Camera._stack_or_cat_cameras(args[0], torch.cat, **kwargs)
+ return super().__torch_function__(func, types, args, kwargs)
+
+ @property
+ def device(self):
+ return self.K.device
+
+    # here we assume that cx, cy are approximately W/2 and H/2
+ @property
+ def hfov(self):
+ return 2 * torch.atan(self.params[..., 2] / self.params[..., 0])
+
+ @property
+ def vfov(self):
+ return 2 * torch.atan(self.params[..., 3] / self.params[..., 1])
+
+ @property
+ def max_fov(self):
+ return 150.0 / 180.0 * np.pi, 150.0 / 180.0 * np.pi
+
+
+class Pinhole(Camera):
+ def __init__(self, params=None, K=None):
+ assert params is not None or K is not None
+ # params = [fx, fy, cx, cy]
+ if params is None:
+ params = torch.stack(
+ [K[..., 0, 0], K[..., 1, 1], K[..., 0, 2], K[..., 1, 2]], dim=-1
+ )
+ super().__init__(params=params, K=K)
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def project(self, pcd):
+ b, _, h, w = pcd.shape
+ pcd_flat = pcd.reshape(b, 3, -1) # [B, 3, H*W]
+ cam_coords = self.K @ pcd_flat
+ pcd_proj = cam_coords[:, :2] / cam_coords[:, -1:].clamp(min=0.01)
+ pcd_proj = pcd_proj.reshape(b, 2, h, w)
+        invalid = (
+            (pcd_proj[:, 0] < 0)
+            | (pcd_proj[:, 0] >= w)
+            | (pcd_proj[:, 1] < 0)
+            | (pcd_proj[:, 1] >= h)
+        )
+ self.projection_mask = (~invalid).unsqueeze(1)
+ return pcd_proj
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def unproject(self, uv):
+ b, _, h, w = uv.shape
+ uv_flat = uv.reshape(b, 2, -1) # [B, 2, H*W]
+ uv_homogeneous = torch.cat(
+ [uv_flat, torch.ones(b, 1, h * w, device=uv.device)], dim=1
+ ) # [B, 3, H*W]
+ K_inv = torch.inverse(self.K.float())
+ xyz = K_inv @ uv_homogeneous
+ xyz = xyz / xyz[:, -1:].clip(min=1e-4)
+ xyz = xyz.reshape(b, 3, h, w)
+ self.unprojection_mask = xyz[:, -1:] > 1e-4
+ return xyz
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def reconstruct(self, depth):
+ b, _, h, w = depth.shape
+ uv = coords_grid(b, h, w, device=depth.device)
+ xyz = self.unproject(uv) * depth.clip(min=0.0)
+ return xyz
+
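+# Minimal usage sketch (values assumed, not part of the original module): a
+# Pinhole camera can be built from the 4-vector [fx, fy, cx, cy] or directly
+# from a 3x3 intrinsic matrix K, and `reconstruct` lifts a depth map into a
+# metric point cloud.
+#
+#   K = torch.tensor([[500.0, 0.0, 320.0],
+#                     [0.0, 500.0, 240.0],
+#                     [0.0, 0.0, 1.0]])[None]
+#   cam = Pinhole(K=K)                       # params derived from K
+#   depth = torch.rand(1, 1, 480, 640) * 10
+#   xyz = cam.reconstruct(depth)             # [1, 3, 480, 640]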
+
+class EUCM(Camera):
+ def __init__(self, params):
+ # params = [fx, fy, cx, cy, alpha, beta]
+ super().__init__(params=params, K=None)
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def project(self, xyz):
+ H, W = xyz.shape[-2:]
+        fx, fy, cx, cy, alpha, beta = self.params[..., :6].unbind(dim=1)
+ x, y, z = xyz.unbind(dim=1)
+ d = torch.sqrt(beta * (x**2 + y**2) + z**2)
+
+ x = x / (alpha * d + (1 - alpha) * z).clip(min=1e-3)
+ y = y / (alpha * d + (1 - alpha) * z).clip(min=1e-3)
+
+ Xnorm = fx * x + cx
+ Ynorm = fy * y + cy
+
+ coords = torch.stack([Xnorm, Ynorm], dim=1)
+
+ invalid = (
+ (coords[:, 0] < 0)
+ | (coords[:, 0] > W)
+ | (coords[:, 1] < 0)
+ | (coords[:, 1] > H)
+ | (z < 0)
+ )
+ self.projection_mask = (~invalid).unsqueeze(1)
+
+ return coords
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def unproject(self, uv):
+ u, v = uv.unbind(dim=1)
+ fx, fy, cx, cy, alpha, beta = self.params.unbind(dim=1)
+ mx = (u - cx) / fx
+ my = (v - cy) / fy
+ r_square = mx**2 + my**2
+ valid_mask = r_square < torch.where(
+ alpha < 0.5, 1e6, 1 / (beta * (2 * alpha - 1))
+ )
+ sqrt_val = 1 - (2 * alpha - 1) * beta * r_square
+ mz = (1 - beta * (alpha**2) * r_square) / (
+ alpha * torch.sqrt(sqrt_val.clip(min=1e-5)) + (1 - alpha)
+ )
+ coeff = 1 / torch.sqrt(mx**2 + my**2 + mz**2 + 1e-5)
+
+ x = coeff * mx
+ y = coeff * my
+ z = coeff * mz
+ self.unprojection_mask = valid_mask & (z > 1e-3)
+
+ xnorm = torch.stack((x, y, z.clamp(1e-3)), dim=1)
+ return xnorm
+
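+# For reference, the Enhanced Unified Camera Model used above projects a point
+# (x, y, z) as
+#   d = sqrt(beta * (x^2 + y^2) + z^2)
+#   u = fx * x / (alpha * d + (1 - alpha) * z) + cx
+#   v = fy * y / (alpha * d + (1 - alpha) * z) + cy
+# so alpha = 0 reduces it to a plain pinhole camera, while larger alpha/beta
+# bend the projection towards fisheye-like behaviour.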
+
+class Spherical(Camera):
+ def __init__(self, params):
+        # HFoV and VFoV are in radians and halved!
+ # params: [fx, fy, cx, cy, W, H, HFoV/2, VFoV/2]
+ # fx,fy,cx,cy = dummy values
+ super().__init__(params=params, K=None)
+
+ def resize(self, factor):
+ self.K[..., :2, :] *= factor
+ self.params[..., :6] *= factor
+ return self
+
+ def crop(self, left, top, right, bottom):
+ self.K[..., 0, 2] -= left
+ self.K[..., 1, 2] -= top
+ self.params[..., 2] -= left
+ self.params[..., 3] -= top
+ W, H = self.params[..., 4], self.params[..., 5]
+ angle_ratio_W = (W - left - right) / W
+ angle_ratio_H = (H - top - bottom) / H
+
+ self.params[..., 4] -= left + right
+ self.params[..., 5] -= top + bottom
+
+ # rescale hfov and vfov
+ self.params[..., 6] *= angle_ratio_W
+ self.params[..., 7] *= angle_ratio_H
+ return self
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def project(self, xyz):
+ width, height = self.params[..., 4], self.params[..., 5]
+ hfov, vfov = 2 * self.params[..., 6], 2 * self.params[..., 7]
+ longitude = torch.atan2(xyz[:, 0], xyz[:, 2])
+ latitude = torch.asin(xyz[:, 1] / torch.norm(xyz, dim=1).clamp(min=1e-5))
+
+ u = longitude / hfov * (width - 1) + (width - 1) / 2
+ v = latitude / vfov * (height - 1) + (height - 1) / 2
+
+ return torch.stack([u, v], dim=1)
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def unproject(self, uv):
+ u, v = uv.unbind(dim=1)
+
+ width, height = self.params[..., 4], self.params[..., 5]
+ hfov, vfov = 2 * self.params[..., 6], 2 * self.params[..., 7]
+ longitude = (u - (width - 1) / 2) / (width - 1) * hfov
+ latitude = (v - (height - 1) / 2) / (height - 1) * vfov
+ x = torch.cos(latitude) * torch.sin(longitude)
+ z = torch.cos(latitude) * torch.cos(longitude)
+ y = torch.sin(latitude)
+ unit_sphere = torch.stack([x, y, z], dim=1)
+ unit_sphere = unit_sphere / torch.norm(unit_sphere, dim=1, keepdim=True).clip(
+ min=1e-5
+ )
+
+ return unit_sphere
+
+ def reconstruct(self, depth):
+ id_coords = coords_grid(
+ 1, depth.shape[-2], depth.shape[-1], device=depth.device
+ )
+ return self.unproject(id_coords) * depth
+
+ def get_new_fov(self, new_shape, original_shape):
+ new_hfov = 2 * self.params[..., 6] * new_shape[1] / original_shape[1]
+ new_vfov = 2 * self.params[..., 7] * new_shape[0] / original_shape[0]
+ return new_hfov, new_vfov
+
+ @property
+ def hfov(self):
+ return 2 * self.params[..., 6]
+
+ @property
+ def vfov(self):
+ return 2 * self.params[..., 7]
+
+ @property
+ def max_fov(self):
+ return 2 * np.pi, 0.9 * np.pi # avoid strong distortion on tops
+
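+# For reference: Spherical implements an equirectangular mapping. `unproject`
+# turns a pixel into (longitude, latitude) spanning the stored HFoV/VFoV and
+# returns a unit ray, so in `reconstruct` the depth map is interpreted as
+# radial (Euclidean) distance rather than z-depth. A full panorama corresponds
+# to HFoV = 2*pi and VFoV = pi, i.e. params[..., 6:8] = [pi, pi/2] since the
+# stored values are halved; `max_fov` caps the vertical span at 0.9*pi to
+# avoid the strongest distortion at the poles.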
+
+class OPENCV(Camera):
+ def __init__(self, params):
+ super().__init__(params=params, K=None)
+ # params: [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, p1, p2, s1, s2, s3, s4]
+ self.use_radial = self.params[..., 4:10].abs().sum() > 1e-6
+ assert (
+ self.params[..., 7:10].abs().sum() == 0.0
+ ), "Do not support poly division model"
+ self.use_tangential = self.params[..., 10:12].abs().sum() > 1e-6
+ self.use_thin_prism = self.params[..., 12:].abs().sum() > 1e-6
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def project(self, xyz):
+ eps = 1e-9
+ B, _, H, W = xyz.shape
+ N = H * W
+ xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3)
+
+ # Radial correction.
+ z = xyz[:, :, 2].reshape(B, N, 1)
+ z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z)
+ ab = xyz[:, :, :2] / z
+ r = torch.norm(ab, dim=-1, p=2, keepdim=True)
+ th = r
+
+        th_pow = torch.cat(
+            [torch.pow(th, 2 + i * 2) for i in range(3)], dim=-1
+        )  # Create powers of th (th^2, th^4, th^6)
+ distortion_coeffs_num = self.params[:, 4:7].reshape(B, 1, 3)
+ distortion_coeffs_den = self.params[:, 7:10].reshape(B, 1, 3)
+ th_num = 1 + torch.sum(th_pow * distortion_coeffs_num, dim=-1, keepdim=True)
+ th_den = 1 + torch.sum(th_pow * distortion_coeffs_den, dim=-1, keepdim=True)
+
+ xr_yr = ab * th_num / th_den
+ uv_dist = xr_yr
+
+ # Tangential correction.
+ p0 = self.params[..., -6].reshape(B, 1)
+ p1 = self.params[..., -5].reshape(B, 1)
+ xr = xr_yr[:, :, 0].reshape(B, N)
+ yr = xr_yr[:, :, 1].reshape(B, N)
+ xr_yr_sq = torch.square(xr_yr)
+ xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
+ yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
+ rd_sq = xr_sq + yr_sq
+ uv_dist_tu = uv_dist[:, :, 0] + (
+ (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
+ )
+ uv_dist_tv = uv_dist[:, :, 1] + (
+ (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
+ )
+ uv_dist = torch.stack(
+ [uv_dist_tu, uv_dist_tv], dim=-1
+ ) # Avoids in-place complaint.
+
+ # Thin Prism correction.
+ s0 = self.params[..., -4].reshape(B, 1)
+ s1 = self.params[..., -3].reshape(B, 1)
+ s2 = self.params[..., -2].reshape(B, 1)
+ s3 = self.params[..., -1].reshape(B, 1)
+ rd_4 = torch.square(rd_sq)
+ uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
+ uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
+
+ # Finally, apply standard terms: focal length and camera centers.
+ if self.params.shape[-1] == 15:
+ fx_fy = self.params[..., 0].reshape(B, 1, 1)
+ cx_cy = self.params[..., 1:3].reshape(B, 1, 2)
+ else:
+ fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
+ cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
+ result = uv_dist * fx_fy + cx_cy
+
+ result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2)
+ invalid = (
+ (result[:, 0] < 0)
+ | (result[:, 0] > W)
+ | (result[:, 1] < 0)
+ | (result[:, 1] > H)
+ )
+ self.projection_mask = (~invalid).unsqueeze(1)
+ self.overlap_mask = self.mask_overlap_projection(result)
+
+ return result
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def unproject(self, uv, max_iters: int = 10):
+ eps = 1e-3
+ B, _, H, W = uv.shape
+ N = H * W
+ uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2)
+
+ if self.params.shape[-1] == 15:
+ fx_fy = self.params[..., 0].reshape(B, 1, 1)
+ cx_cy = self.params[..., 1:3].reshape(B, 1, 2)
+ else:
+ fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
+ cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
+
+ uv_dist = (uv - cx_cy) / fx_fy
+
+ # Compute xr_yr using Newton's method.
+ xr_yr = uv_dist.clone() # Initial guess.
+ max_iters_tanprism = (
+ max_iters if self.use_thin_prism or self.use_tangential else 0
+ )
+
+ for _ in range(max_iters_tanprism):
+ uv_dist_est = xr_yr.clone()
+ xr = xr_yr[..., 0].reshape(B, N)
+ yr = xr_yr[..., 1].reshape(B, N)
+ xr_yr_sq = torch.square(xr_yr)
+ xr_sq = xr_yr_sq[..., 0].reshape(B, N)
+ yr_sq = xr_yr_sq[..., 1].reshape(B, N)
+ rd_sq = xr_sq + yr_sq
+
+ if self.use_tangential:
+ # Tangential terms.
+ p0 = self.params[..., -6].reshape(B, 1)
+ p1 = self.params[..., -5].reshape(B, 1)
+ uv_dist_est[..., 0] = uv_dist_est[..., 0] + (
+ (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
+ )
+ uv_dist_est[..., 1] = uv_dist_est[..., 1] + (
+ (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
+ )
+
+ if self.use_thin_prism:
+ # Thin Prism terms.
+ s0 = self.params[..., -4].reshape(B, 1)
+ s1 = self.params[..., -3].reshape(B, 1)
+ s2 = self.params[..., -2].reshape(B, 1)
+ s3 = self.params[..., -1].reshape(B, 1)
+ rd_4 = torch.square(rd_sq)
+ uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
+ uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
+
+ # Compute the derivative of uv_dist w.r.t. xr_yr.
+ duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2)
+
+ if self.use_tangential:
+ duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1
+ offdiag = 2.0 * (xr * p1 + yr * p0)
+ duv_dist_dxr_yr[..., 0, 1] = offdiag
+ duv_dist_dxr_yr[..., 1, 0] = offdiag
+ duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0
+
+ if self.use_thin_prism:
+ xr_yr_sq_norm = xr_sq + yr_sq
+ temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm)
+ duv_dist_dxr_yr[..., 0, 0] = duv_dist_dxr_yr[..., 0, 0] + (xr * temp1)
+ duv_dist_dxr_yr[..., 0, 1] = duv_dist_dxr_yr[..., 0, 1] + (yr * temp1)
+ temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm)
+ duv_dist_dxr_yr[..., 1, 0] = duv_dist_dxr_yr[..., 1, 0] + (xr * temp2)
+ duv_dist_dxr_yr[..., 1, 1] = duv_dist_dxr_yr[..., 1, 1] + (yr * temp2)
+
+ mat = duv_dist_dxr_yr.reshape(-1, 2, 2)
+ a = mat[:, 0, 0].reshape(-1, 1, 1)
+ b = mat[:, 0, 1].reshape(-1, 1, 1)
+ c = mat[:, 1, 0].reshape(-1, 1, 1)
+ d = mat[:, 1, 1].reshape(-1, 1, 1)
+ det = 1.0 / ((a * d) - (b * c))
+ top = torch.cat([d, -b], dim=-1)
+ bot = torch.cat([-c, a], dim=-1)
+ inv = det * torch.cat([top, bot], dim=-2)
+ inv = inv.reshape(B, N, 2, 2)
+ diff = uv_dist - uv_dist_est
+ a = inv[..., 0, 0]
+ b = inv[..., 0, 1]
+ c = inv[..., 1, 0]
+ d = inv[..., 1, 1]
+ e = diff[..., 0]
+ f = diff[..., 1]
+ step = torch.stack([a * e + b * f, c * e + d * f], dim=-1)
+ # Newton step.
+ xr_yr = xr_yr + step
+
+ # Compute theta using Newton's method.
+ xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1)
+ th = xr_yr_norm.clone()
+ max_iters_radial = max_iters if self.use_radial else 0
+ c = (
+ torch.tensor([2.0 * i + 3 for i in range(3)], device=self.device)
+ .reshape(1, 1, 3)
+ .repeat(B, 1, 1)
+ )
+ radial_params_num = self.params[..., 4:7].reshape(B, 1, 3)
+
+ # Trust region parameters
+ delta = torch.full((B, N, 1), 0.1, device=self.device) # Initial trust radius
+ delta_max = torch.tensor(1.0, device=self.device) # Maximum trust radius
+ eta = 0.1 # Acceptable reduction threshold
+
+ for i in range(max_iters_radial):
+            th_sq = th * th  # th^2
+            # Compute powers of th^2 up to th^6
+            theta_powers = torch.cat(
+                [th_sq ** (j + 1) for j in range(3)], dim=-1
+            )  # Shape: (B, N, 3)
+
+ # Compute th_radial: radial distortion model applied to th
+ th_radial = 1.0 + torch.sum(
+ theta_powers * radial_params_num, dim=-1, keepdim=True
+ )
+ th_radial = th_radial * th # Multiply by th at the end
+
+            # Compute derivative of the distorted radius w.r.t. th
+            dthd_th = 1.0 + torch.sum(
+                c * radial_params_num * theta_powers, dim=-1, keepdim=True
+            )
+
+ # Compute residual
+ residual = th_radial - xr_yr_norm # Shape: (B, N, 1)
+ residual_norm = torch.norm(residual, dim=2, keepdim=True) # For each pixel
+
+ # Check for convergence
+ if torch.max(torch.abs(residual)) < eps:
+ break
+
+ # Avoid division by zero by adding a small epsilon
+ safe_dthd_th = dthd_th.clone()
+ zero_derivative_mask = dthd_th.abs() < eps
+ safe_dthd_th[zero_derivative_mask] = eps
+
+ # Compute Newton's step
+ step = -residual / safe_dthd_th
+
+ # Compute predicted reduction
+ predicted_reduction = -(residual * step).sum(dim=2, keepdim=True)
+
+ # Adjust step based on trust region
+ step_norm = torch.norm(step, dim=2, keepdim=True)
+ over_trust_mask = step_norm > delta
+
+ # Scale step if it exceeds trust radius
+ step_scaled = step.clone()
+ step_scaled[over_trust_mask] = step[over_trust_mask] * (
+ delta[over_trust_mask] / step_norm[over_trust_mask]
+ )
+
+ # Update theta
+ th_new = th + step_scaled
+
+ # Compute new residual
+ th_sq_new = th_new * th_new
+ theta_powers_new = torch.cat(
+ [th_sq_new ** (j + 1) for j in range(3)], dim=-1
+ )
+ th_radial_new = 1.0 + torch.sum(
+ theta_powers_new * radial_params_num, dim=-1, keepdim=True
+ )
+ th_radial_new = th_radial_new * th_new
+ residual_new = th_radial_new - xr_yr_norm
+ residual_new_norm = torch.norm(residual_new, dim=2, keepdim=True)
+
+ # Compute actual reduction
+ actual_reduction = residual_norm - residual_new_norm
+
+ # Compute ratio of actual to predicted reduction
+ # predicted_reduction[predicted_reduction.abs() < eps] = eps #* torch.sign(predicted_reduction[predicted_reduction.abs() < eps])
+ rho = actual_reduction / predicted_reduction
+ rho[(actual_reduction == 0) & (predicted_reduction == 0)] = 1.0
+
+ # Update trust radius delta
+ delta_update_mask = rho > 0.5
+ delta[delta_update_mask] = torch.min(
+ 2.0 * delta[delta_update_mask], delta_max
+ )
+
+ delta_decrease_mask = rho < 0.2
+ delta[delta_decrease_mask] = 0.25 * delta[delta_decrease_mask]
+
+ # Accept or reject the step
+ accept_step_mask = rho > eta
+ th = torch.where(accept_step_mask, th_new, th)
+
+ # Compute the ray direction using theta and xr_yr.
+ close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps)
+ ray_dir = torch.where(close_to_zero, xr_yr, th / xr_yr_norm * xr_yr)
+
+ ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2)
+ ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2)
+
+ return ray
+
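+# The inverse distortion above has no closed form, so it is solved numerically:
+# Newton's method on the 2x2 tangential/thin-prism system first, then a
+# trust-region-damped Newton iteration on the scalar radial polynomial
+#   r_d = r * (1 + k1*r^2 + k2*r^4 + k3*r^6).
+# A plain scalar version of that radial step looks as follows (illustrative
+# sketch only, names assumed; the code above additionally clips the step to a
+# trust radius for robustness):
+#
+#   def invert_radial(r_d, k1, k2, k3, iters=10):
+#       r = r_d                                    # initial guess
+#       for _ in range(iters):
+#           f = r * (1 + k1*r**2 + k2*r**4 + k3*r**6) - r_d
+#           df = 1 + 3*k1*r**2 + 5*k2*r**4 + 7*k3*r**6
+#           r = r - f / df
+#       return r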
+
+class Fisheye624(Camera):
+ def __init__(self, params):
+ super().__init__(params=params, K=None)
+ # params: [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, p1, p2, s1, s2, s3, s4]
+ self.use_radial = self.params[..., 4:10].abs().sum() > 1e-6
+ self.use_tangential = self.params[..., 10:12].abs().sum() > 1e-6
+ self.use_thin_prism = self.params[..., 12:].abs().sum() > 1e-6
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def project(self, xyz):
+ eps = 1e-9
+ B, _, H, W = xyz.shape
+ N = H * W
+ xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3)
+
+ # Radial correction.
+ z = xyz[:, :, 2].reshape(B, N, 1)
+ z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z)
+ ab = xyz[:, :, :2] / z
+ r = torch.norm(ab, dim=-1, p=2, keepdim=True)
+ th = torch.atan(r)
+ th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r)
+
+ th_pow = torch.cat(
+ [torch.pow(th, 3 + i * 2) for i in range(6)], dim=-1
+ ) # Create powers of th (th^3, th^5, ...)
+ distortion_coeffs = self.params[:, 4:10].reshape(B, 1, 6)
+ th_k = th + torch.sum(th_pow * distortion_coeffs, dim=-1, keepdim=True)
+
+ xr_yr = th_k * th_divr
+ uv_dist = xr_yr
+
+ # Tangential correction.
+ p0 = self.params[..., -6].reshape(B, 1)
+ p1 = self.params[..., -5].reshape(B, 1)
+ xr = xr_yr[:, :, 0].reshape(B, N)
+ yr = xr_yr[:, :, 1].reshape(B, N)
+ xr_yr_sq = torch.square(xr_yr)
+ xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
+ yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
+ rd_sq = xr_sq + yr_sq
+ uv_dist_tu = uv_dist[:, :, 0] + (
+ (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
+ )
+ uv_dist_tv = uv_dist[:, :, 1] + (
+ (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
+ )
+ uv_dist = torch.stack(
+ [uv_dist_tu, uv_dist_tv], dim=-1
+ ) # Avoids in-place complaint.
+
+ # Thin Prism correction.
+ s0 = self.params[..., -4].reshape(B, 1)
+ s1 = self.params[..., -3].reshape(B, 1)
+ s2 = self.params[..., -2].reshape(B, 1)
+ s3 = self.params[..., -1].reshape(B, 1)
+ rd_4 = torch.square(rd_sq)
+ uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
+ uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
+
+ # Finally, apply standard terms: focal length and camera centers.
+ if self.params.shape[-1] == 15:
+ fx_fy = self.params[..., 0].reshape(B, 1, 1)
+ cx_cy = self.params[..., 1:3].reshape(B, 1, 2)
+ else:
+ fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
+ cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
+ result = uv_dist * fx_fy + cx_cy
+
+ result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2)
+ invalid = (
+ (result[:, 0] < 0)
+ | (result[:, 0] > W)
+ | (result[:, 1] < 0)
+ | (result[:, 1] > H)
+ )
+ self.projection_mask = (~invalid).unsqueeze(1)
+ self.overlap_mask = self.mask_overlap_projection(result)
+
+ return result
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def unproject(self, uv, max_iters: int = 10):
+ eps = 1e-3
+ B, _, H, W = uv.shape
+ N = H * W
+ uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2)
+
+ if self.params.shape[-1] == 15:
+ fx_fy = self.params[..., 0].reshape(B, 1, 1)
+ cx_cy = self.params[..., 1:3].reshape(B, 1, 2)
+ else:
+ fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
+ cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
+
+ uv_dist = (uv - cx_cy) / fx_fy
+
+ # Compute xr_yr using Newton's method.
+ xr_yr = uv_dist.clone() # Initial guess.
+ max_iters_tanprism = (
+ max_iters if self.use_thin_prism or self.use_tangential else 0
+ )
+
+ for _ in range(max_iters_tanprism):
+ uv_dist_est = xr_yr.clone()
+ xr = xr_yr[..., 0].reshape(B, N)
+ yr = xr_yr[..., 1].reshape(B, N)
+ xr_yr_sq = torch.square(xr_yr)
+ xr_sq = xr_yr_sq[..., 0].reshape(B, N)
+ yr_sq = xr_yr_sq[..., 1].reshape(B, N)
+ rd_sq = xr_sq + yr_sq
+
+ if self.use_tangential:
+ # Tangential terms.
+ p0 = self.params[..., -6].reshape(B, 1)
+ p1 = self.params[..., -5].reshape(B, 1)
+ uv_dist_est[..., 0] = uv_dist_est[..., 0] + (
+ (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
+ )
+ uv_dist_est[..., 1] = uv_dist_est[..., 1] + (
+ (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
+ )
+
+ if self.use_thin_prism:
+ # Thin Prism terms.
+ s0 = self.params[..., -4].reshape(B, 1)
+ s1 = self.params[..., -3].reshape(B, 1)
+ s2 = self.params[..., -2].reshape(B, 1)
+ s3 = self.params[..., -1].reshape(B, 1)
+ rd_4 = torch.square(rd_sq)
+ uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4)
+ uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4)
+
+ # Compute the derivative of uv_dist w.r.t. xr_yr.
+ duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2)
+
+ if self.use_tangential:
+ duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1
+ offdiag = 2.0 * (xr * p1 + yr * p0)
+ duv_dist_dxr_yr[..., 0, 1] = offdiag
+ duv_dist_dxr_yr[..., 1, 0] = offdiag
+ duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0
+
+ if self.use_thin_prism:
+ xr_yr_sq_norm = xr_sq + yr_sq
+ temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm)
+ duv_dist_dxr_yr[..., 0, 0] = duv_dist_dxr_yr[..., 0, 0] + (xr * temp1)
+ duv_dist_dxr_yr[..., 0, 1] = duv_dist_dxr_yr[..., 0, 1] + (yr * temp1)
+ temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm)
+ duv_dist_dxr_yr[..., 1, 0] = duv_dist_dxr_yr[..., 1, 0] + (xr * temp2)
+ duv_dist_dxr_yr[..., 1, 1] = duv_dist_dxr_yr[..., 1, 1] + (yr * temp2)
+ # Compute 2x2 inverse manually here since torch.inverse() is very slow.
+ # Because this is slow: inv = duv_dist_dxr_yr.inverse()
+ # About a 10x reduction in speed with above line.
+ mat = duv_dist_dxr_yr.reshape(-1, 2, 2)
+ a = mat[:, 0, 0].reshape(-1, 1, 1)
+ b = mat[:, 0, 1].reshape(-1, 1, 1)
+ c = mat[:, 1, 0].reshape(-1, 1, 1)
+ d = mat[:, 1, 1].reshape(-1, 1, 1)
+ det = 1.0 / ((a * d) - (b * c))
+ top = torch.cat([d, -b], dim=-1)
+ bot = torch.cat([-c, a], dim=-1)
+ inv = det * torch.cat([top, bot], dim=-2)
+ inv = inv.reshape(B, N, 2, 2)
+ diff = uv_dist - uv_dist_est
+ a = inv[..., 0, 0]
+ b = inv[..., 0, 1]
+ c = inv[..., 1, 0]
+ d = inv[..., 1, 1]
+ e = diff[..., 0]
+ f = diff[..., 1]
+ step = torch.stack([a * e + b * f, c * e + d * f], dim=-1)
+ # Newton step.
+ xr_yr = xr_yr + step
+
+ # Compute theta using Newton's method.
+ xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1)
+ th = xr_yr_norm.clone()
+ max_iters_radial = max_iters if self.use_radial else 0
+ c = (
+ torch.tensor([2.0 * i + 3 for i in range(6)], device=self.device)
+ .reshape(1, 1, 6)
+ .repeat(B, 1, 1)
+ )
+ radial_params = self.params[..., 4:10].reshape(B, 1, 6)
+
+ # Trust region parameters
+ delta = torch.full((B, N, 1), 0.1, device=self.device) # Initial trust radius
+ delta_max = torch.tensor(1.0, device=self.device) # Maximum trust radius
+ eta = 0.1 # Acceptable reduction threshold
+
+ for i in range(max_iters_radial):
+ th_sq = th * th
+ # Compute powers of th^2 up to th^(12)
+ theta_powers = torch.cat(
+ [th_sq ** (i + 1) for i in range(6)], dim=-1
+ ) # Shape: (B, N, 6)
+
+ # Compute th_radial: radial distortion model applied to th
+ th_radial = 1.0 + torch.sum(
+ theta_powers * radial_params, dim=-1, keepdim=True
+ )
+ th_radial = th_radial * th
+
+            # Compute derivative of the distorted radius w.r.t. th
+            dthd_th = 1.0 + torch.sum(
+                c * radial_params * theta_powers, dim=-1, keepdim=True
+            )
+
+ # Compute residual
+ residual = th_radial - xr_yr_norm # Shape: (B, N, 1)
+ residual_norm = torch.norm(residual, dim=2, keepdim=True)
+
+ # Check for convergence
+ if torch.max(torch.abs(residual)) < eps:
+ break
+
+ # Avoid division by zero by adding a small epsilon
+ safe_dthd_th = dthd_th.clone()
+ zero_derivative_mask = dthd_th.abs() < eps
+ safe_dthd_th[zero_derivative_mask] = eps
+
+ # Compute Newton's step
+ step = -residual / safe_dthd_th
+
+ # Compute predicted reduction
+ predicted_reduction = -(residual * step).sum(dim=2, keepdim=True)
+
+ # Adjust step based on trust region
+ step_norm = torch.norm(step, dim=2, keepdim=True)
+ over_trust_mask = step_norm > delta
+
+ # Scale step if it exceeds trust radius
+ step_scaled = step.clone()
+ step_scaled[over_trust_mask] = step[over_trust_mask] * (
+ delta[over_trust_mask] / step_norm[over_trust_mask]
+ )
+
+ # Update theta
+ th_new = th + step_scaled
+
+ # Compute new residual
+ th_sq_new = th_new * th_new
+ theta_powers_new = torch.cat(
+ [th_sq_new ** (j + 1) for j in range(6)], dim=-1
+ )
+ th_radial_new = 1.0 + torch.sum(
+ theta_powers_new * radial_params, dim=-1, keepdim=True
+ )
+ th_radial_new = th_radial_new * th_new
+ residual_new = th_radial_new - xr_yr_norm
+ residual_new_norm = torch.norm(residual_new, dim=2, keepdim=True)
+
+ # Compute actual reduction
+ actual_reduction = residual_norm - residual_new_norm
+
+ # Compute ratio of actual to predicted reduction
+ rho = actual_reduction / predicted_reduction
+ rho[(actual_reduction == 0) & (predicted_reduction == 0)] = 1.0
+
+ # Update trust radius delta
+ delta_update_mask = rho > 0.5
+ delta[delta_update_mask] = torch.min(
+ 2.0 * delta[delta_update_mask], delta_max
+ )
+
+ delta_decrease_mask = rho < 0.2
+ delta[delta_decrease_mask] = 0.25 * delta[delta_decrease_mask]
+
+ # Accept or reject the step
+ accept_step_mask = rho > eta
+ th = torch.where(accept_step_mask, th_new, th)
+
+ # Compute the ray direction using theta and xr_yr.
+ close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps)
+ ray_dir = torch.where(close_to_zero, xr_yr, torch.tan(th) / xr_yr_norm * xr_yr)
+
+ ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2)
+ ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2)
+
+ return ray
+
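+# For reference: Fisheye624 is a Kannala-Brandt-style fisheye with 6 radial,
+# 2 tangential and 4 thin-prism coefficients (hence the name). Unlike OPENCV
+# above, the radial polynomial acts on the incidence angle theta = atan(r):
+#   theta_d = theta + k1*theta^3 + ... + k6*theta^13,
+# and `unproject` inverts it with the same damped-Newton scheme before mapping
+# back through tan(theta). With all coefficients at zero the model reduces to
+# the ideal equidistant fisheye (image radius = f * theta).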
+
+class MEI(Camera):
+ def __init__(self, params):
+ super().__init__(params=params, K=None)
+ # fx fy cx cy k1 k2 p1 p2 xi
+ self.use_radial = self.params[..., 4:6].abs().sum() > 1e-6
+ self.use_tangential = self.params[..., 6:8].abs().sum() > 1e-6
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def unproject(self, uv, max_iters: int = 20):
+ eps = 1e-6
+ B, _, H, W = uv.shape
+ N = H * W
+ uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2)
+
+ k1, k2, p0, p1, xi = self.params[..., 4:9].unbind(dim=1)
+ fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
+ cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
+
+ uv_dist = (uv - cx_cy) / fx_fy
+
+ # Compute xr_yr using Newton's method.
+ xr_yr = uv_dist.clone() # Initial guess.
+ max_iters_tangential = max_iters if self.use_tangential else 0
+ for _ in range(max_iters_tangential):
+ uv_dist_est = xr_yr.clone()
+
+ # Tangential terms.
+ xr = xr_yr[..., 0]
+ yr = xr_yr[..., 1]
+ xr_yr_sq = xr_yr**2
+ xr_sq = xr_yr_sq[..., 0]
+ yr_sq = xr_yr_sq[..., 1]
+ rd_sq = xr_sq + yr_sq
+ uv_dist_est[..., 0] = uv_dist_est[..., 0] + (
+ (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
+ )
+ uv_dist_est[..., 1] = uv_dist_est[..., 1] + (
+ (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
+ )
+
+ # Compute the derivative of uv_dist w.r.t. xr_yr.
+ duv_dist_dxr_yr = torch.ones((B, N, 2, 2), device=uv.device)
+ duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1
+ offdiag = 2.0 * (xr * p1 + yr * p0)
+ duv_dist_dxr_yr[..., 0, 1] = offdiag
+ duv_dist_dxr_yr[..., 1, 0] = offdiag
+ duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0
+
+ mat = duv_dist_dxr_yr.reshape(-1, 2, 2)
+ a = mat[:, 0, 0].reshape(-1, 1, 1)
+ b = mat[:, 0, 1].reshape(-1, 1, 1)
+ c = mat[:, 1, 0].reshape(-1, 1, 1)
+ d = mat[:, 1, 1].reshape(-1, 1, 1)
+ det = 1.0 / ((a * d) - (b * c))
+ top = torch.cat([d, -b], dim=-1)
+ bot = torch.cat([-c, a], dim=-1)
+ inv = det * torch.cat([top, bot], dim=-2)
+ inv = inv.reshape(B, N, 2, 2)
+
+ diff = uv_dist - uv_dist_est
+ a = inv[..., 0, 0]
+ b = inv[..., 0, 1]
+ c = inv[..., 1, 0]
+ d = inv[..., 1, 1]
+ e = diff[..., 0]
+ f = diff[..., 1]
+ step = torch.stack([a * e + b * f, c * e + d * f], dim=-1)
+
+ # Newton step.
+ xr_yr = xr_yr + step
+
+ # Compute theta using Newton's method.
+ xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1)
+ th = xr_yr_norm.clone()
+ max_iters_radial = max_iters if self.use_radial else 0
+ for _ in range(max_iters_radial):
+ th_radial = 1.0 + k1 * torch.pow(th, 2) + k2 * torch.pow(th, 4)
+ dthd_th = 1.0 + 3.0 * k1 * torch.pow(th, 2) + 5.0 * k2 * torch.pow(th, 4)
+ th_radial = th_radial * th
+ step = (xr_yr_norm - th_radial) / dthd_th
+ # handle dthd_th close to 0.
+ step = torch.where(
+ torch.abs(dthd_th) > eps, step, torch.sign(step) * eps * 10.0
+ )
+ th = th + step
+
+ # Compute the ray direction using theta and xr_yr.
+ close_to_zero = (torch.abs(th) < eps) & (torch.abs(xr_yr_norm) < eps)
+ ray_dir = torch.where(close_to_zero, xr_yr, th * xr_yr / xr_yr_norm)
+
+ # Compute the 3D projective ray
+ rho2_u = (
+ ray_dir.norm(p=2, dim=2, keepdim=True) ** 2
+ ) # B N 1 # x_c * x_c + y_c * y_c
+ xi = xi.reshape(B, 1, 1)
+ sqrt_term = torch.sqrt(1.0 + (1.0 - xi * xi) * rho2_u)
+ P_z = 1.0 - xi * (rho2_u + 1.0) / (xi + sqrt_term)
+
+        # Special case when xi is 1.0 (parabolic-mirror case of the unified model)
+ P_z = torch.where(xi == 1.0, (1.0 - rho2_u) / 2.0, P_z)
+
+ ray = torch.cat([ray_dir, P_z], dim=-1)
+ ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2)
+
+ # remove nans
+ where_nan = ray.isnan().any(dim=1, keepdim=True).repeat(1, 3, 1, 1)
+ ray = torch.where(where_nan, torch.zeros_like(ray), ray)
+
+ return ray
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def project(self, xyz):
+ is_flat = xyz.ndim == 3
+ B, N = xyz.shape[:2]
+
+ if not is_flat:
+ B, _, H, W = xyz.shape
+ N = H * W
+ xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3)
+
+ k1, k2, p0, p1, xi = self.params[..., 4:].unbind(dim=1)
+ fx_fy = self.params[..., 0:2].reshape(B, 1, 2)
+ cx_cy = self.params[..., 2:4].reshape(B, 1, 2)
+
+ norm = xyz.norm(p=2, dim=-1, keepdim=True)
+ ab = xyz[..., :-1] / (xyz[..., -1:] + xi.reshape(B, 1, 1) * norm)
+
+ # radial correction
+ r = ab.norm(dim=-1, p=2, keepdim=True)
+ k1 = self.params[..., 4].reshape(B, 1, 1)
+ k2 = self.params[..., 5].reshape(B, 1, 1)
+ # ab / r * th * (1 + k1 * (th ** 2) + k2 * (th**4))
+ # but here r = th, no spherical distortion
+ xr_yr = ab * (1 + k1 * (r**2) + k2 * (r**4))
+
+ # Tangential correction.
+ uv_dist = xr_yr
+ p0 = self.params[:, -3].reshape(B, 1)
+ p1 = self.params[:, -2].reshape(B, 1)
+ xr = xr_yr[..., 0].reshape(B, N)
+ yr = xr_yr[..., 1].reshape(B, N)
+ xr_yr_sq = torch.square(xr_yr)
+ xr_sq = xr_yr_sq[:, :, 0].reshape(B, N)
+ yr_sq = xr_yr_sq[:, :, 1].reshape(B, N)
+ rd_sq = xr_sq + yr_sq
+ uv_dist_tu = uv_dist[:, :, 0] + (
+ (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1
+ )
+ uv_dist_tv = uv_dist[:, :, 1] + (
+ (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0
+ )
+ uv_dist = torch.stack(
+ [uv_dist_tu, uv_dist_tv], dim=-1
+ ) # Avoids in-place complaint.
+
+ result = uv_dist * fx_fy + cx_cy
+
+ if not is_flat:
+ result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2)
+ invalid = (
+ (result[:, 0] < 0)
+ | (result[:, 0] > W)
+ | (result[:, 1] < 0)
+ | (result[:, 1] > H)
+ )
+ self.projection_mask = (~invalid).unsqueeze(1)
+        # overlap masking is disabled here: it creates a hole in the centre of the image
+        # self.overlap_mask = self.mask_overlap_projection(result)
+
+ return result
+
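+# For reference, the Mei unified omnidirectional model implemented above first
+# projects a point X onto the unit sphere and shifts it by xi along the optical
+# axis,
+#   (a, b) = (x, y) / (z + xi * ||X||),
+# then applies radial (k1, k2) and tangential (p1, p2) distortion and the
+# pinhole intrinsics; with xi = 0 and zero distortion this is a plain pinhole
+# projection.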
+
+class BatchCamera(Camera):
+ def __init__(self, params, K, original_class, cameras):
+ super().__init__(params, K)
+ self.original_class = original_class
+ self.cameras = cameras
+
+ # Delegate these methods to original camera
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def project(self, points_3d):
+ return torch.cat(
+ [
+ camera.project(points_3d[i : i + 1])
+ for i, camera in enumerate(self.cameras)
+ ]
+ )
+
+ @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+ def unproject(self, points_2d):
+ def recursive_unproject(cameras):
+ if isinstance(cameras, list):
+ return [recursive_unproject(camera) for camera in cameras]
+ else:
+ return cameras.unproject(points_2d)
+
+ def flatten_and_cat(nested_list):
+ if isinstance(nested_list[0], list):
+ return torch.cat(
+ [flatten_and_cat(sublist) for sublist in nested_list], dim=0
+ )
+ else:
+ return torch.cat(nested_list, dim=0)
+
+ unprojected = recursive_unproject(self.cameras)
+ return flatten_and_cat(unprojected)
+
+ def crop(self, left, top, right=None, bottom=None):
+ val = torch.cat(
+ [
+ camera.crop(left, top, right, bottom)
+ for i, camera in enumerate(self.cameras)
+ ]
+ )
+ return val
+
+ def resize(self, ratio):
+ val = torch.cat([camera.resize(ratio) for i, camera in enumerate(self.cameras)])
+ return val
+
+ def reconstruct(self, depth):
+ val = torch.cat(
+ [
+ camera.reconstruct(depth[i : i + 1])
+ for i, camera in enumerate(self.cameras)
+ ]
+ )
+ return val
+
+ def get_projection_mask(self):
+ return torch.cat(
+ [camera.projection_mask for i, camera in enumerate(self.cameras)]
+ )
+
+ def to(self, device, non_blocking=False):
+ self = super().to(device, non_blocking=non_blocking)
+ self.cameras = recursive_apply(
+ self.cameras, lambda camera: camera.to(device, non_blocking=non_blocking)
+ )
+ return self
+
+ def reshape(self, *shape):
+ # Reshape the intrinsic matrix (K) and params
+ # we know that the shape of K is (..., 3, 3) and params is (..., 16)
+ reshaped_K = self.K.reshape(*shape, 3, 3)
+ reshaped_params = self.params.reshape(*shape, self.params.shape[-1])
+
+ self.cameras = np.array(self.cameras, dtype=object).reshape(shape).tolist()
+ self.original_class = (
+ np.array(self.original_class, dtype=object).reshape(shape).tolist()
+ )
+
+ # Create a new BatchCamera with reshaped K and params
+ return BatchCamera(
+ reshaped_params, reshaped_K, self.original_class, self.cameras
+ )
+
+ def get_new_fov(self, new_shape, original_shape):
+ return [
+ camera.get_new_fov(new_shape, original_shape)
+ for i, camera in enumerate(self.cameras)
+ ]
+
+ def squeeze(self, dim):
+ return BatchCamera(
+ self.params.squeeze(dim),
+ self.K.squeeze(dim),
+ squeeze_list(self.original_class, dim=dim),
+ squeeze_list(self.cameras, dim=dim),
+ )
+
+ def __getitem__(self, idx):
+ # If it's an integer index, return a single camera
+ if isinstance(idx, int):
+ return self.cameras[idx]
+
+ # If it's a slice, return a new BatchCamera with sliced cameras
+ elif isinstance(idx, slice):
+ return BatchCamera(
+ self.params[idx],
+ self.K[idx],
+ self.original_class[idx],
+ self.cameras[idx],
+ )
+
+ raise TypeError(f"Invalid index type: {type(idx)}")
+
+ def __setitem__(self, idx, value):
+ # If it's an integer index, return a single camera
+ if isinstance(idx, int):
+ self.cameras[idx] = value
+ self.params[idx, :] = 0.0
+ self.params[idx, : value.params.shape[1]] = value.params[0]
+ self.K[idx] = value.K[0]
+
+ self.original_class[idx] = getattr(
+ value, "original_class", value.__class__.__name__
+ )
+
+ # If it's a slice, return a new BatchCamera with sliced cameras
+ elif isinstance(idx, slice):
+ # Update each internal attribute using the slice
+ self.params[idx] = value.params
+ self.K[idx] = value.K
+ self.original_class[idx] = value.original_class
+ self.cameras[idx] = value.cameras
+
+ def __len__(self):
+ return len(self.cameras)
+
+ @classmethod
+ def from_camera(cls, camera):
+ return cls(camera.params, camera.K, [camera.__class__.__name__], [camera])
+
+ @property
+ def is_perspective(self):
+ return recursive_apply(self.cameras, lambda camera: isinstance(camera, Pinhole))
+
+ @property
+ def is_spherical(self):
+ return recursive_apply(
+ self.cameras, lambda camera: isinstance(camera, Spherical)
+ )
+
+ @property
+ def is_eucm(self):
+ return recursive_apply(self.cameras, lambda camera: isinstance(camera, EUCM))
+
+ @property
+ def is_fisheye(self):
+ return recursive_apply(
+ self.cameras, lambda camera: isinstance(camera, Fisheye624)
+ )
+
+ @property
+ def is_pinhole(self):
+ return recursive_apply(self.cameras, lambda camera: isinstance(camera, Pinhole))
+
+ @property
+ def hfov(self):
+ return recursive_apply(self.cameras, lambda camera: camera.hfov)
+
+ @property
+ def vfov(self):
+ return recursive_apply(self.cameras, lambda camera: camera.vfov)
+
+ @property
+ def max_fov(self):
+ return recursive_apply(self.cameras, lambda camera: camera.max_fov)
+
+
+# sampler helpers
+import json
+import random
+from math import log
+
+import torch.nn as nn
+
+
+def eucm(boundaries, mult, batch, device, dtype):
+ alpha_min, alpha_max = boundaries[0][0] * mult, boundaries[0][1] * mult
+ beta_mean, beta_std = boundaries[1][0] * mult, boundaries[1][1] * mult
+ alpha = (
+ torch.rand(batch, device=device, dtype=dtype) * (alpha_max - alpha_min)
+ + alpha_min
+ )
+ beta = F.softplus(
+ torch.randn(batch, device=device, dtype=dtype) * beta_std + beta_mean,
+ beta=log(2),
+ )
+ return alpha, beta
+
+
+def free_fisheye(boundaries, mult, batch, device, dtype):
+ k1_min, k1_max = boundaries[0][0] * mult, boundaries[0][1] * mult
+ k2_min, k2_max = boundaries[1][0] * mult, boundaries[1][1] * mult
+ k3_min, k3_max = boundaries[2][0] * mult, boundaries[2][1] * mult
+ k4_min, k4_max = boundaries[3][0] * mult, boundaries[3][1] * mult
+ k5_min, k5_max = boundaries[4][0] * mult, boundaries[4][1] * mult
+ k6_min, k6_max = boundaries[5][0] * mult, boundaries[5][1] * mult
+ p1_min, p1_max = boundaries[6][0] * mult, boundaries[6][1] * mult
+ p2_min, p2_max = boundaries[7][0] * mult, boundaries[7][1] * mult
+ s1_min, s1_max = boundaries[8][0] * mult, boundaries[8][1] * mult
+ s2_min, s2_max = boundaries[9][0] * mult, boundaries[9][1] * mult
+ s3_min, s3_max = boundaries[10][0] * mult, boundaries[10][1] * mult
+ s4_min, s4_max = boundaries[11][0] * mult, boundaries[11][1] * mult
+ k1 = torch.rand(batch, device=device, dtype=dtype) * (k1_max - k1_min) + k1_min
+ k2 = torch.rand(batch, device=device, dtype=dtype) * (k2_max - k2_min) + k2_min
+ k3 = torch.rand(batch, device=device, dtype=dtype) * (k3_max - k3_min) + k3_min
+ k4 = torch.rand(batch, device=device, dtype=dtype) * (k4_max - k4_min) + k4_min
+ k5 = torch.rand(batch, device=device, dtype=dtype) * (k5_max - k5_min) + k5_min
+ k6 = torch.rand(batch, device=device, dtype=dtype) * (k6_max - k6_min) + k6_min
+ p1 = torch.rand(batch, device=device, dtype=dtype) * (p1_max - p1_min) + p1_min
+ p2 = torch.rand(batch, device=device, dtype=dtype) * (p2_max - p2_min) + p2_min
+ s1 = torch.rand(batch, device=device, dtype=dtype) * (s1_max - s1_min) + s1_min
+ s2 = torch.rand(batch, device=device, dtype=dtype) * (s2_max - s2_min) + s2_min
+ s3 = torch.rand(batch, device=device, dtype=dtype) * (s3_max - s3_min) + s3_min
+ s4 = torch.rand(batch, device=device, dtype=dtype) * (s4_max - s4_min) + s4_min
+ return k1, k2, k3, k4, k5, k6, p1, p2, s1, s2, s3, s4
+
+
+def mei(boundaries, mult, batch, device, dtype):
+ k1_min, k1_max = boundaries[0][0] * mult, boundaries[0][1] * mult
+ k2_min, k2_max = boundaries[1][0] * mult, boundaries[1][1] * mult
+ p1_min, p1_max = boundaries[2][0] * mult, boundaries[2][1] * mult
+ p2_min, p2_max = boundaries[3][0] * mult, boundaries[3][1] * mult
+ xi_min, xi_max = boundaries[4][0] * mult, boundaries[4][1] * mult
+ k1 = torch.rand(batch, device=device, dtype=dtype) * (k1_max - k1_min) + k1_min
+ k2 = torch.rand(batch, device=device, dtype=dtype) * (k2_max - k2_min) + k2_min
+ p1 = torch.rand(batch, device=device, dtype=dtype) * (p1_max - p1_min) + p1_min
+ p2 = torch.rand(batch, device=device, dtype=dtype) * (p2_max - p2_min) + p2_min
+ xi = torch.rand(batch, device=device, dtype=dtype) * (xi_max - xi_min) + xi_min
+ return k1, k2, p1, p2, xi
+
+
+def consistent_fisheye(boundaries, mult, batch, device, dtype):
+ sign = random.choice([-1, 1])
+ return free_fisheye(boundaries, sign * mult, batch, device, dtype)
+
+
+def invert_fisheye(boundaries, mult, batch, device, dtype):
+ k1_min, k1_max = boundaries[0][0] * mult, boundaries[0][1] * mult
+ k2_min, k2_max = boundaries[1][0] * mult, boundaries[1][1] * mult
+ k3_min, k3_max = boundaries[2][0] * mult, boundaries[2][1] * mult
+ k4_min, k4_max = boundaries[3][0] * mult, boundaries[3][1] * mult
+ k5_min, k5_max = boundaries[4][0] * mult, boundaries[4][1] * mult
+ k6_min, k6_max = boundaries[5][0] * mult, boundaries[5][1] * mult
+ p1_min, p1_max = boundaries[6][0] * mult, boundaries[6][1] * mult
+ p2_min, p2_max = boundaries[7][0] * mult, boundaries[7][1] * mult
+ s1_min, s1_max = boundaries[8][0] * mult, boundaries[8][1] * mult
+ s2_min, s2_max = boundaries[9][0] * mult, boundaries[9][1] * mult
+ s3_min, s3_max = boundaries[10][0] * mult, boundaries[10][1] * mult
+ s4_min, s4_max = boundaries[11][0] * mult, boundaries[11][1] * mult
+
+ sign = random.choice([-1, 1])
+ k1 = torch.rand(batch, device=device, dtype=dtype) * (k1_max - k1_min) + k1_min
+ k1 = sign * k1
+ k2 = torch.rand(batch, device=device, dtype=dtype) * (k2_max - k2_min) + k2_min
+ k2 = -1 * sign * k2
+ k3 = torch.rand(batch, device=device, dtype=dtype) * (k3_max - k3_min) + k3_min
+ k3 = sign * k3
+ k4 = torch.rand(batch, device=device, dtype=dtype) * (k4_max - k4_min) + k4_min
+ k4 = -1 * sign * k4
+ k5 = torch.rand(batch, device=device, dtype=dtype) * (k5_max - k5_min) + k5_min
+ k5 = sign * k5
+ k6 = torch.rand(batch, device=device, dtype=dtype) * (k6_max - k6_min) + k6_min
+ k6 = -1 * sign * k6
+
+ p1 = torch.rand(batch, device=device, dtype=dtype) * (p1_max - p1_min) + p1_min
+ p2 = torch.rand(batch, device=device, dtype=dtype) * (p2_max - p2_min) + p2_min
+ s1 = torch.rand(batch, device=device, dtype=dtype) * (s1_max - s1_min) + s1_min
+ s2 = torch.rand(batch, device=device, dtype=dtype) * (s2_max - s2_min) + s2_min
+ s3 = torch.rand(batch, device=device, dtype=dtype) * (s3_max - s3_min) + s3_min
+ s4 = torch.rand(batch, device=device, dtype=dtype) * (s4_max - s4_min) + s4_min
+ return k1, k2, k3, k4, k5, k6, p1, p2, s1, s2, s3, s4
+
+
+class CameraSampler(nn.Module):
+ def __init__(self):
+ super().__init__()
+ with open("camera_sampler.json", "r") as f:
+ config = json.load(f)
+ self.camera_type = config["type"]
+ self.sampling_fn = config["fn"]
+ self.boundaries = nn.ParameterList(
+ [
+ nn.Parameter(torch.tensor(x), requires_grad=False)
+ for x in config["boundaries"]
+ ]
+ )
+ self.probs = nn.Parameter(torch.tensor(config["probs"]), requires_grad=False)
+
+ def forward(self, fx, fy, cx, cy, mult, ratio, H):
+ selected_idx = torch.multinomial(self.probs, num_samples=1)
+ device, dtype = fx.device, fx.dtype
+
+ selected_camera = self.camera_type[selected_idx]
+ selected_sampling_fn = self.sampling_fn[selected_idx]
+ selected_boundaries = self.boundaries[selected_idx]
+ if "Fisheye" in selected_camera or "OPENCV" in selected_camera:
+ mult = mult * ratio
+
+ params = eval(selected_sampling_fn)(
+ selected_boundaries, mult, len(fx), device, dtype
+ )
+ params = torch.stack([fx, fy, cx, cy, *params], dim=1)
+ camera = eval(selected_camera)(params=params)
+ return camera
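+
+# CameraSampler reads a `camera_sampler.json` from the current working
+# directory with parallel lists; that file is not part of this diff, so the
+# layout below is only an assumed illustration of the keys the constructor
+# accesses:
+#
+#   {
+#     "type":       ["EUCM", "Fisheye624", "MEI"],
+#     "fn":         ["eucm", "consistent_fisheye", "mei"],
+#     "boundaries": [[[0.5, 0.9], [0.2, 0.4]], ..., ...],
+#     "probs":      [0.4, 0.4, 0.2]
+#   }
+#
+# `forward` draws one entry via `torch.multinomial(self.probs, 1)` and builds
+# the chosen camera with params = [fx, fy, cx, cy, *sampled distortion].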
diff --git a/unik3d/utils/chamfer_distance.py b/unik3d/utils/chamfer_distance.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f0068f9bfef86acd62d86c8774be3464f8e5636
--- /dev/null
+++ b/unik3d/utils/chamfer_distance.py
@@ -0,0 +1,158 @@
+import warnings
+from typing import Union
+
+import torch
+
+try:
+ from unik3d.ops.knn import knn_points
+except ImportError as e:
+ warnings.warn(
+ "!! To run evaluation you need KNN. Please compile KNN: "
+ "`cd unik3d/ops/knn && bash compile.sh`."
+ )
+    def knn_points(*args, **kwargs):  # fallback that fails loudly when evaluation is attempted
+        raise ImportError(
+            "knn_points is unavailable, compile it with: `cd unik3d/ops/knn && bash compile.sh`"
+        )
+
+
+def _validate_chamfer_reduction_inputs(
+ batch_reduction: Union[str, None], point_reduction: str
+):
+ """Check the requested reductions are valid.
+
+ Args:
+ batch_reduction: Reduction operation to apply for the loss across the
+ batch, can be one of ["mean", "sum"] or None.
+ point_reduction: Reduction operation to apply for the loss across the
+ points, can be one of ["mean", "sum"].
+ """
+ if batch_reduction is not None and batch_reduction not in ["mean", "sum"]:
+ raise ValueError('batch_reduction must be one of ["mean", "sum"] or None')
+ if point_reduction not in ["mean", "sum"]:
+ raise ValueError('point_reduction must be one of ["mean", "sum"]')
+
+
+def _handle_pointcloud_input(
+ points: torch.Tensor,
+ lengths: Union[torch.Tensor, None],
+ normals: Union[torch.Tensor, None],
+):
+ """
+ If points is an instance of Pointclouds, retrieve the padded points tensor
+ along with the number of points per batch and the padded normals.
+ Otherwise, return the input points (and normals) with the number of points per cloud
+ set to the size of the second dimension of `points`.
+ """
+ if points.ndim != 3:
+ raise ValueError("Expected points to be of shape (N, P, D)")
+ X = points
+ if lengths is not None and (lengths.ndim != 1 or lengths.shape[0] != X.shape[0]):
+ raise ValueError("Expected lengths to be of shape (N,)")
+ if lengths is None:
+ lengths = torch.full(
+ (X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device
+ )
+ if normals is not None and normals.ndim != 3:
+ raise ValueError("Expected normals to be of shape (N, P, 3")
+
+ return X, lengths, normals
+
+
+class ChamferDistance(torch.nn.Module):
+ def forward(
+ self,
+ x,
+ y,
+ x_lengths=None,
+ y_lengths=None,
+ x_normals=None,
+ y_normals=None,
+ weights=None,
+ batch_reduction: Union[str, None] = "mean",
+ point_reduction: str = "mean",
+ ):
+ """
+ Chamfer distance between two pointclouds x and y.
+
+ Args:
+            x: FloatTensor of shape (N, P1, D) representing a batch of point clouds
+                with at most P1 points in each batch element, batch size N and
+                feature dimension D.
+            y: FloatTensor of shape (N, P2, D) representing a batch of point clouds
+                with at most P2 points in each batch element, batch size N and
+                feature dimension D.
+ x_lengths: Optional LongTensor of shape (N,) giving the number of points in each
+ cloud in x.
+            y_lengths: Optional LongTensor of shape (N,) giving the number of points in each
+                cloud in y.
+ x_normals: Optional FloatTensor of shape (N, P1, D).
+ y_normals: Optional FloatTensor of shape (N, P2, D).
+ weights: Optional FloatTensor of shape (N,) giving weights for
+ batch elements for reduction operation.
+ batch_reduction: Reduction operation to apply for the loss across the
+ batch, can be one of ["mean", "sum"] or None.
+ point_reduction: Reduction operation to apply for the loss across the
+ points, can be one of ["mean", "sum"].
+
+        Returns:
+            4-element tuple containing
+
+            - **cham_x**: Tensor of shape (N, P1) with the nearest-neighbor distance
+              (as returned by `knn_points`) from each point in x to y, zeroed for
+              padded points and scaled by `weights` if given.
+            - **cham_y**: Tensor of shape (N, P2), the same from y to x.
+            - **idx_x**: LongTensor of shape (N, P1) with the index in y of each
+              nearest neighbor.
+            - **idx_y**: LongTensor of shape (N, P2) with the index in x of each
+              nearest neighbor.
+
+            The reduction arguments are validated but no reduction is applied here;
+            normals are accepted for API compatibility but are not used.
+        """
+ _validate_chamfer_reduction_inputs(batch_reduction, point_reduction)
+
+ x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals)
+ y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals)
+
+ return_normals = x_normals is not None and y_normals is not None
+
+ N, P1, D = x.shape
+ P2 = y.shape[1]
+
+ # Check if inputs are heterogeneous and create a lengths mask.
+ is_x_heterogeneous = (x_lengths != P1).any()
+ is_y_heterogeneous = (y_lengths != P2).any()
+ x_mask = (
+ torch.arange(P1, device=x.device)[None] >= x_lengths[:, None]
+ ) # shape [N, P1]
+ y_mask = (
+ torch.arange(P2, device=y.device)[None] >= y_lengths[:, None]
+ ) # shape [N, P2]
+
+ if y.shape[0] != N or y.shape[2] != D:
+ raise ValueError("y does not have the correct shape.")
+ if weights is not None:
+ if weights.size(0) != N:
+ raise ValueError("weights must be of shape (N,).")
+ if not (weights >= 0).all():
+ raise ValueError("weights cannot be negative.")
+ if weights.sum() == 0.0:
+ weights = weights.view(N, 1)
+ if batch_reduction in ["mean", "sum"]:
+ return (
+ (x.sum((1, 2)) * weights).sum() * 0.0,
+ (x.sum((1, 2)) * weights).sum() * 0.0,
+ )
+ return (
+ (x.sum((1, 2)) * weights) * 0.0,
+ (x.sum((1, 2)) * weights) * 0.0,
+ )
+
+ x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1)
+ y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1)
+
+ cham_x = x_nn.dists[..., 0] # (N, P1)
+ cham_y = y_nn.dists[..., 0] # (N, P2)
+
+ if is_x_heterogeneous:
+ cham_x[x_mask] = 0.0
+ if is_y_heterogeneous:
+ cham_y[y_mask] = 0.0
+
+ if weights is not None:
+ cham_x *= weights.view(N, 1)
+ cham_y *= weights.view(N, 1)
+
+ return cham_x, cham_y, x_nn.idx[..., -1], y_nn.idx[..., -1]
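+
+# Minimal usage sketch (assumes the KNN extension has been compiled; shapes
+# are made up for the example):
+#
+#   chamfer = ChamferDistance()
+#   x = torch.rand(2, 1000, 3).cuda()
+#   y = torch.rand(2, 1200, 3).cuda()
+#   cham_x, cham_y, idx_x, idx_y = chamfer(x, y)
+#   # cham_x: [2, 1000] nearest-neighbour distance of each x-point into y
+#   # idx_x:  [2, 1000] index of that nearest neighbour in y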
diff --git a/unik3d/utils/constants.py b/unik3d/utils/constants.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e8c70b011e18eb9c53ffd4c075afab8d98b7123
--- /dev/null
+++ b/unik3d/utils/constants.py
@@ -0,0 +1,42 @@
+import math
+
+import torch
+
+NAME_PAGE = "submission3"
+OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073)
+OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711)
+IMAGENET_DATASET_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DATASET_STD = (0.229, 0.224, 0.225)
+DEPTH_BINS = torch.cat(
+ (
+ torch.logspace(math.log10(0.1), math.log10(180.0), steps=512),
+ torch.tensor([260.0]),
+ ),
+ dim=0,
+)
+LOGERR_BINS = torch.linspace(-2, 2, steps=128 + 1)
+LINERR_BINS = torch.linspace(-50, 50, steps=256 + 1)
+
+VERBOSE = False
+OUTDOOR_DATASETS = [
+ "Argoverse",
+ "DDAD",
+ "DrivingStereo",
+ "Mapillary",
+ "BDD",
+ "A2D2",
+ "Nuscenes",
+ "Cityscape",
+ "KITTI",
+ "DENSE",
+ "DIML",
+ "NianticMapFree",
+ "DL3DV",
+ "KITTIMulti",
+ "Waymo",
+ "Argoverse2",
+ "BEDLAM",
+ "NeRDS360",
+ "BlendedMVG",
+ "MegaDepthS",
+]
diff --git a/unik3d/utils/coordinate.py b/unik3d/utils/coordinate.py
new file mode 100644
index 0000000000000000000000000000000000000000..a09129982737d288db26a3717337af136d163c32
--- /dev/null
+++ b/unik3d/utils/coordinate.py
@@ -0,0 +1,27 @@
+import torch
+
+
+def coords_grid(b, h, w, homogeneous=False, device=None, noisy=False):
+ pixel_coords_x = torch.linspace(0.5, w - 0.5, w, device=device)
+ pixel_coords_y = torch.linspace(0.5, h - 0.5, h, device=device)
+ if noisy: # \pm 0.5px noise
+ pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5
+ pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5
+
+ stacks = [pixel_coords_x.repeat(h, 1), pixel_coords_y.repeat(w, 1).t()]
+ if homogeneous:
+ ones = torch.ones_like(stacks[0]) # [H, W]
+ stacks.append(ones)
+ grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W]
+ grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W]
+ if device is not None:
+ grid = grid.to(device)
+
+ return grid
+
+
+def normalize_coords(coords, h, w):
+ c = torch.tensor([(w - 1) / 2.0, (h - 1) / 2.0], device=coords.device).view(
+ 1, 2, 1, 1
+ )
+ return (coords - c) / c
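+
+# Quick reference (illustrative): coords_grid returns pixel centres, so the
+# top-left pixel is (0.5, 0.5) rather than (0, 0).
+#
+#   grid = coords_grid(b=1, h=2, w=3)             # [1, 2, 2, 3], channels are (x, y)
+#   grid[0, :, 0, 0]                              # tensor([0.5000, 0.5000])
+#   coords_grid(1, 2, 3, homogeneous=True).shape  # [1, 3, 2, 3], row of ones appended
+#
+# normalize_coords maps pixel index 0 to -1 and index w-1 (or h-1) to +1,
+# i.e. the align_corners=True convention expected by F.grid_sample.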
diff --git a/unik3d/utils/distributed.py b/unik3d/utils/distributed.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bf45f02fb7e0a6420373bf07d926303d62b7e50
--- /dev/null
+++ b/unik3d/utils/distributed.py
@@ -0,0 +1,247 @@
+import os
+import pickle
+import platform
+import subprocess
+import warnings
+
+import cv2
+import torch
+import torch.utils.data.distributed
+from torch import distributed as dist
+from torch import multiprocessing as mp
+
+_LOCAL_PROCESS_GROUP = None
+
+
+def is_dist_avail_and_initialized():
+ if not dist.is_available():
+ return False
+ if not dist.is_initialized():
+ return False
+ return True
+
+
+def get_rank():
+ if not is_dist_avail_and_initialized():
+ return 0
+ return dist.get_rank()
+
+
+def get_local_rank() -> int:
+ """
+ Returns:
+ The rank of the current process within the local (per-machine) process group.
+ """
+ if not is_dist_avail_and_initialized():
+ return 0
+ assert _LOCAL_PROCESS_GROUP is not None
+ return dist.get_rank(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_local_size() -> int:
+ """
+ Returns:
+ The size of the per-machine process group,
+ i.e. the number of processes per machine.
+ """
+ if not is_dist_avail_and_initialized():
+ return 1
+ assert _LOCAL_PROCESS_GROUP is not None
+ return dist.get_world_size(group=_LOCAL_PROCESS_GROUP)
+
+
+def get_world_size():
+ if not is_dist_avail_and_initialized():
+ return 1
+ return dist.get_world_size()
+
+
+def barrier():
+ if not is_dist_avail_and_initialized():
+ return
+ dist.barrier()
+
+
+def is_main_process():
+ return get_rank() == 0
+
+
+def is_rank_zero(args):
+ return args.rank == 0
+
+
+def get_dist_info():
+ if dist.is_available() and dist.is_initialized():
+ rank = dist.get_rank()
+ world_size = dist.get_world_size()
+ else:
+ rank = 0
+ world_size = 1
+ return rank, world_size
+
+
+def setup_multi_processes(cfg):
+ """Setup multi-processing environment variables."""
+ # set multi-process start method as `fork` to speed up the training
+ if platform.system() != "Windows":
+ mp_start_method = cfg.get("mp_start_method", "fork")
+ current_method = mp.get_start_method(allow_none=True)
+ if current_method is not None and current_method != mp_start_method:
+ warnings.warn(
+ f"Multi-processing start method `{mp_start_method}` is "
+                f"different from the previous setting `{current_method}`. "
+                f"It will be forcibly set to `{mp_start_method}`. You can change "
+ f"this behavior by changing `mp_start_method` in your config."
+ )
+ mp.set_start_method(mp_start_method, force=True)
+
+ # disable opencv multithreading to avoid system being overloaded
+ # opencv_num_threads = cfg.get('opencv_num_threads', 0)
+ # cv2.setNumThreads(opencv_num_threads)
+
+ # setup OMP threads
+ # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa
+ # workers_per_gpu = cfg.get('workers_per_gpu', 4)
+
+ # if 'OMP_NUM_THREADS' not in os.environ and workers_per_gpu > 1:
+ # omp_num_threads = 1
+ # warnings.warn(
+ # f'Setting OMP_NUM_THREADS environment variable for each process '
+ # f'to be {omp_num_threads} in default, to avoid your system being '
+ # f'overloaded, please further tune the variable for optimal '
+ # f'performance in your application as needed.')
+ # os.environ['OMP_NUM_THREADS'] = str(omp_num_threads)
+
+ # setup MKL threads
+ # if 'MKL_NUM_THREADS' not in os.environ and workers_per_gpu > 1:
+ # mkl_num_threads = os.environ.get('OMP_NUM_THREADS', 1)
+ # warnings.warn(
+ # f'Setting MKL_NUM_THREADS environment variable for each process '
+ # f'to be {mkl_num_threads} in default, to avoid your system being '
+ # f'overloaded, please further tune the variable for optimal '
+ # f'performance in your application as needed.')
+ # os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads)
+
+
+def setup_slurm(backend: str, port: str) -> None:
+ proc_id = int(os.environ["SLURM_PROCID"])
+ ntasks = int(os.environ["SLURM_NTASKS"])
+ node_list = os.environ["SLURM_NODELIST"]
+
+ num_gpus = torch.cuda.device_count()
+
+ torch.cuda.set_device(proc_id % num_gpus)
+ if "MASTER_ADDR" not in os.environ:
+ addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
+ os.environ["MASTER_PORT"] = str(port)
+ os.environ["MASTER_ADDR"] = addr
+ else:
+ addr = os.environ["MASTER_ADDR"]
+ os.environ["WORLD_SIZE"] = str(ntasks)
+ os.environ["LOCAL_RANK"] = str(proc_id % num_gpus)
+ os.environ["RANK"] = str(proc_id)
+ print(
+ proc_id,
+ ntasks,
+ num_gpus,
+ proc_id % num_gpus,
+ node_list,
+ addr,
+ os.environ["MASTER_PORT"],
+ os.system("nvidia-smi -L"),
+ )
+ dist.init_process_group(backend, rank=proc_id, world_size=ntasks)
+
+
+def sync_tensor_across_gpus(t, dim=0, cat=True):
+ if t is None or not (dist.is_available() and dist.is_initialized()):
+ return t
+ t = torch.atleast_1d(t)
+ group = dist.group.WORLD
+ group_size = torch.distributed.get_world_size(group)
+
+ local_size = torch.tensor(t.size(dim), device=t.device)
+ all_sizes = [torch.zeros_like(local_size) for _ in range(group_size)]
+ dist.all_gather(all_sizes, local_size)
+ max_size = max(all_sizes)
+ size_diff = max_size.item() - local_size.item()
+ if size_diff:
+ padding = torch.zeros(size_diff, device=t.device, dtype=t.dtype)
+ t = torch.cat((t, padding))
+
+ gather_t_tensor = [torch.zeros_like(t) for _ in range(group_size)]
+ dist.all_gather(gather_t_tensor, t)
+ all_ts = []
+ for t, size in zip(gather_t_tensor, all_sizes):
+ all_ts.append(t[:size])
+ if cat:
+ return torch.cat(all_ts, dim=0)
+ return all_ts
+
+
+def sync_string_across_gpus(keys: list[str], device, dim=0):
+ keys_serialized = pickle.dumps(keys, protocol=pickle.HIGHEST_PROTOCOL)
+ keys_serialized_tensor = (
+ torch.frombuffer(keys_serialized, dtype=torch.uint8).clone().to(device)
+ )
+ keys_serialized_tensor = sync_tensor_across_gpus(
+ keys_serialized_tensor, dim=0, cat=False
+ )
+ keys = [
+ key
+ for keys in keys_serialized_tensor
+ for key in pickle.loads(bytes(keys.cpu().tolist()))
+ ]
+ return keys
+
+
+def create_local_process_group() -> None:
+ num_workers_per_machine = torch.cuda.device_count()
+ global _LOCAL_PROCESS_GROUP
+ assert _LOCAL_PROCESS_GROUP is None
+ assert get_world_size() % num_workers_per_machine == 0
+ num_machines = get_world_size() // num_workers_per_machine
+ machine_rank = get_rank() // num_workers_per_machine
+ for i in range(num_machines):
+ ranks_on_i = list(
+ range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine)
+ )
+ pg = dist.new_group(ranks_on_i)
+ if i == machine_rank:
+ _LOCAL_PROCESS_GROUP = pg
+
+
+def _get_global_gloo_group():
+ if dist.get_backend() == "nccl":
+ return dist.new_group(backend="gloo")
+ else:
+ return dist.group.WORLD
+
+
+def all_gather(data, group=None):
+ if get_world_size() == 1:
+ return [data]
+ if group is None:
+ group = (
+ _get_global_gloo_group()
+ ) # use CPU group by default, to reduce GPU RAM usage.
+ world_size = dist.get_world_size(group)
+ if world_size == 1:
+ return [data]
+
+ output = [None for _ in range(world_size)]
+ dist.all_gather_object(output, data, group=group)
+ return output
+
+
+def local_broadcast_process_authkey():
+ if get_local_size() == 1:
+ return
+ local_rank = get_local_rank()
+ authkey = bytes(mp.current_process().authkey)
+ all_keys = all_gather(authkey)
+ local_leader_key = all_keys[get_rank() - local_rank]
+ if authkey != local_leader_key:
+ # print("Process authkey is different from the key of local leader! workers are launched independently ??")
+ # print("Overwriting local authkey ...")
+ mp.current_process().authkey = local_leader_key
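+
+
+if __name__ == "__main__":
+    # Illustrative sketch (assumes torch.distributed has *not* been initialized):
+    # the helpers above fall back to single-process defaults, so they are safe
+    # to call unconditionally from training code.
+    print("rank, world_size:", get_dist_info())   # (0, 1)
+    print("is main process:", is_main_process())  # True
+    print("all_gather:", all_gather({"x": 1}))    # [{'x': 1}]
+    barrier()                                     # no-op without a process group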
diff --git a/unik3d/utils/ema_torch.py b/unik3d/utils/ema_torch.py
new file mode 100644
index 0000000000000000000000000000000000000000..a50685b841a3d86bf32f2fdd73c64700233f53a8
--- /dev/null
+++ b/unik3d/utils/ema_torch.py
@@ -0,0 +1,340 @@
+from __future__ import division, unicode_literals
+
+import contextlib
+import copy
+import weakref
+from math import tanh
+from typing import Iterable, Optional
+
+import torch
+
+
+class DummyExponentialMovingAverage:
+ def __init__(self, *args, **kwargs):
+ pass
+
+ def _get_parameters(self, *args, **kwargs):
+ pass
+
+ def get_current_decay(self, *args, **kwargs):
+ pass
+
+ def update(self, *args, **kwargs):
+ pass
+
+ def copy_to(self, *args, **kwargs):
+ pass
+
+ def store(self, *args, **kwargs):
+ return
+
+ def restore(self, *args, **kwargs):
+ return
+
+ @contextlib.contextmanager
+ def average_parameters(self, *args, **kwargs):
+ try:
+ yield
+ finally:
+ pass
+
+ def to(self, *args, **kwargs):
+ pass
+
+ def state_dict(self, *args, **kwargs):
+ pass
+
+ def load_state_dict(self, *args, **kwargs):
+ pass
+
+
+class ExponentialMovingAverage:
+ """
+ Maintains (exponential) moving average of a set of parameters.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter` (typically from
+ `model.parameters()`).
+ Note that EMA is computed on *all* provided parameters,
+ regardless of whether or not they have `requires_grad = True`;
+ this allows a single EMA object to be consistantly used even
+            this allows a single EMA object to be consistently used even
+
+            If you want to exclude some parameters from the EMA, do not pass them
+ to the object in the first place. For example:
+
+ ExponentialMovingAverage(
+ parameters=[p for p in model.parameters() if p.requires_grad],
+ decay=0.9
+ )
+
+ will ignore parameters that do not require grad.
+
+ decay: The exponential decay.
+
+ use_num_updates: Whether to use number of updates when computing
+ averages.
+ """
+
+ def __init__(
+ self,
+ parameters: Iterable[torch.nn.Parameter],
+ decay: float,
+ use_num_updates: bool = True,
+ update_after_step: int = 10000,
+ tau: int = 20000,
+ switch: bool = False,
+ save_memory: bool = True,
+ ):
+ if decay < 0.0 or decay > 1.0:
+ raise ValueError("Decay must be between 0 and 1")
+ self.decay = decay
+        self.switch = switch  # if True, keep EMA params in the model after the averaging context
+ self.num_updates = 0 if use_num_updates else None
+ parameters = list(parameters)
+ self.shadow_params = [p.clone().detach() for p in parameters]
+ self.collected_params = None
+ # By maintaining only a weakref to each parameter,
+ # we maintain the old GC behaviour of ExponentialMovingAverage:
+ # if the model goes out of scope but the ExponentialMovingAverage
+ # is kept, no references to the model or its parameters will be
+ # maintained, and the model will be cleaned up.
+ self._params_refs = [weakref.ref(p) for p in parameters]
+ self.update_after_step = update_after_step
+ self.tau = tau
+ self.save_memory = save_memory
+
+ def _get_parameters(
+ self, parameters: Optional[Iterable[torch.nn.Parameter]]
+ ) -> Iterable[torch.nn.Parameter]:
+ if parameters is None:
+ parameters = [p() for p in self._params_refs]
+ if any(p is None for p in parameters):
+ raise ValueError(
+ "(One of) the parameters with which this ExponentialMovingAverage was initialized no longer exists (was garbage collected);"
+ " please either provide `parameters` explicitly or keep the model to which they belong from being garbage collected."
+ )
+ return parameters
+ else:
+ parameters = list(parameters)
+ if len(parameters) != len(self.shadow_params):
+ raise ValueError(
+ "Number of parameters passed as argument is different "
+ "from number of shadow parameters maintained by this "
+ "ExponentialMovingAverage"
+ )
+ return parameters
+
+ def get_current_decay(self):
+ epoch = max(self.num_updates - self.update_after_step - 1, 0.0)
+ if epoch <= 0:
+ return 0.0
+ value = tanh(epoch / self.tau) * self.decay
+ return value
+
+ def update(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None:
+ """
+ Update currently maintained parameters.
+
+ Call this every time the parameters are updated, such as the result of
+ the `optimizer.step()` call.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; usually the same set of
+ parameters used to initialize this object. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
+ """
+ parameters = self._get_parameters(parameters)
+ decay = self.get_current_decay()
+ if self.num_updates is not None:
+ self.num_updates += 1
+
+ one_minus_decay = 1.0 - decay
+ with torch.no_grad():
+ for s_param, param in zip(self.shadow_params, parameters):
+ tmp = s_param - param
+ # tmp will be a new tensor so we can do in-place
+ tmp.mul_(one_minus_decay)
+ s_param.sub_(tmp)
+
+ def copy_to(
+ self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ) -> None:
+ """
+ Copy current averaged parameters into given collection of parameters.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored moving averages. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
+ """
+ parameters = self._get_parameters(parameters)
+ for s_param, param in zip(self.shadow_params, parameters):
+ param.data.copy_(s_param.data)
+
+ def store(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None:
+ """
+ Save the current parameters for restoring later.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+                temporarily stored. If `None`, the parameters with which this
+ `ExponentialMovingAverage` was initialized will be used.
+ """
+ parameters = self._get_parameters(parameters)
+ self.collected_params = [param.detach().clone() for param in parameters]
+
+ def restore(
+ self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ) -> None:
+ """
+ Restore the parameters stored with the `store` method.
+ Useful to validate the model with EMA parameters without affecting the
+ original optimization process. Store the parameters before the
+ `copy_to` method. After validation (or model saving), use this to
+ restore the former parameters.
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
+ """
+ if self.collected_params is None:
+ raise RuntimeError(
+ "This ExponentialMovingAverage has no `store()`ed weights "
+ "to `restore()`"
+ )
+ parameters = self._get_parameters(parameters)
+ for c_param, param in zip(self.collected_params, parameters):
+ param.data.copy_(c_param.data)
+
+ @contextlib.contextmanager
+ def average_parameters(
+ self, parameters: Optional[Iterable[torch.nn.Parameter]] = None
+ ):
+ r"""
+ Context manager for validation/inference with averaged parameters.
+
+ Equivalent to:
+
+ ema.store()
+ ema.copy_to()
+ try:
+ ...
+ finally:
+ ema.restore()
+
+ Args:
+ parameters: Iterable of `torch.nn.Parameter`; the parameters to be
+ updated with the stored parameters. If `None`, the
+ parameters with which this `ExponentialMovingAverage` was
+ initialized will be used.
+ """
+ parameters = self._get_parameters(parameters)
+ self.store(parameters)
+ self.copy_to(parameters)
+ try:
+ yield
+ finally:
+ if not self.switch:
+ self.restore(parameters)
+ if self.save_memory:
+ self.collected_params = None
+
+ def to(self, device=None, dtype=None) -> None:
+ r"""Move internal buffers of the ExponentialMovingAverage to `device`.
+
+ Args:
+ device: like `device` argument to `torch.Tensor.to`
+ """
+ # .to() on the tensors handles None correctly
+ self.shadow_params = [
+ (
+ p.to(device=device, dtype=dtype)
+ if p.is_floating_point()
+ else p.to(device=device)
+ )
+ for p in self.shadow_params
+ ]
+ if self.collected_params is not None:
+ self.collected_params = [
+ (
+ p.to(device=device, dtype=dtype)
+ if p.is_floating_point()
+ else p.to(device=device)
+ )
+ for p in self.collected_params
+ ]
+ return
+
+ def state_dict(self) -> dict:
+ r"""Returns the state of the ExponentialMovingAverage as a dict."""
+ # Following PyTorch conventions, references to tensors are returned:
+ # "returns a reference to the state and not its copy!" -
+ # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
+ return {
+ "decay": self.decay,
+ "num_updates": self.num_updates,
+ "shadow_params": self.shadow_params,
+ "collected_params": self.collected_params,
+ }
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ r"""Loads the ExponentialMovingAverage state.
+
+ Args:
+ state_dict (dict): EMA state. Should be an object returned
+ from a call to :meth:`state_dict`.
+ """
+ # deepcopy, to be consistent with module API
+ state_dict = copy.deepcopy(state_dict)
+ self.decay = state_dict["decay"]
+ if self.decay < 0.0 or self.decay > 1.0:
+ raise ValueError("Decay must be between 0 and 1")
+ self.num_updates = state_dict["num_updates"]
+ assert self.num_updates is None or isinstance(
+ self.num_updates, int
+ ), "Invalid num_updates"
+
+ self.shadow_params = state_dict["shadow_params"]
+ assert isinstance(self.shadow_params, list), "shadow_params must be a list"
+ assert all(
+ isinstance(p, torch.Tensor) for p in self.shadow_params
+ ), "shadow_params must all be Tensors"
+
+ self.collected_params = state_dict["collected_params"]
+ if self.collected_params is not None:
+ assert isinstance(
+ self.collected_params, list
+ ), "collected_params must be a list"
+ assert all(
+ isinstance(p, torch.Tensor) for p in self.collected_params
+ ), "collected_params must all be Tensors"
+ assert len(self.collected_params) == len(
+ self.shadow_params
+ ), "collected_params and shadow_params had different lengths"
+
+ if len(self.shadow_params) == len(self._params_refs):
+            # Consistent with torch.optim.Optimizer, cast things to a consistent
+            # device and dtype with the parameters
+ params = [p() for p in self._params_refs]
+ # If parameters have been garbage collected, just load the state
+ # we were given without change.
+ if not any(p is None for p in params):
+ # ^ parameter references are still good
+ for i, p in enumerate(params):
+ self.shadow_params[i] = self.shadow_params[i].to(
+ device=p.device, dtype=p.dtype
+ )
+ if self.collected_params is not None:
+ self.collected_params[i] = self.collected_params[i].to(
+ device=p.device, dtype=p.dtype
+ )
+ else:
+ raise ValueError(
+ "Tried to `load_state_dict()` with the wrong number of "
+ "parameters in the saved state."
+ )
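+
+
+if __name__ == "__main__":
+    # Minimal usage sketch with a single toy parameter; the hyper-parameters
+    # here (decay, update_after_step, tau) are illustrative only.
+    p = torch.nn.Parameter(torch.zeros(3))
+    ema = ExponentialMovingAverage([p], decay=0.999, update_after_step=0, tau=1)
+    for _ in range(5):
+        with torch.no_grad():
+            p.add_(1.0)  # stand-in for an optimizer step
+        ema.update()     # refresh the shadow (averaged) copy
+    with ema.average_parameters():
+        print("EMA weights:", p.detach())  # shadow weights are swapped in here
+    print("raw weights:", p.detach())      # original weights restored on exit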
diff --git a/unik3d/utils/evaluation_depth.py b/unik3d/utils/evaluation_depth.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ae9feb66384b73832ecc20471513be1a897e80d
--- /dev/null
+++ b/unik3d/utils/evaluation_depth.py
@@ -0,0 +1,337 @@
+from collections import defaultdict
+from functools import partial
+
+import torch
+import torch.nn.functional as F
+import torchvision.transforms.v2.functional as TF
+from PIL import Image
+
+from unik3d.utils.chamfer_distance import ChamferDistance
+from unik3d.utils.constants import DEPTH_BINS
+
+chamfer_cls = ChamferDistance()
+
+
+def kl_div(gt, pred, eps: float = 1e-6):
+ depth_bins = DEPTH_BINS.to(gt.device)
+ gt, pred = torch.bucketize(
+ gt, boundaries=depth_bins, out_int32=True
+ ), torch.bucketize(pred, boundaries=depth_bins, out_int32=True)
+ gt = torch.bincount(gt, minlength=len(depth_bins) + 1)
+ pred = torch.bincount(pred, minlength=len(depth_bins) + 1)
+ gt = gt / gt.sum()
+ pred = pred / pred.sum()
+ return torch.sum(gt * (torch.log(gt + eps) - torch.log(pred + eps)))
+
+
+def chamfer_dist(tensor1, tensor2):
+ x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
+ y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
+ dist1, dist2, idx1, idx2 = chamfer_cls(
+ tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
+ )
+ return (torch.sqrt(dist1) + torch.sqrt(dist2)) / 2
+
+
+def auc(tensor1, tensor2, thresholds):
+ x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
+ y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
+ dist1, dist2, idx1, idx2 = chamfer_cls(
+ tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
+ )
+ # compute precision recall
+ precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds]
+ recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds]
+ auc_value = torch.trapz(
+ torch.tensor(precisions, device=tensor1.device),
+ torch.tensor(recalls, device=tensor1.device),
+ )
+ return auc_value
+
+
+def delta(tensor1, tensor2, exponent):
+ inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1))
+ return (inlier < 1.25**exponent).to(torch.float32).mean()
+
+
+def rho(tensor1, tensor2):
+ min_deg = 0.5
+ tensor1_norm = tensor1 / torch.norm(tensor1, dim=-1, p=2, keepdim=True).clip(
+ min=1e-6
+ )
+ tensor2_norm = tensor2 / torch.norm(tensor2, dim=-1, p=2, keepdim=True).clip(
+ min=1e-6
+ )
+ max_polar_angle = torch.arccos(tensor1_norm[..., 2]).max() * 180.0 / torch.pi
+
+ if max_polar_angle < 100.0:
+ threshold = 15.0
+ elif max_polar_angle < 190.0:
+ threshold = 20.0
+ else:
+ threshold = 30.0
+
+ acos_clip = 1 - 1e-6
+ # inner prod of norm vector -> cosine
+ angular_error = (
+ torch.arccos(
+ (tensor1_norm * tensor2_norm)
+ .sum(dim=-1)
+ .clip(min=-acos_clip, max=acos_clip)
+ )
+ * 180.0
+ / torch.pi
+ )
+ thresholds = torch.linspace(min_deg, threshold, steps=100, device=tensor1.device)
+ y_values = [
+ (angular_error.abs() <= th).to(torch.float32).mean() for th in thresholds
+ ]
+ auc_value = torch.trapz(
+ torch.tensor(y_values, device=tensor1.device), thresholds
+ ) / (threshold - min_deg)
+ return auc_value
+
+
+def tau(tensor1, tensor2, perc):
+ inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1))
+ return (inlier < (1.0 + perc)).to(torch.float32).mean()
+
+
+@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+def ssi(tensor1, tensor2, qtl=0.05):
+ stability_mat = 1e-9 * torch.eye(2, device=tensor1.device)
+ error = (tensor1 - tensor2).abs()
+ mask = error < torch.quantile(error, 1 - qtl)
+ tensor1_mask = tensor1.to(torch.float32)[mask]
+ tensor2_mask = tensor2.to(torch.float32)[mask]
+ stability_mat = 1e-4 * torch.eye(2, device=tensor1.device)
+ tensor2_one = torch.stack([tensor2_mask, torch.ones_like(tensor2_mask)], dim=1)
+ A = torch.matmul(tensor2_one.T, tensor2_one) + stability_mat
+ det_A = A[0, 0] * A[1, 1] - A[0, 1] * A[1, 0]
+ A_inv = (1.0 / det_A) * torch.tensor(
+ [[A[1, 1], -A[0, 1]], [-A[1, 0], A[0, 0]]], device=tensor1.device
+ )
+ b = tensor2_one.T @ tensor1_mask.unsqueeze(1)
+ scale_shift = A_inv @ b
+ scale, shift = scale_shift.squeeze().chunk(2, dim=0)
+ return tensor2 * scale + shift
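+
+
+def _demo_ssi_alignment():
+    """Illustrative sketch (hypothetical helper, not used elsewhere): `ssi`
+    recovers the closed-form least-squares scale and shift that best align
+    `pred` to `gt` (ignoring the top `qtl` fraction of residuals), so the
+    returned prediction is directly comparable to `gt`."""
+    gt = torch.rand(1000) * 10 + 1.0
+    pred = 0.5 * gt + 2.0 + 0.01 * torch.randn_like(gt)  # wrong scale and shift
+    aligned = ssi(gt, pred)
+    return (aligned - gt).abs().mean()  # small residual after alignment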
+
+
+def si(tensor1, tensor2):
+ return tensor2 * torch.median(tensor1) / torch.median(tensor2)
+
+
+def arel(tensor1, tensor2):
+ tensor2 = tensor2 * torch.median(tensor1) / torch.median(tensor2)
+ return (torch.abs(tensor1 - tensor2) / tensor1).mean()
+
+
+def d_auc(tensor1, tensor2):
+ exponents = torch.linspace(0.01, 5.0, steps=100, device=tensor1.device)
+ deltas = [delta(tensor1, tensor2, exponent) for exponent in exponents]
+ return torch.trapz(torch.tensor(deltas, device=tensor1.device), exponents) / 5.0
+
+
+def f1_score(tensor1, tensor2, thresholds):
+ x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device)
+ y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device)
+ dist1, dist2, idx1, idx2 = chamfer_cls(
+ tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths
+ )
+ # compute precision recall
+ precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds]
+ recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds]
+ precisions = torch.tensor(precisions, device=tensor1.device)
+ recalls = torch.tensor(recalls, device=tensor1.device)
+ f1_thresholds = 2 * precisions * recalls / (precisions + recalls)
+ f1_thresholds = torch.where(
+ torch.isnan(f1_thresholds), torch.zeros_like(f1_thresholds), f1_thresholds
+ )
+ f1_value = torch.trapz(f1_thresholds) / len(thresholds)
+ return f1_value
+
+
+def f1_score_si(tensor1, tensor2, thresholds):
+ tensor2 = (
+ tensor2
+ * torch.median(tensor1.norm(dim=-1))
+ / torch.median(tensor2.norm(dim=-1))
+ )
+ f1_value = f1_score(tensor1, tensor2, thresholds)
+ return f1_value
+
+
+DICT_METRICS = {
+ "d1": partial(delta, exponent=1.0),
+ "d2": partial(delta, exponent=2.0),
+ "d3": partial(delta, exponent=3.0),
+ "rmse": lambda gt, pred: torch.sqrt(((gt - pred) ** 2).mean()),
+ "rmselog": lambda gt, pred: torch.sqrt(
+ ((torch.log(gt) - torch.log(pred)) ** 2).mean()
+ ),
+ "arel": lambda gt, pred: (torch.abs(gt - pred) / gt).mean(),
+ "sqrel": lambda gt, pred: (((gt - pred) ** 2) / gt).mean(),
+ "log10": lambda gt, pred: torch.abs(torch.log10(pred) - torch.log10(gt)).mean(),
+ "silog": lambda gt, pred: 100 * torch.std(torch.log(pred) - torch.log(gt)).mean(),
+ "medianlog": lambda gt, pred: 100
+ * (torch.log(pred) - torch.log(gt)).median().abs(),
+ "d_auc": d_auc,
+ "tau": partial(tau, perc=0.03),
+}
+
+
+DICT_METRICS_3D = {
+ "MSE_3d": lambda gt, pred, thresholds: torch.norm(gt - pred, dim=0, p=2),
+ "arel_3d": lambda gt, pred, thresholds: torch.norm(gt - pred, dim=0, p=2)
+ / torch.norm(gt, dim=0, p=2),
+ "tau_3d": lambda gt, pred, thresholds: (
+ (torch.norm(pred, dim=0, p=2) / torch.norm(gt, dim=0, p=2)).log().abs().exp()
+ < 1.25
+ )
+ .float()
+ .mean(),
+ "chamfer": lambda gt, pred, thresholds: chamfer_dist(
+ gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1)
+ ),
+ "F1": lambda gt, pred, thresholds: f1_score(
+ gt.unsqueeze(0).permute(0, 2, 1),
+ pred.unsqueeze(0).permute(0, 2, 1),
+ thresholds=thresholds,
+ ),
+ "F1_si": lambda gt, pred, thresholds: f1_score_si(
+ gt.unsqueeze(0).permute(0, 2, 1),
+ pred.unsqueeze(0).permute(0, 2, 1),
+ thresholds=thresholds,
+ ),
+ "rays": lambda gt, pred, thresholds: rho(
+ gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1)
+ ),
+}
+
+
+DICT_METRICS_FLOW = {
+ "epe": lambda gt, pred: torch.sqrt(torch.square(gt - pred).sum(dim=0)),
+ "epe1": lambda gt, pred: torch.sqrt(torch.square(gt - pred).sum(dim=0)) < 1,
+ "epe3": lambda gt, pred: torch.sqrt(torch.square(gt - pred).sum(dim=0)) < 3,
+ "epe5": lambda gt, pred: torch.sqrt(torch.square(gt - pred).sum(dim=0)) < 5,
+}
+
+
+DICT_METRICS_D = {
+ "a1": lambda gt, pred: (torch.maximum((gt / pred), (pred / gt)) > 1.25**1.0).to(
+ torch.float32
+ ),
+ "abs_rel": lambda gt, pred: (torch.abs(gt - pred) / gt),
+}
+
+
+def eval_depth(
+ gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, max_depth=None
+):
+ summary_metrics = defaultdict(list)
+ # preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear")
+ for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)):
+ if max_depth is not None:
+ mask = mask & (gt <= max_depth)
+ for name, fn in DICT_METRICS.items():
+ if name in ["tau", "d1", "arel"]:
+ for rescale_fn in ["ssi", "si"]:
+ summary_metrics[f"{name}_{rescale_fn}"].append(
+ fn(gt[mask], eval(rescale_fn)(gt[mask], pred[mask]))
+ )
+ summary_metrics[name].append(fn(gt[mask], pred[mask]).mean())
+ return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()}
+
+
+def eval_3d(
+ gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, thresholds=None
+):
+ summary_metrics = defaultdict(list)
+ MAX_PIXELS = 75_000 # 300_000
+ ratio = min(1.0, (MAX_PIXELS / masks[0].sum()) ** 0.5)
+ h_max, w_max = int(gts.shape[-2] * ratio), int(gts.shape[-1] * ratio)
+ gts = F.interpolate(gts, size=(h_max, w_max), mode="nearest-exact")
+ preds = F.interpolate(preds, size=(h_max, w_max), mode="nearest-exact")
+ masks = F.interpolate(
+ masks.float(), size=(h_max, w_max), mode="nearest-exact"
+ ).bool()
+
+ for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)):
+ if not torch.any(mask):
+ continue
+ for name, fn in DICT_METRICS_3D.items():
+ summary_metrics[name].append(
+ fn(gt[:, mask.squeeze()], pred[:, mask.squeeze()], thresholds).mean()
+ )
+ return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()}
+
+
+def compute_aucs(gt, pred, mask, uncertainties, steps=50, metrics=["abs_rel"]):
+ dict_ = {}
+ x_axis = torch.linspace(0, 1, steps=steps + 1, device=gt.device)
+ quantiles = torch.linspace(0, 1 - 1 / steps, steps=steps, device=gt.device)
+ zer = torch.tensor(0.0, device=gt.device)
+ # revert order (high uncertainty first)
+ uncertainties = -uncertainties[mask]
+ gt = gt[mask]
+ pred = pred[mask]
+ true_uncert = {metric: -DICT_METRICS_D[metric](gt, pred) for metric in metrics}
+ # get percentiles for sampling and corresponding subsets
+ thresholds = torch.quantile(uncertainties, quantiles)
+ subs = [(uncertainties >= t) for t in thresholds]
+
+ # compute sparsification curves for each metric (add 0 for final sampling)
+ for metric in metrics:
+ opt_thresholds = torch.quantile(true_uncert[metric], quantiles)
+ opt_subs = [(true_uncert[metric] >= t) for t in opt_thresholds]
+ sparse_curve = torch.stack(
+ [DICT_METRICS[metric](gt[sub], pred[sub]) for sub in subs] + [zer], dim=0
+ )
+ opt_curve = torch.stack(
+ [DICT_METRICS[metric](gt[sub], pred[sub]) for sub in opt_subs] + [zer],
+ dim=0,
+ )
+ rnd_curve = DICT_METRICS[metric](gt, pred)
+
+ dict_[f"AUSE_{metric}"] = torch.trapz(sparse_curve - opt_curve, x=x_axis)
+ dict_[f"AURG_{metric}"] = rnd_curve - torch.trapz(sparse_curve, x=x_axis)
+
+ return dict_
+
+
+def eval_depth_uncertainties(
+ gts: torch.Tensor,
+ preds: torch.Tensor,
+ uncertainties: torch.Tensor,
+ masks: torch.Tensor,
+ max_depth=None,
+):
+ summary_metrics = defaultdict(list)
+ preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear")
+ for i, (gt, pred, mask, uncertainty) in enumerate(
+ zip(gts, preds, masks, uncertainties)
+ ):
+ if max_depth is not None:
+ mask = torch.logical_and(mask, gt < max_depth)
+ for name, fn in DICT_METRICS.items():
+ summary_metrics[name].append(fn(gt[mask], pred[mask]))
+ for name, val in compute_aucs(gt, pred, mask, uncertainty).items():
+ summary_metrics[name].append(val)
+ return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()}
+
+
+def lazy_eval_depth(
+ gt_fns, pred_fns, min_depth=1e-2, max_depth=None, depth_scale=256.0
+):
+ summary_metrics = defaultdict(list)
+ for i, (gt_fn, pred_fn) in enumerate(zip(gt_fns, pred_fns)):
+ gt = TF.pil_to_tensor(Image.open(gt_fn)).to(torch.float32) / depth_scale
+ pred = TF.pil_to_tensor(Image.open(pred_fn)).to(torch.float32) / depth_scale
+ mask = gt > min_depth
+ if max_depth is not None:
+ mask_2 = gt < max_depth
+ mask = torch.logical_and(mask, mask_2)
+ for name, fn in DICT_METRICS.items():
+ summary_metrics[name].append(fn(gt[mask], pred[mask]))
+
+ return {name: torch.mean(vals).item() for name, vals in summary_metrics.items()}
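+
+
+if __name__ == "__main__":
+    # Self-contained sketch of the threshold metrics above (delta / tau) on
+    # synthetic depth maps; values are illustrative. `eval_3d` is skipped here
+    # since its Chamfer-based metrics need the compiled ChamferDistance kernels.
+    torch.manual_seed(0)
+    gt = torch.rand(1, 1, 32, 32) * 10 + 0.5
+    noise = (1.0 + 0.05 * torch.randn_like(gt)).clamp(min=0.1)
+    pred = gt * noise                            # ~5% multiplicative noise
+    mask = torch.ones_like(gt, dtype=torch.bool)
+    print("d1 :", delta(gt[mask], pred[mask], exponent=1.0).item())
+    print("tau:", tau(gt[mask], pred[mask], perc=0.03).item())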
diff --git a/unik3d/utils/geometric.py b/unik3d/utils/geometric.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c61d8e5e7ef954b2423063057eba4dbe72f6704
--- /dev/null
+++ b/unik3d/utils/geometric.py
@@ -0,0 +1,479 @@
+from typing import Tuple
+
+import torch
+from torch.nn import functional as F
+
+
+# @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+def generate_rays(
+ camera_intrinsics: torch.Tensor, image_shape: Tuple[int, int], noisy: bool = False
+):
+ batch_size, device, dtype = (
+ camera_intrinsics.shape[0],
+ camera_intrinsics.device,
+ camera_intrinsics.dtype,
+ )
+ # print("CAMERA DTYPE", dtype)
+ height, width = image_shape
+ # Generate grid of pixel coordinates
+ pixel_coords_x = torch.linspace(0, width - 1, width, device=device, dtype=dtype)
+ pixel_coords_y = torch.linspace(0, height - 1, height, device=device, dtype=dtype)
+ if noisy:
+ pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5
+ pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5
+ pixel_coords = torch.stack(
+ [pixel_coords_x.repeat(height, 1), pixel_coords_y.repeat(width, 1).t()], dim=2
+ ) # (H, W, 2)
+ pixel_coords = pixel_coords + 0.5
+
+ # Calculate ray directions
+ intrinsics_inv = torch.inverse(camera_intrinsics.float()).to(dtype)
+ homogeneous_coords = torch.cat(
+ [pixel_coords, torch.ones_like(pixel_coords[:, :, :1])], dim=2
+ ) # (H, W, 3)
+ ray_directions = torch.matmul(
+ intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1)
+    )  # (B, 3, H*W)
+
+ # unstable normalization, need float32?
+ ray_directions = F.normalize(ray_directions, dim=1) # (B, 3, H*W)
+ ray_directions = ray_directions.permute(0, 2, 1) # (B, H*W, 3)
+
+ theta = torch.atan2(ray_directions[..., 0], ray_directions[..., -1])
+ phi = torch.acos(ray_directions[..., 1])
+ # pitch = torch.asin(ray_directions[..., 1])
+ # roll = torch.atan2(ray_directions[..., 0], - ray_directions[..., 1])
+ angles = torch.stack([theta, phi], dim=-1)
+ return ray_directions, angles
+
+
+@torch.jit.script
+def spherical_zbuffer_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
+ theta = spherical_tensor[..., 0] # Extract polar angle
+ phi = spherical_tensor[..., 1] # Extract azimuthal angle
+ z = spherical_tensor[..., 2] # Extract zbuffer depth
+
+ # y = r * cos(phi)
+ # x = r * sin(phi) * sin(theta)
+ # z = r * sin(phi) * cos(theta)
+ # =>
+ # r = z / sin(phi) / cos(theta)
+ # y = z / (sin(phi) / cos(phi)) / cos(theta)
+ # x = z * sin(theta) / cos(theta)
+ x = z * torch.tan(theta)
+ y = z / torch.tan(phi) / torch.cos(theta)
+
+ euclidean_tensor = torch.stack((x, y, z), dim=-1)
+ return euclidean_tensor
+
+
+@torch.jit.script
+def spherical_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor:
+ theta = spherical_tensor[..., 0] # Extract polar angle
+ phi = spherical_tensor[..., 1] # Extract azimuthal angle
+ r = spherical_tensor[..., 2] # Extract radius
+ # x = r * torch.sin(theta) * torch.sin(phi)
+ # y = r * torch.cos(theta)
+ # z = r * torch.cos(phi) * torch.sin(theta)
+ x = r * torch.sin(theta) * torch.cos(phi)
+ y = r * torch.sin(theta) * torch.sin(phi)
+ z = r * torch.cos(theta)
+ euclidean_tensor = torch.stack((x, y, z), dim=-1)
+ return euclidean_tensor
+
+
+@torch.jit.script
+def euclidean_to_spherical(spherical_tensor: torch.Tensor) -> torch.Tensor:
+    x = spherical_tensor[..., 0]  # Extract x coordinate
+    y = spherical_tensor[..., 1]  # Extract y coordinate
+    z = spherical_tensor[..., 2]  # Extract z coordinate
+ # y = r * cos(phi)
+ # x = r * sin(phi) * sin(theta)
+ # z = r * sin(phi) * cos(theta)
+ r = torch.sqrt(x**2 + y**2 + z**2)
+ theta = torch.atan2(x / r, z / r)
+ phi = torch.acos(y / r)
+
+ euclidean_tensor = torch.stack((theta, phi, r), dim=-1)
+ return euclidean_tensor
+
+
+@torch.jit.script
+def euclidean_to_spherical_zbuffer(euclidean_tensor: torch.Tensor) -> torch.Tensor:
+ pitch = torch.asin(euclidean_tensor[..., 1])
+ yaw = torch.atan2(euclidean_tensor[..., 0], euclidean_tensor[..., -1])
+ z = euclidean_tensor[..., 2] # Extract zbuffer depth
+ euclidean_tensor = torch.stack((pitch, yaw, z), dim=-1)
+ return euclidean_tensor
+
+
+@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+def unproject_points(
+ depth: torch.Tensor, camera_intrinsics: torch.Tensor
+) -> torch.Tensor:
+ """
+ Unprojects a batch of depth maps to 3D point clouds using camera intrinsics.
+
+ Args:
+ depth (torch.Tensor): Batch of depth maps of shape (B, 1, H, W).
+ camera_intrinsics (torch.Tensor): Camera intrinsic matrix of shape (B, 3, 3).
+
+ Returns:
+ torch.Tensor: Batch of 3D point clouds of shape (B, 3, H, W).
+ """
+ batch_size, _, height, width = depth.shape
+ device = depth.device
+
+ # Create pixel grid
+ y_coords, x_coords = torch.meshgrid(
+ torch.arange(height, device=device),
+ torch.arange(width, device=device),
+ indexing="ij",
+ )
+ pixel_coords = torch.stack((x_coords, y_coords), dim=-1) # (H, W, 2)
+
+ # Get homogeneous coords (u v 1)
+ pixel_coords_homogeneous = torch.cat(
+ (pixel_coords, torch.ones((height, width, 1), device=device)), dim=-1
+ )
+ pixel_coords_homogeneous = pixel_coords_homogeneous.permute(2, 0, 1).flatten(
+ 1
+ ) # (3, H*W)
+ # Apply K^-1 @ (u v 1): [B, 3, 3] @ [3, H*W] -> [B, 3, H*W]
+ camera_intrinsics_inv = camera_intrinsics.clone()
+ # invert camera intrinsics
+ camera_intrinsics_inv[:, 0, 0] = 1 / camera_intrinsics_inv[:, 0, 0]
+ camera_intrinsics_inv[:, 1, 1] = 1 / camera_intrinsics_inv[:, 1, 1]
+
+ unprojected_points = camera_intrinsics_inv @ pixel_coords_homogeneous # (B, 3, H*W)
+ unprojected_points = unprojected_points.view(
+ batch_size, 3, height, width
+ ) # (B, 3, H, W)
+ unprojected_points = unprojected_points * depth # (B, 3, H, W)
+ return unprojected_points
+
+
+@torch.jit.script
+def project_points(
+ points_3d: torch.Tensor,
+ intrinsic_matrix: torch.Tensor,
+ image_shape: Tuple[int, int],
+) -> torch.Tensor:
+ # Project 3D points onto the image plane via intrinsics (u v w) = (x y z) @ K^T
+ points_2d = torch.matmul(points_3d, intrinsic_matrix.transpose(1, 2))
+
+ # Normalize projected points: (u v w) -> (u / w, v / w, 1)
+ points_2d = points_2d[..., :2] / points_2d[..., 2:]
+
+ # To pixels (rounding!!!), no int as it breaks gradient
+ points_2d = points_2d.round()
+
+    # points need to be inside the image (could degenerate if all points project outside)
+ valid_mask = (
+ (points_2d[..., 0] >= 0)
+ & (points_2d[..., 0] < image_shape[1])
+ & (points_2d[..., 1] >= 0)
+ & (points_2d[..., 1] < image_shape[0])
+ )
+
+ # Calculate the flat indices of the valid pixels
+ flat_points_2d = points_2d[..., 0] + points_2d[..., 1] * image_shape[1]
+ flat_indices = flat_points_2d.long()
+
+ # Create depth maps and counts using scatter_add, (B, H, W)
+ depth_maps = torch.zeros(
+ [points_3d.shape[0], *image_shape], device=points_3d.device
+ )
+ counts = torch.zeros([points_3d.shape[0], *image_shape], device=points_3d.device)
+
+ # Loop over batches to apply masks and accumulate depth/count values
+ for i in range(points_3d.shape[0]):
+ valid_indices = flat_indices[i, valid_mask[i]]
+ depth_maps[i].view(-1).scatter_add_(
+ 0, valid_indices, points_3d[i, valid_mask[i], 2]
+ )
+ counts[i].view(-1).scatter_add_(
+ 0, valid_indices, torch.ones_like(points_3d[i, valid_mask[i], 2])
+ )
+
+ # Calculate mean depth for each pixel in each batch
+ mean_depth_maps = depth_maps / counts.clamp(min=1.0)
+ return mean_depth_maps.reshape(-1, 1, *image_shape) # (B, 1, H, W)
+
+
+@torch.jit.script
+def downsample(data: torch.Tensor, downsample_factor: int = 2):
+ N, _, H, W = data.shape
+ data = data.view(
+ N,
+ H // downsample_factor,
+ downsample_factor,
+ W // downsample_factor,
+ downsample_factor,
+ 1,
+ )
+ data = data.permute(0, 1, 3, 5, 2, 4).contiguous()
+ data = data.view(-1, downsample_factor * downsample_factor)
+ data_tmp = torch.where(data == 0.0, 1e5 * torch.ones_like(data), data)
+ data = torch.min(data_tmp, dim=-1).values
+ data = data.view(N, 1, H // downsample_factor, W // downsample_factor)
+ data = torch.where(data > 1000, torch.zeros_like(data), data)
+ return data
+
+
+@torch.jit.script
+def flat_interpolate(
+ flat_tensor: torch.Tensor,
+ old: Tuple[int, int],
+ new: Tuple[int, int],
+ antialias: bool = False,
+ mode: str = "bilinear",
+) -> torch.Tensor:
+ if old[0] == new[0] and old[1] == new[1]:
+ return flat_tensor
+ tensor = flat_tensor.view(flat_tensor.shape[0], old[0], old[1], -1).permute(
+ 0, 3, 1, 2
+ ) # b c h w
+ tensor_interp = F.interpolate(
+ tensor,
+ size=(new[0], new[1]),
+ mode=mode,
+ align_corners=False,
+ antialias=antialias,
+ )
+ flat_tensor_interp = tensor_interp.view(
+ flat_tensor.shape[0], -1, new[0] * new[1]
+ ).permute(
+ 0, 2, 1
+ ) # b (h w) c
+ return flat_tensor_interp.contiguous()
+
+
+# # @torch.jit.script
+# def displacement_relative_neighbour(gt: torch.Tensor, mask: torch.Tensor = None, kernel_size: int = 7, ndim: int =4):
+# pad = kernel_size // 2
+# n_neighbours = int(kernel_size**2)
+
+# # when torchscipt will support nested generators in listcomp or usage of range
+# # in product(range_, range_), then use listcomp, so far speedup ~5% wrt std python
+# if mask is None:
+# mask = torch.ones_like(gt).bool()
+
+# lst_gts, lst_masks = [], []
+# for i in range(-kernel_size//2 + 1, kernel_size//2 + 1):
+# for j in range(-kernel_size//2 + 1, kernel_size//2 + 1):
+# if i != 0 or j != 0:
+# lst_gts.append(torch.roll(gt, shifts=(i, j), dims=(-2, -1)))
+# lst_masks.append(torch.roll(F.pad(mask, (pad,) * 4), shifts=(i, j), dims=(-2, -1)))
+# gts = torch.cat(lst_gts, dim=-3)
+# masks = torch.cat(lst_masks, dim=-3)
+
+# masks = masks[..., pad:-pad, pad:-pad]
+# masks[~mask.repeat(*(1,) * (ndim - 3), n_neighbours-1, 1, 1,)] = False # No displacement known if seed is missing
+# log_gts = gts.clamp(min=1e-6).log() - gt.repeat(*(1,) * (ndim - 3), n_neighbours-1, 1, 1).clamp(min=1e-6).log()
+# return log_gts, masks
+
+
+# @torch.jit.script
+# def antidisplacement_relative_neighbour(preds: torch.Tensor, kernel_size: int = 7):
+# lst_preds, lst_masks = [], []
+# cnt = 0
+# pad = kernel_size // 2
+# mask = F.pad(torch.ones((preds.shape[0], 1, preds.shape[-2], preds.shape[-1]), device=preds.device), (pad,) * 4)
+# for i in range(-kernel_size//2 + 1, kernel_size//2 + 1):
+# for j in range(-kernel_size//2 + 1, kernel_size//2 + 1):
+# if i != 0 or j !=0:
+# lst_preds.append(torch.roll(preds[:, cnt], shifts=(-i, -j), dims=(-2, -1)))
+# lst_masks.append(torch.roll(mask, shifts=(-i, -j), dims=(-2, -1)))
+# cnt += 1
+# preds_ensamble = torch.stack(lst_preds, dim=1)
+# masks = torch.cat(lst_masks, dim=1)
+# masks = masks[..., pad:-pad, pad:-pad]
+# return preds_ensamble, masks
+
+
+# def unproject(uv, fx, fy, cx, cy, xi=0, alpha=0):
+# u, v = uv.unbind(dim=1)
+# mx = (u - cx) / fx
+# my = (v - cy) / fy
+# r_square = mx ** 2 + my ** 2
+# root = 1 - (2 * alpha - 1) * r_square
+# valid_mask = root >= 0
+# root[~valid_mask] = 0.0
+# mz = (1 - (alpha ** 2) * r_square) / (alpha * torch.sqrt(root) + (1 - alpha))
+# coeff = (mz * xi + torch.sqrt(mz ** 2 + (1 - xi ** 2) * r_square)) / (mz ** 2 + r_square)
+
+# x = coeff * mx
+# y = coeff * my
+# z = coeff * mz - xi
+# # z = z.clamp(min=1e-7)
+
+# x_norm = x / z
+# y_norm = y / z
+# z_norm = z / z
+# xnorm = torch.stack(( x_norm, y_norm, z_norm ), dim=1)
+# # print("unproj", xnorm.shape, xnorm[:, -1].mean())
+
+# return xnorm, valid_mask.unsqueeze(1).repeat(1, 3, 1, 1)
+
+
+# def project(point3D, fx, fy, cx, cy, xi=0, alpha=0):
+# B, C, H, W = point3D.shape
+# x, y, z = point3D.unbind(dim=1)
+# z = z.clamp(min=1e-7)
+# d_1 = torch.sqrt( x ** 2 + y ** 2 + z ** 2 )
+# d_2 = torch.sqrt( x ** 2 + y ** 2 + (xi * d_1 + z) ** 2 )
+
+# div = alpha * d_2 + (1 - alpha) * (xi * d_1 + z)
+# Xnorm = fx * x / div + cx
+# Ynorm = fy * y / div + cy
+
+# coords = torch.stack([Xnorm, Ynorm], dim=1)
+# w1 = torch.where(alpha <= 0.5, alpha / (1 - alpha), (1 - alpha) / alpha)
+# w2 = w1 + xi / ((2 * w1 * xi + xi ** 2 + 1) ** 0.5)
+# valid_mask = z > - w2 * d_1
+
+# # Return pixel coordinates
+# return coords, valid_mask.unsqueeze(1).repeat(1, 2, 1, 1)
+
+
+@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+def unproject(uv, fx, fy, cx, cy, alpha=None, beta=None):
+ uv = uv.float()
+ fx = fx.float()
+ fy = fy.float()
+ cx = cx.float()
+ cy = cy.float()
+ u, v = uv.unbind(dim=1)
+ alpha = torch.zeros_like(fx) if alpha is None else alpha.float()
+ beta = torch.ones_like(fx) if beta is None else beta.float()
+ mx = (u - cx) / fx
+ my = (v - cy) / fy
+ r_square = mx**2 + my**2
+ valid_mask = r_square < torch.where(alpha < 0.5, 1e6, 1 / (beta * (2 * alpha - 1)))
+ sqrt_val = 1 - (2 * alpha - 1) * beta * r_square
+ mz = (1 - beta * (alpha**2) * r_square) / (
+ alpha * torch.sqrt(sqrt_val.clip(min=1e-5)) + (1 - alpha)
+ )
+ coeff = 1 / torch.sqrt(mx**2 + my**2 + mz**2 + 1e-5)
+
+ x = coeff * mx
+ y = coeff * my
+ z = coeff * mz
+ valid_mask = valid_mask & (z > 1e-3)
+
+ xnorm = torch.stack((x, y, z.clamp(1e-3)), dim=1)
+ return xnorm, valid_mask.unsqueeze(1)
+
+
+@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32)
+def project(point3D, fx, fy, cx, cy, alpha=None, beta=None):
+ H, W = point3D.shape[-2:]
+ alpha = torch.zeros_like(fx) if alpha is None else alpha
+ beta = torch.ones_like(fx) if beta is None else beta
+ x, y, z = point3D.unbind(dim=1)
+ d = torch.sqrt(beta * (x**2 + y**2) + z**2)
+
+ x = x / (alpha * d + (1 - alpha) * z).clip(min=1e-3)
+ y = y / (alpha * d + (1 - alpha) * z).clip(min=1e-3)
+
+ Xnorm = fx * x + cx
+ Ynorm = fy * y + cy
+
+ coords = torch.stack([Xnorm, Ynorm], dim=1)
+
+ invalid = (
+ (coords[:, 0] < 0)
+ | (coords[:, 0] > W)
+ | (coords[:, 1] < 0)
+ | (coords[:, 1] > H)
+ | (z < 0)
+ )
+
+ # Return pixel coordinates
+ return coords, (~invalid).unsqueeze(1)
+
+
+def rays2angles(rays: torch.Tensor) -> torch.Tensor:
+ theta = torch.atan2(rays[..., 0], rays[..., -1])
+ phi = torch.acos(rays[..., 1])
+ angles = torch.stack([theta, phi], dim=-1)
+ return angles
+
+
+@torch.jit.script
+def dilate(image, kernel_size: int | tuple[int, int]):
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size, kernel_size)
+ device, dtype = image.device, image.dtype
+ padding = (kernel_size[0] // 2, kernel_size[1] // 2)
+ kernel = torch.ones((1, 1, *kernel_size), dtype=torch.float32, device=image.device)
+ dilated_image = F.conv2d(image.float(), kernel, padding=padding, stride=1)
+ dilated_image = torch.where(
+ dilated_image > 0,
+ torch.tensor(1.0, device=device),
+ torch.tensor(0.0, device=device),
+ )
+ return dilated_image.to(dtype)
+
+
+@torch.jit.script
+def erode(image, kernel_size: int | tuple[int, int]):
+ if isinstance(kernel_size, int):
+ kernel_size = (kernel_size, kernel_size)
+ device, dtype = image.device, image.dtype
+ padding = (kernel_size[0] // 2, kernel_size[1] // 2)
+ kernel = torch.ones((1, 1, *kernel_size), dtype=torch.float32, device=image.device)
+ eroded_image = F.conv2d(image.float(), kernel, padding=padding, stride=1)
+ eroded_image = torch.where(
+ eroded_image == (kernel_size[0] * kernel_size[1]),
+ torch.tensor(1.0, device=device),
+ torch.tensor(0.0, device=device),
+ )
+ return eroded_image.to(dtype)
+
+
+@torch.jit.script
+def iou(mask1: torch.Tensor, mask2: torch.Tensor) -> torch.Tensor:
+ device = mask1.device
+
+ # Ensure the masks are binary (0 or 1)
+ mask1 = mask1.to(torch.bool)
+ mask2 = mask2.to(torch.bool)
+
+ # Compute intersection and union
+ intersection = torch.sum(mask1 & mask2).to(torch.float32)
+ union = torch.sum(mask1 | mask2).to(torch.float32)
+
+ # Compute IoU
+ iou = intersection / union.clip(min=1.0)
+
+ return iou
+
+
+if __name__ == "__main__":
+ kernel_size = 3
+ image = torch.tensor(
+ [
+ [
+ [
+ [1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1],
+ [1, 1, 1, 1, 1],
+ ]
+ ]
+ ],
+ dtype=torch.bool,
+ )
+
+ print("testing dilate and erode, with image:\n", image, image.shape)
+
+ # Perform dilation
+ dilated_image = dilate(image, kernel_size)
+ print("Dilated Image:\n", dilated_image)
+
+ # Perform erosion
+ eroded_image = erode(image, kernel_size)
+ print("Eroded Image:\n", eroded_image)
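+
+    # Additional sketch: with alpha=0 and beta=1, `unproject`/`project` reduce
+    # to the standard pinhole model; the values below are illustrative only.
+    Hs, Ws = 4, 6
+    fx = torch.full((1, 1, 1), 50.0)
+    fy = torch.full((1, 1, 1), 50.0)
+    cx = torch.full((1, 1, 1), Ws / 2.0)
+    cy = torch.full((1, 1, 1), Hs / 2.0)
+    vs, us = torch.meshgrid(
+        torch.arange(Hs, dtype=torch.float32),
+        torch.arange(Ws, dtype=torch.float32),
+        indexing="ij",
+    )
+    uv = torch.stack((us, vs), dim=0).unsqueeze(0)         # (1, 2, H, W)
+    rays, valid = unproject(uv, fx, fy, cx, cy)            # unit ray directions
+    coords, in_view = project(2.0 * rays, fx, fy, cx, cy)  # back to pixels
+    print("Pinhole round-trip error (px):", (coords - uv).abs().max().item())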
diff --git a/unik3d/utils/knn.py b/unik3d/utils/knn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e12787497268f37cc022ef14207e084d4d5eb352
--- /dev/null
+++ b/unik3d/utils/knn.py
@@ -0,0 +1,248 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+# pyre-unsafe
+
+from collections import namedtuple
+from typing import Union
+
+import torch
+from torch.autograd import Function
+from torch.autograd.function import once_differentiable
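+
+# NOTE: the functions below call `_C.knn_points_idx` / `_C.knn_points_backward`,
+# i.e. a compiled KNN extension. In the upstream PyTorch3D sources this is
+# provided by `from pytorch3d import _C`; an equivalent `_C` must be importable
+# in this module for `knn_points` to run (assumption based on the code below).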
+
+_KNN = namedtuple("KNN", "dists idx knn")
+
+
+class _knn_points(Function):
+ """
+ Torch autograd Function wrapper for KNN C++/CUDA implementations.
+ """
+
+ @staticmethod
+ # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently.
+ def forward(
+ ctx,
+ p1,
+ p2,
+ lengths1,
+ lengths2,
+ K,
+ version,
+ norm: int = 2,
+ return_sorted: bool = True,
+ ):
+ """
+ K-Nearest neighbors on point clouds.
+
+ Args:
+ p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
+ containing up to P1 points of dimension D.
+ p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
+ containing up to P2 points of dimension D.
+ lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
+ length of each pointcloud in p1. Or None to indicate that every cloud has
+ length P1.
+ lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
+ length of each pointcloud in p2. Or None to indicate that every cloud has
+ length P2.
+ K: Integer giving the number of nearest neighbors to return.
+ version: Which KNN implementation to use in the backend. If version=-1,
+ the correct implementation is selected based on the shapes of the inputs.
+ norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2).
+ return_sorted: (bool) whether to return the nearest neighbors sorted in
+ ascending order of distance.
+
+ Returns:
+ p1_dists: Tensor of shape (N, P1, K) giving the squared distances to
+ the nearest neighbors. This is padded with zeros both where a cloud in p2
+ has fewer than K points and where a cloud in p1 has fewer than P1 points.
+
+ p1_idx: LongTensor of shape (N, P1, K) giving the indices of the
+ K nearest neighbors from points in p1 to points in p2.
+ Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
+ neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
+ in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points.
+ """
+ if not ((norm == 1) or (norm == 2)):
+ raise ValueError("Support for 1 or 2 norm.")
+
+ idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version)
+
+ # sort KNN in ascending order if K > 1
+ if K > 1 and return_sorted:
+ if lengths2.min() < K:
+ P1 = p1.shape[1]
+ mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None]
+ # mask has shape [N, K], true where dists irrelevant
+ mask = mask[:, None].expand(-1, P1, -1)
+ # mask has shape [N, P1, K], true where dists irrelevant
+ dists[mask] = float("inf")
+ dists, sort_idx = dists.sort(dim=2)
+ dists[mask] = 0
+ else:
+ dists, sort_idx = dists.sort(dim=2)
+ idx = idx.gather(2, sort_idx)
+
+ ctx.save_for_backward(p1, p2, lengths1, lengths2, idx)
+ ctx.mark_non_differentiable(idx)
+ ctx.norm = norm
+ return dists, idx
+
+ @staticmethod
+ @once_differentiable
+ def backward(ctx, grad_dists, grad_idx):
+ p1, p2, lengths1, lengths2, idx = ctx.saved_tensors
+ norm = ctx.norm
+ # TODO(gkioxari) Change cast to floats once we add support for doubles.
+ if not (grad_dists.dtype == torch.float32):
+ grad_dists = grad_dists.float()
+ if not (p1.dtype == torch.float32):
+ p1 = p1.float()
+ if not (p2.dtype == torch.float32):
+ p2 = p2.float()
+ grad_p1, grad_p2 = _C.knn_points_backward(
+ p1, p2, lengths1, lengths2, idx, norm, grad_dists
+ )
+ return grad_p1, grad_p2, None, None, None, None, None, None
+
+
+def knn_points(
+ p1: torch.Tensor,
+ p2: torch.Tensor,
+ lengths1: Union[torch.Tensor, None] = None,
+ lengths2: Union[torch.Tensor, None] = None,
+ norm: int = 2,
+ K: int = 1,
+ version: int = -1,
+ return_nn: bool = False,
+ return_sorted: bool = True,
+) -> _KNN:
+ """
+ K-Nearest neighbors on point clouds.
+
+ Args:
+ p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each
+ containing up to P1 points of dimension D.
+ p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each
+ containing up to P2 points of dimension D.
+ lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the
+ length of each pointcloud in p1. Or None to indicate that every cloud has
+ length P1.
+ lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the
+ length of each pointcloud in p2. Or None to indicate that every cloud has
+ length P2.
+ norm: Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2.
+ K: Integer giving the number of nearest neighbors to return.
+ version: Which KNN implementation to use in the backend. If version=-1,
+ the correct implementation is selected based on the shapes of the inputs.
+ return_nn: If set to True returns the K nearest neighbors in p2 for each point in p1.
+ return_sorted: (bool) whether to return the nearest neighbors sorted in
+ ascending order of distance.
+
+ Returns:
+ dists: Tensor of shape (N, P1, K) giving the squared distances to
+ the nearest neighbors. This is padded with zeros both where a cloud in p2
+ has fewer than K points and where a cloud in p1 has fewer than P1 points.
+
+ idx: LongTensor of shape (N, P1, K) giving the indices of the
+ K nearest neighbors from points in p1 to points in p2.
+ Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest
+ neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud
+ in p2 has fewer than K points and where a cloud in p1 has fewer than P1
+ points.
+
+ nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for
+ each point in p1. Concretely, `p2_nn[n, i, k]` gives the k-th nearest neighbor
+ for `p1[n, i]`. Returned if `return_nn` is True.
+ The nearest neighbors are collected using `knn_gather`
+
+ .. code-block::
+
+ p2_nn = knn_gather(p2, p1_idx, lengths2)
+
+ which is a helper function that allows indexing any tensor of shape (N, P2, U) with
+ the indices `p1_idx` returned by `knn_points`. The output is a tensor
+ of shape (N, P1, K, U).
+
+ """
+ if p1.shape[0] != p2.shape[0]:
+ raise ValueError("pts1 and pts2 must have the same batch dimension.")
+ if p1.shape[2] != p2.shape[2]:
+ raise ValueError("pts1 and pts2 must have the same point dimension.")
+
+ p1 = p1.contiguous()
+ p2 = p2.contiguous()
+
+ P1 = p1.shape[1]
+ P2 = p2.shape[1]
+
+ if lengths1 is None:
+ lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device)
+ if lengths2 is None:
+ lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device)
+
+ p1_dists, p1_idx = _knn_points.apply(
+ p1, p2, lengths1, lengths2, K, version, norm, return_sorted
+ )
+
+ p2_nn = None
+ if return_nn:
+ p2_nn = knn_gather(p2, p1_idx, lengths2)
+
+ return _KNN(dists=p1_dists, idx=p1_idx, knn=p2_nn if return_nn else None)
+
+
+def knn_gather(
+ x: torch.Tensor, idx: torch.Tensor, lengths: Union[torch.Tensor, None] = None
+):
+ """
+ A helper function for knn that allows indexing a tensor x with the indices `idx`
+ returned by `knn_points`.
+
+ For example, if `dists, idx = knn_points(p, x, lengths_p, lengths, K)`
+ where p is a tensor of shape (N, L, D) and x a tensor of shape (N, M, D),
+ then one can compute the K nearest neighbors of p with `p_nn = knn_gather(x, idx, lengths)`.
+ It can also be applied for any tensor x of shape (N, M, U) where U != D.
+
+ Args:
+ x: Tensor of shape (N, M, U) containing U-dimensional features to
+ be gathered.
+ idx: LongTensor of shape (N, L, K) giving the indices returned by `knn_points`.
+ lengths: LongTensor of shape (N,) of values in the range [0, M], giving the
+ length of each example in the batch in x. Or None to indicate that every
+ example has length M.
+ Returns:
+ x_out: Tensor of shape (N, L, K, U) resulting from gathering the elements of x
+ with idx, s.t. `x_out[n, l, k] = x[n, idx[n, l, k]]`.
+ If `k > lengths[n]` then `x_out[n, l, k]` is filled with 0.0.
+ """
+ N, M, U = x.shape
+ _N, L, K = idx.shape
+
+ if N != _N:
+ raise ValueError("x and idx must have same batch dimension.")
+
+ if lengths is None:
+ lengths = torch.full((x.shape[0],), M, dtype=torch.int64, device=x.device)
+
+ idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, U)
+ # idx_expanded has shape [N, L, K, U]
+
+ x_out = x[:, :, None].expand(-1, -1, K, -1).gather(1, idx_expanded)
+ # p2_nn has shape [N, L, K, U]
+
+ needs_mask = lengths.min() < K
+ if needs_mask:
+ # mask has shape [N, K], true where idx is irrelevant because
+ # there is less number of points in p2 than K
+ mask = lengths[:, None] <= torch.arange(K, device=x.device)[None]
+
+ # expand mask to shape [N, L, K, U]
+ mask = mask[:, None].expand(-1, L, -1)
+ mask = mask[:, :, :, None].expand(-1, -1, -1, U)
+ x_out[mask] = 0.0
+
+ return x_out
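+
+
+if __name__ == "__main__":
+    # Sketch of `knn_gather` in isolation (pure PyTorch, unlike `knn_points`,
+    # which needs the compiled `_C` extension): gather features for hand-picked
+    # neighbour indices; shapes follow the docstring above.
+    feats = torch.arange(12, dtype=torch.float32).view(1, 4, 3)  # (N=1, M=4, U=3)
+    idx = torch.tensor([[[0, 2], [3, 1]]])                       # (N=1, L=2, K=2)
+    print(knn_gather(feats, idx).shape)                          # (1, 2, 2, 3)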
diff --git a/unik3d/utils/misc.py b/unik3d/utils/misc.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d144b0cf119a781c2e13998db842717b0141ed9
--- /dev/null
+++ b/unik3d/utils/misc.py
@@ -0,0 +1,630 @@
+from functools import wraps
+from time import time
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, reduce, repeat
+from scipy import interpolate
+
+
+@torch.jit.script
+def max_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
+ if len(tensors) == 1:
+ return tensors[0]
+ return torch.stack(tensors, dim=-1).max(dim=-1).values
+
+
+def last_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
+ return tensors[-1]
+
+
+def first_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
+ return tensors[0]
+
+
+@torch.jit.script
+def softmax_stack(
+ tensors: list[torch.Tensor], temperature: float = 1.0
+) -> torch.Tensor:
+ if len(tensors) == 1:
+ return tensors[0]
+ return F.softmax(torch.stack(tensors, dim=-1) / temperature, dim=-1).sum(dim=-1)
+
+
+@torch.jit.script
+def mean_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
+ if len(tensors) == 1:
+ return tensors[0]
+ return torch.stack(tensors, dim=-1).mean(dim=-1)
+
+
+@torch.jit.script
+def sum_stack(tensors: list[torch.Tensor]) -> torch.Tensor:
+ if len(tensors) == 1:
+ return tensors[0]
+ return torch.stack(tensors, dim=-1).sum(dim=-1)
+
+
+def convert_module_to_f16(l):
+ """
+ Convert primitive modules to float16.
+ """
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ l.weight.data = l.weight.data.half()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.half()
+
+
+def convert_module_to_f32(l):
+ """
+ Convert primitive modules to float32, undoing convert_module_to_f16().
+ """
+ if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)):
+ l.weight.data = l.weight.data.float()
+ if l.bias is not None:
+ l.bias.data = l.bias.data.float()
+
+
+def format_seconds(seconds):
+    seconds = int(seconds)  # guard against float inputs (e.g. time() deltas)
+    minutes, seconds = divmod(seconds, 60)
+    hours, minutes = divmod(minutes, 60)
+    return f"{hours:d}:{minutes:02d}:{seconds:02d}"
+
+
+def get_params(module, lr, wd):
+ skip_list = {}
+ skip_keywords = {}
+ if hasattr(module, "no_weight_decay"):
+ skip_list = module.no_weight_decay()
+ if hasattr(module, "no_weight_decay_keywords"):
+ skip_keywords = module.no_weight_decay_keywords()
+ has_decay = []
+ no_decay = []
+ for name, param in module.named_parameters():
+ if not param.requires_grad:
+ continue # frozen weights
+ if (
+ (name in skip_list)
+ or any((kw in name for kw in skip_keywords))
+ or len(param.shape) == 1
+ or name.endswith(".gamma")
+ or name.endswith(".beta")
+ or name.endswith(".bias")
+ ):
+ no_decay.append(param)
+ else:
+ has_decay.append(param)
+
+ group1 = {
+ "params": has_decay,
+ "weight_decay": wd,
+ "lr": lr,
+ "weight_decay_init": wd,
+ "weight_decay_base": wd,
+ "lr_base": lr,
+ }
+ group2 = {
+ "params": no_decay,
+ "weight_decay": 0.0,
+ "lr": lr,
+ "weight_decay_init": 0.0,
+ "weight_decay_base": 0.0,
+ "weight_decay_final": 0.0,
+ "lr_base": lr,
+ }
+ return [group1, group2], [lr, lr]
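+
+
+# Illustrative sketch (hypothetical helper, not used elsewhere): the two groups
+# returned by `get_params` plug directly into a torch optimizer, keeping weight
+# decay off biases, 1-d/norm parameters and any `no_weight_decay` names.
+def _demo_build_optimizer(module: nn.Module, lr: float = 1e-4, wd: float = 0.05):
+    groups, lrs = get_params(module, lr, wd)
+    optimizer = torch.optim.AdamW(groups)  # extra bookkeeping keys are kept
+    return optimizer, lrs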
+
+
+def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage):
+ if var_name in ("cls_token", "mask_token", "pos_embed", "absolute_pos_embed"):
+ return 0
+ elif var_name.startswith("patch_embed"):
+ return 0
+ elif var_name.startswith("layers"):
+ if var_name.split(".")[2] == "blocks":
+ stage_id = int(var_name.split(".")[1])
+ layer_id = int(var_name.split(".")[3]) + sum(layers_per_stage[:stage_id])
+ return layer_id + 1
+ elif var_name.split(".")[2] == "downsample":
+ stage_id = int(var_name.split(".")[1])
+ layer_id = sum(layers_per_stage[: stage_id + 1])
+ return layer_id
+ else:
+ return num_max_layer - 1
+
+
+def get_params_layerdecayswin(module, lr, wd, ld):
+ skip_list = {}
+ skip_keywords = {}
+ if hasattr(module, "no_weight_decay"):
+ skip_list = module.no_weight_decay()
+ if hasattr(module, "no_weight_decay_keywords"):
+ skip_keywords = module.no_weight_decay_keywords()
+ layers_per_stage = module.depths
+ num_layers = sum(layers_per_stage) + 1
+ lrs = []
+ params = []
+ for name, param in module.named_parameters():
+ if not param.requires_grad:
+ print(f"{name} frozen")
+ continue # frozen weights
+ layer_id = get_num_layer_for_swin(name, num_layers, layers_per_stage)
+ lr_cur = lr * ld ** (num_layers - layer_id - 1)
+ # if (name in skip_list) or any((kw in name for kw in skip_keywords)) or len(param.shape) == 1 or name.endswith(".bias"):
+ if (name in skip_list) or any((kw in name for kw in skip_keywords)):
+ wd_cur = 0.0
+ else:
+ wd_cur = wd
+ params.append({"params": param, "weight_decay": wd_cur, "lr": lr_cur})
+ lrs.append(lr_cur)
+ return params, lrs
+
+
+def log(t, eps: float = 1e-5):
+ return torch.log(t.clamp(min=eps))
+
+
+def l2norm(t):
+ return F.normalize(t, dim=-1)
+
+
+def exists(val):
+ return val is not None
+
+
+def identity(t, *args, **kwargs):
+ return t
+
+
+def divisible_by(numer, denom):
+ return (numer % denom) == 0
+
+
+def first(arr, d=None):
+ if len(arr) == 0:
+ return d
+ return arr[0]
+
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if callable(d) else d
+
+
+def maybe(fn):
+ @wraps(fn)
+ def inner(x):
+ if not exists(x):
+ return x
+ return fn(x)
+
+ return inner
+
+
+def once(fn):
+ called = False
+
+ @wraps(fn)
+ def inner(x):
+ nonlocal called
+ if called:
+ return
+ called = True
+ return fn(x)
+
+ return inner
+
+
+def _many(fn):
+ @wraps(fn)
+ def inner(tensors, pattern, **kwargs):
+ return (fn(tensor, pattern, **kwargs) for tensor in tensors)
+
+ return inner
+
+
+rearrange_many = _many(rearrange)
+repeat_many = _many(repeat)
+reduce_many = _many(reduce)
+
+
+def load_pretrained(state_dict, checkpoint):
+    checkpoint_model = checkpoint["model"]
+    if any("encoder." in k for k in checkpoint_model.keys()):
+        checkpoint_model = {
+            k.replace("encoder.", ""): v
+            for k, v in checkpoint_model.items()
+            if k.startswith("encoder.")
+        }
+        print("Detected pre-trained model, removing [encoder.] prefix.")
+    else:
+        print("Detected non-pre-trained model, passing keys through unchanged.")
+    print(">>>>>>>>>> Remapping pre-trained keys for SWIN ..........")
+    checkpoint = load_checkpoint_swin(state_dict, checkpoint_model)
+    return checkpoint
+
+
+def load_checkpoint_swin(model, checkpoint_model):
+ state_dict = model.state_dict()
+    # Geometric interpolation when the pre-trained patch size mismatches the fine-tuned patch size
+ all_keys = list(checkpoint_model.keys())
+ for key in all_keys:
+ if "relative_position_bias_table" in key:
+ relative_position_bias_table_pretrained = checkpoint_model[key]
+ relative_position_bias_table_current = state_dict[key]
+ L1, nH1 = relative_position_bias_table_pretrained.size()
+ L2, nH2 = relative_position_bias_table_current.size()
+ if nH1 != nH2:
+ print(f"Error in loading {key}, passing......")
+ else:
+ if L1 != L2:
+ print(f"{key}: Interpolate relative_position_bias_table using geo.")
+ src_size = int(L1**0.5)
+ dst_size = int(L2**0.5)
+
+ def geometric_progression(a, r, n):
+ return a * (1.0 - r**n) / (1.0 - r)
+
+ left, right = 1.01, 1.5
+ while right - left > 1e-6:
+ q = (left + right) / 2.0
+ gp = geometric_progression(1, q, src_size // 2)
+ if gp > dst_size // 2:
+ right = q
+ else:
+ left = q
+
+ # if q > 1.090307:
+ # q = 1.090307
+
+ dis = []
+ cur = 1
+ for i in range(src_size // 2):
+ dis.append(cur)
+ cur += q ** (i + 1)
+
+ r_ids = [-_ for _ in reversed(dis)]
+
+ x = r_ids + [0] + dis
+ y = r_ids + [0] + dis
+
+ t = dst_size // 2.0
+ dx = np.arange(-t, t + 0.1, 1.0)
+ dy = np.arange(-t, t + 0.1, 1.0)
+
+ print("Original positions = %s" % str(x))
+ print("Target positions = %s" % str(dx))
+
+ all_rel_pos_bias = []
+
+ for i in range(nH1):
+ z = (
+ relative_position_bias_table_pretrained[:, i]
+ .view(src_size, src_size)
+ .float()
+ .numpy()
+ )
+ f_cubic = interpolate.interp2d(x, y, z, kind="cubic")
+ all_rel_pos_bias.append(
+ torch.Tensor(f_cubic(dx, dy))
+ .contiguous()
+ .view(-1, 1)
+ .to(relative_position_bias_table_pretrained.device)
+ )
+
+ new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1)
+ checkpoint_model[key] = new_rel_pos_bias
+
+ # delete relative_position_index since we always re-init it
+ relative_position_index_keys = [
+ k for k in checkpoint_model.keys() if "relative_position_index" in k
+ ]
+ for k in relative_position_index_keys:
+ del checkpoint_model[k]
+
+ # delete relative_coords_table since we always re-init it
+ relative_coords_table_keys = [
+ k for k in checkpoint_model.keys() if "relative_coords_table" in k
+ ]
+ for k in relative_coords_table_keys:
+ del checkpoint_model[k]
+
+    # re-map keys due to name change
+ rpe_mlp_keys = [k for k in checkpoint_model.keys() if "cpb_mlp" in k]
+ for k in rpe_mlp_keys:
+ checkpoint_model[k.replace("cpb_mlp", "rpe_mlp")] = checkpoint_model.pop(k)
+
+ # delete attn_mask since we always re-init it
+ attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k]
+ for k in attn_mask_keys:
+ del checkpoint_model[k]
+
+ encoder_keys = [k for k in checkpoint_model.keys() if k.startswith("encoder.")]
+ for k in encoder_keys:
+ checkpoint_model[k.replace("encoder.", "")] = checkpoint_model.pop(k)
+
+ return checkpoint_model
+
+
+def add_padding_metas(out, image_metas):
+    # left, right, top, bottom
+    paddings = [img_meta.get("paddings", [0] * 4) for img_meta in image_metas]
+    # F.pad expects a plain sequence of ints, so keep each padding as a Python list
+    outs = [
+        F.pad(o, [int(p) for p in padding], value=0.0)
+        for padding, o in zip(paddings, out)
+    ]
+    return torch.stack(outs)
+
+
+# left, right, top, bottom
+def remove_padding(out, paddings):
+ H, W = out.shape[-2:]
+ outs = [
+ o[..., padding[2] : H - padding[3], padding[0] : W - padding[1]]
+ for padding, o in zip(paddings, out)
+ ]
+ return torch.stack(outs)
+
+
+def remove_padding_metas(out, image_metas):
+ # left, right, top, bottom
+ paddings = [
+ torch.tensor(img_meta.get("paddings", [0] * 4)) for img_meta in image_metas
+ ]
+ return remove_padding(out, paddings)
+
+
+def ssi_helper(tensor1, tensor2):
+ stability_mat = 1e-4 * torch.eye(2, device=tensor1.device)
+ tensor2_one = torch.stack([tensor2, torch.ones_like(tensor2)], dim=1)
+ scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ (
+ tensor2_one.T @ tensor1.unsqueeze(1)
+ )
+ scale, shift = scale_shift.squeeze().chunk(2, dim=0)
+ return scale, shift
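+
+# ssi_helper solves the regularized least-squares problem
+#   min_{scale, shift} || scale * tensor2 + shift - tensor1 ||^2
+# in closed form. Quick illustrative check (values are arbitrary):
+#
+#   gt = torch.tensor([2.0, 4.0, 6.0])
+#   pred = torch.tensor([1.0, 2.0, 3.0])
+#   scale, shift = ssi_helper(gt, pred)      # scale ~ 2, shift ~ 0
+#   aligned = scale * pred + shift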
+
+
+def calculate_mean_values(names, values):
+    # Keep a running sum and count for each name
+ name_values = {name: {} for name in names}
+
+ # Iterate through the lists and accumulate values for each name
+ for name, value in zip(names, values):
+ name_values[name]["sum"] = name_values[name].get("sum", 0.0) + value
+ name_values[name]["count"] = name_values[name].get("count", 0.0) + 1
+
+ # Calculate mean values and create the output dictionary
+ output_dict = {
+ name: name_values[name]["sum"] / name_values[name]["count"]
+ for name in name_values
+ }
+
+ return output_dict
+
+
+def remove_leading_dim(infos):
+ if isinstance(infos, dict):
+ return {k: remove_leading_dim(v) for k, v in infos.items()}
+ elif isinstance(infos, torch.Tensor):
+ return infos.squeeze(0)
+ else:
+ return infos
+
+
+def recursive_index(infos, index):
+ if isinstance(infos, dict):
+ return {k: recursive_index(v, index) for k, v in infos.items()}
+ elif isinstance(infos, torch.Tensor):
+ return infos[index]
+ else:
+ return infos
+
+
+def to_cpu(infos):
+ if isinstance(infos, dict):
+ return {k: to_cpu(v) for k, v in infos.items()}
+ elif isinstance(infos, torch.Tensor):
+ return infos.detach()
+ else:
+ return infos
+
+
+def masked_mean(
+ data: torch.Tensor,
+ mask: torch.Tensor | None = None,
+ dim: list[int] | None = None,
+ keepdim: bool = False,
+) -> torch.Tensor:
+ dim = dim if dim is not None else list(range(data.dim()))
+ if mask is None:
+ return data.mean(dim=dim, keepdim=keepdim)
+ mask = mask.float()
+ mask_sum = torch.sum(mask, dim=dim, keepdim=True)
+ mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp(
+ mask_sum, min=1.0
+ )
+ return mask_mean.squeeze(dim) if not keepdim else mask_mean
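+
+# Usage sketch (shapes are assumed): average a map over valid pixels only.
+#
+#   depth = torch.rand(2, 1, 4, 4)
+#   valid = depth > 0.1
+#   mean_depth = masked_mean(depth, valid, dim=[-2, -1], keepdim=True)
+#   # mean_depth has shape (2, 1, 1, 1); an all-False mask yields 0 thanks to
+#   # the clamp(min=1.0) on the mask sum.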
+
+
+class ProfileMethod:
+ def __init__(self, model, func_name, track_statistics=True, verbose=False):
+ self.model = model
+ self.func_name = func_name
+ self.verbose = verbose
+ self.track_statistics = track_statistics
+ self.timings = []
+
+ def __enter__(self):
+ # Start timing
+ if self.verbose:
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ self.start_time = time()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self.verbose:
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ self.end_time = time()
+
+ elapsed_time = self.end_time - self.start_time
+
+ self.timings.append(elapsed_time)
+ if self.track_statistics and len(self.timings) > 25:
+
+ # Compute statistics if tracking
+ timings_array = np.array(self.timings)
+ mean_time = np.mean(timings_array)
+ std_time = np.std(timings_array)
+ quantiles = np.percentile(timings_array, [0, 25, 50, 75, 100])
+ print(
+ f"{self.model.__class__.__name__}.{self.func_name} took {elapsed_time:.4f} seconds"
+ )
+ print(f"Mean Time: {mean_time:.4f} seconds")
+ print(f"Std Time: {std_time:.4f} seconds")
+ print(
+ f"Quantiles: Min={quantiles[0]:.4f}, 25%={quantiles[1]:.4f}, Median={quantiles[2]:.4f}, 75%={quantiles[3]:.4f}, Max={quantiles[4]:.4f}"
+ )
+
+ else:
+ print(
+ f"{self.model.__class__.__name__}.{self.func_name} took {elapsed_time:.4f} seconds"
+ )
+
+
+def profile_method(track_statistics=True, verbose=False):
+ def decorator(func):
+ @wraps(func)
+ def wrapper(self, *args, **kwargs):
+ with ProfileMethod(self, func.__name__, track_statistics, verbose):
+ return func(self, *args, **kwargs)
+
+ return wrapper
+
+ return decorator
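+
+# Decorator sketch (`MyModel` is hypothetical). Timings are printed only when
+# verbose=True, since ProfileMethod.__exit__ is a no-op otherwise:
+#
+#   class MyModel(nn.Module):
+#       @profile_method(verbose=True)
+#       def forward(self, x):
+#           return x * 2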
+
+
+class ProfileFunction:
+ def __init__(self, func_name, track_statistics=True, verbose=False):
+ self.func_name = func_name
+ self.verbose = verbose
+ self.track_statistics = track_statistics
+ self.timings = []
+
+ def __enter__(self):
+ # Start timing
+ if self.verbose:
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+ self.start_time = time()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ if self.verbose:
+ if torch.cuda.is_available():
+ torch.cuda.synchronize()
+
+ self.end_time = time()
+
+ elapsed_time = self.end_time - self.start_time
+
+ self.timings.append(elapsed_time)
+ if self.track_statistics and len(self.timings) > 25:
+
+ # Compute statistics if tracking
+ timings_array = np.array(self.timings)
+ mean_time = np.mean(timings_array)
+ std_time = np.std(timings_array)
+ quantiles = np.percentile(timings_array, [0, 25, 50, 75, 100])
+ print(f"{self.func_name} took {elapsed_time:.4f} seconds")
+ print(f"Mean Time: {mean_time:.4f} seconds")
+ print(f"Std Time: {std_time:.4f} seconds")
+ print(
+ f"Quantiles: Min={quantiles[0]:.4f}, 25%={quantiles[1]:.4f}, Median={quantiles[2]:.4f}, 75%={quantiles[3]:.4f}, Max={quantiles[4]:.4f}"
+ )
+
+ else:
+ print(f"{self.func_name} took {elapsed_time:.4f} seconds")
+
+
+def profile_function(track_statistics=True, verbose=False):
+ def decorator(func):
+ @wraps(func)
+        def wrapper(*args, **kwargs):
+            with ProfileFunction(func.__name__, track_statistics, verbose):
+                return func(*args, **kwargs)
+
+ return wrapper
+
+ return decorator
+
+
+def recursive_apply(inputs, func):
+ if isinstance(inputs, list):
+ return [recursive_apply(camera, func) for camera in inputs]
+ else:
+ return func(inputs)
+
+
+def squeeze_list(nested_list, dim, current_dim=0):
+    # Squeeze this level only when it is the requested dim and has a single element
+ if isinstance(nested_list, list) and len(nested_list) == 1 and current_dim == dim:
+ return squeeze_list(nested_list[0], dim, current_dim + 1)
+ elif isinstance(nested_list, list):
+ return [squeeze_list(item, dim, current_dim + 1) for item in nested_list]
+ else:
+ return nested_list
+
+
+def match_gt(tensor1, tensor2, padding1, padding2, mode: str = "bilinear"):
+ """
+    Transform each item in the tensor1 batch to match tensor2's dimensions and padding.
+
+    Args:
+        tensor1 (torch.Tensor): The input tensor to transform, with shape (batch_size, channels, height, width).
+        tensor2 (torch.Tensor): The target tensor to match, with shape (batch_size, channels, height, width).
+        padding1 (list | None): Per-item padding applied to tensor1, each as (pad_left, pad_right, pad_top, pad_bottom).
+        padding2 (list | None): Per-item padding to apply so the output matches tensor2, same format as padding1.
+        mode (str): Interpolation mode passed to F.interpolate (default: "bilinear").
+
+ Returns:
+ torch.Tensor: The batch of transformed tensors matching tensor2's size and padding.
+ """
+ # Get batch size
+ batch_size = len(tensor1)
+ src_dtype = tensor1[0].dtype
+ tgt_dtype = tensor2[0].dtype
+
+ # List to store transformed tensors
+ transformed_tensors = []
+
+ for i in range(batch_size):
+ item1 = tensor1[i]
+ item2 = tensor2[i]
+
+ h1, w1 = item1.shape[1], item1.shape[2]
+ pad1_l, pad1_r, pad1_t, pad1_b = (
+ padding1[i] if padding1 is not None else (0, 0, 0, 0)
+ )
+ pad2_l, pad2_r, pad2_t, pad2_b = (
+ padding2[i] if padding2 is not None else (0, 0, 0, 0)
+ )
+ item1_unpadded = item1[:, pad1_t : h1 - pad1_b, pad1_l : w1 - pad1_r]
+
+ h2, w2 = (
+ item2.shape[1] - pad2_t - pad2_b,
+ item2.shape[2] - pad2_l - pad2_r,
+ )
+
+ item1_resized = F.interpolate(
+ item1_unpadded.unsqueeze(0).to(tgt_dtype), size=(h2, w2), mode=mode
+ )
+ item1_padded = F.pad(item1_resized, (pad2_l, pad2_r, pad2_t, pad2_b))
+ transformed_tensors.append(item1_padded)
+
+ transformed_batch = torch.cat(transformed_tensors)
+ return transformed_batch.to(src_dtype)
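+
+# Usage sketch (shapes and paddings are assumed): bring a padded prediction onto
+# the ground-truth grid before computing pixel-wise metrics.
+#
+#   pred = torch.rand(1, 1, 100, 120)
+#   gt = torch.rand(1, 1, 228, 308)
+#   pred_on_gt = match_gt(pred, gt, padding1=[(0, 0, 2, 2)], padding2=[(4, 4, 0, 0)])
+#   # pred_on_gt.shape == gt.shape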
diff --git a/unik3d/utils/pose.py b/unik3d/utils/pose.py
new file mode 100644
index 0000000000000000000000000000000000000000..f878bb257bcde1f98d2ce611ed85eeac1a8fcb17
--- /dev/null
+++ b/unik3d/utils/pose.py
@@ -0,0 +1,225 @@
+import torch
+from torch.nn import functional as F
+
+
+def quaternion_to_R(quaternions):
+ """
+ Convert rotations given as quaternions to rotation matrices.
+
+ Args:
+ quaternions: quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Rotation matrices as tensor of shape (..., 3, 3).
+ """
+ r, i, j, k = torch.unbind(quaternions, -1)
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
+
+ o = torch.stack(
+ (
+ 1 - two_s * (j * j + k * k),
+ two_s * (i * j - k * r),
+ two_s * (i * k + j * r),
+ two_s * (i * j + k * r),
+ 1 - two_s * (i * i + k * k),
+ two_s * (j * k - i * r),
+ two_s * (i * k - j * r),
+ two_s * (j * k + i * r),
+ 1 - two_s * (i * i + j * j),
+ ),
+ -1,
+ )
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
+
+
+def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
+ """
+ Convert a unit quaternion to a standard form: one in which the real
+ part is non negative.
+
+ Args:
+ quaternions: Quaternions with real part first,
+ as tensor of shape (..., 4).
+
+ Returns:
+ Standardized quaternions as tensor of shape (..., 4).
+ """
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
+
+
+def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
+ """
+ Returns torch.sqrt(torch.max(0, x))
+ but with a zero subgradient where x is 0.
+ """
+ ret = torch.zeros_like(x)
+ positive_mask = x > 0
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
+ return ret
+
+
+def R_to_quaternion(matrix: torch.Tensor) -> torch.Tensor:
+ """
+ Convert rotations given as rotation matrices to quaternions.
+
+ Args:
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
+
+ Returns:
+ quaternions with real part first, as tensor of shape (..., 4).
+ """
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
+ raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
+
+ batch_dim = matrix.shape[:-2]
+ m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
+ matrix.reshape(batch_dim + (9,)), dim=-1
+ )
+
+ q_abs = _sqrt_positive_part(
+ torch.stack(
+ [
+ 1.0 + m00 + m11 + m22,
+ 1.0 + m00 - m11 - m22,
+ 1.0 - m00 + m11 - m22,
+ 1.0 - m00 - m11 + m22,
+ ],
+ dim=-1,
+ )
+ )
+ # we produce the desired quaternion multiplied by each of r, i, j, k
+ quat_by_rijk = torch.stack(
+ [
+ torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
+ torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
+ torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
+ torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
+ ],
+ dim=-2,
+ )
+
+ # We floor here at 0.1 but the exact level is not important; if q_abs is small,
+ # the candidate won't be picked.
+ flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
+ quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
+
+ # if not for numerical problems, quat_candidates[i] should be same (up to a sign),
+ # forall i; we pick the best-conditioned one (with the largest denominator)
+ out = quat_candidates[
+ F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :
+ ].reshape(batch_dim + (4,))
+ out = standardize_quaternion(out)
+ return out
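+
+# Round-trip sketch (the random quaternion is illustrative only):
+#
+#   q = torch.randn(1, 4)
+#   q = standardize_quaternion(q / q.norm(dim=-1, keepdim=True))
+#   R = quaternion_to_R(q)
+#   q_back = R_to_quaternion(R)      # matches q up to numerical precision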
+
+
+def Rt_to_pose(R, t):
+ assert R.shape[-2:] == (3, 3), "The last two dimensions of R must be 3x3"
+    assert t.shape[-2:] == (3, 1), "The last two dimensions of t must be 3x1"
+
+ # Create the pose matrix
+ pose = torch.cat([R, t], dim=-1)
+ pose = F.pad(pose, (0, 0, 0, 1), value=0)
+ pose[..., 3, 3] = 1
+
+ return pose
+
+
+def pose_to_Rt(pose):
+ assert pose.shape[-2:] == (4, 4), "The last two dimensions of pose must be 4x4"
+
+ # Extract the rotation matrix and translation vector
+ R = pose[..., :3, :3]
+ t = pose[..., :3, 3:]
+
+ return R, t
+
+
+def relative_pose(pose1, pose2):
+ # Compute world_to_cam for pose1
+ pose1_inv = invert_pose(pose1)
+
+    # Compose cam_to_world of pose2 with world_to_cam of pose1: cam2_to_cam1
+ relative_pose = pose1_inv @ pose2
+
+ return relative_pose
+
+
+@torch.autocast(device_type="cuda", dtype=torch.float32)
+def invert_pose(pose):
+ R, t = pose_to_Rt(pose)
+ R_inv = R.transpose(-2, -1)
+ t_inv = -torch.matmul(R_inv, t)
+ pose_inv = Rt_to_pose(R_inv, t_inv)
+ return pose_inv
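+
+# Sanity-check sketch: a pose composed with its inverse is the identity.
+#
+#   R = torch.eye(3).unsqueeze(0)                    # (1, 3, 3)
+#   t = torch.tensor([[[1.0], [2.0], [3.0]]])        # (1, 3, 1)
+#   pose = Rt_to_pose(R, t)                          # (1, 4, 4)
+#   assert torch.allclose(invert_pose(pose) @ pose, torch.eye(4).expand(1, 4, 4))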
+
+
+def apply_pose_transformation(point_cloud, pose):
+ reshape = point_cloud.ndim > 3
+ shapes = point_cloud.shape
+ # Extract rotation and translation from pose
+ R, t = pose_to_Rt(pose)
+
+ # Apply the pose transformation
+ if reshape:
+ point_cloud = point_cloud.reshape(shapes[0], -1, shapes[-1])
+ transformed_points = torch.matmul(point_cloud, R.transpose(-2, -1)) + t.transpose(
+ -2, -1
+ )
+ if reshape:
+ transformed_points = transformed_points.reshape(shapes)
+ return transformed_points
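+
+# Usage sketch (shapes assumed): move a (B, H, W, 3) point map into another frame
+# with a (B, 4, 4) pose, e.g. one built by Rt_to_pose above.
+#
+#   points = torch.rand(2, 32, 32, 3)
+#   pose = torch.eye(4).expand(2, 4, 4)
+#   moved = apply_pose_transformation(points, pose)   # same shape as points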
+
+
+def euler2mat(roll, pitch, yaw) -> torch.Tensor:
+ """
+ Convert Euler angles (roll, pitch, yaw) to a 3x3 rotation matrix.
+
+    Args:
+        roll (torch.Tensor): Tensor of shape (N,), rotation around the x-axis in radians.
+        pitch (torch.Tensor): Tensor of shape (N,), rotation around the y-axis in radians.
+        yaw (torch.Tensor): Tensor of shape (N,), rotation around the z-axis in radians.
+ Returns:
+ torch.Tensor: Tensor of shape (N, 3, 3) representing the rotation matrices.
+ """
+
+ cos_r, sin_r = torch.cos(roll), torch.sin(roll) # Roll
+ cos_p, sin_p = torch.cos(pitch), torch.sin(pitch) # Pitch
+ cos_y, sin_y = torch.cos(yaw), torch.sin(yaw) # Yaw
+
+ # Rotation matrices
+ R_z = torch.zeros((roll.shape[0], 3, 3), device=roll.device)
+ R_y = torch.zeros_like(R_z)
+ R_x = torch.zeros_like(R_z)
+
+    # Z-axis (yaw)
+ R_z[:, 0, 0], R_z[:, 0, 1], R_z[:, 1, 0], R_z[:, 1, 1], R_z[:, 2, 2] = (
+ cos_y,
+ -sin_y,
+ sin_y,
+ cos_y,
+ 1.0,
+ )
+
+    # Y-axis (pitch)
+ R_y[:, 0, 0], R_y[:, 0, 2], R_y[:, 2, 0], R_y[:, 2, 2], R_y[:, 1, 1] = (
+ cos_p,
+ sin_p,
+ -sin_p,
+ cos_p,
+ 1.0,
+ )
+
+    # X-axis (roll)
+ R_x[:, 1, 1], R_x[:, 1, 2], R_x[:, 2, 1], R_x[:, 2, 2], R_x[:, 0, 0] = (
+ cos_r,
+ -sin_r,
+ sin_r,
+ cos_r,
+ 1.0,
+ )
+
+ # Combine rotations: R = R_z * R_y * R_x
+ rotation_matrix = torch.matmul(torch.matmul(R_z, R_y), R_x)
+ return rotation_matrix
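+
+# Quick check (angles are illustrative): zero angles give the identity, and a
+# pure yaw of pi/2 maps the x-axis onto the y-axis.
+#
+#   zeros = torch.zeros(1)
+#   assert torch.allclose(euler2mat(zeros, zeros, zeros), torch.eye(3).unsqueeze(0))
+#   R = euler2mat(zeros, zeros, torch.tensor([torch.pi / 2]))
+#   # R[0] @ torch.tensor([1.0, 0.0, 0.0]) ~ [0, 1, 0]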
diff --git a/unik3d/utils/positional_embedding.py b/unik3d/utils/positional_embedding.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e3ac5727f8e12c2f155de5da66352326ce9bdbd
--- /dev/null
+++ b/unik3d/utils/positional_embedding.py
@@ -0,0 +1,269 @@
+from math import log, log2, pi
+from typing import Optional
+
+import torch
+import torch.nn as nn
+from einops import rearrange, repeat
+
+
+class PositionEmbeddingSine(nn.Module):
+
+ def __init__(
+ self, num_pos_feats=64, temperature=10000, normalize=False, scale=None
+ ):
+ super().__init__()
+ self.num_pos_feats = num_pos_feats
+ self.temperature = temperature
+ self.normalize = normalize
+ if scale is not None and normalize is False:
+ raise ValueError("normalize should be True if scale is passed")
+ if scale is None:
+ scale = 2 * pi
+ self.scale = scale
+
+ def forward(
+ self, x: torch.Tensor, mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ if mask is None:
+ mask = torch.zeros(
+ (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool
+ )
+ not_mask = ~mask
+ y_embed = not_mask.cumsum(1, dtype=torch.float32)
+ x_embed = not_mask.cumsum(2, dtype=torch.float32)
+ if self.normalize:
+ eps = 1e-6
+ y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale
+ x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale
+
+ dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device)
+ dim_t = self.temperature ** (
+ 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats
+ )
+
+ pos_x = x_embed[:, :, :, None] / dim_t
+ pos_y = y_embed[:, :, :, None] / dim_t
+ pos_x = torch.stack(
+ (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos_y = torch.stack(
+ (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4
+ ).flatten(3)
+ pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2)
+ return pos
+
+ def __repr__(self, _repr_indent=4):
+ head = "Positional encoding " + self.__class__.__name__
+ body = [
+ "num_pos_feats: {}".format(self.num_pos_feats),
+ "temperature: {}".format(self.temperature),
+ "normalize: {}".format(self.normalize),
+ "scale: {}".format(self.scale),
+ ]
+ # _repr_indent = 4
+ lines = [head] + [" " * _repr_indent + line for line in body]
+ return "\n".join(lines)
+
+
+class LearnedSinusoidalPosEmb(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ assert (dim % 2) == 0
+ half_dim = dim // 2
+ self.weights = nn.Parameter(torch.randn(half_dim))
+
+ def forward(self, x):
+ x = rearrange(x, "b -> b 1")
+ freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi
+ fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1)
+ fouriered = torch.cat((x, fouriered), dim=-1)
+ return fouriered
+
+
+def broadcat(tensors, dim=-1):
+ num_tensors = len(tensors)
+ shape_lens = set(list(map(lambda t: len(t.shape), tensors)))
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
+ shape_len = list(shape_lens)[0]
+ dim = (dim + shape_len) if dim < 0 else dim
+ dims = list(zip(*map(lambda t: list(t.shape), tensors)))
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
+ assert all(
+ [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)]
+    ), "invalid dimensions for broadcastable concatenation"
+ max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims))
+ expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims))
+ expanded_dims.insert(dim, (dim, dims[dim]))
+ expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims)))
+ tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes)))
+ return torch.cat(tensors, dim=dim)
+
+
+def rotate_half(x):
+ x = rearrange(x, "... (d r) -> ... d r", r=2)
+ x1, x2 = x.unbind(dim=-1)
+ x = torch.stack((-x2, x1), dim=-1)
+ return rearrange(x, "... d r -> ... (d r)")
+
+
+class VisionRotaryEmbedding(nn.Module):
+ def __init__(
+ self,
+ dim,
+ pt_seq_len,
+ ft_seq_len=None,
+ custom_freqs=None,
+ freqs_for="lang",
+ theta=10000,
+ max_freq=10,
+ num_freqs=1,
+ ):
+ super().__init__()
+ if custom_freqs:
+ freqs = custom_freqs
+ elif freqs_for == "lang":
+ freqs = 1.0 / (
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+ )
+ elif freqs_for == "pixel":
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+ elif freqs_for == "constant":
+ freqs = torch.ones(num_freqs).float()
+ else:
+ raise ValueError(f"unknown modality {freqs_for}")
+
+ if ft_seq_len is None:
+ ft_seq_len = pt_seq_len
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
+
+ freqs_h = torch.einsum("..., f -> ... f", t, freqs)
+ freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2)
+
+ freqs_w = torch.einsum("..., f -> ... f", t, freqs)
+ freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2)
+
+ freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1)
+
+ self.register_buffer("freqs_cos", freqs.cos())
+ self.register_buffer("freqs_sin", freqs.sin())
+
+ print("======== shape of rope freq", self.freqs_cos.shape, "========")
+
+ def forward(self, t, start_index=0):
+ rot_dim = self.freqs_cos.shape[-1]
+ end_index = start_index + rot_dim
+ assert (
+ rot_dim <= t.shape[-1]
+ ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}"
+ t_left, t, t_right = (
+ t[..., :start_index],
+ t[..., start_index:end_index],
+ t[..., end_index:],
+ )
+ t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin)
+ return torch.cat((t_left, t, t_right), dim=-1)
+
+
+class VisionRotaryEmbeddingFast(nn.Module):
+ def __init__(
+ self,
+ dim,
+ pt_seq_len,
+ ft_seq_len=None,
+ custom_freqs=None,
+ freqs_for="lang",
+ theta=10000,
+ max_freq=10,
+ num_freqs=1,
+ ):
+ super().__init__()
+ if custom_freqs:
+ freqs = custom_freqs
+ elif freqs_for == "lang":
+ freqs = 1.0 / (
+ theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)
+ )
+ elif freqs_for == "pixel":
+ freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi
+ elif freqs_for == "constant":
+ freqs = torch.ones(num_freqs).float()
+ else:
+ raise ValueError(f"unknown modality {freqs_for}")
+
+ if ft_seq_len is None:
+ ft_seq_len = pt_seq_len
+ t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len
+
+ freqs = torch.einsum("..., f -> ... f", t, freqs)
+ freqs = repeat(freqs, "... n -> ... (n r)", r=2)
+ freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1)
+
+ freqs_cos = freqs.cos().view(-1, freqs.shape[-1])
+ freqs_sin = freqs.sin().view(-1, freqs.shape[-1])
+
+ self.register_buffer("freqs_cos", freqs_cos)
+ self.register_buffer("freqs_sin", freqs_sin)
+
+ def forward(self, t):
+ return t * self.freqs_cos + rotate_half(t) * self.freqs_sin
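+
+# Usage sketch (sizes are assumptions): with dim=32 and a 16x16 token grid the
+# buffers have shape (256, 64), so the rotated features must use a 64-dim head.
+#
+#   rope = VisionRotaryEmbeddingFast(dim=32, pt_seq_len=16)
+#   q = torch.randn(2, 8, 256, 64)       # (batch, heads, tokens, head_dim)
+#   q = rope(q)                          # same shape, rotary-encoded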
+
+
+def generate_fourier_features(
+ x: torch.Tensor,
+ dim: int = 512,
+ max_freq: int = 64,
+ use_cos: bool = False,
+ use_log: bool = False,
+ cat_orig: bool = False,
+):
+ x_orig = x
+ device, dtype, input_dim = x.device, x.dtype, x.shape[-1]
+ num_bands = dim // (2 * input_dim) if use_cos else dim // input_dim
+
+ if use_log:
+ scales = 2.0 ** torch.linspace(
+ 0.0, log2(max_freq), steps=num_bands, device=device, dtype=dtype
+ )
+ else:
+ scales = torch.linspace(
+ 1.0, max_freq / 2, num_bands, device=device, dtype=dtype
+ )
+
+ x = x.unsqueeze(-1)
+ scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)]
+
+ x = x * scales * pi
+ x = torch.cat(
+ (
+ [x.sin(), x.cos()]
+ if use_cos
+ else [
+ x.sin(),
+ ]
+ ),
+ dim=-1,
+ )
+ x = x.flatten(-2)
+ if cat_orig:
+ return torch.cat((x, x_orig), dim=-1)
+ return x
+
+
+# from PIL import Image
+# from unik3d.utils import image_grid, colorize
+# if __name__ == "__main__":
+# H, W = 512, 512
+# resolution = 128
+# mesh = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W))
+# mesh = torch.stack(mesh, dim=0).unsqueeze(0)
+# mesh = mesh.view(1, 2, -1).permute(0, 2, 1)
+
+# features = generate_fourier_features(mesh, dim=32, max_freq=resolution, use_log=True)
+# channels = features.shape[-1]
+# print(features.shape)
+
+# features = features[0].view(H, W, channels).permute(2, 0, 1).numpy()
+# Image.fromarray(image_grid([colorize(1+x, 0.0, 2.0, "viridis") for x in features], rows=8, cols=4)).save(f"tmp_{resolution}.png")
diff --git a/unik3d/utils/sht.py b/unik3d/utils/sht.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba8d0cf5708ba4b87397158ebb5181e8911c1f62
--- /dev/null
+++ b/unik3d/utils/sht.py
@@ -0,0 +1,1639 @@
+"""Real spherical harmonics in Cartesian form for PyTorch.
+
+This is an autogenerated file. See
+https://github.com/cheind/torch-spherical-harmonics
+for more information.
+"""
+
+import torch
+
+
+def rsh_cart_0(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 0.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,1) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ ],
+ -1,
+ )
+
+
+def rsh_cart_1(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 1.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,4) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ ],
+ -1,
+ )
+
+
+def rsh_cart_2(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 2.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,9) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ ],
+ -1,
+ )
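+
+# Evaluation sketch (directions are arbitrary): every rsh_cart_N maps (..., 3)
+# unit vectors to (..., (N+1)**2) harmonic values, here 9 for degree 2.
+#
+#   dirs = torch.randn(5, 3)
+#   dirs = dirs / dirs.norm(dim=-1, keepdim=True)
+#   sh = rsh_cart_2(dirs)                # shape (5, 9)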
+
+
+@torch.autocast(device_type="cuda", enabled=True, dtype=torch.float32)
+def rsh_cart_3(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 3.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,16) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ -0.590043589926644 * y * (3.0 * x2 - y2),
+ 2.89061144264055 * xy * z,
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
+ 1.44530572132028 * z * (x2 - y2),
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
+ ],
+ -1,
+ )
+
+
+def rsh_cart_4(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 4.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,25) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+ x4 = x2**2
+ y4 = y2**2
+ z4 = z2**2
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ -0.590043589926644 * y * (3.0 * x2 - y2),
+ 2.89061144264055 * xy * z,
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
+ 1.44530572132028 * z * (x2 - y2),
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
+ 2.5033429417967 * xy * (x2 - y2),
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 1.48099765681286
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 0.952069922236839 * z2
+ + 0.317356640745613,
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
+ -3.75501441269506 * x2 * y2
+ + 0.625835735449176 * x4
+ + 0.625835735449176 * y4,
+ ],
+ -1,
+ )
+
+
+def rsh_cart_5(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 5.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,36) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+ x4 = x2**2
+ y4 = y2**2
+ z4 = z2**2
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ -0.590043589926644 * y * (3.0 * x2 - y2),
+ 2.89061144264055 * xy * z,
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
+ 1.44530572132028 * z * (x2 - y2),
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
+ 2.5033429417967 * xy * (x2 - y2),
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 1.48099765681286
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 0.952069922236839 * z2
+ + 0.317356640745613,
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
+ -3.75501441269506 * x2 * y2
+ + 0.625835735449176 * x4
+ + 0.625835735449176 * y4,
+ -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 8.30264925952416 * xy * z * (x2 - y2),
+ 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
+ 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.241571547304372
+ * y
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ -1.24747010616985 * z * (1.5 * z2 - 0.5)
+ + 1.6840846433293
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.498988042467941 * z,
+ 0.241571547304372
+ * x
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
+ 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
+ -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ ],
+ -1,
+ )
+
+
+def rsh_cart_6(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 6.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,49) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+ x4 = x2**2
+ y4 = y2**2
+ z4 = z2**2
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ -0.590043589926644 * y * (3.0 * x2 - y2),
+ 2.89061144264055 * xy * z,
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
+ 1.44530572132028 * z * (x2 - y2),
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
+ 2.5033429417967 * xy * (x2 - y2),
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 1.48099765681286
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 0.952069922236839 * z2
+ + 0.317356640745613,
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
+ -3.75501441269506 * x2 * y2
+ + 0.625835735449176 * x4
+ + 0.625835735449176 * y4,
+ -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 8.30264925952416 * xy * z * (x2 - y2),
+ 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
+ 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.241571547304372
+ * y
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ -1.24747010616985 * z * (1.5 * z2 - 0.5)
+ + 1.6840846433293
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.498988042467941 * z,
+ 0.241571547304372
+ * x
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
+ 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
+ -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 4.09910463115149 * x**4 * xy
+ - 13.6636821038383 * xy**3
+ + 4.09910463115149 * xy * y**4,
+ -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
+ 0.00584892228263444
+ * y
+ * (3.0 * x2 - y2)
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+ 0.0701870673916132
+ * xy
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ ),
+ 0.221950995245231
+ * y
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ ),
+ -1.48328138624466
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ + 1.86469659985043
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.953538034014426 * z2
+ - 0.317846011338142,
+ 0.221950995245231
+ * x
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ ),
+ 0.0350935336958066
+ * (x2 - y2)
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ ),
+ 0.00584892228263444
+ * x
+ * (x2 - 3.0 * y2)
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+ 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
+ -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 0.683184105191914 * x2**3
+ + 10.2477615778787 * x2 * y4
+ - 10.2477615778787 * x4 * y2
+ - 0.683184105191914 * y2**3,
+ ],
+ -1,
+ )
+
+
+def rsh_cart_7(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 7.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,64) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+ x4 = x2**2
+ y4 = y2**2
+ z4 = z2**2
+
+ return torch.stack(
+ [
+ xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ -0.590043589926644 * y * (3.0 * x2 - y2),
+ 2.89061144264055 * xy * z,
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
+ 1.44530572132028 * z * (x2 - y2),
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
+ 2.5033429417967 * xy * (x2 - y2),
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 1.48099765681286
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 0.952069922236839 * z2
+ + 0.317356640745613,
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
+ -3.75501441269506 * x2 * y2
+ + 0.625835735449176 * x4
+ + 0.625835735449176 * y4,
+ -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 8.30264925952416 * xy * z * (x2 - y2),
+ 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
+ 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.241571547304372
+ * y
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ -1.24747010616985 * z * (1.5 * z2 - 0.5)
+ + 1.6840846433293
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.498988042467941 * z,
+ 0.241571547304372
+ * x
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
+ 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
+ -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 4.09910463115149 * x**4 * xy
+ - 13.6636821038383 * xy**3
+ + 4.09910463115149 * xy * y**4,
+ -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
+ 0.00584892228263444
+ * y
+ * (3.0 * x2 - y2)
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+ 0.0701870673916132
+ * xy
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ ),
+ 0.221950995245231
+ * y
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ ),
+ -1.48328138624466
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ + 1.86469659985043
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.953538034014426 * z2
+ - 0.317846011338142,
+ 0.221950995245231
+ * x
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ ),
+ 0.0350935336958066
+ * (x2 - y2)
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ ),
+ 0.00584892228263444
+ * x
+ * (x2 - 3.0 * y2)
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+ 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
+ -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 0.683184105191914 * x2**3
+ + 10.2477615778787 * x2 * y4
+ - 10.2477615778787 * x4 * y2
+ - 0.683184105191914 * y2**3,
+ -0.707162732524596
+ * y
+ * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
+ 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
+ 9.98394571852353e-5
+ * y
+ * (5197.5 - 67567.5 * z2)
+ * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00239614697244565
+ * xy
+ * (x2 - y2)
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
+ 0.00397356022507413
+ * y
+ * (3.0 * x2 - y2)
+ * (
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ ),
+ 0.0561946276120613
+ * xy
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ ),
+ 0.206472245902897
+ * y
+ * (
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ ),
+ 1.24862677781952 * z * (1.5 * z2 - 0.5)
+ - 1.68564615005635
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 2.02901851395672
+ * z
+ * (
+ -1.45833333333333
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ + 1.83333333333333
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.9375 * z2
+ - 0.3125
+ )
+ - 0.499450711127808 * z,
+ 0.206472245902897
+ * x
+ * (
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ ),
+ 0.0280973138060306
+ * (x2 - y2)
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ ),
+ 0.00397356022507413
+ * x
+ * (x2 - 3.0 * y2)
+ * (
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ ),
+ 0.000599036743111412
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+ * (-6.0 * x2 * y2 + x4 + y4),
+ 9.98394571852353e-5
+ * x
+ * (5197.5 - 67567.5 * z2)
+ * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
+ -0.707162732524596
+ * x
+ * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
+ ],
+ -1,
+ )
+
+
+# @torch.jit.script
+def rsh_cart_8(xyz: torch.Tensor):
+ """Computes all real spherical harmonics up to degree 8.
+
+ This is an autogenerated method. See
+ https://github.com/cheind/torch-spherical-harmonics
+ for more information.
+
+ Params:
+ xyz: (N,...,3) tensor of points on the unit sphere
+
+ Returns:
+ rsh: (N,...,81) real spherical harmonics
+ projections of input. Ynm is found at index
+ `n*(n+1) + m`, with `0 <= n <= degree` and
+ `-n <= m <= n`.
+ """
+ x = xyz[..., 0]
+ y = xyz[..., 1]
+ z = xyz[..., 2]
+
+ x2 = x**2
+ y2 = y**2
+ z2 = z**2
+ xy = x * y
+ xz = x * z
+ yz = y * z
+ x4 = x2**2
+ y4 = y2**2
+ # z4 = z2**2
+ return torch.stack(
+ [
+            xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]),
+ -0.48860251190292 * y,
+ 0.48860251190292 * z,
+ -0.48860251190292 * x,
+ 1.09254843059208 * xy,
+ -1.09254843059208 * yz,
+ 0.94617469575756 * z2 - 0.31539156525252,
+ -1.09254843059208 * xz,
+ 0.54627421529604 * x2 - 0.54627421529604 * y2,
+ -0.590043589926644 * y * (3.0 * x2 - y2),
+ 2.89061144264055 * xy * z,
+ 0.304697199642977 * y * (1.5 - 7.5 * z2),
+ 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z,
+ 0.304697199642977 * x * (1.5 - 7.5 * z2),
+ 1.44530572132028 * z * (x2 - y2),
+ -0.590043589926644 * x * (x2 - 3.0 * y2),
+ 2.5033429417967 * xy * (x2 - y2),
+ -1.77013076977993 * yz * (3.0 * x2 - y2),
+ 0.126156626101008 * xy * (52.5 * z2 - 7.5),
+ 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 1.48099765681286
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 0.952069922236839 * z2
+ + 0.317356640745613,
+ 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z),
+ 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5),
+ -1.77013076977993 * xz * (x2 - 3.0 * y2),
+ -3.75501441269506 * x2 * y2
+ + 0.625835735449176 * x4
+ + 0.625835735449176 * y4,
+ -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 8.30264925952416 * xy * z * (x2 - y2),
+ 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2),
+ 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.241571547304372
+ * y
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ -1.24747010616985 * z * (1.5 * z2 - 0.5)
+ + 1.6840846433293
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.498988042467941 * z,
+ 0.241571547304372
+ * x
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ ),
+ 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z),
+ 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2),
+ 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4),
+ -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 4.09910463115149 * x**4 * xy
+ - 13.6636821038383 * xy**3
+ + 4.09910463115149 * xy * y**4,
+ -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5),
+ 0.00584892228263444
+ * y
+ * (3.0 * x2 - y2)
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+ 0.0701870673916132
+ * xy
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ ),
+ 0.221950995245231
+ * y
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ ),
+ -1.48328138624466
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ + 1.86469659985043
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.953538034014426 * z2
+ - 0.317846011338142,
+ 0.221950995245231
+ * x
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ ),
+ 0.0350935336958066
+ * (x2 - y2)
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ ),
+ 0.00584892228263444
+ * x
+ * (x2 - 3.0 * y2)
+ * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z),
+ 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4),
+ -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 0.683184105191914 * x2**3
+ + 10.2477615778787 * x2 * y4
+ - 10.2477615778787 * x4 * y2
+ - 0.683184105191914 * y2**3,
+ -0.707162732524596
+ * y
+ * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
+ 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
+ 9.98394571852353e-5
+ * y
+ * (5197.5 - 67567.5 * z2)
+ * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00239614697244565
+ * xy
+ * (x2 - y2)
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z),
+ 0.00397356022507413
+ * y
+ * (3.0 * x2 - y2)
+ * (
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ ),
+ 0.0561946276120613
+ * xy
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ ),
+ 0.206472245902897
+ * y
+ * (
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ ),
+ 1.24862677781952 * z * (1.5 * z2 - 0.5)
+ - 1.68564615005635
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 2.02901851395672
+ * z
+ * (
+ -1.45833333333333
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ + 1.83333333333333
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.9375 * z2
+ - 0.3125
+ )
+ - 0.499450711127808 * z,
+ 0.206472245902897
+ * x
+ * (
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ ),
+ 0.0280973138060306
+ * (x2 - y2)
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ ),
+ 0.00397356022507413
+ * x
+ * (x2 - 3.0 * y2)
+ * (
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ ),
+ 0.000599036743111412
+ * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+ * (-6.0 * x2 * y2 + x4 + y4),
+ 9.98394571852353e-5
+ * x
+ * (5197.5 - 67567.5 * z2)
+ * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
+ -0.707162732524596
+ * x
+ * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
+ 5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3),
+ -2.91570664069932
+ * yz
+ * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3),
+ 7.87853281621404e-6
+ * (1013512.5 * z2 - 67567.5)
+ * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4),
+ 5.10587282657803e-5
+ * y
+ * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
+ * (-10.0 * x2 * y2 + 5.0 * x4 + y4),
+ 0.00147275890257803
+ * xy
+ * (x2 - y2)
+ * (
+ 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+ - 14293.125 * z2
+ + 1299.375
+ ),
+ 0.0028519853513317
+ * y
+ * (3.0 * x2 - y2)
+ * (
+ -7.33333333333333 * z * (52.5 - 472.5 * z2)
+ + 3.0
+ * z
+ * (
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ )
+ - 560.0 * z
+ ),
+ 0.0463392770473559
+ * xy
+ * (
+ -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ + 2.5
+ * z
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ )
+ + 137.8125 * z2
+ - 19.6875
+ ),
+ 0.193851103820053
+ * y
+ * (
+ 3.2 * z * (1.5 - 7.5 * z2)
+ - 2.51428571428571
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ + 2.14285714285714
+ * z
+ * (
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25
+ * z
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ )
+ + 5.48571428571429 * z
+ ),
+ 1.48417251362228
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.86581687426801
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 2.1808249179756
+ * z
+ * (
+ 1.14285714285714 * z * (1.5 * z2 - 0.5)
+ - 1.54285714285714
+ * z
+ * (
+ 1.75
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ - 1.125 * z2
+ + 0.375
+ )
+ + 1.85714285714286
+ * z
+ * (
+ -1.45833333333333
+ * z
+ * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z)
+ + 1.83333333333333
+ * z
+ * (
+ -1.33333333333333 * z * (1.5 * z2 - 0.5)
+ + 1.8
+ * z
+ * (
+ 1.75
+ * z
+ * (
+ 1.66666666666667 * z * (1.5 * z2 - 0.5)
+ - 0.666666666666667 * z
+ )
+ - 1.125 * z2
+ + 0.375
+ )
+ + 0.533333333333333 * z
+ )
+ + 0.9375 * z2
+ - 0.3125
+ )
+ - 0.457142857142857 * z
+ )
+ - 0.954110901614325 * z2
+ + 0.318036967204775,
+ 0.193851103820053
+ * x
+ * (
+ 3.2 * z * (1.5 - 7.5 * z2)
+ - 2.51428571428571
+ * z
+ * (
+ 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ + 2.14285714285714
+ * z
+ * (
+ -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 2.16666666666667
+ * z
+ * (
+ -2.8 * z * (1.5 - 7.5 * z2)
+ + 2.2
+ * z
+ * (
+ 2.25
+ * z
+ * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z)
+ + 9.375 * z2
+ - 1.875
+ )
+ - 4.8 * z
+ )
+ - 10.9375 * z2
+ + 2.1875
+ )
+ + 5.48571428571429 * z
+ ),
+ 0.0231696385236779
+ * (x2 - y2)
+ * (
+ -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ + 2.5
+ * z
+ * (
+ -4.8 * z * (52.5 * z2 - 7.5)
+ + 2.6
+ * z
+ * (
+ 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z)
+ - 91.875 * z2
+ + 13.125
+ )
+ + 48.0 * z
+ )
+ + 137.8125 * z2
+ - 19.6875
+ ),
+ 0.0028519853513317
+ * x
+ * (x2 - 3.0 * y2)
+ * (
+ -7.33333333333333 * z * (52.5 - 472.5 * z2)
+ + 3.0
+ * z
+ * (
+ 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z)
+ + 1063.125 * z2
+ - 118.125
+ )
+ - 560.0 * z
+ ),
+ 0.000368189725644507
+ * (-6.0 * x2 * y2 + x4 + y4)
+ * (
+ 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z)
+ - 14293.125 * z2
+ + 1299.375
+ ),
+ 5.10587282657803e-5
+ * x
+ * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z)
+ * (-10.0 * x2 * y2 + x4 + 5.0 * y4),
+ 7.87853281621404e-6
+ * (1013512.5 * z2 - 67567.5)
+ * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3),
+ -2.91570664069932
+ * xz
+ * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3),
+ -20.4099464848952 * x2**3 * y2
+ - 20.4099464848952 * x2 * y2**3
+ + 0.72892666017483 * x4**2
+ + 51.0248662122381 * x4 * y4
+ + 0.72892666017483 * y4**2,
+ ],
+ -1,
+ )
+
+
+__all__ = [
+ "rsh_cart_0",
+ "rsh_cart_1",
+ "rsh_cart_2",
+ "rsh_cart_3",
+ "rsh_cart_4",
+ "rsh_cart_5",
+ "rsh_cart_6",
+ "rsh_cart_7",
+ "rsh_cart_8",
+]
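+# The `rsh_cart_N` helpers exported above evaluate a real spherical-harmonics
+# basis up to degree N directly from Cartesian direction components; the long
+# polynomial expressions appear to be machine-generated and are left unsimplified.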
+
+
+from typing import Optional
+
+import torch
+
+
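+# A minimal usage sketch for the class below (shapes inferred from `forward`;
+# the argument values are illustrative only):
+#
+#   sh = SphHarm(m=4, n=4)
+#   points = torch.rand(2, 8, 2)  # (B, N, 2): azimuth and colatitude per point
+#   harmonics = sh(points)        # (B, N, K), K = number of (order, degree) pairs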
+class SphHarm(torch.nn.Module):
+ def __init__(self, m, n, dtype=torch.float32) -> None:
+ super().__init__()
+ self.dtype = dtype
+ m = torch.tensor(list(range(-m + 1, m)))
+ n = torch.tensor(list(range(n)))
+ self.is_normalized = False
+ vals = torch.cartesian_prod(m, n).T
+ vals = vals[:, vals[0] <= vals[1]]
+ m, n = vals.unbind(0)
+
+ self.register_buffer("m", tensor=m)
+ self.register_buffer("n", tensor=n)
+ self.register_buffer("l_max", tensor=torch.max(self.n))
+
+ f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre()
+ self.register_buffer("f_a", tensor=f_a)
+ self.register_buffer("f_b", tensor=f_b)
+ self.register_buffer("d0_mask_3d", tensor=d0_mask_3d)
+ self.register_buffer("d1_mask_3d", tensor=d1_mask_3d)
+ self.register_buffer("initial_value", tensor=initial_value)
+
+ @property
+ def device(self):
+ return next(self.buffers()).device
+
+ def forward(self, points: torch.Tensor) -> torch.Tensor:
+ """Computes the spherical harmonics."""
+ # Y_l^m = (-1) ^ m c_l^m P_l^m(cos(theta)) exp(i m phi)
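+        # points[..., 0] is the azimuth (theta) and points[..., 1] the colatitude
+        # (phi); the batch and point dimensions are flattened before evaluation.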
+ B, N, D = points.shape
+ dtype = points.dtype
+ theta, phi = points.view(-1, D).to(self.dtype).unbind(-1)
+ cos_colatitude = torch.cos(phi)
+ legendre = self._gen_associated_legendre(cos_colatitude)
+ vals = torch.stack([self.m.abs(), self.n], dim=0)
+ vals = torch.cat(
+ [
+ vals.repeat(1, theta.shape[0]),
+ torch.arange(theta.shape[0], device=theta.device)
+ .unsqueeze(0)
+ .repeat_interleave(vals.shape[1], dim=1),
+ ],
+ dim=0,
+ )
+ legendre_vals = legendre[vals[0], vals[1], vals[2]]
+ legendre_vals = legendre_vals.reshape(-1, theta.shape[0])
+ angle = torch.outer(self.m.abs(), theta)
+ vandermonde = torch.complex(torch.cos(angle), torch.sin(angle))
+ harmonics = torch.complex(
+ legendre_vals * torch.real(vandermonde),
+ legendre_vals * torch.imag(vandermonde),
+ )
+
+ # Negative order.
+ m = self.m.unsqueeze(-1)
+ harmonics = torch.where(
+ m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics
+ )
+ harmonics = harmonics.permute(1, 0).reshape(B, N, -1).to(dtype)
+ return harmonics
+
+ def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]:
+ """Generates mask for recurrence relation on the remaining entries.
+
+ The remaining entries are with respect to the diagonal and offdiagonal
+ entries.
+
+ Args:
+ l_max: see `gen_normalized_legendre`.
+ Returns:
+ torch.Tensors representing the mask used by the recurrence relations.
+ """
+
+ # Computes all coefficients.
+ m_mat, l_mat = torch.meshgrid(
+ torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
+ torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype),
+ indexing="ij",
+ )
+ if self.is_normalized:
+ c0 = l_mat * l_mat
+ c1 = m_mat * m_mat
+ c2 = 2.0 * l_mat
+ c3 = (l_mat - 1.0) * (l_mat - 1.0)
+ d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1))
+ d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1)))
+ else:
+ d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat)
+ d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat)
+
+ d0_mask_indices = torch.triu_indices(self.l_max + 1, 1)
+ d1_mask_indices = torch.triu_indices(self.l_max + 1, 2)
+
+ d_zeros = torch.zeros(
+ (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
+ )
+ d_zeros[d0_mask_indices] = d0[d0_mask_indices]
+ d0_mask = d_zeros
+
+ d_zeros = torch.zeros(
+ (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device
+ )
+ d_zeros[d1_mask_indices] = d1[d1_mask_indices]
+ d1_mask = d_zeros
+
+ # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere.
+ i = torch.arange(self.l_max + 1, device=self.device)[:, None, None]
+ j = torch.arange(self.l_max + 1, device=self.device)[None, :, None]
+ k = torch.arange(self.l_max + 1, device=self.device)[None, None, :]
+ mask = (i + j - k == 0).to(self.dtype)
+ d0_mask_3d = torch.einsum("jk,ijk->ijk", d0_mask, mask)
+ d1_mask_3d = torch.einsum("jk,ijk->ijk", d1_mask, mask)
+ return (d0_mask_3d, d1_mask_3d)
+
+ def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+ coeff_0 = self.d0_mask_3d[i]
+ coeff_1 = self.d1_mask_3d[i]
+ h = torch.einsum(
+ "ij,ijk->ijk",
+ coeff_0,
+ torch.einsum("ijk,k->ijk", torch.roll(p_val, shifts=1, dims=1), x),
+ ) - torch.einsum("ij,ijk->ijk", coeff_1, torch.roll(p_val, shifts=2, dims=1))
+ p_val = p_val + h
+ return p_val
+
+ def _init_legendre(self):
+ a_idx = torch.arange(1, self.l_max + 1, dtype=self.dtype, device=self.device)
+ b_idx = torch.arange(self.l_max, dtype=self.dtype, device=self.device)
+ if self.is_normalized:
+ # The initial value p(0,0).
+ initial_value: torch.Tensor = torch.tensor(
+ 0.5 / (torch.pi**0.5), device=self.device
+ )
+ f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0)
+ f_b = torch.sqrt(2.0 * b_idx + 3.0)
+ else:
+ # The initial value p(0,0).
+ initial_value = torch.tensor(1.0, device=self.device)
+ f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0)
+ f_b = 2.0 * b_idx + 1.0
+
+ d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask()
+ return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d
+
+ def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor:
+ r"""Computes associated Legendre functions (ALFs) of the first kind.
+
+ The ALFs of the first kind are used in spherical harmonics. The spherical
+ harmonic of degree `l` and order `m` can be written as
+ `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the
+ normalization factor and θ and φ are the colatitude and longitude,
+        respectively. `N_l^m` is chosen so that the spherical harmonics form
+        a set of orthonormal basis functions of L^2(S^2). For the computational
+        efficiency of the spherical harmonics transform, the normalization factor is
+ used in the computation of the ALFs. In addition, normalizing `P_l^m`
+ avoids overflow/underflow and achieves better numerical stability. Three
+ recurrence relations are used in the computation.
+
+ Args:
+ l_max: The maximum degree of the associated Legendre function. Both the
+ degrees and orders are `[0, 1, 2, ..., l_max]`.
+          x: A vector of type `float32` or `float64` containing the sampled points in
+ spherical coordinates, at which the ALFs are computed; `x` is essentially
+ `cos(θ)`. For the numerical integration used by the spherical harmonics
+ transforms, `x` contains the quadrature points in the interval of
+ `[-1, 1]`. There are several approaches to provide the quadrature points:
+ Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev
+ method (`scipy.special.roots_chebyu`), and Driscoll & Healy
+ method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier
+ transforms and convolutions on the 2-sphere." Advances in applied
+ mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature
+            points are nearly equally spaced along θ and provide exact discrete
+            orthogonality, (P^m)^T W P^m = I, where `T` denotes the transpose
+            operation, `W` is a diagonal matrix containing the quadrature weights,
+            and `I` is the identity matrix. The Gauss-Chebyshev points are equally
+            spaced and only provide approximate discrete orthogonality. The
+            Driscoll & Healy quadrature points are equally spaced and provide
+            exact discrete orthogonality. The number of sampling points must be
+            twice the number of frequency points (modes) in the Driscoll & Healy
+            approach, which enables the FFT and achieves a fast spherical
+            harmonics transform.
+ is_normalized: True if the associated Legendre functions are normalized.
+ With normalization, `N_l^m` is applied such that the spherical harmonics
+ form a set of orthonormal basis functions of L^2(S^2).
+
+ Returns:
+ The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values
+          of the ALFs at `x`; the dimensions are, in order, the order, the degree,
+          and the evaluation points.
+ """
+ p = torch.zeros(
+ (self.l_max + 1, self.l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device
+ )
+ p[0, 0] = self.initial_value
+
+ # Compute the diagonal entries p(l,l) with recurrence.
+ y = torch.cumprod(
+ torch.broadcast_to(torch.sqrt(1.0 - x * x), (self.l_max, x.shape[0])), dim=0
+ )
+ p_diag = self.initial_value * torch.einsum("i,ij->ij", self.f_a, y)
+ # torch.diag_indices(l_max + 1)
+ diag_indices = torch.stack(
+ [torch.arange(0, self.l_max + 1, device=x.device)] * 2, dim=0
+ )
+ p[(diag_indices[0][1:], diag_indices[1][1:])] = p_diag
+
+ diag_indices = torch.stack(
+ [torch.arange(0, self.l_max, device=x.device)] * 2, dim=0
+ )
+
+ # Compute the off-diagonal entries with recurrence.
+ p_offdiag = torch.einsum(
+ "ij,ij->ij",
+ torch.einsum("i,j->ij", self.f_b, x),
+ p[(diag_indices[0], diag_indices[1])],
+ ) # p[torch.diag_indices(l_max)])
+ p[(diag_indices[0][: self.l_max], diag_indices[1][: self.l_max] + 1)] = (
+ p_offdiag
+ )
+
+ # Compute the remaining entries with recurrence.
+ if self.l_max > 1:
+ for i in range(2, self.l_max + 1):
+ p = self._recursive(i, p, x)
+ return p
diff --git a/unik3d/utils/validation.py b/unik3d/utils/validation.py
new file mode 100644
index 0000000000000000000000000000000000000000..972ff00f58a12571197f8d133ea3b17df0adb12d
--- /dev/null
+++ b/unik3d/utils/validation.py
@@ -0,0 +1,329 @@
+import json
+import os
+from collections import defaultdict
+from typing import Any, Dict, List, Optional
+
+import numpy as np
+import torch
+import torch.distributed as dist
+import torch.utils.data.distributed
+import wandb
+from PIL import Image
+from torch.nn import functional as F
+from torch.utils.data import DataLoader
+from tqdm import tqdm
+
+from unik3d.utils.distributed import barrier, get_world_size, is_main_process
+from unik3d.utils.misc import remove_leading_dim, remove_padding, ssi_helper
+from unik3d.utils.visualization import colorize, image_grid
+
+
+def stack_mixedshape_numpy(tensor_list, dim=0):
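+    # Zero-pads each HxWxC array to the largest H and W in the list, then stacks
+    # the padded arrays along `dim`.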
+ max_rows = max(tensor.shape[0] for tensor in tensor_list)
+ max_columns = max(tensor.shape[1] for tensor in tensor_list)
+
+ padded_tensors = []
+ for tensor in tensor_list:
+ rows, columns, *_ = tensor.shape
+ pad_rows = max_rows - rows
+ pad_columns = max_columns - columns
+
+ padded_tensor = np.pad(
+ tensor, ((0, pad_rows), (0, pad_columns), (0, 0)), mode="constant"
+ )
+ padded_tensors.append(padded_tensor)
+
+ return np.stack(padded_tensors, axis=dim)
+
+
+def original_image(batch):
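+    # Resizes the network-space image to the padded ground-truth resolution and
+    # then strips the stored (left, right, top, bottom) paddings so that the image
+    # matches the depth ground truth pixel-for-pixel.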
+ paddings = [
+ torch.tensor(pads)
+ for img_meta in batch["img_metas"]
+ for pads in img_meta.get("paddings", [[0] * 4])
+ ]
+ paddings = torch.stack(paddings).to(batch["data"]["image"].device)[
+ ..., [0, 2, 1, 3]
+ ] # lrtb
+
+ T, _, H, W = batch["data"]["depth"].shape
+ batch["data"]["image"] = F.interpolate(
+ batch["data"]["image"],
+        (H + paddings[0][2] + paddings[0][3], W + paddings[0][0] + paddings[0][1]),
+ mode="bilinear",
+ align_corners=False,
+ antialias=True,
+ )
+ batch["data"]["image"] = remove_padding(
+ batch["data"]["image"], paddings.repeat(T, 1)
+ )
+ return batch
+
+
+def original_image_inv(batch, preds=None):
+ paddings = [
+ torch.tensor(pads)
+ for img_meta in batch["img_metas"]
+ for pads in img_meta.get("padding_size", [[0] * 4])
+ ]
+ T, _, H, W = batch["data"]["depth"].shape
+ batch["data"]["image"] = remove_padding(batch["data"]["image"], paddings * T)
+ batch["data"]["image"] = F.interpolate(
+ batch["data"]["image"],
+ (H, W),
+ mode="bilinear",
+ align_corners=False,
+ antialias=True,
+ )
+
+ if preds is not None:
+ for key in ["depth"]:
+ if key in preds:
+ preds[key] = remove_padding(preds[key], paddings * T)
+ preds[key] = F.interpolate(
+ preds[key],
+ (H, W),
+ mode="bilinear",
+ align_corners=False,
+ antialias=True,
+ )
+
+ return batch, preds
+
+
+def aggregate_metrics(metrics_all, exclude_fn=lambda name: False):
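+    # Averages every metric over the non-excluded datasets and prepends the result
+    # under a key built from the first three letters of each dataset name.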
+ aggregate_name = "".join(
+ [name_ds[:3] for name_ds in metrics_all.keys() if not exclude_fn(name_ds)]
+ )
+ metrics_aggregate = defaultdict(list)
+ for name_ds, metrics in metrics_all.items():
+ if exclude_fn(name_ds):
+ continue
+ for metrics_name, metrics_value in metrics.items():
+ metrics_aggregate[metrics_name].append(metrics_value)
+ return {
+ **{aggregate_name: {k: sum(v) / len(v) for k, v in metrics_aggregate.items()}},
+ **metrics_all,
+ }
+
+
+GROUPS = {
+ "SFoV": ["KITTI", "NYUv2Depth", "DiodeIndoor", "ETH3D", "IBims"],
+ "SFoVDi": ["DiodeIndoor_F", "ETH3D_F", "IBims_F"],
+ "LFoV": ["ADT", "KITTI360", "ScanNetpp_F"],
+}
+
+
+def aggregate_metrics_camera(metrics_all):
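+    # Adds one averaged entry per group in GROUPS, computed over whichever of the
+    # group's datasets are present in `metrics_all`.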
+ available_groups = {
+ k: v for k, v in GROUPS.items() if any([name in metrics_all for name in v])
+ }
+ for group_name, group_datasets in available_groups.items():
+ metrics_aggregate = defaultdict(list)
+ for dataset_name in group_datasets:
+ if dataset_name not in metrics_all:
+ print(
+ f"Dataset {dataset_name} not used for aggregation of {group_name}"
+ )
+ continue
+ for metrics_name, metrics_value in metrics_all[dataset_name].items():
+ metrics_aggregate[metrics_name].append(metrics_value)
+ metrics_all[group_name] = {
+ k: sum(v) / len(v) for k, v in metrics_aggregate.items()
+ }
+ return metrics_all
+
+
+def log_metrics(metrics_all, step):
+ for name_ds, metrics in metrics_all.items():
+ for metrics_name, metrics_value in metrics.items():
+ try:
+ wandb.log(
+ {f"Metrics/{name_ds}/{metrics_name}": metrics_value}, step=step
+ )
+            except Exception:
+ print(f"Metrics/{name_ds}/{metrics_name} {round(metrics_value, 4)}")
+
+
+def log_artifacts(artifacts_all, step, run_id):
+ for ds_name, artifacts in artifacts_all.items():
+ rgbs, gts = artifacts["rgbs"], artifacts["gts"]
+ logging_imgs = [
+ *rgbs,
+ *gts,
+ *[
+ x
+ for k, v in artifacts.items()
+ if ("rgbs" not in k and "gts" not in k)
+ for x in v
+ ],
+ ]
+ artifacts_grid = image_grid(logging_imgs, len(artifacts), len(rgbs))
+ try:
+ wandb.log({f"{ds_name}_test": [wandb.Image(artifacts_grid)]}, step=step)
+        except Exception:
+ print(f"Error while saving artifacts at step {step}")
+
+
+def show(vals, dataset, ssi_depth=False):
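+    # Builds colorized RGB / GT / prediction / scaled relative-error panels (plus
+    # any extra "infos" maps) for logging; ground truths are resized to roughly
+    # 300k pixels and predictions can optionally be SSI-aligned for visualization.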
+ output_artifacts, additionals = {}, {}
+ predictions, gts, errors, images = [], [], [], []
+ for v in vals:
+ image = v["image"][0].unsqueeze(0)
+ gt = v["depth"][0].unsqueeze(0)
+ prediction = v["depth_pred"][0].unsqueeze(0)
+ # Downsample for memory and viz
+ # if any([x in dataset.__class__.__name__ for x in ["DDAD", "Argoverse", "Waymo", "DrivingStereo"]]):
+ # gt = F.interpolate(gt, scale_factor=0.5, mode="nearest-exact")
+ # # Dilate for a better visualization
+ # gt[gt < 1e-4] = dilate(gt)[gt < 1e-4]
+ H, W = gt.shape[-2:]
+ aspect_ratio = H / W
+ new_W = int((300_000 / aspect_ratio) ** 0.5)
+ new_H = int(aspect_ratio * new_W)
+ gt = F.interpolate(gt, (new_H, new_W), mode="nearest-exact")
+
+ # Format predictions and errors for every metrics used
+ prediction = F.interpolate(
+ prediction,
+ gt.shape[-2:],
+ mode="bilinear",
+ align_corners=False,
+ antialias=True,
+ )
+ error = torch.zeros_like(prediction)
+ error[gt > dataset.min_depth] = (
+ 4
+ * dataset.max_depth
+ * torch.abs(gt - prediction)[gt > dataset.min_depth]
+ / gt[gt > dataset.min_depth]
+ )
+ if ssi_depth:
+ scale, shift = ssi_helper(gt[gt > 0], prediction[gt > 0])
+ prediction = (prediction * scale + shift).clip(0.0, dataset.max_depth)
+ prediction = colorize(
+ prediction.squeeze().cpu().detach().numpy(),
+ vmin=dataset.min_depth,
+ vmax=dataset.max_depth,
+ cmap="magma_r",
+ )
+ error = error.clip(0.0, dataset.max_depth).cpu().detach().numpy()
+ error = colorize(error.squeeze(), vmin=0.001, vmax=1.0, cmap="coolwarm")
+ errors.append(error)
+ predictions.append(prediction)
+
+ image = F.interpolate(
+ image, gt.shape[-2:], mode="bilinear", align_corners=False, antialias=True
+ )
+ image = image.cpu().detach() * dataset.normalization_stats["std"].view(
+ 1, -1, 1, 1
+ ) + dataset.normalization_stats["mean"].view(1, -1, 1, 1)
+ image = (
+ (255 * image)
+ .clip(0.0, 255.0)
+ .to(torch.uint8)
+ .permute(0, 2, 3, 1)
+ .numpy()
+ .squeeze()
+ )
+ gt = gt.clip(0.0, dataset.max_depth).cpu().detach().numpy()
+ gt = colorize(
+ gt.squeeze(), vmin=dataset.min_depth, vmax=dataset.max_depth, cmap="magma_r"
+ )
+ gts.append(gt)
+ images.append(image)
+
+ for name, additional in v.get("infos", {}).items():
+ if name not in additionals:
+ additionals[name] = []
+ if additional[0].shape[0] == 3:
+ val = (
+ (127.5 * (additional[0] + 1))
+ .clip(0, 255)
+ .to(torch.uint8)
+ .cpu()
+ .detach()
+ .permute(1, 2, 0)
+ .numpy()
+ )
+ else:
+ val = colorize(
+ additional[0].cpu().detach().squeeze().numpy(),
+ 0.0,
+ dataset.max_depth,
+ )
+ additionals[name].append(val)
+
+ output_artifacts.update(
+ {
+ f"predictions": stack_mixedshape_numpy(predictions),
+ f"errors": stack_mixedshape_numpy(errors),
+ "rgbs": stack_mixedshape_numpy(images),
+ "gts": stack_mixedshape_numpy(gts),
+ **{k: stack_mixedshape_numpy(v) for k, v in additionals.items()},
+ }
+ )
+ return output_artifacts
+
+
+METRIC_B = "F1"
+INVERT = True
+SSI_VISUALIZATION = True
+
+
+def validate(
+ model,
+ test_loaders: Dict[str, DataLoader],
+ step,
+ run_id,
+ context,
+ idxs=(1, 100, 150, 1000),
+):
+
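+    # Runs every test loader through the model, accumulates the dataset-specific
+    # metrics, keeps a few batches (selected via `idxs`) for image logging, and
+    # finally aggregates and logs everything on the main process.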
+ metrics_all, predictions_select = {}, {}
+ world_size = get_world_size()
+ for name_ds, test_loader in test_loaders.items():
+ idxs = [idx % len(test_loader.dataset) for idx in idxs]
+ ds_show = []
+ for i, batch in enumerate(test_loader):
+ with context:
+ batch["data"] = {
+ k: v.to(model.device) for k, v in batch["data"].items()
+ }
+ preds = model(batch["data"], batch["img_metas"])
+
+ if batch["data"]["image"].ndim == 5:
+ batch["data"] = remove_leading_dim(batch["data"])
+ if preds["depth"].ndim == 5:
+ preds = remove_leading_dim(preds)
+ batch = original_image(batch)
+ test_loader.dataset.accumulate_metrics(
+ inputs=batch["data"],
+ preds=preds,
+ keyframe_idx=batch["img_metas"][0].get("keyframe_idx"),
+ )
+
+ # for prediction images logging
+ if i * world_size in idxs:
+ ii = (len(preds["depth"]) + 1) // 2 - 1
+ slice_ = slice(ii, ii + 1)
+ batch["data"] = {k: v[slice_] for k, v in batch["data"].items()}
+ preds["depth"] = preds["depth"][slice_]
+ ds_show.append({**batch["data"], **{"depth_pred": preds["depth"]}})
+
+ barrier()
+
+ metrics_all[name_ds] = test_loader.dataset.get_evaluation()
+ predictions_select[name_ds] = show(
+ ds_show, test_loader.dataset, ssi_depth=SSI_VISUALIZATION
+ )
+
+ barrier()
+ if is_main_process():
+ log_artifacts(artifacts_all=predictions_select, step=step, run_id=run_id)
+ metrics_all = aggregate_metrics(
+ metrics_all, exclude_fn=lambda name: "mono" in name
+ )
+ metrics_all = aggregate_metrics_camera(metrics_all)
+ log_metrics(metrics_all=metrics_all, step=step)
+ return metrics_all
diff --git a/unik3d/utils/visualization.py b/unik3d/utils/visualization.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b04ea0b29255d3169963a1d381edaea19a6618
--- /dev/null
+++ b/unik3d/utils/visualization.py
@@ -0,0 +1,251 @@
+import os
+from typing import Optional
+
+import matplotlib.pyplot as plt
+import numpy as np
+import torch
+import torch.nn.functional as F
+import wandb
+from PIL import Image
+
+from unik3d.utils.distributed import get_rank
+from unik3d.utils.misc import ssi_helper
+
+
+def colorize(
+    value: np.ndarray,
+    vmin: Optional[float] = None,
+    vmax: Optional[float] = None,
+    cmap: str = "magma_r",
+):
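+    # Maps a single-channel array to an RGB uint8 image using the given matplotlib
+    # colormap; values below 1e-4 are treated as invalid and rendered black, and
+    # inputs that already have multiple channels are returned unchanged.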
+ # if already RGB, do nothing
+ if value.ndim > 2:
+ if value.shape[-1] > 1:
+ return value
+ value = value[..., 0]
+ invalid_mask = value < 0.0001
+ # normalize
+ vmin = value.min() if vmin is None else vmin
+ vmax = value.max() if vmax is None else vmax
+ value = (value - vmin) / (vmax - vmin) # vmin..vmax
+
+ # set color
+ cmapper = plt.get_cmap(cmap)
+ value = cmapper(value, bytes=True) # (nxmx4)
+ value[invalid_mask] = 0
+ img = value[..., :3]
+ return img
+
+
+def image_grid(imgs: list[np.ndarray], rows: int, cols: int) -> np.ndarray:
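+    # Pastes the images row-major into a rows x cols grid, resizing each one to the
+    # size of the first image.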
+ if not len(imgs):
+ return None
+ assert len(imgs) == rows * cols
+ h, w = imgs[0].shape[:2]
+ grid = Image.new("RGB", size=(cols * w, rows * h))
+
+ for i, img in enumerate(imgs):
+ grid.paste(
+ Image.fromarray(img.astype(np.uint8)).resize(
+ (w, h), resample=Image.BILINEAR
+ ),
+ box=(i % cols * w, i // cols * h),
+ )
+
+ return np.array(grid)
+
+
+def get_pointcloud_from_rgbd(
+    image: np.ndarray,
+    depth: np.ndarray,
+    mask: np.ndarray,
+    intrinsic_matrix: np.ndarray,
+    extrinsic_matrix: Optional[np.ndarray] = None,
+):
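+    # Back-projects the valid (masked-in) pixels with the pinhole model,
+    # x = (u - cx) * z / fx and y = -(v - cy) * z / fy (y is flipped so +y points
+    # up), and returns an (N, 3 + C) array of XYZ coordinates followed by color.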
+ depth = np.array(depth).squeeze()
+ mask = np.array(mask).squeeze()
+ # Mask the depth array
+    masked_depth = np.ma.masked_where(np.logical_not(mask), depth)
+ # masked_depth = np.ma.masked_greater(masked_depth, 8000)
+ # Create idx array
+ idxs = np.indices(masked_depth.shape)
+ u_idxs = idxs[1]
+ v_idxs = idxs[0]
+ # Get only non-masked depth and idxs
+ z = masked_depth[~masked_depth.mask]
+ compressed_u_idxs = u_idxs[~masked_depth.mask]
+ compressed_v_idxs = v_idxs[~masked_depth.mask]
+ image = np.stack(
+ [image[..., i][~masked_depth.mask] for i in range(image.shape[-1])], axis=-1
+ )
+
+ # Calculate local position of each point
+ # Apply vectorized math to depth using compressed arrays
+ cx = intrinsic_matrix[0, 2]
+ fx = intrinsic_matrix[0, 0]
+ x = (compressed_u_idxs - cx) * z / fx
+ cy = intrinsic_matrix[1, 2]
+ fy = intrinsic_matrix[1, 1]
+ # Flip y as we want +y pointing up not down
+ y = -((compressed_v_idxs - cy) * z / fy)
+
+ # # Apply camera_matrix to pointcloud as to get the pointcloud in world coords
+ # if extrinsic_matrix is not None:
+ # # Calculate camera pose from extrinsic matrix
+ # camera_matrix = np.linalg.inv(extrinsic_matrix)
+ # # Create homogenous array of vectors by adding 4th entry of 1
+ # # At the same time flip z as for eye space the camera is looking down the -z axis
+ # w = np.ones(z.shape)
+ # x_y_z_eye_hom = np.vstack((x, y, -z, w))
+ # # Transform the points from eye space to world space
+ # x_y_z_world = np.dot(camera_matrix, x_y_z_eye_hom)[:3]
+ # return x_y_z_world.T
+ # else:
+ x_y_z_local = np.stack((x, y, z), axis=-1)
+ return np.concatenate([x_y_z_local, image], axis=-1)
+
+
+def save_file_ply(xyz, rgb, pc_file):
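+    # Writes an ASCII PLY file: a fixed header followed by one "x y z r g b" line
+    # per point; RGB values in [0, 1] are rescaled to 0-255 first.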
+ if rgb.max() < 1.001:
+ rgb = rgb * 255.0
+ rgb = rgb.astype(np.uint8)
+ # print(rgb)
+ with open(pc_file, "w") as f:
+ # headers
+ f.writelines(
+ [
+ "ply\n" "format ascii 1.0\n",
+ "element vertex {}\n".format(xyz.shape[0]),
+ "property float x\n",
+ "property float y\n",
+ "property float z\n",
+ "property uchar red\n",
+ "property uchar green\n",
+ "property uchar blue\n",
+ "end_header\n",
+ ]
+ )
+
+ for i in range(xyz.shape[0]):
+ str_v = "{:10.6f} {:10.6f} {:10.6f} {:d} {:d} {:d}\n".format(
+ xyz[i, 0], xyz[i, 1], xyz[i, 2], rgb[i, 0], rgb[i, 1], rgb[i, 2]
+ )
+ f.write(str_v)
+
+
+# really awful fct... FIXME
+
+
+def train_artifacts(rgbs, gts, preds, infos={}):
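+    # Builds a logging grid: one row of RGB inputs, then (when ground truth is
+    # available) log-depth GT and prediction rows sharing a color range taken from
+    # the 2nd/98th percentiles of log(1 + gt), plus one row per extra "infos" entry.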
+ # interpolate to same shape, will be distorted! FIXME TODO
+ shape = rgbs[0].shape[-2:]
+ gts = F.interpolate(gts, shape, mode="nearest-exact")
+
+ rgbs = [
+ (127.5 * (rgb + 1))
+ .clip(0, 255)
+ .to(torch.uint8)
+ .cpu()
+ .detach()
+ .permute(1, 2, 0)
+ .numpy()
+ for rgb in rgbs
+ ]
+ new_gts, new_preds = [], []
+ num_additional, additionals = 0, []
+
+ if len(gts) > 0:
+ for i, gt in enumerate(gts):
+ # scale, shift = ssi_helper(gts[i][gts[i]>0].cpu().detach(), preds[i][gts[i]>0].cpu().detach())
+ scale, shift = 1, 0
+ up = torch.quantile(
+ torch.log(1 + gts[i][gts[i] > 0]).float().cpu().detach(), 0.98
+ ).item()
+ down = torch.quantile(
+ torch.log(1 + gts[i][gts[i] > 0]).float().cpu().detach(), 0.02
+ ).item()
+ gt = gts[i].cpu().detach().squeeze().numpy()
+ pred = (preds[i].cpu().detach() * scale + shift).squeeze().numpy()
+ new_gts.append(
+ colorize(np.log(1.0 + gt), vmin=down, vmax=up)
+ ) # , vmin=vmin, vmax=vmax))
+ new_preds.append(
+ colorize(np.log(1.0 + pred), vmin=down, vmax=up)
+ ) # , vmin=vmin, vmax=vmax))
+
+ gts, preds = new_gts, new_preds
+ else:
+ preds = [
+ colorize(pred.cpu().detach().squeeze().numpy(), 0.0, 80.0)
+ for i, pred in enumerate(preds)
+ ]
+
+ for name, info in infos.items():
+ num_additional += 1
+ if info.shape[1] == 3:
+ additionals.extend(
+ [
+ (127.5 * (x + 1))
+ .clip(0, 255)
+ .to(torch.uint8)
+ .cpu()
+ .detach()
+ .permute(1, 2, 0)
+ .numpy()
+ for x in info
+ ]
+ )
+ else: # must be depth!
+ additionals.extend(
+ [
+ colorize(x.cpu().detach().squeeze().numpy())
+ for i, x in enumerate(info)
+ ]
+ )
+
+ num_rows = 2 + int(len(gts) > 0) + num_additional
+
+ artifacts_grid = image_grid(
+ [*rgbs, *gts, *preds, *additionals], num_rows, len(rgbs)
+ )
+ return artifacts_grid
+
+
+def log_train_artifacts(rgbs, gts, preds, step, infos={}):
+ artifacts_grid = train_artifacts(rgbs, gts, preds, infos)
+ try:
+ wandb.log({f"training": [wandb.Image(artifacts_grid)]}, step=step)
+ except:
+ Image.fromarray(artifacts_grid).save(
+ os.path.join(
+ os.environ.get("TMPDIR", "/tmp"),
+ f"{get_rank()}_art_grid{step}.png",
+ )
+ )
+ print("Logging training images failed")
+
+
+def plot_quiver(flow, spacing, margin=0, **kwargs):
+ """Plots less dense quiver field.
+
+ Args:
+ ax: Matplotlib axis
+ flow: motion vectors
+ spacing: space (px) between each arrow in grid
+ margin: width (px) of enclosing region without arrows
+ kwargs: quiver kwargs (default: angles="xy", scale_units="xy")
+ """
+ h, w, *_ = flow.shape
+
+ nx = int((w - 2 * margin) / spacing)
+ ny = int((h - 2 * margin) / spacing)
+
+ x = np.linspace(margin, w - margin - 1, nx, dtype=np.int64)
+ y = np.linspace(margin, h - margin - 1, ny, dtype=np.int64)
+
+ flow = flow[np.ix_(y, x)]
+ u = flow[:, :, 0]
+ v = flow[:, :, 1]
+
+ kwargs = {**dict(angles="xy", scale_units="xy"), **kwargs}
+ fig, ax = plt.subplots(figsize=(10, 10))
+ ax.quiver(x, y, u, v, **kwargs)
+
+ # ax.set_ylim(sorted(ax.get_ylim(), reverse=True))
+ return fig, ax