diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..a1ca0df48306fa3795b32723d9f6d6d76e2fc4d3 --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +**/__pycache__/ +**/build/ +**/dist/ +**/*egg-info +.gradio/ + +# ignore scripts +_*.sh +__*.png +__*.jpg +__*.webp +___*.py +**/___*.py + +# ignore pcds +*.ply diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..12d4ca74d98cfb36269464dced75e9441b7bf645 --- /dev/null +++ b/LICENSE @@ -0,0 +1,407 @@ +Attribution-NonCommercial 4.0 International + +======================================================================= + +Creative Commons Corporation ("Creative Commons") is not a law firm and +does not provide legal services or legal advice. Distribution of +Creative Commons public licenses does not create a lawyer-client or +other relationship. Creative Commons makes its licenses and related +information available on an "as-is" basis. Creative Commons gives no +warranties regarding its licenses, any material licensed under their +terms and conditions, or any related information. Creative Commons +disclaims all liability for damages resulting from their use to the +fullest extent possible. + +Using Creative Commons Public Licenses + +Creative Commons public licenses provide a standard set of terms and +conditions that creators and other rights holders may use to share +original works of authorship and other material subject to copyright +and certain other rights specified in the public license below. The +following considerations are for informational purposes only, are not +exhaustive, and do not form part of our licenses. + + Considerations for licensors: Our public licenses are + intended for use by those authorized to give the public + permission to use material in ways otherwise restricted by + copyright and certain other rights. Our licenses are + irrevocable. Licensors should read and understand the terms + and conditions of the license they choose before applying it. + Licensors should also secure all rights necessary before + applying our licenses so that the public can reuse the + material as expected. Licensors should clearly mark any + material not subject to the license. This includes other CC- + licensed material, or material used under an exception or + limitation to copyright. More considerations for licensors: + wiki.creativecommons.org/Considerations_for_licensors + + Considerations for the public: By using one of our public + licenses, a licensor grants the public permission to use the + licensed material under specified terms and conditions. If + the licensor's permission is not necessary for any reason--for + example, because of any applicable exception or limitation to + copyright--then that use is not regulated by the license. Our + licenses grant only permissions under copyright and certain + other rights that a licensor has authority to grant. Use of + the licensed material may still be restricted for other + reasons, including because others have copyright or other + rights in the material. A licensor may make special requests, + such as asking that all changes be marked or described. + Although not required by our licenses, you are encouraged to + respect those requests where reasonable. 
More considerations + for the public: + wiki.creativecommons.org/Considerations_for_licensees + +======================================================================= + +Creative Commons Attribution-NonCommercial 4.0 International Public +License + +By exercising the Licensed Rights (defined below), You accept and agree +to be bound by the terms and conditions of this Creative Commons +Attribution-NonCommercial 4.0 International Public License ("Public +License"). To the extent this Public License may be interpreted as a +contract, You are granted the Licensed Rights in consideration of Your +acceptance of these terms and conditions, and the Licensor grants You +such rights in consideration of benefits the Licensor receives from +making the Licensed Material available under these terms and +conditions. + + +Section 1 -- Definitions. + + a. Adapted Material means material subject to Copyright and Similar + Rights that is derived from or based upon the Licensed Material + and in which the Licensed Material is translated, altered, + arranged, transformed, or otherwise modified in a manner requiring + permission under the Copyright and Similar Rights held by the + Licensor. For purposes of this Public License, where the Licensed + Material is a musical work, performance, or sound recording, + Adapted Material is always produced where the Licensed Material is + synched in timed relation with a moving image. + + b. Adapter's License means the license You apply to Your Copyright + and Similar Rights in Your contributions to Adapted Material in + accordance with the terms and conditions of this Public License. + + c. Copyright and Similar Rights means copyright and/or similar rights + closely related to copyright including, without limitation, + performance, broadcast, sound recording, and Sui Generis Database + Rights, without regard to how the rights are labeled or + categorized. For purposes of this Public License, the rights + specified in Section 2(b)(1)-(2) are not Copyright and Similar + Rights. + d. Effective Technological Measures means those measures that, in the + absence of proper authority, may not be circumvented under laws + fulfilling obligations under Article 11 of the WIPO Copyright + Treaty adopted on December 20, 1996, and/or similar international + agreements. + + e. Exceptions and Limitations means fair use, fair dealing, and/or + any other exception or limitation to Copyright and Similar Rights + that applies to Your use of the Licensed Material. + + f. Licensed Material means the artistic or literary work, database, + or other material to which the Licensor applied this Public + License. + + g. Licensed Rights means the rights granted to You subject to the + terms and conditions of this Public License, which are limited to + all Copyright and Similar Rights that apply to Your use of the + Licensed Material and that the Licensor has authority to license. + + h. Licensor means the individual(s) or entity(ies) granting rights + under this Public License. + + i. NonCommercial means not primarily intended for or directed towards + commercial advantage or monetary compensation. For purposes of + this Public License, the exchange of the Licensed Material for + other material subject to Copyright and Similar Rights by digital + file-sharing or similar means is NonCommercial provided there is + no payment of monetary compensation in connection with the + exchange. + + j. 
Share means to provide material to the public by any means or + process that requires permission under the Licensed Rights, such + as reproduction, public display, public performance, distribution, + dissemination, communication, or importation, and to make material + available to the public including in ways that members of the + public may access the material from a place and at a time + individually chosen by them. + + k. Sui Generis Database Rights means rights other than copyright + resulting from Directive 96/9/EC of the European Parliament and of + the Council of 11 March 1996 on the legal protection of databases, + as amended and/or succeeded, as well as other essentially + equivalent rights anywhere in the world. + + l. You means the individual or entity exercising the Licensed Rights + under this Public License. Your has a corresponding meaning. + + +Section 2 -- Scope. + + a. License grant. + + 1. Subject to the terms and conditions of this Public License, + the Licensor hereby grants You a worldwide, royalty-free, + non-sublicensable, non-exclusive, irrevocable license to + exercise the Licensed Rights in the Licensed Material to: + + a. reproduce and Share the Licensed Material, in whole or + in part, for NonCommercial purposes only; and + + b. produce, reproduce, and Share Adapted Material for + NonCommercial purposes only. + + 2. Exceptions and Limitations. For the avoidance of doubt, where + Exceptions and Limitations apply to Your use, this Public + License does not apply, and You do not need to comply with + its terms and conditions. + + 3. Term. The term of this Public License is specified in Section + 6(a). + + 4. Media and formats; technical modifications allowed. The + Licensor authorizes You to exercise the Licensed Rights in + all media and formats whether now known or hereafter created, + and to make technical modifications necessary to do so. The + Licensor waives and/or agrees not to assert any right or + authority to forbid You from making technical modifications + necessary to exercise the Licensed Rights, including + technical modifications necessary to circumvent Effective + Technological Measures. For purposes of this Public License, + simply making modifications authorized by this Section 2(a) + (4) never produces Adapted Material. + + 5. Downstream recipients. + + a. Offer from the Licensor -- Licensed Material. Every + recipient of the Licensed Material automatically + receives an offer from the Licensor to exercise the + Licensed Rights under the terms and conditions of this + Public License. + + b. No downstream restrictions. You may not offer or impose + any additional or different terms or conditions on, or + apply any Effective Technological Measures to, the + Licensed Material if doing so restricts exercise of the + Licensed Rights by any recipient of the Licensed + Material. + + 6. No endorsement. Nothing in this Public License constitutes or + may be construed as permission to assert or imply that You + are, or that Your use of the Licensed Material is, connected + with, or sponsored, endorsed, or granted official status by, + the Licensor or others designated to receive attribution as + provided in Section 3(a)(1)(A)(i). + + b. Other rights. + + 1. 
Moral rights, such as the right of integrity, are not + licensed under this Public License, nor are publicity, + privacy, and/or other similar personality rights; however, to + the extent possible, the Licensor waives and/or agrees not to + assert any such rights held by the Licensor to the limited + extent necessary to allow You to exercise the Licensed + Rights, but not otherwise. + + 2. Patent and trademark rights are not licensed under this + Public License. + + 3. To the extent possible, the Licensor waives any right to + collect royalties from You for the exercise of the Licensed + Rights, whether directly or through a collecting society + under any voluntary or waivable statutory or compulsory + licensing scheme. In all other cases the Licensor expressly + reserves any right to collect such royalties, including when + the Licensed Material is used other than for NonCommercial + purposes. + + +Section 3 -- License Conditions. + +Your exercise of the Licensed Rights is expressly made subject to the +following conditions. + + a. Attribution. + + 1. If You Share the Licensed Material (including in modified + form), You must: + + a. retain the following if it is supplied by the Licensor + with the Licensed Material: + + i. identification of the creator(s) of the Licensed + Material and any others designated to receive + attribution, in any reasonable manner requested by + the Licensor (including by pseudonym if + designated); + + ii. a copyright notice; + + iii. a notice that refers to this Public License; + + iv. a notice that refers to the disclaimer of + warranties; + + v. a URI or hyperlink to the Licensed Material to the + extent reasonably practicable; + + b. indicate if You modified the Licensed Material and + retain an indication of any previous modifications; and + + c. indicate the Licensed Material is licensed under this + Public License, and include the text of, or the URI or + hyperlink to, this Public License. + + 2. You may satisfy the conditions in Section 3(a)(1) in any + reasonable manner based on the medium, means, and context in + which You Share the Licensed Material. For example, it may be + reasonable to satisfy the conditions by providing a URI or + hyperlink to a resource that includes the required + information. + + 3. If requested by the Licensor, You must remove any of the + information required by Section 3(a)(1)(A) to the extent + reasonably practicable. + + 4. If You Share Adapted Material You produce, the Adapter's + License You apply must not prevent recipients of the Adapted + Material from complying with this Public License. + + +Section 4 -- Sui Generis Database Rights. + +Where the Licensed Rights include Sui Generis Database Rights that +apply to Your use of the Licensed Material: + + a. for the avoidance of doubt, Section 2(a)(1) grants You the right + to extract, reuse, reproduce, and Share all or a substantial + portion of the contents of the database for NonCommercial purposes + only; + + b. if You include all or a substantial portion of the database + contents in a database in which You have Sui Generis Database + Rights, then the database in which You have Sui Generis Database + Rights (but not its individual contents) is Adapted Material; and + + c. You must comply with the conditions in Section 3(a) if You Share + all or a substantial portion of the contents of the database. 
+ +For the avoidance of doubt, this Section 4 supplements and does not +replace Your obligations under this Public License where the Licensed +Rights include other Copyright and Similar Rights. + + +Section 5 -- Disclaimer of Warranties and Limitation of Liability. + + a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE + EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS + AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF + ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS, + IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION, + WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR + PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS, + ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT + KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT + ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU. + + b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE + TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION, + NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT, + INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES, + COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR + USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN + ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR + DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR + IN PART, THIS LIMITATION MAY NOT APPLY TO YOU. + + c. The disclaimer of warranties and limitation of liability provided + above shall be interpreted in a manner that, to the extent + possible, most closely approximates an absolute disclaimer and + waiver of all liability. + + +Section 6 -- Term and Termination. + + a. This Public License applies for the term of the Copyright and + Similar Rights licensed here. However, if You fail to comply with + this Public License, then Your rights under this Public License + terminate automatically. + + b. Where Your right to use the Licensed Material has terminated under + Section 6(a), it reinstates: + + 1. automatically as of the date the violation is cured, provided + it is cured within 30 days of Your discovery of the + violation; or + + 2. upon express reinstatement by the Licensor. + + For the avoidance of doubt, this Section 6(b) does not affect any + right the Licensor may have to seek remedies for Your violations + of this Public License. + + c. For the avoidance of doubt, the Licensor may also offer the + Licensed Material under separate terms or conditions or stop + distributing the Licensed Material at any time; however, doing so + will not terminate this Public License. + + d. Sections 1, 5, 6, 7, and 8 survive termination of this Public + License. + + +Section 7 -- Other Terms and Conditions. + + a. The Licensor shall not be bound by any additional or different + terms or conditions communicated by You unless expressly agreed. + + b. Any arrangements, understandings, or agreements regarding the + Licensed Material not stated herein are separate from and + independent of the terms and conditions of this Public License. + + +Section 8 -- Interpretation. + + a. For the avoidance of doubt, this Public License does not, and + shall not be interpreted to, reduce, limit, restrict, or impose + conditions on any use of the Licensed Material that could lawfully + be made without permission under this Public License. + + b. 
To the extent possible, if any provision of this Public License is + deemed unenforceable, it shall be automatically reformed to the + minimum extent necessary to make it enforceable. If the provision + cannot be reformed, it shall be severed from this Public License + without affecting the enforceability of the remaining terms and + conditions. + + c. No term or condition of this Public License will be waived and no + failure to comply consented to unless expressly agreed to by the + Licensor. + + d. Nothing in this Public License constitutes or may be interpreted + as a limitation upon, or waiver of, any privileges and immunities + that apply to the Licensor or You, including from the legal + processes of any jurisdiction or authority. + +======================================================================= + +Creative Commons is not a party to its public +licenses. Notwithstanding, Creative Commons may elect to apply one of +its public licenses to material it publishes and in those instances +will be considered the “Licensor.” The text of the Creative Commons +public licenses is dedicated to the public domain under the CC0 Public +Domain Dedication. Except for the limited purpose of indicating that +material is shared under a Creative Commons public license or as +otherwise permitted by the Creative Commons policies published at +creativecommons.org/policies, Creative Commons does not authorize the +use of the trademark "Creative Commons" or any other trademark or logo +of Creative Commons without its prior written consent including, +without limitation, in connection with any unauthorized modifications +to any of its public licenses or any other arrangements, +understandings, or agreements concerning use of licensed material. For +the avoidance of doubt, this paragraph does not form part of the +public licenses. + +Creative Commons may be contacted at creativecommons.org. 
\ No newline at end of file diff --git a/README.md b/README.md index f0442905a8d4cb9e6d0020015731ee82a9fe7c3b..bf3e83af9af16301aaf295e155560c7916b0596c 100644 --- a/README.md +++ b/README.md @@ -8,6 +8,7 @@ sdk_version: 5.22.0 app_file: app.py pinned: false license: cc-by-nc-4.0 +short_description: UniK3D (CVPR 2025) --- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..950e0d92ffe20c33dbab3d03dd2a06f14fb86652 --- /dev/null +++ b/app.py @@ -0,0 +1,800 @@ +import gc +import os +import shutil +import time +from datetime import datetime +from math import pi +import sys + +import gradio as gr +import numpy as np +import torch +import trimesh +from PIL import Image + + +sys.path.append("unik3d/") + +from unik3d.models import UniK3D +from unik3d.utils.camera import OPENCV, Fisheye624, Pinhole, Spherical +from unik3d.utils.visualization import colorize + + +def predictions_to_glb( + predictions, + mask_black_bg=False, + mask_far_points=False, +) -> trimesh.Scene: + print("Building GLB scene") + images = predictions["image"].squeeze().permute(1, 2, 0).cpu().numpy() + world_points = predictions["points"].squeeze().permute(1, 2, 0).cpu().numpy() + + vertices_3d = world_points.reshape(-1, 3) + # flip x and y + vertices_3d[:, 1] *= -1 + vertices_3d[:, 0] *= -1 + colors_rgb = (images.reshape(-1, 3)).astype(np.uint8) + + if mask_black_bg: + black_bg_mask = colors_rgb.sum(axis=1) >= 16 + vertices_3d = vertices_3d[black_bg_mask] + colors_rgb = colors_rgb[black_bg_mask] + + if mask_far_points: + far_points_mask = np.linalg.norm(vertices_3d, axis=-1) < 100.0 + vertices_3d = vertices_3d[far_points_mask] + colors_rgb = colors_rgb[far_points_mask] + + scene_3d = trimesh.Scene() + point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb) + scene_3d.add_geometry(point_cloud_data) + + return scene_3d + + +def instantiate_model(model_name): + type_ = model_name[0].lower() + + name = f"unik3d-vit{type_}" + model = UniK3D.from_pretrained(f"lpiccinelli/{name}") + + # Set resolution level and interpolation mode as specified. 
+ model.resolution_level = 9 + model.interpolation_mode = "bilinear" + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device).eval() + return model + + +def instantiate_camera(camera_name, params, device): + if camera_name == "Predicted": + return None + fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov, H, W = params + if camera_name == "Pinhole": + params = [fx, fy, cx, cy] + elif camera_name == "Fisheye624": + params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2] + elif camera_name == "OPENCV": + params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2] + elif camera_name == "Equirectangular": + # dummy intrinsics for spherical camera, assume hfov -> vfov based on input shapes + hfov2 = hfov * pi / 180.0 / 2 + params = [fx, fy, cx, cy, W, H, hfov2, H / W * hfov2] + camera_name = "Spherical" + + return eval(camera_name)(params=torch.tensor(params).float()).to(device) + + +def run_model(target_dir, model_name, camera_name, params): + + print("Instantiating model and camera...") + model = instantiate_model(model_name) + + image_names = [x for x in os.listdir(target_dir) if x.endswith(".png")] + input_image = np.array(Image.open(os.path.join(target_dir, image_names[-1]))) + image_tensor = torch.from_numpy(input_image).permute(2, 0, 1).unsqueeze(0).float() + device = next(model.parameters()).device + image_tensor = image_tensor.to(device) + H, W = image_tensor.shape[-2:] + params = params + [H, W] + camera = instantiate_camera(camera_name, params=params, device=device) + + # Perform inference with the model. + print("Running inference...") + outputs = model.infer(image_tensor, camera=camera, normalize=True) + outputs["image"] = image_tensor + + return outputs + + +def gradio_demo( + target_dir, + model_name, + camera_name, + fx, + fy, + cx, + cy, + k1, + k2, + k3, + k4, + k5, + k6, + t1, + t2, + hfov, + mask_black_bg, + mask_far_points, +): + print(target_dir) + if not os.path.isdir(target_dir) or target_dir == "None": + return None, "No valid target directory found. Please upload first.", None + + start_time = time.time() + gc.collect() + + print("Running run_model...") + params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov] + with torch.no_grad(): + outputs = run_model(target_dir, model_name, camera_name, params) + + # Save predictions + points = outputs["points"].squeeze().permute(1, 2, 0).cpu().numpy() + rgb = outputs["image"].squeeze().permute(1, 2, 0).cpu().numpy() + + prediction_save_path = os.path.join(target_dir, "predictions.npz") + np.savez(prediction_save_path, {"points": points, "image": rgb}) + + # Build a GLB file name + glbfile = os.path.join( + target_dir, + f"glbscene.glb", + ) + + # Convert predictions to GLB + glbscene = predictions_to_glb( + outputs, + mask_black_bg=mask_black_bg, + mask_far_points=mask_far_points, + ) + glbscene.export(file_obj=glbfile) + + # Cleanup + del outputs + gc.collect() + + end_time = time.time() + print(f"Total time: {end_time - start_time:.2f} seconds") + log_msg = f"Success. Waiting for visualization." 
+ + return glbfile, log_msg, prediction_save_path + + +def handle_uploads(input_image): + gc.collect() + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + tmpdir = os.environ.get("TMPDIR", "/tmp") + target_dir = os.path.join(tmpdir, f"input_images_{timestamp}") + + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + os.makedirs(target_dir) + + dst_path = os.path.join(target_dir, "image.png") + Image.fromarray(input_image).save(dst_path) + image_paths = [dst_path] + + print(f"Files uploaded.") + return target_dir, image_paths + + +def update_gallery_on_upload(input_images): + if input_images is None: + return None, None + target_dir, image_path = handle_uploads(input_images) + return target_dir, "Upload complete. Click 'Run UniK3D' to get 3D pointcloud." + + +def update_parameters(camera): + if camera == "Pinhole": + return ( + gr.update(visible=True), # fx + gr.update(visible=True), # fy + gr.update(visible=True), # cx + gr.update(visible=True), # cy + gr.update(visible=False), # k1 + gr.update(visible=False), # k2 + gr.update(visible=False), # k3 + gr.update(visible=False), # k4 + gr.update(visible=False), # k5 + gr.update(visible=False), # k6 + gr.update(visible=False), # t1 + gr.update(visible=False), # t2 + gr.update(visible=False), # hfov + ) + elif camera == "OPENCV": + return ( + gr.update(visible=True), # fx + gr.update(visible=True), # fy + gr.update(visible=True), # cx + gr.update(visible=True), # cy + gr.update(visible=True), # k1 + gr.update(visible=True), # k2 + gr.update(visible=True), # k3 + gr.update(visible=False), # k4 + gr.update(visible=False), # k5 + gr.update(visible=False), # k6 + gr.update(visible=True), # t1 + gr.update(visible=True), # t2 + gr.update(visible=False), # hfov + ) + elif camera == "Fisheye624": + return ( + gr.update(visible=True), # fx + gr.update(visible=True), # fy + gr.update(visible=True), # cx + gr.update(visible=True), # cy + gr.update(visible=True), # k1 + gr.update(visible=True), # k2 + gr.update(visible=True), # k3 + gr.update(visible=True), # k4 + gr.update(visible=True), # k5 + gr.update(visible=True), # k6 + gr.update(visible=True), # t1 + gr.update(visible=True), # t2 + gr.update(visible=False), # hfov + ) + elif camera == "Equirectangular": + return ( + gr.update(visible=False), # fx + gr.update(visible=False), # fy + gr.update(visible=False), # cx + gr.update(visible=False), # cy + gr.update(visible=False), # k1 + gr.update(visible=False), # k2 + gr.update(visible=False), # k3 + gr.update(visible=False), # k4 + gr.update(visible=False), # k5 + gr.update(visible=False), # k6 + gr.update(visible=False), # t1 + gr.update(visible=False), # t2 + gr.update(visible=True), # hfov + ) + elif camera == "Predicted": + return ( + gr.update(visible=False), # fx + gr.update(visible=False), # fy + gr.update(visible=False), # cx + gr.update(visible=False), # cy + gr.update(visible=False), # k1 + gr.update(visible=False), # k2 + gr.update(visible=False), # k3 + gr.update(visible=False), # k4 + gr.update(visible=False), # k5 + gr.update(visible=False), # k6 + gr.update(visible=False), # t1 + gr.update(visible=False), # t2 + gr.update(visible=False), # hfov + ) + else: + raise ValueError(f"Invalid camera type: {camera}") + + +def clear_fields(): + return None + + +def update_log(): + return "Loading Model and Running Inference..." + + +def update_visualization(target_dir, mask_black_bg, mask_far_points, is_example): + + if is_example == "True": + return ( + None, + "No reconstruction available. 
Please click the Reconstruct button first.", + ) + + if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): + return ( + None, + "No reconstruction available. Please click the Reconstruct button first.", + ) + + predictions_path = os.path.join(target_dir, "predictions.npz") + if not os.path.exists(predictions_path): + return ( + None, + f"No reconstruction available at {predictions_path}. Please run 'Reconstruct' first.", + ) + + loaded = np.load(predictions_path, allow_pickle=True) + predictions = {key: loaded[key] for key in loaded.keys()} + + glbfile = os.path.join( + target_dir, + f"glbscene.glb", + ) + + if not os.path.exists(glbfile): + glbscene = predictions_to_glb( + predictions, + mask_black_bg=mask_black_bg, + mask_far_points=mask_far_points, + ) + glbscene.export(file_obj=glbfile) + + return glbfile, "Updating Visualization" + + +if __name__ == "__main__": + theme = gr.themes.Citrus() + theme.set( + checkbox_label_background_fill_selected="*button_primary_background_fill", + checkbox_label_text_color_selected="*button_primary_text_color", + ) + + with gr.Blocks( + theme=theme, + css=""" + .custom-log * { + font-style: italic; + font-size: 22px !important; + background-image: linear-gradient(120deg, #ff7e26 0%, #ff9c59 60%, #fff4d6 100%); + -webkit-background-clip: text; + background-clip: text; + font-weight: bold !important; + color: transparent !important; + text-align: center !important; + } + + .example-log * { + font-style: italic; + font-size: 16px !important; + background-image: linear-gradient(120deg, #ff7e26 0%, #ff9c59 60%, #fff4d6 100%); + -webkit-background-clip: text; + background-clip: text; + color: transparent !important; + } + + #my_radio .wrap { + display: flex; + flex-wrap: nowrap; + justify-content: center; + align-items: center; + } + + #my_radio .wrap label { + display: flex; + width: 50%; + justify-content: center; + align-items: center; + margin: 0; + padding: 10px 0; + box-sizing: border-box; + } + """, + ) as demo: + + # Instead of gr.State, we use a hidden Textbox: + is_example = gr.Textbox(label="is_example", visible=False, value="None") + + gr.HTML( + """ +

UniK3D: Universal Camera Monocular 3D Estimation

+

+ 🌟 GitHub Repository | + 🚀 Project Page +

+ +
+

Upload one image to create a 3D estimation of a scene or object. UniK3D directly predicts the 3D geometry of any scene, from any camera.

+ +

Getting Started:

+
  1. Upload Your Image: Use the "Upload Images" panel to provide your input.
+ 2. Run: Click the "Run UniK3D" button to start the 3D estimation process.
+ 3. Visualize: The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file.
+

Please note: our model runs on CPU on this Hugging Face Space. Actual inference takes less than 100 ms per image on consumer-level GPUs. Web-based 3D pointcloud visualization may be slow due to Gradio's rendering. For faster visualization, run our demo locally from our GitHub repository.

+
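For reference, a minimal sketch of the equivalent local run, using only calls that already appear in this file (UniK3D.from_pretrained, model.infer); the "Large" checkpoint and the "your_image.png" path are illustrative placeholders:

import numpy as np
import torch
from PIL import Image
from unik3d.models import UniK3D

model = UniK3D.from_pretrained("lpiccinelli/unik3d-vitl")  # "Large" variant, as in instantiate_model()
model.resolution_level = 9
model.interpolation_mode = "bilinear"
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device).eval()

rgb = np.array(Image.open("your_image.png"))  # placeholder path; H x W x 3, uint8
tensor = torch.from_numpy(rgb).permute(2, 0, 1).unsqueeze(0).float().to(device)
with torch.no_grad():
    # camera=None corresponds to the demo's "Predicted" camera option
    outputs = model.infer(tensor, camera=None, normalize=True)
points = outputs["points"]  # (1, 3, H, W) 3D points, the same tensor the GLB export uses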
+ """ + ) + + target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None") + + with gr.Row(): + with gr.Column(): + camera_dropdown = gr.Dropdown( + choices=[ + "Predicted", + "Pinhole", + "Fisheye624", + "OPENCV", + "Equirectangular", + ], + label="Input Camera", + ) + model_dropdown = gr.Dropdown( + choices=["Large", "Base", "Small"], label="Utilized Model" + ) + mask_black_bg = gr.Checkbox( + label="Filter Black Background", value=False + ) + mask_far_points = gr.Checkbox(label="Filter Far Points", value=False) + + with gr.Column(): + fx = gr.Number(label="Focal length x", value=500.0, visible=False) + fy = gr.Number(label="Focal length y", value=500.0, visible=False) + cx = gr.Number(label="Center projection x", value=320.0, visible=False) + cy = gr.Number(label="Center projection y", value=240.0, visible=False) + hfov = gr.Number( + label="Horizontal FoV (degree)", value=0.0, visible=False + ) + + with gr.Column(): + k1 = gr.Number(label="Radial 1", value=0.0, visible=False) + k2 = gr.Number(label="Radial 2", value=0.0, visible=False) + k3 = gr.Number(label="Radial 3", value=0.0, visible=False) + k4 = gr.Number(label="Radial 4", value=0.0, visible=False) + + with gr.Column(): + k5 = gr.Number(label="Radial 5", value=0.0, visible=False) + k6 = gr.Number(label="Radial 6", value=0.0, visible=False) + t1 = gr.Number(label="Tangential 1", value=0.0, visible=False) + t2 = gr.Number(label="Tangential 2", value=0.0, visible=False) + + with gr.Row(): + with gr.Column(scale=1): + input_image = gr.Image(label="Upload Images") + gr.Markdown("**3D Estimation**") + with gr.Row(): + log_output = gr.Markdown( + "Please upload one image at a time, then click `Run UniK3D`.", + elem_classes=["custom-log"], + ) + reconstruction_npy = gr.File( + label="Download 3D Pointcloud", type="filepath" + ) + + with gr.Column(scale=2): + reconstruction_output = gr.Model3D( + height=520, zoom_speed=0.5, pan_speed=0.5 + ) + with gr.Row(): + submit_btn = gr.Button("Run UniK3D", scale=1, variant="primary") + clear_btn = gr.ClearButton( + [ + input_image, + reconstruction_output, + log_output, + target_dir_output, + reconstruction_npy, + ], + scale=1, + ) + + examples = [ + [ + "assets/demo/poorthings.jpg", + "Large", + "Predicted", + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + True, + False, + ], + [ + "assets/demo/naruto.jpg", + "Large", + "Predicted", + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + False, + False, + ], + [ + "assets/demo/bears.jpg", + "Large", + "Predicted", + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + True, + False, + ], + [ + "assets/demo/berzirk.jpg", + "Large", + "Predicted", + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + True, + False, + ], + [ + "assets/demo/luke.webp", + "Large", + "Predicted", + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + False, + False, + ], + [ + "assets/demo/equirectangular.jpg", + "Large", + "Equirectangular", + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 360.0, + False, + False, + ], + [ + "assets/demo/venice.jpg", + "Large", + "Equirectangular", + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 360.0, + False, + True, + ], + [ + "assets/demo/dl3dv.png", + "Large", + "OPENCV", + 429.57611083984375, + 429.6898193359375, + 479.5, + 269.5, + 
-0.0014844092074781656, + 0.0007422995404340327, + 0.0, + 0.0, + 0.0, + 0.0, + 0.00012013866944471374, + 0.001125041046179831, + 0.0, + False, + False, + ], + [ + "assets/demo/scannet.jpg", + "Large", + "Fisheye624", + 791.90869140625, + 792.7230834960938, + 878.16796875, + 585.045166015625, + -0.029167557135224342, + -0.006803446915000677, + -0.0012682401575148106, + -4.6094228309812024e-05, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + False, + False, + ], + ] + + def example_pipeline( + input_image, + model_name, + camera_name, + fx, + fy, + cx, + cy, + k1, + k2, + k3, + k4, + k5, + k6, + t1, + t2, + hfov, + mask_black_bg, + mask_far_points, + ): + target_dir, image_path = handle_uploads(input_image) + glbfile, log_msg, prediction_save_path = gradio_demo( + target_dir, + model_name, + camera_name, + fx, + fy, + cx, + cy, + k1, + k2, + k3, + k4, + k5, + k6, + t1, + t2, + hfov, + mask_black_bg, + mask_far_points, + ) + return ( + glbfile, + log_msg, + prediction_save_path, + target_dir, + image_path, + ) + + gr.Markdown("Click any row to load an example.", elem_classes=["example-log"]) + + gr.Examples( + examples=examples, + inputs=[ + input_image, + model_dropdown, + camera_dropdown, + fx, + fy, + cx, + cy, + k1, + k2, + k3, + k4, + k5, + k6, + t1, + t2, + hfov, + mask_black_bg, + mask_far_points, + ], + outputs=[reconstruction_output, log_output, reconstruction_npy], + fn=example_pipeline, + cache_examples=False, + examples_per_page=50, + ) + + submit_btn.click( + fn=clear_fields, inputs=[], outputs=[reconstruction_output] + ).then(fn=update_log, inputs=[], outputs=[log_output]).then( + fn=gradio_demo, + inputs=[ + target_dir_output, + model_dropdown, + camera_dropdown, + fx, + fy, + cx, + cy, + k1, + k2, + k3, + k4, + k5, + k6, + t1, + t2, + hfov, + mask_black_bg, + mask_far_points, + ], + outputs=[reconstruction_output, log_output, reconstruction_npy], + ).then( + fn=lambda: "False", inputs=[], outputs=[is_example] + ) + + mask_black_bg.change( + update_visualization, + [target_dir_output, mask_black_bg, mask_far_points, is_example], + [reconstruction_output, log_output], + ) + + mask_far_points.change( + update_visualization, + [target_dir_output, mask_black_bg, mask_far_points, is_example], + [reconstruction_output, log_output], + ) + + input_image.change( + fn=update_gallery_on_upload, + inputs=[input_image], + outputs=[target_dir_output, log_output], + ) + + # Dynamically update intrinsic parameter visibility when camera selection changes. 
+ camera_dropdown.change( + fn=update_parameters, + inputs=camera_dropdown, + outputs=[fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov], + ) + + # demo.queue(max_size=20).launch(show_error=True, share=False, ssr_mode=False) + demo.launch( + show_error=True, + ) diff --git a/assets/demo/bears.jpg b/assets/demo/bears.jpg new file mode 100644 index 0000000000000000000000000000000000000000..99f6ec57a49b31f946b034a0a3831324abf91e25 Binary files /dev/null and b/assets/demo/bears.jpg differ diff --git a/assets/demo/berzirk.jpg b/assets/demo/berzirk.jpg new file mode 100644 index 0000000000000000000000000000000000000000..e5fa22911064487993c6a70297a544655fb172a2 Binary files /dev/null and b/assets/demo/berzirk.jpg differ diff --git a/assets/demo/dl3dv.json b/assets/demo/dl3dv.json new file mode 100644 index 0000000000000000000000000000000000000000..6bf2bbd8a0a952e9a8462a975d1e1e63c5a0a0cc --- /dev/null +++ b/assets/demo/dl3dv.json @@ -0,0 +1,4 @@ +{ + "name": "OPENCV", + "params": [429.57611083984375, 429.6898193359375, 479.5, 269.5, -0.0014844092074781656, 0.0007422995404340327, 0.0, 0.0, 0.0, 0.0, 0.00012013866944471374, 0.001125041046179831, 0.0, 0.0, 0.0, 0.0] +} \ No newline at end of file diff --git a/assets/demo/dl3dv.png b/assets/demo/dl3dv.png new file mode 100644 index 0000000000000000000000000000000000000000..a3f2bb1302251b09c20f99a452bae29d6251daa6 Binary files /dev/null and b/assets/demo/dl3dv.png differ diff --git a/assets/demo/equirectangular.jpg b/assets/demo/equirectangular.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ef669756640bf3b9f71775ae078c116502e9184d Binary files /dev/null and b/assets/demo/equirectangular.jpg differ diff --git a/assets/demo/kitti360.json b/assets/demo/kitti360.json new file mode 100644 index 0000000000000000000000000000000000000000..cdacd45482f9052ca9a10047a6db44c5a7a670ce --- /dev/null +++ b/assets/demo/kitti360.json @@ -0,0 +1,14 @@ +{ + "params": [ + 890.8814086914062, + 890.5255737304688, + 477.7955017089844, + 470.34332275390625, + 0.016798235476017, + 1.6548773050308228, + 0.000422239420004189, + 0.000424621335696429, + 2.213404655456543 + ], + "name": "MEI" +} \ No newline at end of file diff --git a/assets/demo/kitti360.png b/assets/demo/kitti360.png new file mode 100644 index 0000000000000000000000000000000000000000..4a58026a71cfbc3787dc2e91c11d29d54ef90138 Binary files /dev/null and b/assets/demo/kitti360.png differ diff --git a/assets/demo/luke.webp b/assets/demo/luke.webp new file mode 100644 index 0000000000000000000000000000000000000000..b60384b4bb9bf67a626b3f84e226bfa39b2e561c Binary files /dev/null and b/assets/demo/luke.webp differ diff --git a/assets/demo/naruto.jpg b/assets/demo/naruto.jpg new file mode 100644 index 0000000000000000000000000000000000000000..6bb0aad52c5aeaa9a1ac062fede8d1313017f062 Binary files /dev/null and b/assets/demo/naruto.jpg differ diff --git a/assets/demo/poorthings.jpg b/assets/demo/poorthings.jpg new file mode 100644 index 0000000000000000000000000000000000000000..12b018df8c45097d3cb3c2ff2c6725327c0897e9 Binary files /dev/null and b/assets/demo/poorthings.jpg differ diff --git a/assets/demo/scannet.jpg b/assets/demo/scannet.jpg new file mode 100644 index 0000000000000000000000000000000000000000..ad3b4a68bbf4e5a232de75a2bf19ac29d84687f8 Binary files /dev/null and b/assets/demo/scannet.jpg differ diff --git a/assets/demo/scannet.json b/assets/demo/scannet.json new file mode 100644 index 0000000000000000000000000000000000000000..4df47995b7af91b18360192537e2aa3aad80a527 --- 
/dev/null +++ b/assets/demo/scannet.json @@ -0,0 +1,21 @@ +{ + "params": [ + 791.90869140625, + 792.7230834960938, + 878.16796875, + 585.045166015625, + -0.029167557135224342, + -0.006803446915000677, + -0.0012682401575148106, + -4.6094228309812024e-05, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0 + ], + "name": "Fisheye624" +} \ No newline at end of file diff --git a/assets/demo/venice.jpg b/assets/demo/venice.jpg new file mode 100644 index 0000000000000000000000000000000000000000..2c6c7dd8629c9887b51786b92a1c8e110014070b Binary files /dev/null and b/assets/demo/venice.jpg differ diff --git a/assets/docs/unik3d-banner.png b/assets/docs/unik3d-banner.png new file mode 100644 index 0000000000000000000000000000000000000000..598ac6713d629653e18850bf493967cb07800acb Binary files /dev/null and b/assets/docs/unik3d-banner.png differ diff --git a/assets/docs/unik3d-teaser.png b/assets/docs/unik3d-teaser.png new file mode 100644 index 0000000000000000000000000000000000000000..a60145bf44be0cc2091ef94471608de1a98378d3 Binary files /dev/null and b/assets/docs/unik3d-teaser.png differ diff --git a/configs/config_vitb.json b/configs/config_vitb.json new file mode 100644 index 0000000000000000000000000000000000000000..ad6b813bb4f43744b45bd5ba7626f28b03e48302 --- /dev/null +++ b/configs/config_vitb.json @@ -0,0 +1,159 @@ +{ + "generic": { + "seed": 42, + "deterministic": true, + "name_page": "ufish" + }, + "training": { + "n_iters": 250000, + "batch_size": 8, + "validation_interval": 2500, + "nsteps_accumulation_gradient": 4, + "lr": 5e-05, + "lr_final": 1e-06, + "lr_warmup": 1.0, + "cycle_beta": true, + "wd": 0.1, + "wd_final": 0.1, + "warmup_iters": 75000, + "ld": 1.0, + "drop_path": 0.0, + "ema": 0.9995, + "f16": "f16", + "clipping": 1.0, + "losses": { + "depth": { + "name": "Scale", + "weight": 1.0, + "fn": "l1", + "gamma": 1.0, + "alpha": 1.0, + "output_fn": "sqrt", + "input_fn": "log" + }, + "camera": { + "name": "PolarRegression", + "weight": 1.0, + "gamma": 1.0, + "alpha": 1.0, + "fn": "l1", + "output_fn": "sqrt", + "input_fn": "linear", + "dims": [ + 1, + 2 + ], + "polar_weight": 3.0, + "polar_asym": 0.7 + }, + "confidence": { + "name": "Confidence", + "weight": 0.1, + "input_fn": "log", + "output_fn": "sqrt" + } + } + }, + "data": { + "image_shape": [ + 518, + 518 + ], + "resize_method": "contextcrop", + "normalization": "imagenet", + "pair": 1, + "mini": 1.0, + "num_frames": 1, + "sampling": { + "KITTI": 1.0 + }, + "train_datasets": [ + "KITTI" + ], + "val_datasets": [ + "KITTI" + ], + "data_root": "datasets", + "crop": "garg", + "augmentations": { + "random_scale": 4.0, + "random_translate_x": 0.04, + "random_translate_y": 0.01, + "scale_p": 0.0, + "translate_p": 0.0, + "random_rotation": 0.0, + "rotation_p": 0.0, + "random_shear": 0.0, + "affine_p": 0.0, + "random_jitter": 0.5, + "jitter_p": 1.0, + "random_blur": 2.0, + "blur_p": 0.5, + "random_gamma": 0.5, + "gamma_p": 1.0, + "grayscale_p": 0.2, + "flip_p": 0.5, + "cut_p": 0.0, + "invert_p": 0.0, + "shape_mult": 14, + "noise_pad": 1.0, + "test_context": 1.0 + }, + "shape_constraints": { + "ratio_bounds": [ + 0.5, + 2.5 + ], + "pixels_max": 600000.0, + "pixels_min": 200000.0, + "height_min": 15, + "width_min": 15, + "shape_mult": 14, + "sample": true + } + }, + "model": { + "name": "UniK3D", + "num_heads": 8, + "expansion": 4, + "num_steps": 100000, + "layer_scale": 1e-4, + "camera": { + "augment": true, + "weak_ratio": 0.9, + "tau": 50000 + }, + "pixel_decoder": { + "name": "Decoder", + "hidden_dim": 384, + "dropout": 0.0, + 
"depths": [ + 2, + 2, + 2 + ], + "detach": 0.1, + "out_dim": 48, + "kernel_size": 3, + "num_prompt_blocks": 1, + "use_norm": false + }, + "pixel_encoder": { + "lr": 3e-06, + "wd": 0.1, + "name": "dinov2_vitb14", + "frozen_stages": 0, + "num_register_tokens": 0, + "use_norm": true, + "freeze_norm": true, + "pretrained": null, + "stacking_fn": "last", + "output_idx": [ + 3, + 6, + 9, + 12 + ] + } + } +} \ No newline at end of file diff --git a/configs/config_vitl.json b/configs/config_vitl.json new file mode 100644 index 0000000000000000000000000000000000000000..ca6865b908e9c8a0319ed0e6f3e5bc3f7450f362 --- /dev/null +++ b/configs/config_vitl.json @@ -0,0 +1,159 @@ +{ + "generic": { + "seed": 42, + "deterministic": true, + "name_page": "ufish" + }, + "training": { + "n_iters": 250000, + "batch_size": 8, + "validation_interval": 2500, + "nsteps_accumulation_gradient": 4, + "lr": 5e-05, + "lr_final": 1e-06, + "lr_warmup": 1.0, + "cycle_beta": true, + "wd": 0.1, + "wd_final": 0.1, + "warmup_iters": 75000, + "ld": 1.0, + "drop_path": 0.0, + "ema": 0.9995, + "f16": "f16", + "clipping": 1.0, + "losses": { + "depth": { + "name": "Scale", + "weight": 1.0, + "fn": "l1", + "gamma": 1.0, + "alpha": 1.0, + "output_fn": "sqrt", + "input_fn": "log" + }, + "camera": { + "name": "PolarRegression", + "weight": 1.0, + "gamma": 1.0, + "alpha": 1.0, + "fn": "l1", + "output_fn": "sqrt", + "input_fn": "linear", + "dims": [ + 1, + 2 + ], + "polar_weight": 3.0, + "polar_asym": 0.7 + }, + "confidence": { + "name": "Confidence", + "weight": 0.1, + "input_fn": "log", + "output_fn": "sqrt" + } + } + }, + "data": { + "image_shape": [ + 518, + 518 + ], + "resize_method": "contextcrop", + "normalization": "imagenet", + "pair": 1, + "mini": 1.0, + "num_frames": 1, + "sampling": { + "KITTI": 1.0 + }, + "train_datasets": [ + "KITTI" + ], + "val_datasets": [ + "KITTI" + ], + "data_root": "datasets", + "crop": "garg", + "augmentations": { + "random_scale": 4.0, + "random_translate_x": 0.04, + "random_translate_y": 0.01, + "scale_p": 0.0, + "translate_p": 0.0, + "random_rotation": 0.0, + "rotation_p": 0.0, + "random_shear": 0.0, + "affine_p": 0.0, + "random_jitter": 0.5, + "jitter_p": 1.0, + "random_blur": 2.0, + "blur_p": 0.5, + "random_gamma": 0.5, + "gamma_p": 1.0, + "grayscale_p": 0.2, + "flip_p": 0.5, + "cut_p": 0.0, + "invert_p": 0.0, + "shape_mult": 14, + "noise_pad": 1.0, + "test_context": 1.0 + }, + "shape_constraints": { + "ratio_bounds": [ + 0.5, + 2.5 + ], + "pixels_max": 600000.0, + "pixels_min": 200000.0, + "height_min": 15, + "width_min": 15, + "shape_mult": 14, + "sample": true + } + }, + "model": { + "name": "UniK3D", + "num_heads": 8, + "expansion": 4, + "num_steps": 100000, + "layer_scale": 1e-4, + "camera": { + "augment": true, + "weak_ratio": 0.9, + "tau": 50000 + }, + "pixel_decoder": { + "name": "Decoder", + "hidden_dim": 512, + "dropout": 0.0, + "depths": [ + 2, + 2, + 2 + ], + "detach": 0.1, + "out_dim": 64, + "kernel_size": 3, + "num_prompt_blocks": 1, + "use_norm": false + }, + "pixel_encoder": { + "lr": 3e-06, + "wd": 0.1, + "name": "dinov2_vitl14", + "frozen_stages": 0, + "num_register_tokens": 0, + "use_norm": true, + "freeze_norm": true, + "pretrained": null, + "stacking_fn": "last", + "output_idx": [ + 6, + 12, + 18, + 24 + ] + } + } +} \ No newline at end of file diff --git a/configs/config_vits.json b/configs/config_vits.json new file mode 100644 index 0000000000000000000000000000000000000000..ca37e00100b51e681f233f9b638aa0c4beb7b734 --- /dev/null +++ b/configs/config_vits.json @@ -0,0 +1,159 @@ 
+{ + "generic": { + "seed": 42, + "deterministic": true, + "name_page": "ufish" + }, + "training": { + "n_iters": 250000, + "batch_size": 8, + "validation_interval": 2500, + "nsteps_accumulation_gradient": 4, + "lr": 5e-05, + "lr_final": 1e-06, + "lr_warmup": 1.0, + "cycle_beta": true, + "wd": 0.1, + "wd_final": 0.1, + "warmup_iters": 75000, + "ld": 1.0, + "drop_path": 0.0, + "ema": 0.9995, + "f16": "f16", + "clipping": 1.0, + "losses": { + "depth": { + "name": "Scale", + "weight": 1.0, + "fn": "l1", + "gamma": 1.0, + "alpha": 1.0, + "output_fn": "sqrt", + "input_fn": "log" + }, + "camera": { + "name": "PolarRegression", + "weight": 1.0, + "gamma": 1.0, + "alpha": 1.0, + "fn": "l1", + "output_fn": "sqrt", + "input_fn": "linear", + "dims": [ + 1, + 2 + ], + "polar_weight": 3.0, + "polar_asym": 0.7 + }, + "confidence": { + "name": "Confidence", + "weight": 0.1, + "input_fn": "log", + "output_fn": "sqrt" + } + } + }, + "data": { + "image_shape": [ + 518, + 518 + ], + "resize_method": "contextcrop", + "normalization": "imagenet", + "pair": 1, + "mini": 1.0, + "num_frames": 1, + "sampling": { + "KITTI": 1.0 + }, + "train_datasets": [ + "KITTI" + ], + "val_datasets": [ + "KITTI" + ], + "data_root": "datasets", + "crop": "garg", + "augmentations": { + "random_scale": 4.0, + "random_translate_x": 0.04, + "random_translate_y": 0.01, + "scale_p": 0.0, + "translate_p": 0.0, + "random_rotation": 0.0, + "rotation_p": 0.0, + "random_shear": 0.0, + "affine_p": 0.0, + "random_jitter": 0.5, + "jitter_p": 1.0, + "random_blur": 2.0, + "blur_p": 0.5, + "random_gamma": 0.5, + "gamma_p": 1.0, + "grayscale_p": 0.2, + "flip_p": 0.5, + "cut_p": 0.0, + "invert_p": 0.0, + "shape_mult": 14, + "noise_pad": 1.0, + "test_context": 1.0 + }, + "shape_constraints": { + "ratio_bounds": [ + 0.5, + 2.5 + ], + "pixels_max": 600000.0, + "pixels_min": 200000.0, + "height_min": 15, + "width_min": 15, + "shape_mult": 14, + "sample": true + } + }, + "model": { + "name": "UniK3D", + "num_heads": 8, + "expansion": 4, + "num_steps": 100000, + "layer_scale": 1e-4, + "camera": { + "augment": true, + "weak_ratio": 0.9, + "tau": 50000 + }, + "pixel_decoder": { + "name": "Decoder", + "hidden_dim": 256, + "dropout": 0.0, + "depths": [ + 2, + 2, + 2 + ], + "detach": 0.1, + "out_dim": 32, + "kernel_size": 3, + "num_prompt_blocks": 1, + "use_norm": false + }, + "pixel_encoder": { + "lr": 3e-06, + "wd": 0.1, + "name": "dinov2_vits14", + "frozen_stages": 0, + "num_register_tokens": 0, + "use_norm": true, + "freeze_norm": true, + "pretrained": null, + "stacking_fn": "last", + "output_idx": [ + 3, + 6, + 9, + 12 + ] + } + } +} \ No newline at end of file diff --git a/gradio_demo.py b/gradio_demo.py new file mode 100644 index 0000000000000000000000000000000000000000..7f7c654cda3d22989cf54fa0a25c91d5f46341dd --- /dev/null +++ b/gradio_demo.py @@ -0,0 +1,796 @@ +import gc +import os +import shutil +import time +from datetime import datetime +from math import pi + +import gradio as gr +import numpy as np +import torch +import trimesh +from PIL import Image + +from unik3d.models import UniK3D +from unik3d.utils.camera import OPENCV, Fisheye624, Pinhole, Spherical +from unik3d.utils.visualization import colorize + + +def predictions_to_glb( + predictions, + mask_black_bg=False, + mask_far_points=False, +) -> trimesh.Scene: + print("Building GLB scene") + images = predictions["image"].squeeze().permute(1, 2, 0).cpu().numpy() + world_points = predictions["points"].squeeze().permute(1, 2, 0).cpu().numpy() + + vertices_3d = world_points.reshape(-1, 3) + # 
flip x and y + vertices_3d[:, 1] *= -1 + vertices_3d[:, 0] *= -1 + colors_rgb = (images.reshape(-1, 3)).astype(np.uint8) + + if mask_black_bg: + black_bg_mask = colors_rgb.sum(axis=1) >= 16 + vertices_3d = vertices_3d[black_bg_mask] + colors_rgb = colors_rgb[black_bg_mask] + + if mask_far_points: + far_points_mask = np.linalg.norm(vertices_3d, axis=-1) < 100.0 + vertices_3d = vertices_3d[far_points_mask] + colors_rgb = colors_rgb[far_points_mask] + + scene_3d = trimesh.Scene() + point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb) + scene_3d.add_geometry(point_cloud_data) + + return scene_3d + + +def instantiate_model(model_name): + type_ = model_name[0].lower() + + name = f"unik3d-vit{type_}" + model = UniK3D.from_pretrained(f"lpiccinelli/{name}") + + # Set resolution level and interpolation mode as specified. + model.resolution_level = 9 + model.interpolation_mode = "bilinear" + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device).eval() + return model + + +def instantiate_camera(camera_name, params, device): + if camera_name == "Predicted": + return None + fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov, H, W = params + if camera_name == "Pinhole": + params = [fx, fy, cx, cy] + elif camera_name == "Fisheye624": + params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2] + elif camera_name == "OPENCV": + params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2] + elif camera_name == "Equirectangular": + # dummy intrinsics for spherical camera, assume hfov -> vfov based on input shapes + hfov2 = hfov * pi / 180.0 / 2 + params = [fx, fy, cx, cy, W, H, hfov2, H / W * hfov2] + camera_name = "Spherical" + + return eval(camera_name)(params=torch.tensor(params).float()).to(device) + + +def run_model(target_dir, model_name, camera_name, params): + + print("Instantiating model and camera...") + model = instantiate_model(model_name) + + image_names = [x for x in os.listdir(target_dir) if x.endswith(".png")] + input_image = np.array(Image.open(os.path.join(target_dir, image_names[-1]))) + image_tensor = torch.from_numpy(input_image).permute(2, 0, 1).unsqueeze(0).float() + device = next(model.parameters()).device + image_tensor = image_tensor.to(device) + H, W = image_tensor.shape[-2:] + params = params + [H, W] + camera = instantiate_camera(camera_name, params=params, device=device) + + # Perform inference with the model. + print("Running inference...") + outputs = model.infer(image_tensor, camera=camera, normalize=True) + outputs["image"] = image_tensor + + return outputs + + +def gradio_demo( + target_dir, + model_name, + camera_name, + fx, + fy, + cx, + cy, + k1, + k2, + k3, + k4, + k5, + k6, + t1, + t2, + hfov, + mask_black_bg, + mask_far_points, +): + print(target_dir) + if not os.path.isdir(target_dir) or target_dir == "None": + return None, "No valid target directory found. 
Please upload first.", None + + start_time = time.time() + gc.collect() + + print("Running run_model...") + params = [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov] + with torch.no_grad(): + outputs = run_model(target_dir, model_name, camera_name, params) + + # Save predictions + points = outputs["points"].squeeze().permute(1, 2, 0).cpu().numpy() + rgb = outputs["image"].squeeze().permute(1, 2, 0).cpu().numpy() + + prediction_save_path = os.path.join(target_dir, "predictions.npz") + np.savez(prediction_save_path, {"points": points, "image": rgb}) + + # Build a GLB file name + glbfile = os.path.join( + target_dir, + f"glbscene.glb", + ) + + # Convert predictions to GLB + glbscene = predictions_to_glb( + outputs, + mask_black_bg=mask_black_bg, + mask_far_points=mask_far_points, + ) + glbscene.export(file_obj=glbfile) + + # Cleanup + del outputs + gc.collect() + + end_time = time.time() + print(f"Total time: {end_time - start_time:.2f} seconds") + log_msg = f"Success. Waiting for visualization." + + return glbfile, log_msg, prediction_save_path + + +def handle_uploads(input_image): + gc.collect() + + timestamp = datetime.now().strftime("%Y%m%d_%H%M%S_%f") + tmpdir = os.environ.get("TMPDIR", "/tmp") + target_dir = os.path.join(tmpdir, f"input_images_{timestamp}") + + if os.path.exists(target_dir): + shutil.rmtree(target_dir) + os.makedirs(target_dir) + + dst_path = os.path.join(target_dir, "image.png") + Image.fromarray(input_image).save(dst_path) + image_paths = [dst_path] + + print(f"Files uploaded.") + return target_dir, image_paths + + +def update_gallery_on_upload(input_images): + if input_images is None: + return None, None + target_dir, image_path = handle_uploads(input_images) + return target_dir, "Upload complete. Click 'Run UniK3D' to get 3D pointcloud." 
+ + +def update_parameters(camera): + if camera == "Pinhole": + return ( + gr.update(visible=True), # fx + gr.update(visible=True), # fy + gr.update(visible=True), # cx + gr.update(visible=True), # cy + gr.update(visible=False), # k1 + gr.update(visible=False), # k2 + gr.update(visible=False), # k3 + gr.update(visible=False), # k4 + gr.update(visible=False), # k5 + gr.update(visible=False), # k6 + gr.update(visible=False), # t1 + gr.update(visible=False), # t2 + gr.update(visible=False), # hfov + ) + elif camera == "OPENCV": + return ( + gr.update(visible=True), # fx + gr.update(visible=True), # fy + gr.update(visible=True), # cx + gr.update(visible=True), # cy + gr.update(visible=True), # k1 + gr.update(visible=True), # k2 + gr.update(visible=True), # k3 + gr.update(visible=False), # k4 + gr.update(visible=False), # k5 + gr.update(visible=False), # k6 + gr.update(visible=True), # t1 + gr.update(visible=True), # t2 + gr.update(visible=False), # hfov + ) + elif camera == "Fisheye624": + return ( + gr.update(visible=True), # fx + gr.update(visible=True), # fy + gr.update(visible=True), # cx + gr.update(visible=True), # cy + gr.update(visible=True), # k1 + gr.update(visible=True), # k2 + gr.update(visible=True), # k3 + gr.update(visible=True), # k4 + gr.update(visible=True), # k5 + gr.update(visible=True), # k6 + gr.update(visible=True), # t1 + gr.update(visible=True), # t2 + gr.update(visible=False), # hfov + ) + elif camera == "Equirectangular": + return ( + gr.update(visible=False), # fx + gr.update(visible=False), # fy + gr.update(visible=False), # cx + gr.update(visible=False), # cy + gr.update(visible=False), # k1 + gr.update(visible=False), # k2 + gr.update(visible=False), # k3 + gr.update(visible=False), # k4 + gr.update(visible=False), # k5 + gr.update(visible=False), # k6 + gr.update(visible=False), # t1 + gr.update(visible=False), # t2 + gr.update(visible=True), # hfov + ) + elif camera == "Predicted": + return ( + gr.update(visible=False), # fx + gr.update(visible=False), # fy + gr.update(visible=False), # cx + gr.update(visible=False), # cy + gr.update(visible=False), # k1 + gr.update(visible=False), # k2 + gr.update(visible=False), # k3 + gr.update(visible=False), # k4 + gr.update(visible=False), # k5 + gr.update(visible=False), # k6 + gr.update(visible=False), # t1 + gr.update(visible=False), # t2 + gr.update(visible=False), # hfov + ) + else: + raise ValueError(f"Invalid camera type: {camera}") + + +def clear_fields(): + return None + + +def update_log(): + return "Loading Model and Running Inference..." + + +def update_visualization(target_dir, mask_black_bg, mask_far_points, is_example): + + if is_example == "True": + return ( + None, + "No reconstruction available. Please click the Reconstruct button first.", + ) + + if not target_dir or target_dir == "None" or not os.path.isdir(target_dir): + return ( + None, + "No reconstruction available. Please click the Reconstruct button first.", + ) + + predictions_path = os.path.join(target_dir, "predictions.npz") + if not os.path.exists(predictions_path): + return ( + None, + f"No reconstruction available at {predictions_path}. 
Please run 'Reconstruct' first.", + ) + + loaded = np.load(predictions_path, allow_pickle=True) + predictions = {key: loaded[key] for key in loaded.keys()} + + glbfile = os.path.join( + target_dir, + f"glbscene.glb", + ) + + if not os.path.exists(glbfile): + glbscene = predictions_to_glb( + predictions, + mask_black_bg=mask_black_bg, + mask_far_points=mask_far_points, + ) + glbscene.export(file_obj=glbfile) + + return glbfile, "Updating Visualization" + + +if __name__ == "__main__": + theme = gr.themes.Citrus() + theme.set( + checkbox_label_background_fill_selected="*button_primary_background_fill", + checkbox_label_text_color_selected="*button_primary_text_color", + ) + + with gr.Blocks( + theme=theme, + css=""" + .custom-log * { + font-style: italic; + font-size: 22px !important; + background-image: linear-gradient(120deg, #ff7e26 0%, #ff9c59 60%, #fff4d6 100%); + -webkit-background-clip: text; + background-clip: text; + font-weight: bold !important; + color: transparent !important; + text-align: center !important; + } + + .example-log * { + font-style: italic; + font-size: 16px !important; + background-image: linear-gradient(120deg, #ff7e26 0%, #ff9c59 60%, #fff4d6 100%); + -webkit-background-clip: text; + background-clip: text; + color: transparent !important; + } + + #my_radio .wrap { + display: flex; + flex-wrap: nowrap; + justify-content: center; + align-items: center; + } + + #my_radio .wrap label { + display: flex; + width: 50%; + justify-content: center; + align-items: center; + margin: 0; + padding: 10px 0; + box-sizing: border-box; + } + """, + ) as demo: + + # Instead of gr.State, we use a hidden Textbox: + is_example = gr.Textbox(label="is_example", visible=False, value="None") + + gr.HTML( + """ +

UniK3D: Universal Camera Monocular 3D Estimation

+

+ 🌟 GitHub Repository | + 🚀 Project Page +

+ +
+

Upload one image to create a 3D estimation of a scene or object. UniK3D can directly predict the 3D geometry of any scene captured by any camera.

+ +

Getting Started:

+
    +
  1. Upload Your Image: Use the "Upload Images" panel to provide your input.
  2. Run: Click the "Run UniK3D" button to start the 3D estimation process.
  3. Visualize: The 3D reconstruction will appear in the viewer on the right. You can rotate, pan, and zoom to explore the model, and download the GLB file.
+

Please note: Our model runs on CPU on HuggingFace Space. Actual inference takes less than 100 ms per image on consumer-level GPUs. Web-based 3D pointcloud visualization may be slow due to Gradio's rendering. For faster visualization, run our demo locally from our GitHub repository.

+
+ """ + ) + + target_dir_output = gr.Textbox(label="Target Dir", visible=False, value="None") + + with gr.Row(): + with gr.Column(): + camera_dropdown = gr.Dropdown( + choices=[ + "Predicted", + "Pinhole", + "Fisheye624", + "OPENCV", + "Equirectangular", + ], + label="Input Camera", + ) + model_dropdown = gr.Dropdown( + choices=["Large", "Base", "Small"], label="Utilized Model" + ) + mask_black_bg = gr.Checkbox( + label="Filter Black Background", value=False + ) + mask_far_points = gr.Checkbox(label="Filter Far Points", value=False) + + with gr.Column(): + fx = gr.Number(label="Focal length x", value=500.0, visible=False) + fy = gr.Number(label="Focal length y", value=500.0, visible=False) + cx = gr.Number(label="Center projection x", value=320.0, visible=False) + cy = gr.Number(label="Center projection y", value=240.0, visible=False) + hfov = gr.Number( + label="Horizontal FoV (degree)", value=0.0, visible=False + ) + + with gr.Column(): + k1 = gr.Number(label="Radial 1", value=0.0, visible=False) + k2 = gr.Number(label="Radial 2", value=0.0, visible=False) + k3 = gr.Number(label="Radial 3", value=0.0, visible=False) + k4 = gr.Number(label="Radial 4", value=0.0, visible=False) + + with gr.Column(): + k5 = gr.Number(label="Radial 5", value=0.0, visible=False) + k6 = gr.Number(label="Radial 6", value=0.0, visible=False) + t1 = gr.Number(label="Tangential 1", value=0.0, visible=False) + t2 = gr.Number(label="Tangential 2", value=0.0, visible=False) + + with gr.Row(): + with gr.Column(scale=1): + input_image = gr.Image(label="Upload Images") + gr.Markdown("**3D Estimation**") + with gr.Row(): + log_output = gr.Markdown( + "Please upload one image at a time, then click `Run UniK3D`.", + elem_classes=["custom-log"], + ) + reconstruction_npy = gr.File( + label="Download 3D Pointcloud", type="filepath" + ) + + with gr.Column(scale=2): + reconstruction_output = gr.Model3D( + height=520, zoom_speed=0.5, pan_speed=0.5 + ) + with gr.Row(): + submit_btn = gr.Button("Run UniK3D", scale=1, variant="primary") + clear_btn = gr.ClearButton( + [ + input_image, + reconstruction_output, + log_output, + target_dir_output, + reconstruction_npy, + ], + scale=1, + ) + + examples = [ + [ + "assets/demo/poorthings.jpg", + "Large", + "Predicted", + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + True, + False, + ], + [ + "assets/demo/naruto.jpg", + "Large", + "Predicted", + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + False, + False, + ], + [ + "assets/demo/bears.png", + "Large", + "Predicted", + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + True, + False, + ], + [ + "assets/demo/berzirk.jpg", + "Large", + "Predicted", + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + True, + False, + ], + [ + "assets/demo/luke.webp", + "Large", + "Predicted", + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + False, + False, + ], + [ + "assets/demo/equirectangular.jpg", + "Large", + "Equirectangular", + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 360.0, + False, + False, + ], + [ + "assets/demo/venice.jpg", + "Large", + "Equirectangular", + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 360.0, + False, + True, + ], + [ + "assets/demo/dl3dv.png", + "Large", + "OPENCV", + 429.57611083984375, + 429.6898193359375, + 479.5, + 269.5, + 
-0.0014844092074781656, + 0.0007422995404340327, + 0.0, + 0.0, + 0.0, + 0.0, + 0.00012013866944471374, + 0.001125041046179831, + 0.0, + False, + False, + ], + [ + "assets/demo/scannet.png", + "Large", + "Fisheye624", + 791.90869140625, + 792.7230834960938, + 878.16796875, + 585.045166015625, + -0.029167557135224342, + -0.006803446915000677, + -0.0012682401575148106, + -4.6094228309812024e-05, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + False, + False, + ], + ] + + def example_pipeline( + input_image, + model_name, + camera_name, + fx, + fy, + cx, + cy, + k1, + k2, + k3, + k4, + k5, + k6, + t1, + t2, + hfov, + mask_black_bg, + mask_far_points, + ): + target_dir, image_path = handle_uploads(input_image) + glbfile, log_msg, prediction_save_path = gradio_demo( + target_dir, + model_name, + camera_name, + fx, + fy, + cx, + cy, + k1, + k2, + k3, + k4, + k5, + k6, + t1, + t2, + hfov, + mask_black_bg, + mask_far_points, + ) + return ( + glbfile, + log_msg, + prediction_save_path, + target_dir, + image_path, + ) + + gr.Markdown("Click any row to load an example.", elem_classes=["example-log"]) + + gr.Examples( + examples=examples, + inputs=[ + input_image, + model_dropdown, + camera_dropdown, + fx, + fy, + cx, + cy, + k1, + k2, + k3, + k4, + k5, + k6, + t1, + t2, + hfov, + mask_black_bg, + mask_far_points, + ], + outputs=[reconstruction_output, log_output, reconstruction_npy], + fn=example_pipeline, + cache_examples=False, + examples_per_page=50, + ) + + submit_btn.click( + fn=clear_fields, inputs=[], outputs=[reconstruction_output] + ).then(fn=update_log, inputs=[], outputs=[log_output]).then( + fn=gradio_demo, + inputs=[ + target_dir_output, + model_dropdown, + camera_dropdown, + fx, + fy, + cx, + cy, + k1, + k2, + k3, + k4, + k5, + k6, + t1, + t2, + hfov, + mask_black_bg, + mask_far_points, + ], + outputs=[reconstruction_output, log_output, reconstruction_npy], + ).then( + fn=lambda: "False", inputs=[], outputs=[is_example] + ) + + mask_black_bg.change( + update_visualization, + [target_dir_output, mask_black_bg, mask_far_points, is_example], + [reconstruction_output, log_output], + ) + + mask_far_points.change( + update_visualization, + [target_dir_output, mask_black_bg, mask_far_points, is_example], + [reconstruction_output, log_output], + ) + + input_image.change( + fn=update_gallery_on_upload, + inputs=[input_image], + outputs=[target_dir_output, log_output], + ) + + # Dynamically update intrinsic parameter visibility when camera selection changes. 
+ camera_dropdown.change( + fn=update_parameters, + inputs=camera_dropdown, + outputs=[fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, t1, t2, hfov], + ) + + # demo.queue(max_size=20).launch(show_error=True, share=False, ssr_mode=False) + demo.launch( + show_error=True, + ) diff --git a/hubconf.py b/hubconf.py new file mode 100644 index 0000000000000000000000000000000000000000..1f66ecb5b1cb48e8169d25578656f3137cf4498b --- /dev/null +++ b/hubconf.py @@ -0,0 +1,29 @@ +dependencies = ["torch", "huggingface_hub"] + +import os +import json + +import torch +import huggingface_hub + +from unik3d.models import UniK3D as UniK3D_ + +BACKBONES = ["vitl", "vitb", "vits"] + + +def UniK3D(backbone="vitl", pretrained=True): + assert backbone in BACKBONES, f"backbone must be one of {BACKBONES}" + repo_dir = os.path.dirname(os.path.realpath(__file__)) + with open(os.path.join(repo_dir, "configs", f"config_{backbone}.json")) as f: + config = json.load(f) + + model = UniK3D_(config) + if pretrained: + path = huggingface_hub.hf_hub_download(repo_id=f"lpiccinelli/unik3d-{backbone}", filename=f"pytorch_model.bin", repo_type="model") + info = model.load_state_dict(torch.load(path), strict=False) + print(f"UniK3D-{backbone} is loaded with:") + print(f"\t missing keys: {info.missing_keys}") + print(f"\t additional keys: {info.unexpected_keys}") + + return model + diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 0000000000000000000000000000000000000000..fd34b34b16cc63a01cd357e6eea96433888acda0 --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,25 @@ +[build-system] +requires = ["setuptools"] +build-backend = "setuptools.build_meta" + +[tool.pyright] +include = ["unik3d"] + +[project] +name = "unik3d" +version = "0.1" +authors = [{name = "Luigi Piccinelli", email = "lpiccinelli@ethz.ch"}] +description = "UniK3D: Universal Monocular Metric Depth Estimation" +readme = "README.md" +license = { text="Creatives Common BY-NC 4.0 license"} +requires-python = ">=3.11.0" +dynamic = ["dependencies"] + +[tool.setuptools.dynamic] +dependencies = {file = ["requirements.txt"]} + +[tool.setuptools.package-data] +"*" = ["py.typed"] + +[tool.setuptools.packages.find] +include = ["unik3d*"] diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..536fd8919111c6731776590ebd8e821b04f73bd7 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,84 @@ +appdirs +attrs +black +blosc2 +botocore>=1.34.54 +certifi>=2022.12.7 +charset-normalizer +click +contourpy +cycler +docker-pycreds +einops>=0.7.0 +filelock +flake8>=7.0.0 +flake8-bugbear>=24.2.6 +flake8-comprehensions>=3.14.0 +fonttools +fsspec +fvcore>=0.1.5.post20221221 +gitdb +GitPython +gradio +h5py>=3.10.0 +huggingface-hub>=0.22.0 +idna +imageio +imath +iopath +isort +Jinja2 +jmespath +kiwisolver +MarkupSafe +matplotlib +mccabe +mpmath +msgpack +mypy-extensions +ndindex +networkx +ninja +numexpr +numpy<2.0.0 +opencv-python +OpenEXR +packaging +pandas +pathspec +pillow>=10.2.0 +platformdirs +portalocker +protobuf>=4.25.3 +psutil +py-cpuinfo +pycodestyle +pyflakes +pyparsing +python-dateutil +pytz +PyYAML +requests +safetensors +scipy +sentry-sdk +setproctitle +six +smmap +sympy +tables +tabulate +termcolor +timm +tqdm +trimesh +triton>=2.4.0 +typing_extensions +tzdata==2024.1 +urllib3==1.26.13 +wandb +yacs +torch>=2.4.0 +torchvision>=0.19.0 +torchaudio>=2.4.0 +xformers>=0.0.26 \ No newline at end of file diff --git a/requirements_demo.txt b/requirements_demo.txt new file mode 100644 index 
0000000000000000000000000000000000000000..536fd8919111c6731776590ebd8e821b04f73bd7 --- /dev/null +++ b/requirements_demo.txt @@ -0,0 +1,84 @@ +appdirs +attrs +black +blosc2 +botocore>=1.34.54 +certifi>=2022.12.7 +charset-normalizer +click +contourpy +cycler +docker-pycreds +einops>=0.7.0 +filelock +flake8>=7.0.0 +flake8-bugbear>=24.2.6 +flake8-comprehensions>=3.14.0 +fonttools +fsspec +fvcore>=0.1.5.post20221221 +gitdb +GitPython +gradio +h5py>=3.10.0 +huggingface-hub>=0.22.0 +idna +imageio +imath +iopath +isort +Jinja2 +jmespath +kiwisolver +MarkupSafe +matplotlib +mccabe +mpmath +msgpack +mypy-extensions +ndindex +networkx +ninja +numexpr +numpy<2.0.0 +opencv-python +OpenEXR +packaging +pandas +pathspec +pillow>=10.2.0 +platformdirs +portalocker +protobuf>=4.25.3 +psutil +py-cpuinfo +pycodestyle +pyflakes +pyparsing +python-dateutil +pytz +PyYAML +requests +safetensors +scipy +sentry-sdk +setproctitle +six +smmap +sympy +tables +tabulate +termcolor +timm +tqdm +trimesh +triton>=2.4.0 +typing_extensions +tzdata==2024.1 +urllib3==1.26.13 +wandb +yacs +torch>=2.4.0 +torchvision>=0.19.0 +torchaudio>=2.4.0 +xformers>=0.0.26 \ No newline at end of file diff --git a/scripts/README.md b/scripts/README.md new file mode 100644 index 0000000000000000000000000000000000000000..f35f660e7ee343c9d7a962d4d366a81c610f786c --- /dev/null +++ b/scripts/README.md @@ -0,0 +1,55 @@ +## Training + +We provide the `train.py` script that allows to load the dataset, initialize and start the training. From the root of the repo: + +```bash +export REPO=`pwd` +export PYTHONPATH=${REPO}:${PYTHONPATH} + +# Adapt all this to your setup +export TMPDIR="/tmp" +export TORCH_HOME=${TMPDIR} +export HUGGINGFACE_HUB_CACHE=${TMPDIR} +export WANDB_HOME=${TMPDIR} +export DATAROOT= + + +export MASTER_PORT=$((( RANDOM % 600 ) + 29400 )) +if [ $NNODES -gt 1 ]; then + export MASTER_PORT=29400 +fi + +# this is the config will be used +export CFG="config_vitl.json" +``` + +If you are on a machine without SLURM you can run the following: +```bash +# make the following input-dependent for multi-node +export NNODES=1 +export RANK=0 +export MASTER_ADDR=127.0.0.1 +export CUDA_VISIBLE_DEVICES="0" # set yours + +export GPUS=$(echo ${CUDA_VISIBLE_DEVICES} | tr ',' '\n' | wc -l) +echo "Start script with python from: `which python`" +torchrun --rdzv-backend=c10d --nnodes=${NNODES} --nproc_per_node=${GPUS} --rdzv-endpoint ${MASTER_ADDR}:${MASTER_PORT} ${REPO}/scripts/train.py --config-file ${REPO}/configs/${CFG} --distributed +``` + +If you system has SLURM, all the information will be set by the scheduler and you have to run just: +```bash +srun -c ${SLURM_CPUS_PER_TASK} --kill-on-bad-exit=1 python -u ${REPO}/scripts/train.py --config-file ${REPO}/configs/${CFG} --master-port ${MASTER_PORT} --distributed +``` + + +### Datasets + +We used both image-based and sequence-based dataset. The `ImageDataset` class is actually for legacy only as we moved image-based dataset to be "dummy" single-frame sequences.
+We [provide two example datasets to get familiar with the pipeline and structure, namely iBims-1 and Sintel](https://drive.google.com/drive/folders/1FKsa5-b3EX0ukZq7bxord5fC5OfUiy16?usp=sharing), image- and sequence-based, respectively.
+You can adapt the data loading and processing to your own data; however, you need to keep the same interface so that the model stays consistent and trains "out-of-the-box" (see the sketch below).
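+
+For reference, a dataset keeping this interface could look as follows. This is only a sketch: the class name, HDF5 file, split files, and depth range are placeholders to adapt to your data; the structure mirrors the sequence datasets already shipped in `unik3d/datasets`.
+
+```python
+from typing import Any
+
+from unik3d.datasets.sequence_dataset import SequenceDataset
+
+
+class MyDataset(SequenceDataset):
+    # Placeholder values: adapt depth range, decoding scale, and file names to your data.
+    min_depth = 0.01
+    max_depth = 100.0
+    depth_scale = 256.0  # factor used to decode stored depth maps to meters
+    test_split = "val.txt"
+    train_split = "train.txt"
+    sequences_file = "sequences.json"
+    hdf5_paths = ["MyDataset.hdf5"]
+
+    def __init__(
+        self,
+        image_shape: tuple[int, int],
+        split_file: str,
+        test_mode: bool,
+        normalize: bool,
+        augmentations_db: dict[str, Any],
+        resize_method: str,
+        mini: float = 1.0,
+        num_frames: int = 1,
+        benchmark: bool = False,
+        decode_fields: list[str] = ["image", "depth"],
+        inplace_fields: list[str] = ["camera_params", "cam2w"],
+        **kwargs,
+    ) -> None:
+        super().__init__(
+            image_shape=image_shape,
+            split_file=split_file,
+            test_mode=test_mode,
+            benchmark=benchmark,
+            normalize=normalize,
+            augmentations_db=augmentations_db,
+            resize_method=resize_method,
+            mini=mini,
+            num_frames=num_frames,
+            decode_fields=decode_fields,
+            inplace_fields=inplace_fields,
+            **kwargs,
+        )
+
+    def pre_pipeline(self, results):
+        # Flags consumed by the training pipeline: dense GT, synthetic data, quality tier.
+        results = super().pre_pipeline(results)
+        results["dense"] = [False] * self.num_frames * self.num_copies
+        results["synthetic"] = [False] * self.num_frames * self.num_copies
+        results["quality"] = [1] * self.num_frames * self.num_copies
+        return results
+```
+
+To use it, expose the class from `unik3d/datasets/__init__.py` and list its name in the config's `data.train_datasets`, since `train.py` looks datasets up by name via `getattr(datasets, name)`.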
+ + +### Additional dependencies + +We require chamfer distance for the evaluation, you can compile the knn operation under `ops/knn`: `bash compile.sh` from the directory `$REPO/unik3d/ops/knn`. Set the correct `export TORCH_CUDA_ARCH_LIST`, according to the hardware you are working on. +For training and to perform augmentation, you can use `camera_augmenter.py`; however the splatting requires you to install operations by cloning and installing from `github.com/hperrot/splatting`. \ No newline at end of file diff --git a/scripts/demo.py b/scripts/demo.py new file mode 100644 index 0000000000000000000000000000000000000000..fd6765b53e06b3d0d24f8a65561d0021a0d9b4bf --- /dev/null +++ b/scripts/demo.py @@ -0,0 +1,150 @@ +import argparse +import json +import os + +import numpy as np +import torch +import torch.nn as nn +from PIL import Image + +from unik3d.models import UniK3D +from unik3d.utils.camera import (MEI, OPENCV, BatchCamera, Fisheye624, Pinhole, + Spherical) +from unik3d.utils.visualization import colorize, save_file_ply + +SAVE = False +BASE_PATH = os.path.join( + os.path.dirname(os.path.dirname(os.path.abspath(__file__))), "assets", "demo" +) + + +def infer(model, rgb_path, camera_path, rays=None): + rgb = np.array(Image.open(rgb_path)) + rgb_torch = torch.from_numpy(rgb).permute(2, 0, 1) + + camera = None + if camera_path is not None: + with open(camera_path, "r") as f: + camera_dict = json.load(f) + + params = torch.tensor(camera_dict["params"]) + name = camera_dict["name"] + assert name in ["Fisheye624", "Spherical", "OPENCV", "Pinhole", "MEI"] + camera = eval(name)(params=params) + + outputs = model.infer(rgb=rgb_torch, camera=camera, normalize=True, rays=rays) + + return rgb_torch, outputs + + +def infer_equirectangular(model, rgb_path): + rgb = np.array(Image.open(rgb_path)) + rgb_torch = torch.from_numpy(rgb).permute(2, 0, 1) + + # assuming full equirectangular image horizontally + H, W = rgb.shape[:2] + hfov_half = np.pi + vfov_half = np.pi * H / W + assert vfov_half <= np.pi / 2 + + params = [W, H, hfov_half, vfov_half] + camera = Spherical(params=torch.tensor([1.0] * 4 + params)) + + outputs = model.infer(rgb=rgb_torch, camera=camera, normalize=True) + return rgb_torch, outputs + + +def save(rgb, outputs, name, base_path, save_pointcloud=False): + depth = outputs["depth"] + rays = outputs["rays"] + points = outputs["points"] + + depth = depth.cpu().numpy() + rays = ((rays + 1) * 127.5).clip(0, 255) + + Image.fromarray(colorize(depth.squeeze())).save( + os.path.join(base_path, f"{name}_depth.png") + ) + Image.fromarray(rgb.squeeze().permute(1, 2, 0).cpu().numpy()).save( + os.path.join(base_path, f"{name}_rgb.png") + ) + Image.fromarray(rays.squeeze().permute(1, 2, 0).byte().cpu().numpy()).save( + os.path.join(base_path, f"{name}_rays.png") + ) + + if save_pointcloud: + predictions_3d = points.permute(0, 2, 3, 1).reshape(-1, 3).cpu().numpy() + rgb = rgb.permute(1, 2, 0).reshape(-1, 3).cpu().numpy() + save_file_ply(predictions_3d, rgb, os.path.join(base_path, f"{name}.ply")) + + +def demo(model): + # RGB + CAMERA + rgb, outputs = infer( + model, + os.path.join(BASE_PATH, f"scannet.png"), + os.path.join(BASE_PATH, "scannet.json"), + ) + if SAVE: + save(rgb, outputs, name="scannet", base_path=BASE_PATH) + + # get GT and pred + pts_pred = outputs["points"].squeeze().cpu().permute(1, 2, 0).numpy() + pts_gt = np.load("./assets/demo/scannet.npy").astype(float) + mask = np.linalg.norm(pts_gt, axis=-1) > 0 + error = np.linalg.norm(pts_pred - pts_gt, axis=-1) + error = 
np.mean(error[mask] ** 2) ** 0.5 + + # Trade-off between speed and resolution + model.resolution_level = 1 + rgb, outputs = infer( + model, + os.path.join(BASE_PATH, f"scannet.png"), + os.path.join(BASE_PATH, "scannet.json"), + ) + if SAVE: + save(rgb, outputs, name="scannet_lowres", base_path=BASE_PATH) + + # RGB + rgb, outputs = infer(model, os.path.join(BASE_PATH, f"poorthings.jpg"), None) + if SAVE: + save(rgb, outputs, name="poorthings", base_path=BASE_PATH) + + # RGB + CAMERA + rgb, outputs = infer( + model, + os.path.join(BASE_PATH, f"dl3dv.png"), + os.path.join(BASE_PATH, "dl3dv.json"), + ) + if SAVE: + save(rgb, outputs, name="dl3dv", base_path=BASE_PATH) + + # EQUIRECTANGULAR + rgb, outputs = infer_equirectangular( + model, os.path.join(BASE_PATH, f"equirectangular.jpg") + ) + if SAVE: + save(rgb, outputs, name="equirectangular", base_path=BASE_PATH) + + print("Output keys are", outputs.keys()) + + if SAVE: + print("Done! Results saved in", BASE_PATH) + + print(f"RMSE on 3D clouds for ScanNet++ sample: {100*error:.1f}cm") + + +if __name__ == "__main__": + print("Torch version:", torch.__version__) + type_ = "l" # available types: s, b, l + name = f"unik3d-vit{type_}" + model = UniK3D.from_pretrained(f"lpiccinelli/{name}") + + # set resolution level in [0,10) and output interpolation + model.resolution_level = 9 + model.interpolation_mode = "bilinear" + + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + model = model.to(device).eval() + + demo(model) diff --git a/scripts/train.py b/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..95d25552fa45af91d42e5e89c8462f0a4205b545 --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,630 @@ +import argparse +import json +import os +import random +import uuid +from contextlib import nullcontext +from copy import deepcopy +from datetime import datetime as dt +from functools import partial +from math import log2 +from time import sleep, time +from typing import Any, Dict + +import git +import numpy as np +import psutil +import torch +import torch.nn as nn +import torch.utils.data.distributed +import wandb +from PIL import Image +from torch import distributed as dist +from torch import optim +from torch.nn.parallel.distributed import DistributedDataParallel +from torch.utils.data import DataLoader, RandomSampler, SequentialSampler +from tqdm import tqdm + +import unik3d.datasets as datasets +from unik3d.datasets import (ConcatDataset, DistributedSamplerNoDuplicate, + collate_fn, get_weights) +from unik3d.models import UniK3D +from unik3d.ops.scheduler import CosineScheduler +from unik3d.utils import (barrier, format_seconds, is_main_process, + log_train_artifacts, validate) +from unik3d.utils.distributed import (create_local_process_group, + local_broadcast_process_authkey, + setup_multi_processes, setup_slurm, + sync_string_across_gpus, + sync_tensor_across_gpus) +from unik3d.utils.ema_torch import (DummyExponentialMovingAverage, + ExponentialMovingAverage) +from unik3d.utils.misc import calculate_mean_values + +EMA_INTERVAL = 10 +EMA_TAU = 10000 +EMA_START = 50000 + + +MAP_DTYPE = { + "f16": torch.float16, + "bf16": torch.bfloat16, + "f32": torch.float32, +} + + +def aggregate_sync_losses(dict_: dict[str, torch.Tensor], device): + keys = list(dict_.keys()) + values = torch.tensor(list(dict_.values()), device=device) + keys = sync_string_across_gpus(keys, device) + values = sync_tensor_across_gpus(values, dim=0).cpu().tolist() + dict_ = calculate_mean_values(keys, values) + return dict_ + + 
+def main_worker(config: Dict[str, Any], args: argparse.Namespace): + + current_process = psutil.Process(os.getpid()) + device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") + seed = config["generic"]["seed"] + + if not args.distributed: + args.rank = 0 + args.local_rank = 0 + args.world_size = 1 + else: + # initializes the distributed backend which will take care of synchronizing nodes/GPUs + setup_multi_processes(config) + is_slurm = "SLURM_PROCID" in os.environ + if is_slurm: + setup_slurm("nccl", port=args.master_port) + args.rank = int(os.environ["RANK"]) + args.world_size = int(os.environ["WORLD_SIZE"]) + args.local_rank = device = int(os.environ["LOCAL_RANK"]) + if not is_slurm: + import datetime + + dist.init_process_group( + "nccl", + rank=args.rank, + world_size=args.world_size, + timeout=datetime.timedelta(seconds=30 * 60), + ) + torch.cuda.set_device(device) + create_local_process_group() + local_broadcast_process_authkey() + print( + f"Start running DDP on: {args.rank} (local: {args.local_rank}) with seed {seed + args.rank}." + ) + config["training"]["batch_size"] = int( + config["training"]["batch_size"] / args.world_size + ) + dist.barrier() + + # Fix seed + # Different for every machine to avoid sampling + # the same element across machines + seed = seed + args.rank + random.seed(seed) + np.random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + os.environ["PYTHONHASHSEED"] = str(seed) + + batch_size = config["training"]["batch_size"] + if is_main_process(): + print("Config: ", args.config_file) + print( + f"Torch version:{torch.__version__}, cuda:{torch.version.cuda}, cudnn:{torch.backends.cudnn.version()}, threads:{torch.get_num_threads()}" + ) + print("BatchSize per GPU: ", batch_size) + print( + f"Divided into {config['training']['nsteps_accumulation_gradient']} accumulation step" + ) + + ############################## + ########### MODEL ############ + ############################## + # Build model + model = UniK3D(config).to(device) + model.eval() + print(f"MODEL: {model.__class__.__name__} at {model.device}") + torch.cuda.empty_cache() + + if args.distributed: + model = nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = DistributedDataParallel( + model, + find_unused_parameters=False, + device_ids=[device], + output_device=device, + ) + + ############################## + ######### OPTIMIZER ########## + ############################## + dtype_16bit = config["training"]["f16"] + is_16bit = dtype_16bit != "f32" + clipping = config["training"].get("clipping", None) + + # Optimize + ddp_model = model.module if args.distributed else model + params = ddp_model.get_params(config) + optimizer = optim.AdamW( + params, + eps=6e-8 if is_16bit else 1e-8, # smallest subnormal fp16 number is 5.96e-8 + # amsgrad=is_16bit, # use max instead of avg v_hat, avoid small number divisions? 
+ ) + + # Load Model: + step = 0 + if config["training"].get("pretrained", None) is not None: + ddp_model.load_pretrained(config["training"]["pretrained"]) + pretrained = torch.load( + config["training"]["pretrained"], map_location="cpu", weights_only=False + ) + try: + optimizer.load_state_dict(pretrained["optimizer"]) + except Exception as e: + if is_main_process(): + print("Could not load optimizer state dict:", e) + step = pretrained.get("step", 0) + ddp_model.pixel_decoder.steps = step + + # EMA + ema_class = ( + ExponentialMovingAverage + if config["training"]["ema"] > 0.0 + else DummyExponentialMovingAverage + ) + ema_handle = ema_class( + ddp_model.parameters_grad(), + 1 - (1 - config["training"]["ema"]) * EMA_INTERVAL, + update_after_step=config["training"]["warmup_iters"] / EMA_INTERVAL, + switch=True, + tau=EMA_TAU // EMA_INTERVAL, + ) + setattr(ema_handle, "num_updates", step // EMA_INTERVAL) + + ############################## + ######### GENERICS ########### + ############################## + resize_method = config["data"].get("resize_method", "hard") + crop = config["data"].get("crop", "garg") + augmentations_db = config["data"].get("augmentations", {}) + shape_constraints = config["data"].get("shape_constraints", {}) + image_shape = config["data"]["image_shape"] + mini = config["data"]["mini"] + nsteps_accumulation_gradient = config["training"]["nsteps_accumulation_gradient"] + batch_size = config["training"]["batch_size"] + clipping_fn = torch.nn.utils.clip_grad_norm_ + + is_shell = int(os.environ.get("SHELL_JOB", 0)) + run_id = sync_string_across_gpus( + [f"{dt.now().strftime('%d-%h_%H-%M')}-{uuid.uuid4()}"], device + )[0] + + if not is_shell and is_main_process(): + repo_folder = os.path.dirname(os.path.realpath(__file__)) + try: + repo = git.Repo(repo_folder) + current_head = repo.head if repo.head.is_detached else repo.active_branch + notes = f"MESSAGE: {current_head.commit.message} HASH:{current_head.commit.hexsha} BRANCH:{current_head.name}" + except: + print(f"problem with {repo_folder}, does it exist?") + notes = "" + + # restore the original batchsize, not acquired by other calls from now on + if args.distributed: + config["training"]["batch_size"] = ( + config["training"]["batch_size"] * args.world_size + ) + wandb.init( + project="UniK3D", + name=run_id, + config=config, + tags=None, + notes=notes, + dir=os.environ.get("WANDB_HOME", os.environ.get("TMPDIR", "/tmp")), + ) + wandb.watch(model) + + ############################## + ########## DATASET ########### + ############################## + # Datasets loading + train_datasets, val_datasets = {}, {} + if is_main_process(): + print("Loading training datasets...") + dims = 0 + + for dataset in config["data"]["train_datasets"]: + assert hasattr(datasets, dataset), f"{dataset} not a custom dataset" + train_dataset: datasets.BaseDataset = getattr(datasets, dataset) + train_datasets[dataset] = train_dataset( + image_shape=image_shape, + split_file=train_dataset.train_split, + test_mode=False, + crop=crop, + augmentations_db=augmentations_db, + shape_constraints=shape_constraints, + normalize=config["data"].get("normalization", "imagenet"), + resize_method=resize_method, + mini=mini, + num_frames=config["data"].get("num_frames", 1), + fps_range=[1, 5], + num_copies=config["data"]["pair"], + ) + dim = ( + train_datasets[dataset].dataset._addr.numel() * 8 + + train_datasets[dataset].dataset._lst.numel() + ) / (2**20) + if hasattr(train_datasets[dataset], "sequences"): + dim += ( + 
train_datasets[dataset].sequences._addr.numel() * 8 + + train_datasets[dataset].sequences._lst.numel() + ) / (2**20) + dims = dims + dim + if is_main_process(): + print(f"{dataset}: {dim:.1f}MB") + + print(f"All training datasets loaded, with total size: {dims:.1f}MB") + + barrier() + + assert batch_size % config["data"]["pair"] == 0 + batch_size = batch_size // config["data"]["pair"] + assert batch_size % nsteps_accumulation_gradient == 0 + batch_chunk = batch_size // nsteps_accumulation_gradient + + train_dataset = ConcatDataset( + list(train_datasets.values()), + shape_constraints=shape_constraints, + ) + + if is_main_process(): + print("Loading validation datasets...") + for dataset in config["data"]["val_datasets"]: + val_dataset: datasets.BaseDataset = getattr(datasets, dataset) + val_datasets[dataset] = val_dataset( + image_shape=image_shape, + split_file=val_dataset.test_split, + test_mode=True, + crop=crop, + shape_constraints=shape_constraints, + augmentations_db=augmentations_db, + normalize=config["data"].get("normalization", "imagenet"), + resize_method=resize_method, + num_frames=1, + mini=1.0, + num_copies=1, + ) + + # Dataset samplers, create distributed sampler pinned to rank + if args.distributed: + sampling = deepcopy(config["data"]["sampling"]) + weights, num_samples = get_weights(train_datasets, sampling) + train_sampler = torch.utils.data.WeightedRandomSampler( + weights, num_samples, replacement=True + ) + valid_samplers = { + k: DistributedSamplerNoDuplicate( + v, + num_replicas=args.world_size, + rank=args.rank, + shuffle=False, + drop_last=False, + ) + for k, v in val_datasets.items() + } + else: + train_sampler = RandomSampler(train_dataset) + valid_samplers = {k: SequentialSampler(v) for k, v in val_datasets.items()} + + train_sampler = torch.utils.data.BatchSampler( + train_sampler, batch_size=batch_size, drop_last=True + ) + + # Dataset loader + val_batch_size = 1 + num_workers = int(os.environ.get("SLURM_CPUS_PER_TASK", 4)) + train_loader = DataLoader( + train_dataset, + num_workers=num_workers, + sampler=train_sampler, + pin_memory=True, + collate_fn=partial(collate_fn, is_batched=True), + persistent_workers=True if num_workers else None, + ) + val_loaders = { + name_dataset: DataLoader( + dataset, + batch_size=val_batch_size, + shuffle=False, + num_workers=num_workers, + sampler=valid_samplers[name_dataset], + pin_memory=True, + drop_last=False, + collate_fn=partial(collate_fn, is_batched=False), + ) + for name_dataset, dataset in val_datasets.items() + } + + # SCHEDULERS! 
+ scheduler_wd = CosineScheduler( + optimizer, + key="weight_decay", + init_value=config["training"]["wd"], + base_value=config["training"]["wd"], + final_value=config["training"]["wd_final"], + warmup_iters=0, + total_iters=config["training"]["n_iters"], + flat_iters=config["training"]["warmup_iters"], + step_init=step - 1, + ) + scheduler_lr = CosineScheduler( + optimizer, + key="lr", + init_value=config["training"]["lr"] * config["training"].get("lr_warmup", 1.0), + final_value=config["training"]["lr_final"], + warmup_iters=5000, + flat_iters=config["training"]["warmup_iters"], + total_iters=config["training"]["n_iters"], + step_init=step - 1, + ) + scheduler_betas = CosineScheduler( + optimizer, + key="betas", + init_value=0.95 if config["training"].get("cycle_betas", True) else 0.9, + base_value=0.85 if config["training"].get("cycle_betas", True) else 0.9, + final_value=0.95 if config["training"].get("cycle_betas", True) else 0.9, + warmup_iters=config["training"]["warmup_iters"], + total_iters=config["training"]["n_iters"], + step_init=step - 1, + ) + + # Set loss scaler for half precision training + sanity zeroing grads + dtype = MAP_DTYPE[dtype_16bit] + if not torch.cuda.is_bf16_supported() and is_16bit: + dtype = torch.float16 + + context = torch.autocast(device_type="cuda", dtype=dtype, enabled=is_16bit) + # use float16 to check for instability at inference an avoid bfloat16 for coarseness + context_val = torch.autocast( + device_type="cuda", dtype=torch.float16, enabled=is_16bit + ) + optimizer.zero_grad(set_to_none=True) + + ############################## + ########## TRAINING ########## + ############################## + # Remember that if i-th layer is frozen, this will break gradient checkpointing + # in layer i+1-th. This is because CheckpointFunction treats the i+1-th input as + # without gradient, thus the i+1-th layer does not have grads (?). 
To solve it, + # just add requires_grad_() to the inputs coming from the frozen layer + ddp_model.train() + + start = time() + n_steps = config["training"]["n_iters"] + init_steps = int(step) + track_pbar = is_shell + + if is_main_process(): + print("Is a shell job?", is_shell) + print("Use dtype:", dtype if is_16bit else torch.float32) + print( + f'Train for {config["training"]["n_iters"]} steps, validate every {config["training"]["validation_interval"]} steps' + ) + print(f"START with {num_workers} workers") + if track_pbar: + pbar = tqdm(total=n_steps - init_steps) + + scaler = torch.amp.GradScaler( + "cuda", + init_scale=2**14 if dtype_16bit == "f16" else 2**40, + enabled=is_16bit, + growth_factor=1.2, + backoff_factor=0.8, + growth_interval=500, + ) + track_losses, track_grad = {}, {} + system_memory = dict(psutil.virtual_memory()._asdict())["available"] / 2**30 + cpid_memory = current_process.memory_info()[0] / 2.0**30 + gpu_mem = (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) / 2**30 + while True: + for j, batches in enumerate(train_loader): + system_memory = ( + 0.99 * system_memory + + 0.01 * dict(psutil.virtual_memory()._asdict())["available"] / 2**30 + ) + cpid_memory = ( + 0.99 * cpid_memory + 0.01 * current_process.memory_info()[0] / 2.0**30 + ) + gpu_mem = ( + 0.99 * gpu_mem + + 0.01 + * (torch.cuda.mem_get_info()[1] - torch.cuda.mem_get_info()[0]) + / 2**30 + ) + if j % 1000 == 0 and is_main_process(): + print(f"System information at step {j}") + print(f"System-wide RAM available: {system_memory:.2f}GB") + print(f"CPU utilization: {psutil.cpu_percent(interval=None)}%") + print(f"GPU memory utilized: {gpu_mem:.2f}GB") + + batches["data"] = { + k: v.to(model.device, non_blocking=True) + for k, v in batches["data"].items() + } + for idx in range(nsteps_accumulation_gradient): + batch = {} + batch_slice = slice(idx * batch_chunk, (idx + 1) * batch_chunk) + batch["data"] = {k: v[batch_slice] for k, v in batches["data"].items()} + batch["img_metas"] = batches["img_metas"][batch_slice] + with ( + model.no_sync() + if idx < nsteps_accumulation_gradient - 1 + else nullcontext() + ): + with context: + preds, losses = model(batch["data"], batch["img_metas"]) + loss = sum(losses["opt"].values()) + scaler.scale(loss).backward() + + losses_dict = { + k: v.detach() for loss in losses.values() for k, v in loss.items() + } + track_losses.update( + { + k: track_losses.get(k, 0.0) + + torch.nan_to_num(v, nan=1e5, posinf=1e5, neginf=1e5) + for k, v in losses_dict.items() + } + ) + ddp_model.loss_history = track_losses + + if clipping is not None: + scaler.unscale_(optimizer) + grad_norm = clipping_fn(ddp_model.parameters_grad(), clipping) + if torch.isfinite(grad_norm): + track_losses.update( + {"Grad_Norm": track_losses.get("Grad_Norm", 0.0) + grad_norm} + ) + + # there is a deeper issue, either log/sqrt of negative loss + # or the inputs create large values and destroy model weights + if is_16bit and scaler.get_scale() < 1: + raise ValueError("Scale went less than 1, ISSUE!!!") + + scaler.step(optimizer) + scaler.update() + + scheduler_wd.step() + scheduler_lr.step() + scheduler_betas.step() + model.module.step() + optimizer.zero_grad(set_to_none=True) + if step % EMA_INTERVAL == 0: + ema_handle.update() + + if is_main_process() and track_pbar: + pbar.update(1) + + step += 1 + + # LOGGING + if step % 100 == 0 and is_main_process(): + log_num = min(10, preds["depth"].shape[0]) + log_train_artifacts( + batch["data"]["image"][-log_num:, 0].float(), + ( + 
batch["data"]["depth"][-log_num:, 0].float() + if "depth" in batch["data"] + else [] + ), + preds["depth"][-log_num:, 0].detach().float(), + infos={ + k: v[-log_num:, 0] for k, v in preds.get("infos", {}).items() + }, + step=step, + ) + + if step % 50 == 0: + track_losses = { + k: v / (50 * nsteps_accumulation_gradient) + for k, v in track_losses.items() + } + # grad norm is for every step! + track_losses["Grad_Norm"] = ( + track_losses["Grad_Norm"] * nsteps_accumulation_gradient + ) + track_losses = aggregate_sync_losses(track_losses, device=model.device) + if is_main_process(): + elapsed = int(time() - start) + eta = int(elapsed * (n_steps - step) / max(1, step - init_steps)) + print( + f"Step {step}/{n_steps} [{format_seconds(elapsed)}<{format_seconds(eta)}]" + ) + try: + wandb.log( + { + **{f"Train/{k}": v for k, v in track_losses.items()}, + **{f"Train/lr": scheduler_lr.get()[-1]}, + **{f"Train/wd": scheduler_wd.get()[-2]}, + **{f"Train/scale_f16": log2(scaler.get_scale())}, + }, + step=step, + ) + except Exception as e: + print("Not logging loss because of:", e) + if step % 100 == 0: + log_loss_dict = { + f"Train/{k}": v for k, v in track_losses.items() + } + print( + ", ".join( + [f"{k}: {v:.5f}" for k, v in log_loss_dict.items()] + ) + ) + track_losses = {} # reinit every 50 steps, average the current 50 steps + + # Validation + is_last_step = step >= config["training"]["n_iters"] + is_validation = step % config["training"]["validation_interval"] == 0 + if is_last_step or is_validation: + torch.cuda.empty_cache() + barrier() + if is_main_process(): + print(f"Validation at {step}th step...") + ddp_model.eval() + start_validation = time() + with torch.no_grad(), ema_handle.average_parameters(): + validate( + model, + test_loaders=val_loaders, + step=step, + run_id=run_id, + idxs=(64, 96, 224, 256), # random + context=context_val, + ) + + if is_main_process(): + print(f"Elapsed: {format_seconds(int(time() - start_validation))}") + ddp_model.train() + torch.cuda.empty_cache() + + if step >= config["training"]["n_iters"]: + if is_main_process() and track_pbar: + pbar.close() + wandb.finish(0) + dist.destroy_process_group() + return 0 + + +if __name__ == "__main__": + if "SLURM_PROCID" in os.environ: + os.environ["TRITON_CACHE_DIR"] = "/tmp" + # Arguments + parser = argparse.ArgumentParser( + description="Training script", conflict_handler="resolve" + ) + parser.add_argument("--config-file", type=str, required=True) + parser.add_argument("--master-port", type=str) + parser.add_argument("--distributed", action="store_true") + parser.add_argument("--local_rank", type=int, default=0) + + args = parser.parse_args() + with open(args.config_file, "r") as f: + config = json.load(f) + + deterministic = config["generic"].get("deterministic", True) + torch.backends.cudnn.deterministic = deterministic + torch.backends.cudnn.benchmark = not deterministic + + torch.backends.cudnn.allow_tf32 = True + torch.backends.cuda.matmul.allow_tf32 = True + torch.set_float32_matmul_precision("high") + torch.backends.cuda.enable_mem_efficient_sdp(False) + torch.set_num_threads(1) + main_worker(config, args) diff --git a/unik3d/__init__.py b/unik3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4526bd9a1b4b0b508458724e3157d4d8dd8ebc30 --- /dev/null +++ b/unik3d/__init__.py @@ -0,0 +1 @@ +from .models import UniK3D diff --git a/unik3d/datasets/_2d3ds.py b/unik3d/datasets/_2d3ds.py new file mode 100644 index 
0000000000000000000000000000000000000000..021e86c11a55eba9c9d5ff6754f6ab2d1db34b74 --- /dev/null +++ b/unik3d/datasets/_2d3ds.py @@ -0,0 +1,67 @@ +from typing import Any + +import torch + +from unik3d.datasets.pipelines import Compose, PanoCrop, PanoRoll +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class d2D3DS(SequenceDataset): + min_depth = 0.01 + max_depth = 10.0 + depth_scale = 512.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"2D3DS.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["cam2w", "camera_params"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + self.resizer = Compose( + [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer] + ) + + def preprocess(self, results): + self.resizer.ctx = None + if self.test_mode: + for i, seq in enumerate(results["sequence_fields"]): + results[seq]["points"] = results[seq]["camera"].reconstruct( + results[seq]["depth"] + ) + results[seq]["depth"] = results[seq]["points"][:, -1:] + results[seq]["gt_fields"].add("points") + return super().preprocess(results) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [False] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/_4dor.py b/unik3d/datasets/_4dor.py new file mode 100644 index 0000000000000000000000000000000000000000..f9db5355112799eef5ae8bd13cb0fb52144ca064 --- /dev/null +++ b/unik3d/datasets/_4dor.py @@ -0,0 +1,52 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class d4DOR(SequenceDataset): + min_depth = 0.01 + max_depth = 10.0 + depth_scale = 1000.0 + default_fps = 10 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["4DOR.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["camera_params", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [False] * self.num_frames * self.num_copies + results["si"] = [False] * self.num_frames * self.num_copies + results["quality"] = [2] * 
self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/__init__.py b/unik3d/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..2e6a339b51c69a6f612615136dc2429ca36fac95 --- /dev/null +++ b/unik3d/datasets/__init__.py @@ -0,0 +1,161 @@ +from ._2d3ds import d2D3DS +from ._4dor import d4DOR +from .a2d2 import A2D2 +from .adt import ADT +from .aimotive import aiMotive +from .argoverse import Argoverse +from .argoverse2 import Argoverse2 +from .arkit import ARKit +from .ase import ASE +from .base_dataset import BaseDataset +from .bdd import BDD +from .bedlam import BEDLAM +from .behave import Behave +from .blendedmvg import BlendedMVG +from .cityscape import Cityscape +from .ddad import DDAD +from .deep360 import Deep360 +from .dense import DENSE +from .diml import DIML +from .diode import DiodeIndoor, DiodeIndoor_F +from .dl3dv import DL3DV +from .driving_stereo import DrivingStereo +from .dtu_rmvd import DTURMVD +from .dummy import Dummy +from .dynamic_replica import DynReplica +from .eden import EDEN +from .eth3d import ETH3D, ETH3D_F, ETH3DRMVD +from .facedepth import FaceDepth +from .flsea import FLSea +from .futurehouse import FutureHouse +from .gibson import Gibson +from .hammer import HAMMER +from .hm3d import HM3D +from .hoi4d import HOI4D +from .hypersim import HyperSim +from .ibims import IBims, IBims_F +from .ken_burns import KenBurns +from .kitti import KITTI, KITTIRMVD, KITTIBenchmark +from .kitti360 import KITTI360 +from .lyft import Lyft +from .mapillary import Mapillary +from .matrix_city import MatrixCity +from .matterport3d import Matterport3D +from .megadepth import MegaDepth +from .megadepth_s import MegaDepthS +from .midair import MidAir +from .mip import MIP +from .ms2 import MS2 +from .mvimgnet import MVImgNet +from .mvsynth import MVSynth +from .nerds360 import NeRDS360 +from .niantic_mapfree import NianticMapFree +from .nuscenes import Nuscenes +from .nyuv2 import NYUv2Depth +from .point_odyssey import PointOdyssey +from .proteus import Proteus +from .samplers import (DistributedSamplerNoDuplicate, + DistributedSamplerWrapper, ShardedInfiniteSampler) +from .scannet import ScanNet +from .scannetpp import ScanNetpp, ScanNetpp_F +from .sintel import Sintel +from .sunrgbd import SUNRGBD +from .synscapes import Synscapes +from .tartanair import TartanAir +from .taskonomy import Taskonomy +from .tat_rmvd import TATRMVD +from .theo import Theo +from .unrealstereo4k import UnrealStereo4K +from .urbansyn import UrbanSyn +from .utils import ConcatDataset, collate_fn, get_weights +from .vkitti import VKITTI +from .void import VOID +from .waymo import Waymo +from .wildrgbd import WildRGBD + +__all__ = [ + "Dummy", + "BaseDataset", + "get_weights" "DistributedSamplerNoDuplicate", + "ShardedInfiniteSampler", + "DistributedSamplerWrapper", + "ConcatDataset", + "PairDataset", + "collate_fn", + # additional, do not count + "WaymoImage", + "MegaDepth", + "COCO2017", + "ImageNet", + "OASISv2", + # image based + "Argoverse", + "DDAD", + "IBims", + "NYUv2Depth", + "DrivingStereo", + "VOID", + "Mapillary", + "ScanNet", + "Taskonomy", + "BDD", + "A2D2", + "Nuscenes", + "SUNRGBD", + "ETH3D", + "HAMMER", + "Cityscape", + "KITTI", + "DENSE", + "DIML", + "DiodeIndoor", + "FLSea", + "ARKitScenes", + "Lyft", + "HyperSim", + "KenBurns", + "HRWSI", + "UrbanSyn", + "Synscapes", + "Gibson", + "Matterport3D", + "_2D3DS", + # sequence based + "TartanAir", + "WildRGBD", + "ScanNetS", + "ScanNetpp", + "MVImgNet", + "NianticMapFree", + "DL3DV", 
+ "PointOdyssey", + "KITTIMulti", + "Waymo", + "Argoverse2", + "UnrealStereo4K", + "MatrixCity", + "HM3D", + "MVSynth", + "EDEN", + # sequence based, but not usable for seq, only image + "BEDLAM", + "NeRDS360", + "BlendedMVG", + "DynReplica", + "ARKitS", + "Sintel", + "VKITTI", + "MegaDepthS", + # benchmarks + "KITTIBenchmark", + "ETH3DRMVD", + "DTURMVD", + "KITTIRMVD", + "TATRMVD", + "DiodeIndoor_F", + "IBims_F", + "ETH3D_F", + "KITTI360", + "ScanNetpp_F", + "ADT", +] diff --git a/unik3d/datasets/a2d2.py b/unik3d/datasets/a2d2.py new file mode 100644 index 0000000000000000000000000000000000000000..4a01b2037dd27f88137161b4d1a850df69f71d90 --- /dev/null +++ b/unik3d/datasets/a2d2.py @@ -0,0 +1,78 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class A2D2(ImageDataset): + min_depth = 0.01 + max_depth = 120.0 + depth_scale = 256.0 + train_split = "train_clean.txt" + intrisics_file = "intrinsics.json" + hdf5_paths = ["a2d2.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics_val = torch.tensor( + intrinsics[os.path.join(*image_filename.split("/")[:2])] + ).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val] + dataset.append(sample) + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [False] * self.num_copies + results["quality"] = [1] * self.num_copies + return results diff --git a/unik3d/datasets/adt.py b/unik3d/datasets/adt.py new file mode 100644 index 0000000000000000000000000000000000000000..8883efd6d8e2590467ea4314ab2f13d837da684e --- /dev/null +++ b/unik3d/datasets/adt.py @@ -0,0 +1,68 @@ +from typing import Any + +import torch + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class ADT(SequenceDataset): + min_depth = 0.01 + max_depth = 20.0 + depth_scale = 1000.0 + test_split = "val.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"ADT.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["camera_params", "cam2w"], + **kwargs, + ) -> None: + 
super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, # if not test_mode else [*decode_fields, "points"], + inplace_fields=inplace_fields, + **kwargs, + ) + + def preprocess(self, results): + self.resizer.ctx = None + for i, seq in enumerate(results["sequence_fields"]): + # Create a mask where the distance from the center is less than H/2 + H, W = results[seq]["image"].shape[-2:] + x = torch.linspace(-W / 2 - 0.5, W / 2 + 0.5, W) + y = torch.linspace(-H / 2 - 0.5, H / 2 + 0.5, H) + xv, yv = torch.meshgrid(x, y, indexing="xy") + distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W) + results[seq]["validity_mask"] = distance_from_center < (H / 2) + 20 + results[seq]["depth_mask"] = results[seq]["validity_mask"].clone() + results[seq]["mask_fields"].add("depth_mask") + results[seq]["mask_fields"].add("validity_mask") + + return super().preprocess(results) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/aimotive.py b/unik3d/datasets/aimotive.py new file mode 100644 index 0000000000000000000000000000000000000000..38557bc68a9f1aeeff0003eae46f12b46ec197ba --- /dev/null +++ b/unik3d/datasets/aimotive.py @@ -0,0 +1,51 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class aiMotive(SequenceDataset): + min_depth = 0.01 + max_depth = 100.0 + depth_scale = 256.0 + default_fps = 10 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["aiMotive.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["camera_params", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [False] * self.num_frames * self.num_copies + results["synthetic"] = [False] * self.num_frames * self.num_copies + results["quality"] = [2] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/argoverse.py b/unik3d/datasets/argoverse.py new file mode 100644 index 0000000000000000000000000000000000000000..b4f4ea6612b3a29f537e0690ccb3b30b04d06d24 --- /dev/null +++ b/unik3d/datasets/argoverse.py @@ -0,0 +1,73 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class Argoverse(ImageDataset): + min_depth = 0.05 + max_depth = 120.0 + depth_scale = 256.0 + test_split = "argo_val.txt" + train_split = 
"argo_train.txt" + intrisics_file = "argo_intrinsics.json" + hdf5_paths = ["argoverse11.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.crop = crop + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val] + dataset.append(sample) + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() diff --git a/unik3d/datasets/argoverse2.py b/unik3d/datasets/argoverse2.py new file mode 100644 index 0000000000000000000000000000000000000000..5a49f49d18979cb39b1be7385b4ec4c09b47a244 --- /dev/null +++ b/unik3d/datasets/argoverse2.py @@ -0,0 +1,49 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class Argoverse2(SequenceDataset): + min_depth = 0.05 + max_depth = 120.0 + depth_scale = 256.0 + test_split = "val.txt" + train_split = "train.txt" + sequences_file = "sequences_clean.json" + hdf5_paths = [f"AV2_viz.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [False] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/arkit.py b/unik3d/datasets/arkit.py new file mode 100644 index 0000000000000000000000000000000000000000..96225cd8b248cad6d1a5072cfc39adcf880c2297 --- /dev/null +++ b/unik3d/datasets/arkit.py @@ -0,0 +1,49 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class ARKit(SequenceDataset): + min_depth = 0.01 + max_depth = 10.0 + depth_scale = 1000.0 + test_split = "Training.txt" + train_split = "Training.txt" + sequences_file = "sequences.json" + hdf5_paths = ["ARKitS.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: 
bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["quality"] = [2] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/ase.py b/unik3d/datasets/ase.py new file mode 100644 index 0000000000000000000000000000000000000000..5f57fd2119dd589ca62800d907fc4c235dc1ed1b --- /dev/null +++ b/unik3d/datasets/ase.py @@ -0,0 +1,66 @@ +from typing import Any + +import torch + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class ASE(SequenceDataset): + min_depth = 0.01 + max_depth = 20.0 + depth_scale = 1000.0 + test_split = "val.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"ASE.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["camera_params", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def preprocess(self, results): + self.resizer.ctx = None + for i, seq in enumerate(results["sequence_fields"]): + # Create a mask where the distance from the center is less than H/2 + H, W = results[seq]["image"].shape[-2:] + x = torch.linspace(-W / 2 - 0.5, W / 2 + 0.5, W) + y = torch.linspace(-H / 2 - 0.5, H / 2 + 0.5, H) + xv, yv = torch.meshgrid(x, y, indexing="xy") + distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W) + results[seq]["validity_mask"] = distance_from_center < (H / 2) + 20 + results[seq]["mask_fields"].add("validity_mask") + + return super().preprocess(results) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/base_dataset.py b/unik3d/datasets/base_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..9474e767c569cc75b89da2d0510de808ed79208b --- /dev/null +++ b/unik3d/datasets/base_dataset.py @@ -0,0 +1,344 @@ +import os +from abc import abstractmethod +from copy import deepcopy +from math import ceil, log +from typing import Any, Dict, Tuple + +import numpy as np +import torch +from torch.utils.data import Dataset + +import unik3d.datasets.pipelines as pipelines +from unik3d.utils import (eval_3d, eval_depth, identity, 
is_main_process, + recursive_index, sync_tensor_across_gpus) +from unik3d.utils.constants import (IMAGENET_DATASET_MEAN, + IMAGENET_DATASET_STD, OPENAI_DATASET_MEAN, + OPENAI_DATASET_STD) + + +class BaseDataset(Dataset): + min_depth = 0.01 + max_depth = 1000.0 + + def __init__( + self, + image_shape: Tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: Dict[str, Any], + shape_constraints: Dict[str, Any], + resize_method: str, + mini: float, + num_copies: int = 1, + **kwargs, + ) -> None: + super().__init__() + assert normalize in [None, "imagenet", "openai"] + + self.split_file = split_file + self.test_mode = test_mode + self.data_root = os.environ["DATAROOT"] + self.image_shape = image_shape + self.resize_method = resize_method + self.mini = mini + self.num_frames = 1 + self.num_copies = num_copies + self.metrics_store = {} + self.metrics_count = {} + + if normalize == "imagenet": + self.normalization_stats = { + "mean": torch.tensor(IMAGENET_DATASET_MEAN), + "std": torch.tensor(IMAGENET_DATASET_STD), + } + elif normalize == "openai": + self.normalization_stats = { + "mean": torch.tensor(OPENAI_DATASET_MEAN), + "std": torch.tensor(OPENAI_DATASET_STD), + } + else: + self.normalization_stats = { + "mean": torch.tensor([0.0, 0.0, 0.0]), + "std": torch.tensor([1.0, 1.0, 1.0]), + } + + for k, v in augmentations_db.items(): + setattr(self, k, v) + self.shape_constraints = shape_constraints + if not self.test_mode: + self._augmentation_space() + + self.masker = pipelines.AnnotationMask( + min_value=0.0, + max_value=self.max_depth if test_mode else None, + custom_fn=identity, + ) + self.filler = pipelines.RandomFiller(test_mode=test_mode) + + shape_mult = self.shape_constraints["shape_mult"] + self.image_shape = [ + ceil(self.image_shape[0] / shape_mult) * shape_mult, + ceil(self.image_shape[1] / shape_mult) * shape_mult, + ] + self.resizer = pipelines.ContextCrop( + image_shape=self.image_shape, + train_ctx_range=(1.0 / self.random_scale, 1.0 * self.random_scale), + test_min_ctx=self.test_context, + keep_original=test_mode, + shape_constraints=self.shape_constraints, + ) + + self.collecter = pipelines.Collect( + keys=["image_fields", "mask_fields", "gt_fields", "camera_fields"] + ) + + def __len__(self): + return len(self.dataset) + + def pack_batch(self, results): + results["paddings"] = [ + results[x]["paddings"][0] for x in results["sequence_fields"] + ] + for fields_name in [ + "image_fields", + "gt_fields", + "mask_fields", + "camera_fields", + ]: + fields = results.get(fields_name) + packed = { + field: torch.cat( + [results[seq][field] for seq in results["sequence_fields"]] + ) + for field in fields + } + results.update(packed) + return results + + def unpack_batch(self, results): + for fields_name in [ + "image_fields", + "gt_fields", + "mask_fields", + "camera_fields", + ]: + fields = results.get(fields_name) + unpacked = { + field: { + seq: results[field][idx : idx + 1] + for idx, seq in enumerate(results["sequence_fields"]) + } + for field in fields + } + results.update(unpacked) + return results + + def _augmentation_space(self): + self.augmentations_dict = { + "Flip": pipelines.RandomFlip(prob=self.flip_p), + "Jitter": pipelines.RandomColorJitter( + (-self.random_jitter, self.random_jitter), prob=self.jitter_p + ), + "Gamma": pipelines.RandomGamma( + (-self.random_gamma, self.random_gamma), prob=self.gamma_p + ), + "Blur": pipelines.GaussianBlur( + kernel_size=13, sigma=(0.1, self.random_blur), prob=self.blur_p + ), + "Grayscale": 
pipelines.RandomGrayscale(prob=self.grayscale_p), + } + + def augment(self, results): + for name, aug in self.augmentations_dict.items(): + results = aug(results) + return results + + def prepare_depth_eval(self, inputs, preds): + new_preds = {} + keyframe_idx = getattr(self, "keyframe_idx", None) + slice_idx = slice( + keyframe_idx, keyframe_idx + 1 if keyframe_idx is not None else None + ) + new_gts = inputs["depth"][slice_idx] + new_masks = inputs["depth_mask"][slice_idx].bool() + for key, val in preds.items(): + if "depth" in key: + new_preds[key] = val[slice_idx] + return new_gts, new_preds, new_masks + + def prepare_points_eval(self, inputs, preds): + new_preds = {} + new_gts = inputs["points"] + new_masks = inputs["depth_mask"].bool() + if "points_mask" in inputs: + new_masks = inputs["points_mask"].bool() + for key, val in preds.items(): + if "points" in key: + new_preds[key] = val + return new_gts, new_preds, new_masks + + def add_points(self, inputs): + inputs["points"] = inputs.get("camera_original", inputs["camera"]).reconstruct( + inputs["depth"] + ) + return inputs + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def accumulate_metrics( + self, + inputs, + preds, + keyframe_idx=None, + metrics=["depth", "points", "flow_fwd", "pairwise"], + ): + if "depth" in inputs and "points" not in inputs: + inputs = self.add_points(inputs) + + available_metrics = [] + for metric in metrics: + metric_in_gt = any((metric in k for k in inputs.keys())) + metric_in_pred = any((metric in k for k in preds.keys())) + if metric_in_gt and metric_in_pred: + available_metrics.append(metric) + + if keyframe_idx is not None: + inputs = recursive_index(inputs, slice(keyframe_idx, keyframe_idx + 1)) + preds = recursive_index(preds, slice(keyframe_idx, keyframe_idx + 1)) + + if "depth" in available_metrics: + depth_gt, depth_pred, depth_masks = self.prepare_depth_eval(inputs, preds) + self.accumulate_metrics_depth(depth_gt, depth_pred, depth_masks) + + if "points" in available_metrics: + points_gt, points_pred, points_masks = self.prepare_points_eval( + inputs, preds + ) + self.accumulate_metrics_3d(points_gt, points_pred, points_masks) + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def accumulate_metrics_depth(self, gts, preds, masks): + for eval_type, pred in preds.items(): + log_name = eval_type.replace("depth", "").strip("-").strip("_") + if log_name not in self.metrics_store: + self.metrics_store[log_name] = {} + current_count = self.metrics_count.get( + log_name, torch.tensor([], device=gts.device) + ) + new_count = masks.view(gts.shape[0], -1).sum(dim=-1) + self.metrics_count[log_name] = torch.cat([current_count, new_count]) + for k, v in eval_depth(gts, pred, masks, max_depth=self.max_depth).items(): + current_metric = self.metrics_store[log_name].get( + k, torch.tensor([], device=gts.device) + ) + self.metrics_store[log_name][k] = torch.cat([current_metric, v]) + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def accumulate_metrics_3d(self, gts, preds, masks): + thresholds = torch.linspace( + log(self.min_depth), + log(self.max_depth / 20), + steps=100, + device=gts.device, + ).exp() + for eval_type, pred in preds.items(): + log_name = eval_type.replace("points", "").strip("-").strip("_") + if log_name not in self.metrics_store: + self.metrics_store[log_name] = {} + current_count = self.metrics_count.get( + log_name, torch.tensor([], device=gts.device) + ) + new_count = masks.view(gts.shape[0], 
-1).sum(dim=-1) + self.metrics_count[log_name] = torch.cat([current_count, new_count]) + for k, v in eval_3d(gts, pred, masks, thresholds=thresholds).items(): + current_metric = self.metrics_store[log_name].get( + k, torch.tensor([], device=gts.device) + ) + self.metrics_store[log_name][k] = torch.cat([current_metric, v]) + + def get_evaluation(self, metrics=None): + metric_vals = {} + for eval_type in metrics if metrics is not None else self.metrics_store.keys(): + assert self.metrics_store[eval_type] + cnts = sync_tensor_across_gpus(self.metrics_count[eval_type]) + for name, val in self.metrics_store[eval_type].items(): + # vals_r = (sync_tensor_across_gpus(val) * cnts / cnts.sum()).sum() + vals_r = sync_tensor_across_gpus(val).mean() + metric_vals[f"{eval_type}_{name}".strip("_")] = np.round( + vals_r.cpu().item(), 5 + ) + self.metrics_store[eval_type] = {} + self.metrics_count = {} + return metric_vals + + def replicate(self, results): + for i in range(1, self.num_copies): + results[(0, i)] = {k: deepcopy(v) for k, v in results[(0, 0)].items()} + results["sequence_fields"].append((0, i)) + return results + + def log_load_dataset(self): + if is_main_process(): + info = f"Loaded {self.__class__.__name__} with {len(self)} images." + print(info) + + def pre_pipeline(self, results): + results["image_fields"] = results.get("image_fields", set()) + results["gt_fields"] = results.get("gt_fields", set()) + results["mask_fields"] = results.get("mask_fields", set()) + results["sequence_fields"] = results.get("sequence_fields", set()) + results["camera_fields"] = results.get("camera_fields", set()) + results["dataset_name"] = ( + [self.__class__.__name__] * self.num_frames * self.num_copies + ) + results["depth_scale"] = [self.depth_scale] * self.num_frames * self.num_copies + results["si"] = [False] * self.num_frames * self.num_copies + results["dense"] = [False] * self.num_frames * self.num_copies + results["synthetic"] = [False] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + results["valid_camera"] = [True] * self.num_frames * self.num_copies + results["valid_pose"] = [True] * self.num_frames * self.num_copies + return results + + def eval_mask(self, valid_mask): + return valid_mask + + def chunk(self, dataset, chunk_dim=1, pct=1.0): + subsampled_datasets = [ + x + for i in range(0, len(dataset), int(1 / pct * chunk_dim)) + for x in dataset[i : i + chunk_dim] + ] + return subsampled_datasets + + @abstractmethod + def preprocess(self, results): + raise NotImplementedError + + @abstractmethod + def postprocess(self, results): + raise NotImplementedError + + @abstractmethod + def get_mapper(self): + raise NotImplementedError + + @abstractmethod + def get_intrinsics(self, idx, image_name): + raise NotImplementedError + + @abstractmethod + def get_extrinsics(self, idx, image_name): + raise NotImplementedError + + @abstractmethod + def load_dataset(self): + raise NotImplementedError + + @abstractmethod + def get_single_item(self, idx, sample=None, mapper=None): + raise NotImplementedError + + @abstractmethod + def __getitem__(self, idx): + raise NotImplementedError diff --git a/unik3d/datasets/bdd.py b/unik3d/datasets/bdd.py new file mode 100644 index 0000000000000000000000000000000000000000..2f8956ea4ea77cf0db86b1be744a3daaacd04616 --- /dev/null +++ b/unik3d/datasets/bdd.py @@ -0,0 +1,82 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils 
import DatasetFromList + + +class BDD(ImageDataset): + min_depth = 0.01 + max_depth = 70.0 + depth_scale = 256.0 + test_split = "val.txt" + train_split = "train_clean.txt" + intrisics_file = "intrinsics.json" + hdf5_paths = ["BDD.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics_val = torch.tensor( + intrinsics[os.path.join(*image_filename.split("/")[:2])] + ).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val] + dataset.append(sample) + h5file.close() + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + if self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=0.1) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["si"] = [True] * self.num_copies + results["valid_camera"] = [False] * self.num_copies + results["dense"] = [False] * self.num_copies + results["quality"] = [2] * self.num_copies + return results diff --git a/unik3d/datasets/bedlam.py b/unik3d/datasets/bedlam.py new file mode 100644 index 0000000000000000000000000000000000000000..e9e9cfdf46e5ae57bcc5b3becdf6043dbe1f7f8b --- /dev/null +++ b/unik3d/datasets/bedlam.py @@ -0,0 +1,50 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class BEDLAM(SequenceDataset): + min_depth = 0.01 + max_depth = 256.0 + depth_scale = 1000.0 + test_split = "train.txt" + train_split = "val.txt" + sequences_file = "sequences.json" + hdf5_paths = ["BEDLAM.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/behave.py b/unik3d/datasets/behave.py new file mode 100644 
index 0000000000000000000000000000000000000000..155afc72800367f09133505032dea8d94eb80083 --- /dev/null +++ b/unik3d/datasets/behave.py @@ -0,0 +1,52 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class Behave(SequenceDataset): + min_depth = 0.01 + max_depth = 10.0 + depth_scale = 1000.0 + default_fps = 10 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["Behave.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["camera_params", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [False] * self.num_frames * self.num_copies + results["si"] = [False] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/blendedmvg.py b/unik3d/datasets/blendedmvg.py new file mode 100644 index 0000000000000000000000000000000000000000..b3fa5f52b63785d0a4b02addb347506c51048ab9 --- /dev/null +++ b/unik3d/datasets/blendedmvg.py @@ -0,0 +1,50 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class BlendedMVG(SequenceDataset): + min_depth = 0.01 + max_depth = 5000.0 + depth_scale = 1000.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences_clean.json" + hdf5_paths = ["BlendedMVG_.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["si"] = [False] * self.num_frames * self.num_copies + results["quality"] = [2] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/cityscape.py b/unik3d/datasets/cityscape.py new file mode 100644 index 0000000000000000000000000000000000000000..0f462d6571fc24acc47925550f4331bb1053b651 --- /dev/null +++ b/unik3d/datasets/cityscape.py @@ -0,0 +1,78 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class Cityscape(ImageDataset): 
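+ # Per-image intrinsics (a 3x3 K matrix) are read from intrinsics.json stored inside cityscape.hdf5; depth uses depth_scale = 256 with a valid range of 0.05-80 m (min_depth / max_depth below).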
+ min_depth = 0.05 + max_depth = 80.0 + depth_scale = 256.0 + test_split = "val.txt" + train_split = "train.txt" + intrisics_file = "intrinsics.json" + hdf5_paths = ["cityscape.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.crop = crop + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val] + dataset.append(sample) + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["quality"] = [2] * self.num_copies + return results diff --git a/unik3d/datasets/ddad.py b/unik3d/datasets/ddad.py new file mode 100644 index 0000000000000000000000000000000000000000..322089c73c88c7ac2208c6b2b5f5b682effb7c95 --- /dev/null +++ b/unik3d/datasets/ddad.py @@ -0,0 +1,84 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class DDAD(ImageDataset): + min_depth = 0.05 + max_depth = 120.0 + depth_scale = 256.0 + test_split = "val.txt" + train_split = "train.txt" + intrisics_file = "intrinsics.json" + hdf5_paths = [f"ddad/ddad_{i}.hdf5" for i in range(8)] + + def __init__( + self, + image_shape, + split_file, + test_mode, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii").strip("\n") + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename, chunk_idx = line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val, chunk_idx] + dataset.append(sample) + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, 
pct=self.mini) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + "K": 2, + "chunk_idx": 3, + } + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [False] * self.num_copies + results["quality"] = [1] * self.num_copies + return results diff --git a/unik3d/datasets/deep360.py b/unik3d/datasets/deep360.py new file mode 100644 index 0000000000000000000000000000000000000000..20bf72aeacade6166201fac523137900d9ca1eba --- /dev/null +++ b/unik3d/datasets/deep360.py @@ -0,0 +1,56 @@ +from typing import Any + +import torch + +from unik3d.datasets.pipelines import Compose, PanoCrop, PanoRoll +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class Deep360(SequenceDataset): + min_depth = 0.1 + max_depth = 1000.0 + depth_scale = 1000.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"Deep360.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["cam2w", "camera_params"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + self.resizer = Compose( + [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer] + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/dense.py b/unik3d/datasets/dense.py new file mode 100644 index 0000000000000000000000000000000000000000..69e31b74a2d48a43067f2112a9e5d339620d2ced --- /dev/null +++ b/unik3d/datasets/dense.py @@ -0,0 +1,91 @@ +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class DENSE(ImageDataset): + CAM_INTRINSIC = { + "ALL": torch.tensor( + [ + [1177.8614, 0.0, 474.319027], + [0.0, 1177.8614, 224.275919], + [0.0, 0.0, 1.0], + ] + ) + } + min_depth = 0.05 + max_depth = 80.0 + depth_scale = 255.0 + test_split = "train.txt" + train_split = "train.txt" + hdf5_paths = ["DENSE.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.intrisics = {} + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = 
txt_file.tostring().decode("ascii")[:-1] # correct the -1 + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + sample = [image_filename, depth_filename] + dataset.append(sample) + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def get_intrinsics(self, idx, image_name): + return self.CAM_INTRINSIC["ALL"].clone() + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + } + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [False] * self.num_copies + results["quality"] = [1] * self.num_copies + return results diff --git a/unik3d/datasets/diml.py b/unik3d/datasets/diml.py new file mode 100644 index 0000000000000000000000000000000000000000..4a5b38411611655b7219bfacbe8c6fb528331084 --- /dev/null +++ b/unik3d/datasets/diml.py @@ -0,0 +1,79 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class DIML(ImageDataset): + min_depth = 0.01 + max_depth = 100.0 + depth_scale = 256.0 + test_split = "test.txt" + train_split = "train.txt" + intrisics_file = "intrinsics.json" + hdf5_paths = ["DIML.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.intrisics = {} + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics_val = torch.tensor( + intrinsics[image_filename.split("/")[0]] + ).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val] + dataset.append(sample) + h5file.close() + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_copies + results["quality"] = [2] * self.num_copies + return results diff --git a/unik3d/datasets/diode.py b/unik3d/datasets/diode.py new file mode 100644 index 0000000000000000000000000000000000000000..9dc180eb096c52e50c237e8555cb952927f537c8 --- /dev/null +++ b/unik3d/datasets/diode.py @@ -0,0 +1,280 @@ +import os + +import h5py +import numpy as np +import polars as pl # assumption: the `pl.from_dict` calls in the legacy DiodeOutdoor/Diode loaders below refer to polars +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.pipelines import AnnotationMask +from unik3d.datasets.sequence_dataset import SequenceDataset +from unik3d.datasets.utils import DatasetFromList + + +class DiodeIndoor(ImageDataset): + CAM_INTRINSIC = { + "ALL": torch.tensor([[886.81, 0, 512], [0, 927.06, 384], [0, 0, 1]]) + } + min_depth =
0.01 + max_depth = 25.0 + depth_scale = 256.0 + test_split = "val.txt" + train_split = "train.txt" + hdf5_paths = ["DiodeIndoor.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + # load annotations + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + sample = [ + image_filename, + depth_filename, + ] + dataset.append(sample) + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def get_intrinsics(self, *args, **kwargs): + return self.CAM_INTRINSIC["ALL"].clone() + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + } + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_copies + results["quality"] = [1] * self.num_copies + return results + + +class DiodeIndoor_F(SequenceDataset): + min_depth = 0.01 + max_depth = 25.0 + depth_scale = 1000.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["DiodeIndoor-F.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, float], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["camera_params", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=( + decode_fields if not test_mode else [*decode_fields, "points"] + ), + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results + + +class DiodeOutdoor(ImageDataset): + CAM_INTRINSIC = { + "ALL": torch.tensor([[886.81, 0, 512], [0, 927.06, 384], [0, 0, 1]]) + } + min_depth = 0.1 + max_depth = 80.0 + log_mean = 0 + log_std = 1 + test_split = "diode_outdoor_val.txt" + train_split = "diode_outdoor_train.txt" + hdf5_paths = ["diode.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + depth_scale=256, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, 
+ ) + self.test_mode = test_mode + self.depth_scale = depth_scale + + self.masker = AnnotationMask( + min_value=self.min_depth, + max_value=self.max_depth if test_mode else None, + custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x, + ) + # load annotations + self.load_dataset() + + def load_dataset(self): + self.h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(self.h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] + dataset = {"depth_filename": [], "image_filename": []} + for line in txt_string.split("\n"): + depth_filename = line.strip().split(" ")[1] + img_name = line.strip().split(" ")[0] + image_filename = img_name + dataset["depth_filename"].append(depth_filename) + dataset["image_filename"].append(image_filename) + + self.dataset = pl.from_dict(dataset) + + if not self.test_mode and self.mini: + self.dataset = self.dataset[::2] + + +class Diode(ImageDataset): + CAM_INTRINSIC = { + "ALL": torch.tensor([[886.81, 0, 512], [0, 927.06, 384], [0, 0, 1]]) + } + log_mean = 0 + log_std = 1 + min_depth = 0.6 + max_depth = 80.0 + test_split = "diode_val.txt" + train_split = "diode_train.txt" + hdf5_paths = ["diode.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + depth_scale=256, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + self.depth_scale = depth_scale + + self.masker = AnnotationMask( + min_value=self.min_depth, + max_value=self.max_depth if test_mode else None, + custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x, + ) + # load annotations + self.load_dataset() + + def load_dataset(self): + self.h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(self.h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] + dataset = {"depth_filename": [], "image_filename": []} + for line in txt_string.split("\n"): + depth_filename = line.strip().split(" ")[1] + image_filename = line.strip().split(" ")[0] + dataset["depth_filename"].append(depth_filename) + dataset["image_filename"].append(image_filename) + + self.dataset = pl.from_dict(dataset) + + if not self.test_mode and self.mini: + self.dataset = self.dataset[::2] + + def get_intrinsics(self, *args, **kwargs): + return self.CAM_INTRINSIC["ALL"].clone() diff --git a/unik3d/datasets/dl3dv.py b/unik3d/datasets/dl3dv.py new file mode 100644 index 0000000000000000000000000000000000000000..fed077b27393389f38054175eed101d76922269e --- /dev/null +++ b/unik3d/datasets/dl3dv.py @@ -0,0 +1,49 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class DL3DV(SequenceDataset): + min_depth = 0.001 + max_depth = 250.0 + depth_scale = 512.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"DL3DVcv.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], +
inplace_fields: list[str] = ["camera_params", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["si"] = [True] * self.num_frames * self.num_copies + results["quality"] = [2] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/driving_stereo.py b/unik3d/datasets/driving_stereo.py new file mode 100644 index 0000000000000000000000000000000000000000..512c3bd00658609538790c3893a4a3d367ca4629 --- /dev/null +++ b/unik3d/datasets/driving_stereo.py @@ -0,0 +1,82 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class DrivingStereo(ImageDataset): + min_depth = 0.05 + max_depth = 80.0 + depth_scale = 256.0 + test_split = "drivingstereo_val.txt" + train_split = "drivingstereo_train.txt" + intrisics_file = "drivingstereo_intrinsics.json" + hdf5_paths = ["DrivingStereo.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.crop = crop + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val] + dataset.append(sample) + h5file.close() + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + if self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=1.0) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [False] * self.num_copies + results["quality"] = [1] * self.num_copies + return results diff --git a/unik3d/datasets/dtu_rmvd.py b/unik3d/datasets/dtu_rmvd.py new file mode 100644 index 0000000000000000000000000000000000000000..80ee8d54944b7029a5eb5fcd59f499df0910adfb --- /dev/null +++ b/unik3d/datasets/dtu_rmvd.py @@ -0,0 +1,62 @@ +import json +import os +from typing import Any + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.pipelines import AnnotationMask, KittiCrop +from unik3d.datasets.sequence_dataset import SequenceDataset +from unik3d.datasets.utils import DatasetFromList +from unik3d.utils import 
identity + + +class DTURMVD(SequenceDataset): + min_depth = 0.05 + max_depth = 3.0 + depth_scale = 1000.0 + default_fps = 6 + test_split = "test.txt" + train_split = "test.txt" + sequences_file = "sequences.json" + hdf5_paths = ["dtu_rmvd.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["si"] = [True] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/dummy.py b/unik3d/datasets/dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..2e0157c7e2c60631d16603aa7b5616fa32d07cf2 --- /dev/null +++ b/unik3d/datasets/dummy.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +from torch.utils.data import Dataset + + +class Dummy(Dataset): + train_split = None + test_split = None + + def __init__(self, *args, **kwargs): + super().__init__() + self.dataset = np.arange(1_000_000) + + def get_single_item(self, idx): + # results = {} + # results["cam2w"] = torch.eye(4).unsqueeze(0) + # results["K"] = torch.eye(3).unsqueeze(0) + # results["image"] = torch.zeros(1, 3, 1024, 1024).to(torch.uint8) + # results["depth"] = torch.zeros(1, 1, 1024, 1024).to(torch.float32) + return { + "x": {(0, 0): torch.rand(1, 3, 1024, 1024, dtype=torch.float32)}, + "img_metas": {"val": torch.rand(1, 1024, dtype=torch.float32)}, + } + + def __getitem__(self, idx): + if isinstance(idx, (list, tuple)): + results = [self.get_single_item(i) for i in idx] + else: + results = self.get_single_item(idx) + return results + + def __len__(self): + return len(self.dataset) diff --git a/unik3d/datasets/dynamic_replica.py b/unik3d/datasets/dynamic_replica.py new file mode 100644 index 0000000000000000000000000000000000000000..eb6211d476cd8012dfa6be2bccf57e3a37ba1e5f --- /dev/null +++ b/unik3d/datasets/dynamic_replica.py @@ -0,0 +1,51 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class DynReplica(SequenceDataset): + min_depth = 0.01 + max_depth = 20.0 + default_fps = 30.0 + depth_scale = 512.0 + test_split = "val.txt" + train_split = "train.txt" + sequences_file = "sequences_clean.json" + hdf5_paths = ["DynReplica.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + 
decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/eden.py b/unik3d/datasets/eden.py new file mode 100644 index 0000000000000000000000000000000000000000..3e12daeaa1a94cc8a3ed98aa5facb51268e5db28 --- /dev/null +++ b/unik3d/datasets/eden.py @@ -0,0 +1,50 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class EDEN(SequenceDataset): + min_depth = 0.1 + max_depth = 100.0 + depth_scale = 256.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"EDEN.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/eth3d.py b/unik3d/datasets/eth3d.py new file mode 100644 index 0000000000000000000000000000000000000000..20c1110e27124b803458677cbf96592a13df4170 --- /dev/null +++ b/unik3d/datasets/eth3d.py @@ -0,0 +1,164 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.sequence_dataset import SequenceDataset +from unik3d.datasets.utils import DatasetFromList + + +class ETH3D(ImageDataset): + min_depth = 0.01 + max_depth = 50.0 + depth_scale = 1000.0 + test_split = "train.txt" + train_split = "train.txt" + intrisics_file = "intrinsics.json" + hdf5_paths = ["ETH3D.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = 
line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val] + dataset.append(sample) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results + + +class ETH3D_F(SequenceDataset): + min_depth = 0.05 + max_depth = 60.0 + depth_scale = 1000.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["ETH3D-F.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, float], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["camera_params", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=( + decode_fields if not test_mode else [*decode_fields, "points"] + ), + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results + + +class ETH3DRMVD(SequenceDataset): + min_depth = 0.01 + max_depth = 50.0 + depth_scale = 1000.0 + default_fps = 6 + test_split = "test.txt" + train_split = "test.txt" + sequences_file = "sequences.json" + hdf5_paths = ["eth3d_rmvd.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) diff --git a/unik3d/datasets/facedepth.py b/unik3d/datasets/facedepth.py new file mode 100644 index 0000000000000000000000000000000000000000..108afbcde8c332f4b1736646893fddb93e0c5020 --- /dev/null +++ b/unik3d/datasets/facedepth.py @@ -0,0 +1,51 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class FaceDepth(SequenceDataset): + min_depth = 0.01 + max_depth = 10.0 + depth_scale = 1000.0 + default_fps = 10 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["FaceDepth.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + 
image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/flsea.py b/unik3d/datasets/flsea.py new file mode 100644 index 0000000000000000000000000000000000000000..394a5dd1ddfbe4b1154922032e35103ca6d9e62a --- /dev/null +++ b/unik3d/datasets/flsea.py @@ -0,0 +1,100 @@ +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class FLSea(ImageDataset): + CAM_INTRINSIC = { + "canyons": torch.tensor( + [ + [1175.3913431656817, 0.0, 466.2595428966926], + [0.0, 1174.2805075232263, 271.2116633091501], + [0.0, 0.0, 1.0], + ] + ), + "red_sea": torch.tensor( + [ + [1296.666758476217, 0.0, 501.50386149846], + [0.0, 1300.831316354508, 276.161712082695], + [0.0, 0.0, 1.0], + ] + ), + } + min_depth = 0.05 + max_depth = 20.0 + depth_scale = 1000.0 + train_split = "train.txt" + hdf5_paths = ["FLSea.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=False, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.crop = crop + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + sample = [image_filename, depth_filename] + dataset.append(sample) + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + if self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=0.33) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def get_intrinsics(self, idx, image_name): + return self.CAM_INTRINSIC[image_name.split("/")[0]][:, :3].clone() + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + } + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_copies + results["quality"] = [2] * self.num_copies + return results diff --git a/unik3d/datasets/futurehouse.py b/unik3d/datasets/futurehouse.py new file mode 100644 index 0000000000000000000000000000000000000000..d0346177f183735676cfc7f941ed3db09d6a3cb8 --- /dev/null +++ b/unik3d/datasets/futurehouse.py @@ -0,0 +1,56 @@ +from typing import Any + +import torch + +from unik3d.datasets.pipelines import Compose, PanoCrop, PanoRoll +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class FutureHouse(SequenceDataset): + min_depth 
= 0.01 + max_depth = 10.0 + depth_scale = 1000.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"FutureHouse.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["cam2w", "camera_params"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + self.resizer = Compose( + [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer] + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/gibson.py b/unik3d/datasets/gibson.py new file mode 100644 index 0000000000000000000000000000000000000000..da662fffee527d53adfe3bb3467e316186dfb8fd --- /dev/null +++ b/unik3d/datasets/gibson.py @@ -0,0 +1,56 @@ +from typing import Any + +import torch + +from unik3d.datasets.pipelines import Compose, PanoCrop, PanoRoll +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class Gibson(SequenceDataset): + min_depth = 0.01 + max_depth = 10.0 + depth_scale = 1000.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"Gibson.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["cam2w", "camera_params"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + self.resizer = Compose( + [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer] + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/hammer.py b/unik3d/datasets/hammer.py new file mode 100644 index 0000000000000000000000000000000000000000..90b8a8594747fd7cbad04ef6f683f7709605db7c --- /dev/null +++ b/unik3d/datasets/hammer.py @@ -0,0 +1,76 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class HAMMER(ImageDataset): + min_depth = 0.005 + max_depth = 10.0 + depth_scale = 1000.0 + train_split = "test.txt" + test_split = "test.txt" + 
intrisics_file = "intrinsics.json" + hdf5_paths = ["hammer.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.crop = crop + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val] + dataset.append(sample) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/hm3d.py b/unik3d/datasets/hm3d.py new file mode 100644 index 0000000000000000000000000000000000000000..f77897da89f7412de9efd6c78b31fdf9f133b531 --- /dev/null +++ b/unik3d/datasets/hm3d.py @@ -0,0 +1,49 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class HM3D(SequenceDataset): + min_depth = 0.01 + max_depth = 10.0 + depth_scale = 1000.0 + test_split = "val.txt" + train_split = "full.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"HM3D.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["quality"] = [2] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/hoi4d.py b/unik3d/datasets/hoi4d.py new file mode 100644 index 0000000000000000000000000000000000000000..c71408d3b34dc91c080173222a625266027befa1 --- /dev/null +++ b/unik3d/datasets/hoi4d.py @@ -0,0 +1,51 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class HOI4D(SequenceDataset): + min_depth = 0.01 + max_depth = 10.0 + depth_scale = 1000.0 + default_fps = 5 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["HOI4D.hdf5"] + + def 
__init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [False] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/hypersim.py b/unik3d/datasets/hypersim.py new file mode 100644 index 0000000000000000000000000000000000000000..542147980b543e3247b46f505617cd2a1a842341 --- /dev/null +++ b/unik3d/datasets/hypersim.py @@ -0,0 +1,97 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class HyperSim(ImageDataset): + min_depth = 0.01 + max_depth = 50.0 + depth_scale = 1000.0 + test_split = "val.txt" + train_split = "train.txt" + intrisics_file = "intrinsics.json" + hdf5_paths = [f"hypersim/hypersim_{i}.hdf5" for i in range(8)] + + def __init__( + self, + image_shape, + split_file, + test_mode, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii").strip("\n") + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + + # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f: + # f.write(txt_string) + # with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f: + # json.dump(intrinsics, f) + + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename, chunk_idx = line.strip().split(" ") + intrinsics_val = torch.tensor( + intrinsics[os.path.join(*image_filename.split("/")[:2])] + ).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val, chunk_idx] + dataset.append(sample) + h5file.close() + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + if self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=0.1) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + "K": 2, + "chunk_idx": 3, + } + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_copies + 
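
Each pre_pipeline override in this patch only attaches per-sample metadata flags, replicated once per frame and once per augmented copy. Reading across the datasets, quality 0 appears to be reserved for clean synthetic renders and 2 for the noisiest annotations; how the flags are consumed is outside this diff. A toy illustration, assuming num_frames and num_copies come from the base dataset:

num_frames, num_copies = 2, 2  # hypothetical values
flags = {
    "dense": [True] * num_frames * num_copies,      # depth defined at every pixel
    "synthetic": [True] * num_frames * num_copies,  # rendered rather than captured
    "quality": [0] * num_frames * num_copies,       # annotation tier (semantics assumed from usage)
}
assert len(flags["dense"]) == num_frames * num_copies  # one flag per (frame, copy) pair
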
results["synthetic"] = [True] * self.num_copies + results["quality"] = [0] * self.num_copies + return results diff --git a/unik3d/datasets/ibims.py b/unik3d/datasets/ibims.py new file mode 100644 index 0000000000000000000000000000000000000000..11af2b1bbe67c196a991d016dc1d4f9876847b18 --- /dev/null +++ b/unik3d/datasets/ibims.py @@ -0,0 +1,125 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.sequence_dataset import SequenceDataset +from unik3d.datasets.utils import DatasetFromList + + +class IBims(ImageDataset): + min_depth = 0.005 + max_depth = 25.0 + depth_scale = 1000.0 + train_split = "ibims_val.txt" + test_split = "ibims_val.txt" + intrisics_file = "ibims_intrinsics.json" + hdf5_paths = ["ibims.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.crop = crop + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val] + dataset.append(sample) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_copies + results["quality"] = [1] * self.num_copies + return results + + +class IBims_F(SequenceDataset): + min_depth = 0.01 + max_depth = 25.0 + depth_scale = 1000.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["IBims-F.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, float], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["camera_params", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=( + decode_fields if not test_mode else [*decode_fields, "points"] + ), + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/image_dataset.py b/unik3d/datasets/image_dataset.py new file 
mode 100644 index 0000000000000000000000000000000000000000..78fca306d0f9e884eb0c3ade305a8386030ff9ef --- /dev/null +++ b/unik3d/datasets/image_dataset.py @@ -0,0 +1,194 @@ +import io +import os +from time import time +from typing import Any, Dict, List, Tuple + +import numpy as np +import tables +import torch +import torchvision +import torchvision.transforms.v2.functional as TF +from PIL import Image + +from unik3d.datasets.base_dataset import BaseDataset +from unik3d.utils import is_main_process +from unik3d.utils.camera import BatchCamera, Pinhole + +""" +Awful class for legacy reasons, we assume only pinhole cameras +And we "fake" sequences by setting sequence_fields to [(0, 0)] and cam2w as eye(4) +""" + + +class ImageDataset(BaseDataset): + def __init__( + self, + image_shape: Tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: Dict[str, Any], + shape_constraints: Dict[str, Any], + resize_method: str, + mini: float, + benchmark: bool = False, + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + shape_constraints=shape_constraints, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.mapper = self.get_mapper() + + def get_single_item(self, idx, sample=None, mapper=None): + sample = self.dataset[idx] if sample is None else sample + mapper = self.mapper if mapper is None else mapper + + results = { + (0, 0): dict( + gt_fields=set(), + image_fields=set(), + mask_fields=set(), + camera_fields=set(), + ) + } + results = self.pre_pipeline(results) + results["sequence_fields"] = [(0, 0)] + + chunk_idx = ( + int(sample[self.mapper["chunk_idx"]]) if "chunk_idx" in self.mapper else 0 + ) + h5_path = os.path.join(self.data_root, self.hdf5_paths[chunk_idx]) + with tables.File( + h5_path, + mode="r", + libver="latest", + swmr=True, + ) as h5file_chunk: + for key_mapper, idx_mapper in mapper.items(): + if "image" not in key_mapper and "depth" not in key_mapper: + continue + value = sample[idx_mapper] + results[(0, 0)][key_mapper] = value + name = key_mapper.replace("_filename", "") + value_root = "/" + value + + if "image" in key_mapper: + results[(0, 0)]["filename"] = value + file = h5file_chunk.get_node(value_root).read() + image = ( + torchvision.io.decode_image(torch.from_numpy(file)) + .to(torch.uint8) + .squeeze() + ) + results[(0, 0)]["image_fields"].add(name) + results[(0, 0)][f"image_ori_shape"] = image.shape[-2:] + results[(0, 0)][name] = image[None, ...] + + # collect camera information for the given image + name = name.replace("image_", "") + results[(0, 0)]["camera_fields"].update({"camera", "cam2w"}) + K = self.get_intrinsics(idx, value) + if K is None: + K = torch.eye(3) + K[0, 0] = K[1, 1] = 0.7 * self.image_shape[1] + K[0, 2] = 0.5 * self.image_shape[1] + K[1, 2] = 0.5 * self.image_shape[0] + + camera = Pinhole(K=K[None, ...].clone()) + results[(0, 0)]["camera"] = BatchCamera.from_camera(camera) + results[(0, 0)]["cam2w"] = self.get_extrinsics(idx, value)[ + None, ... 
+ ] + + elif "depth" in key_mapper: + # start = time() + file = h5file_chunk.get_node(value_root).read() + depth = Image.open(io.BytesIO(file)) + depth = TF.pil_to_tensor(depth).squeeze().to(torch.float32) + if depth.ndim == 3: + depth = depth[2] + depth[1] * 255 + depth[0] * 255 * 255 + + results[(0, 0)]["gt_fields"].add(name) + results[(0, 0)][f"depth_ori_shape"] = depth.shape + + depth = ( + depth.view(1, 1, *depth.shape).contiguous() / self.depth_scale + ) + results[(0, 0)][name] = depth + + results = self.preprocess(results) + if not self.test_mode: + results = self.augment(results) + results = self.postprocess(results) + return results + + def preprocess(self, results): + results = self.replicate(results) + for i, seq in enumerate(results["sequence_fields"]): + self.resizer.ctx = None + results[seq] = self.resizer(results[seq]) + num_pts = torch.count_nonzero(results[seq]["depth"] > 0) + if num_pts < 50: + raise IndexError(f"Too few points in depth map ({num_pts})") + + for key in results[seq].get("image_fields", ["image"]): + results[seq][key] = results[seq][key].to(torch.float32) / 255 + + # update fields common in sequence + for key in ["image_fields", "gt_fields", "mask_fields", "camera_fields"]: + if key in results[(0, 0)]: + results[key] = results[(0, 0)][key] + results = self.pack_batch(results) + return results + + def postprocess(self, results): + # normalize after because color aug requires [0,255]? + for key in results.get("image_fields", ["image"]): + results[key] = TF.normalize(results[key], **self.normalization_stats) + results = self.filler(results) + results = self.unpack_batch(results) + results = self.masker(results) + results = self.collecter(results) + return results + + def __getitem__(self, idx): + try: + if isinstance(idx, (list, tuple)): + results = [self.get_single_item(i) for i in idx] + else: + results = self.get_single_item(idx) + except Exception as e: + print(f"Error loading sequence {idx} for {self.__class__.__name__}: {e}") + idx = np.random.randint(0, len(self.dataset)) + results = self[idx] + return results + + def get_intrinsics(self, idx, image_name): + idx_sample = self.mapper.get("K", 1000) + sample = self.dataset[idx] + if idx_sample >= len(sample): + return None + return sample[idx_sample] + + def get_extrinsics(self, idx, image_name): + idx_sample = self.mapper.get("cam2w", 1000) + sample = self.dataset[idx] + if idx_sample >= len(sample): + return torch.eye(4) + return sample[idx_sample] + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + "K": 2, + } diff --git a/unik3d/datasets/ken_burns.py b/unik3d/datasets/ken_burns.py new file mode 100644 index 0000000000000000000000000000000000000000..ab443b4ef1298b472a63ecb88d7432f6e1ad95ba --- /dev/null +++ b/unik3d/datasets/ken_burns.py @@ -0,0 +1,95 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class KenBurns(ImageDataset): + min_depth = 0.05 + max_depth = 50.0 + depth_scale = 256.0 + test_split = "val.txt" + train_split = "train.txt" + intrisics_file = "intrinsics.json" + hdf5_paths = [f"3dkenburns/3DKenBurns_{i}.hdf5" for i in range(8)] + + def __init__( + self, + image_shape, + split_file, + test_mode, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + 
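
A worked round-trip (illustrative only, not part of the patch) for the depth decoding used in ImageDataset.get_single_item above: when the stored depth PNG has three channels, the value is recombined as ch2 + ch1*255 + ch0*255*255 and then divided by the dataset's depth_scale.

import torch

depth_scale = 1000.0
raw = torch.tensor([[[1.0]], [[2.0]], [[3.0]]])      # channels (ch0, ch1, ch2) of a 1x1 image
packed = raw[2] + raw[1] * 255 + raw[0] * 255 * 255  # 3 + 510 + 65025 = 65538
depth = packed.view(1, 1, 1, 1) / depth_scale        # 65.538, in metres for depth_scale=1000
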
benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii").strip("\n") + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + + # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f: + # f.write(txt_string) + # with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f: + # json.dump(intrinsics, f) + + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename, chunk_idx = line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val, chunk_idx] + dataset.append(sample) + h5file.close() + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + if self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=0.25) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + "K": 2, + "chunk_idx": 3, + } + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_copies + results["synthetic"] = [True] * self.num_copies + results["quality"] = [0] * self.num_copies + return results diff --git a/unik3d/datasets/kitti.py b/unik3d/datasets/kitti.py new file mode 100644 index 0000000000000000000000000000000000000000..77b7cef52109279edcc758c0f20a846d4fcbd01b --- /dev/null +++ b/unik3d/datasets/kitti.py @@ -0,0 +1,317 @@ +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.pipelines import AnnotationMask, Compose, KittiCrop +from unik3d.datasets.sequence_dataset import SequenceDataset +from unik3d.datasets.utils import DatasetFromList +from unik3d.utils import identity + + +class KITTI(ImageDataset): + CAM_INTRINSIC = { + "2011_09_26": torch.tensor( + [ + [7.215377e02, 0.000000e00, 6.095593e02, 4.485728e01], + [0.000000e00, 7.215377e02, 1.728540e02, 2.163791e-01], + [0.000000e00, 0.000000e00, 1.000000e00, 2.745884e-03], + ] + ), + "2011_09_28": torch.tensor( + [ + [7.070493e02, 0.000000e00, 6.040814e02, 4.575831e01], + [0.000000e00, 7.070493e02, 1.805066e02, -3.454157e-01], + [0.000000e00, 0.000000e00, 1.000000e00, 4.981016e-03], + ] + ), + "2011_09_29": torch.tensor( + [ + [7.183351e02, 0.000000e00, 6.003891e02, 4.450382e01], + [0.000000e00, 7.183351e02, 1.815122e02, -5.951107e-01], + [0.000000e00, 0.000000e00, 1.000000e00, 2.616315e-03], + ] + ), + "2011_09_30": torch.tensor( + [ + [7.070912e02, 0.000000e00, 6.018873e02, 4.688783e01], + [0.000000e00, 7.070912e02, 1.831104e02, 1.178601e-01], + [0.000000e00, 0.000000e00, 1.000000e00, 6.203223e-03], + ] + ), + "2011_10_03": torch.tensor( + [ + [7.188560e02, 0.000000e00, 6.071928e02, 4.538225e01], + [0.000000e00, 7.188560e02, 1.852157e02, -1.130887e-01], + [0.000000e00, 0.000000e00, 1.000000e00, 3.779761e-03], + ] + ), + } + min_depth = 0.05 + max_depth = 80.0 + depth_scale = 256.0 + log_mean = 2.5462 + log_std = 0.5871 + test_split = "kitti_eigen_test.txt" + train_split = 
"kitti_eigen_train.txt" + test_split_benchmark = "kitti_test.txt" + hdf5_paths = ["kitti.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.masker = AnnotationMask( + min_value=0.0, + max_value=self.max_depth if test_mode else None, + custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x, + ) + self.test_mode = test_mode + self.crop = crop + self.cropper_base = KittiCrop(crop_size=(352, 1216)) + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename = line.strip().split(" ")[0] + depth_filename = line.strip().split(" ")[1] + if depth_filename == "None": + continue + sample = [ + image_filename, + depth_filename, + ] + dataset.append(sample) + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def get_intrinsics(self, idx, image_name): + return self.CAM_INTRINSIC[image_name.split("/")[0]][:, :3].clone() + + def preprocess(self, results): + results = self.replicate(results) + for i, seq in enumerate(results["sequence_fields"]): + self.resizer.ctx = None + results[seq] = self.cropper_base(results[seq]) + results[seq] = self.resizer(results[seq]) + num_pts = torch.count_nonzero(results[seq]["depth"] > 0) + if num_pts < 50: + raise IndexError(f"Too few points in depth map ({num_pts})") + + for key in results[seq].get("image_fields", ["image"]): + results[seq][key] = results[seq][key].to(torch.float32) / 255 + + # update fields common in sequence + for key in ["image_fields", "gt_fields", "mask_fields", "camera_fields"]: + if key in results[(0, 0)]: + results[key] = results[(0, 0)][key] + results = self.pack_batch(results) + return results + + def eval_mask(self, valid_mask, info={}): + """Do grag_crop or eigen_crop for testing""" + mask_height, mask_width = valid_mask.shape[-2:] + eval_mask = torch.zeros_like(valid_mask) + if "garg" in self.crop: + eval_mask[ + ..., + int(0.40810811 * mask_height) : int(0.99189189 * mask_height), + int(0.03594771 * mask_width) : int(0.96405229 * mask_width), + ] = 1 + elif "eigen" in self.crop: + eval_mask[ + ..., + int(0.3324324 * mask_height) : int(0.91351351 * mask_height), + int(0.03594771 * mask_width) : int(0.96405229 * mask_width), + ] = 1 + return torch.logical_and(valid_mask, eval_mask) + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + } + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [False] * self.num_copies + results["quality"] = [1] * self.num_copies + return results + + +import json + + +class KITTIBenchmark(ImageDataset): + min_depth = 0.05 + max_depth = 80.0 + depth_scale = 256.0 + test_split = "test_split.txt" + train_split = "val_split.txt" + intrinsics_file = "intrinsics.json" + hdf5_paths = ["kitti_benchmark.hdf5"] + + def __init__( + 
self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=True, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.crop = crop + + self.masker = AnnotationMask( + min_value=self.min_depth, + max_value=self.max_depth if test_mode else None, + custom_fn=lambda x, *args, **kwargs: x, + ) + self.collecter = Collect(keys=["image_fields", "mask_fields", "gt_fields"]) + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_path), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(self.h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrinsics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics = torch.tensor( + intrinsics[os.path.join(*image_filename.split("/")[:2])] + ).squeeze()[:, :3] + sample = { + "image_filename": image_filename, + "depth_filename": depth_filename, + "K": intrinsics, + } + dataset.append(sample) + + self.dataset = DatasetFromList(dataset) + + self.log_load_dataset() + + +class KITTIRMVD(SequenceDataset): + min_depth = 0.05 + max_depth = 80.0 + depth_scale = 256.0 + default_fps = 10 + test_split = "test.txt" + train_split = "test.txt" + sequences_file = "sequences.json" + hdf5_paths = ["kitti_rmvd.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + self.crop = crop + self.resizer = Compose([KittiCrop(crop_size=(352, 1216)), self.resizer]) + + def eval_mask(self, valid_mask, info={}): + """Do grag_crop or eigen_crop for testing""" + mask_height, mask_width = valid_mask.shape[-2:] + eval_mask = torch.zeros_like(valid_mask) + if "garg" in self.crop: + eval_mask[ + ..., + int(0.40810811 * mask_height) : int(0.99189189 * mask_height), + int(0.03594771 * mask_width) : int(0.96405229 * mask_width), + ] = 1 + elif "eigen" in self.crop: + eval_mask[ + ..., + int(0.3324324 * mask_height) : int(0.91351351 * mask_height), + int(0.03594771 * mask_width) : int(0.96405229 * mask_width), + ] = 1 + else: + return valid_mask + return torch.logical_and(valid_mask, eval_mask) diff --git a/unik3d/datasets/kitti360.py b/unik3d/datasets/kitti360.py new file mode 100644 index 0000000000000000000000000000000000000000..403afdd2b28d4052181b7b65c986af847be3c99c --- /dev/null +++ b/unik3d/datasets/kitti360.py @@ -0,0 +1,65 @@ +from typing import Any + +import torch + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class KITTI360(SequenceDataset): 
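
For the 352x1216 KittiCrop applied above, the fractional Garg crop in KITTI.eval_mask works out to the following pixel window (an illustrative recomputation, not part of the patch):

import torch

H, W = 352, 1216
valid_mask = torch.ones(1, H, W, dtype=torch.bool)  # pretend every pixel is valid
eval_mask = torch.zeros_like(valid_mask)
eval_mask[..., int(0.40810811 * H):int(0.99189189 * H),   # rows 143..348
          int(0.03594771 * W):int(0.96405229 * W)] = 1    # cols 43..1171
garg = valid_mask & eval_mask  # what eval_mask() returns when self.crop contains "garg"
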
+ min_depth = 0.01 + max_depth = 80.0 + depth_scale = 256.0 + train_split = "train.txt" + test_split = "val_split.txt" + sequences_file = "sequences_split.json" + hdf5_paths = [f"KITTI360.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["camera_params", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=( + decode_fields if not test_mode else [*decode_fields, "points"] + ), + inplace_fields=inplace_fields, + **kwargs, + ) + + def preprocess(self, results): + self.resizer.ctx = None + for i, seq in enumerate(results["sequence_fields"]): + # Create a mask where the distance from the center is less than H/2 + H, W = results[seq]["image"].shape[-2:] + x = torch.linspace(-W / 2, W / 2, W) + y = torch.linspace(-H / 2, H / 2, H) + xv, yv = torch.meshgrid(x, y, indexing="xy") + distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W) + results[seq]["validity_mask"] = distance_from_center < (H / 2) + return super().preprocess(results) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [False] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/lyft.py b/unik3d/datasets/lyft.py new file mode 100644 index 0000000000000000000000000000000000000000..bafeef68bc07a8222202c315e25ae24e8acf8597 --- /dev/null +++ b/unik3d/datasets/lyft.py @@ -0,0 +1,84 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class Lyft(ImageDataset): + min_depth = 0.05 + max_depth = 80.0 + depth_scale = 256.0 + test_split = "test.txt" + train_split = "train.txt" + intrisics_file = "intrinsics.json" + hdf5_paths = ["Lyft2.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f: + # f.write(txt_string) + # with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f: + # json.dump(intrinsics, f) + + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics_val = 
torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [ + image_filename, + depth_filename, + intrinsics_val, + ] + dataset.append(sample) + h5file.close() + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [False] + return results diff --git a/unik3d/datasets/mapillary.py b/unik3d/datasets/mapillary.py new file mode 100644 index 0000000000000000000000000000000000000000..d58054add6dcd1cd5243de3bbd9eb496862d3488 --- /dev/null +++ b/unik3d/datasets/mapillary.py @@ -0,0 +1,84 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class Mapillary(ImageDataset): + min_depth = 0.01 + max_depth = 70.0 + depth_scale = 256.0 + test_split = "mapillary_val.txt" + train_split = "mapillary_train_clean.txt" + intrisics_file = "intrinsics.json" + hdf5_paths = ["Mapillary.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.crop = crop + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val] + dataset.append(sample) + h5file.close() + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + if self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=0.05) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["si"] = [True] * self.num_copies + results["valid_camera"] = [False] * self.num_copies + results["dense"] = [False] * self.num_copies + results["quality"] = [2] * self.num_copies + return results diff --git a/unik3d/datasets/matrix_city.py b/unik3d/datasets/matrix_city.py new file mode 100644 index 0000000000000000000000000000000000000000..32a7f9f93451def80a1561d18a93777bc357f445 --- /dev/null +++ b/unik3d/datasets/matrix_city.py @@ -0,0 +1,50 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class MatrixCity(SequenceDataset): + min_depth = 0.01 + max_depth = 200.0 + depth_scale = 1000.0 + test_split = "test.txt" + train_split = "train_full.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"MatrixCity.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, 
Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/matterport3d.py b/unik3d/datasets/matterport3d.py new file mode 100644 index 0000000000000000000000000000000000000000..7c1602725974ae7959987ee1127024696dc0a3f2 --- /dev/null +++ b/unik3d/datasets/matterport3d.py @@ -0,0 +1,56 @@ +from typing import Any + +import torch + +from unik3d.datasets.pipelines import Compose, PanoCrop, PanoRoll +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class Matterport3D(SequenceDataset): + min_depth = 0.01 + max_depth = 10.0 + depth_scale = 1000.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"Matterport3D.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["cam2w", "camera_params"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + self.resizer = Compose( + [PanoCrop(), PanoRoll(test_mode=test_mode), self.resizer] + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/megadepth.py b/unik3d/datasets/megadepth.py new file mode 100644 index 0000000000000000000000000000000000000000..1d65d8d1604595b14e516b1258ad49d23fb33acd --- /dev/null +++ b/unik3d/datasets/megadepth.py @@ -0,0 +1,83 @@ +import os + +import h5py +import numpy as np + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class MegaDepth(ImageDataset): + min_depth = 0.01 + max_depth = 1000.0 + depth_scale = 50.0 + test_split = "test.txt" + train_split = "train.txt" + hdf5_paths = ["MegaDepth.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + 
resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 + # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f: + # f.write(txt_string) + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + sample = [ + image_filename, + depth_filename, + ] + dataset.append(sample) + h5file.close() + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + else: + dataset = self.chunk(dataset, chunk_dim=1, pct=0.5) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["ssi"] = [True] + results["valid_camera"] = [False] + results["dense"] = [False] + return results + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + } diff --git a/unik3d/datasets/megadepth_s.py b/unik3d/datasets/megadepth_s.py new file mode 100644 index 0000000000000000000000000000000000000000..ad6189460f90aa527574c4e1554131e0c17b1fd1 --- /dev/null +++ b/unik3d/datasets/megadepth_s.py @@ -0,0 +1,50 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class MegaDepthS(SequenceDataset): + min_depth = 0.001 + max_depth = 10000.0 + depth_scale = 512.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences_filter_clean.json" + hdf5_paths = ["MegaDepthS.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["intrinsics", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["si"] = [True] * self.num_frames * self.num_copies + results["dense"] = [False] * self.num_frames * self.num_copies + results["quality"] = [2] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/midair.py b/unik3d/datasets/midair.py new file mode 100644 index 0000000000000000000000000000000000000000..184a78305fd89e9ffc8dd6682ccc9ab5a645d0fc --- /dev/null +++ b/unik3d/datasets/midair.py @@ -0,0 +1,51 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class MidAir(SequenceDataset): + min_depth = 0.1 + max_depth = 1000.0 + depth_scale = 1000.0 + default_fps = 6 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["MidAir.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = 
False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/mip.py b/unik3d/datasets/mip.py new file mode 100644 index 0000000000000000000000000000000000000000..ba18a000e60f5ffc694088de38f69bdf30dc2fe0 --- /dev/null +++ b/unik3d/datasets/mip.py @@ -0,0 +1,52 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class MIP(SequenceDataset): + min_depth = 0.01 + max_depth = 100.0 + depth_scale = 1000.0 + default_fps = 10 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["MIP.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [False] * self.num_frames * self.num_copies + results["si"] = [True] * self.num_frames * self.num_copies + results["quality"] = [2] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/ms2.py b/unik3d/datasets/ms2.py new file mode 100644 index 0000000000000000000000000000000000000000..bd9f13dd1fb3ff96a55cf4992fe35127fc8a2626 --- /dev/null +++ b/unik3d/datasets/ms2.py @@ -0,0 +1,51 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class MS2(SequenceDataset): + min_depth = 0.01 + max_depth = 100.0 + depth_scale = 256.0 + default_fps = 5 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["MS2.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + 
inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [False] * self.num_frames * self.num_copies + results["synthetic"] = [False] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/mvimgnet.py b/unik3d/datasets/mvimgnet.py new file mode 100644 index 0000000000000000000000000000000000000000..53976cb31f8f0abd006d910478716fdbe83d8bf2 --- /dev/null +++ b/unik3d/datasets/mvimgnet.py @@ -0,0 +1,137 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + +INVALID_SEQUENCES = [ + "1/000121f2-0", + "15/1600ae56-0", + "26/000000f3-0", + "33/1d00e677-0", + "43/22008925-0", + "49/000147db-0", + "51/23002a43-0", + "51/23000916-0", + "108/000133ae-0", + "129/000037f2-0", + "141/17012545-0", + "141/1700f3de-0", + "152/1b00e061-0", + "154/1d00decb-0", + "154/1d017c1c-0", + "154/1d0019a5-0", + "154/1d00334d-0", + "154/1d012ed6-0", + "154/1d016b8a-0", + "154/1d016cc1-0", + "154/1d008d5f-0", + "159/000157f9-0", + "159/00000b96-0", + "159/000075c0-0", + "159/0000445c-0", + "159/000056a0-0", + "159/00010c68-0", + "159/0000573b-0", + "159/00002698-0", + "159/00008fca-0", + "159/00009ef8-0", + "159/00015f05-0", + "159/0000c6df-0", + "159/0000ee59-0", + "163/290159d2-0", + "163/29016c7c-0", + "163/2900239c-0", + "163/29002f7b-0", + "163/29014b05-0", + "163/29000196-0", + "163/2901750f-0", + "164/1b0145cf-0", + "164/1b00eb1d-0", + "164/1b00c28b-0", + "164/1b0110d0-0", + "164/1b00dd20-0", + "165/2600e15a-0", + "165/26008444-0", + "165/260145c5-0", + "165/26003a0c-0", + "165/260106ba-0", + "165/26001548-0", + "167/2a0092b0-0", + "167/2a014dbe-0", + "167/2a003ce6-0", + "169/1800c645-0", + "171/2500014d-0", + "176/1d0021c2-0", + "176/1d014abf-0", + "176/1d00e714-0", + "176/1d0159cb-0", + "176/1e016629-0", + "178/000102b8-0", + "191/23008fdb-0", + "191/2300187f-0", + "191/2300ae68-0", + "191/230076dd-0", + "191/24007d7e-0", + "192/000107b5-0", + "195/1f012359-0", + "195/1f00f751-0", + "195/1f011331-0", + "195/1e00d999-0", + "196/1c01304e-0", + "198/1a00e02f-0", + "198/050084ac-0", + "198/1a0075fa-0", + "199/1e001742-0", + "199/1e00116a-0", + "199/1e011d00-0", + "199/1e018040-0", + "199/1e001107-0", +] + + +class MVImgNet(SequenceDataset): + min_depth = 0.005 + max_depth = 10.0 + # weird scale issue, should be 1000, but avg depth is ~10meters... 
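
Several datasets in this patch (Mapillary, MegaDepth, MegaDepthS, MIP, and others below) are flagged with si/ssi because their depth is only known up to a scale (or scale and shift). How those flags are consumed is outside this diff; a common scale-invariant alignment, shown purely as an assumption of what such a flag usually implies, is median scaling:

import torch


def median_align(pred: torch.Tensor, gt: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # Rescale a metric prediction so its median matches the ground truth on valid pixels.
    scale = gt[mask].median() / pred[mask].median().clamp(min=1e-6)
    return pred * scale
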
+ depth_scale = 1000.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["MVImgNet.hdf5"] + invalid_sequences = INVALID_SEQUENCES + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["intrinsics", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["si"] = [True] * self.num_frames * self.num_copies + results["dense"] = [False] * self.num_frames * self.num_copies + results["quality"] = [2] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/mvsynth.py b/unik3d/datasets/mvsynth.py new file mode 100644 index 0000000000000000000000000000000000000000..1d34f6c0c57ebe32baab09d80459f932cd9aff43 --- /dev/null +++ b/unik3d/datasets/mvsynth.py @@ -0,0 +1,51 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class MVSynth(SequenceDataset): + min_depth = 0.1 + max_depth = 1000.0 + depth_scale = 256.0 + test_split = "val.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"MVSynth.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["si"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/nerds360.py b/unik3d/datasets/nerds360.py new file mode 100644 index 0000000000000000000000000000000000000000..c669a0ab46a3d7b09cdfefa5c688d7aecadb7f32 --- /dev/null +++ b/unik3d/datasets/nerds360.py @@ -0,0 +1,49 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class NeRDS360(SequenceDataset): + min_depth = 0.01 + max_depth = 1000.0 + depth_scale = 1000.0 + test_split = "val.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["NeRDS360.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: 
bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/niantic_mapfree.py b/unik3d/datasets/niantic_mapfree.py new file mode 100644 index 0000000000000000000000000000000000000000..768cb3f0c4f11dd821bef00dd51dbeffb0d9b330 --- /dev/null +++ b/unik3d/datasets/niantic_mapfree.py @@ -0,0 +1,50 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class NianticMapFree(SequenceDataset): + min_depth = 0.1 + max_depth = 250.0 + depth_scale = 512.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"NianticMapFree.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["si"] = [True] * self.num_frames * self.num_copies + results["dense"] = [False] * self.num_frames * self.num_copies + results["quality"] = [2] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/nuscenes.py b/unik3d/datasets/nuscenes.py new file mode 100644 index 0000000000000000000000000000000000000000..1992efba64c81927f0452f363f3018180d4c0873 --- /dev/null +++ b/unik3d/datasets/nuscenes.py @@ -0,0 +1,89 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class Nuscenes(ImageDataset): + min_depth = 0.05 + max_depth = 80.0 + depth_scale = 256.0 + test_split = "val.txt" + train_split = "train.txt" + intrisics_file = "intrinsics.json" + # hdf5_paths = ["Nuscenes2.hdf5"] + hdf5_paths = [f"nuscenes/nuscenes_{i}.hdf5" for i in range(8)] + + def __init__( + self, + image_shape, + split_file, + test_mode, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + 
libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii").strip("\n") + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename, chunk_idx = line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val, chunk_idx] + dataset.append(sample) + h5file.close() + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=6, pct=self.mini) + if self.test_mode: + dataset = self.chunk(dataset, chunk_dim=6, pct=0.1) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + "K": 2, + "chunk_idx": 3, + } + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [False] * self.num_copies + results["quality"] = [1] * self.num_copies + return results diff --git a/unik3d/datasets/nyuv2.py b/unik3d/datasets/nyuv2.py new file mode 100644 index 0000000000000000000000000000000000000000..c270c83af17f4527adacc3fe0122022501c4b14b --- /dev/null +++ b/unik3d/datasets/nyuv2.py @@ -0,0 +1,112 @@ +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.pipelines import AnnotationMask +from unik3d.datasets.utils import DatasetFromList +from unik3d.utils import identity + + +class NYUv2Depth(ImageDataset): + CAM_INTRINSIC = { + "ALL": torch.tensor( + [ + [5.1885790117450188e02, 0, 3.2558244941119034e02], + [0, 5.1946961112127485e02, 2.5373616633400465e02], + [0, 0, 1], + ] + ) + } + min_depth = 0.005 + max_depth = 10.0 + depth_scale = 1000.0 + log_mean = 0.9140 + log_std = 0.4825 + test_split = "nyu_test.txt" + train_split = "nyu_train.txt" + hdf5_paths = ["nyuv2.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.masker = AnnotationMask( + min_value=0.0, + max_value=self.max_depth if test_mode else None, + custom_fn=self.eval_mask if test_mode else lambda x, *args, **kwargs: x, + ) + self.test_mode = test_mode + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename, _ = line.strip().split(" ") + sample = [ + image_filename, + depth_filename, + ] + dataset.append(sample) + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_copies + return results + + def get_intrinsics(self, idx, image_name): + return self.CAM_INTRINSIC["ALL"].clone() + + def 
eval_mask(self, valid_mask, info={}): + border_mask = torch.zeros_like(valid_mask) + border_mask[..., 45:-9, 41:-39] = 1 + return torch.logical_and(valid_mask, border_mask) + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + } + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_copies + results["quality"] = [2] * self.num_copies + return results diff --git a/unik3d/datasets/pipelines/__init__.py b/unik3d/datasets/pipelines/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..07aa134db64b5134bf32e797608360ec70e450df --- /dev/null +++ b/unik3d/datasets/pipelines/__init__.py @@ -0,0 +1,46 @@ +from .formating import AnnotationMask, Collect +from .transforms import (Compose, ContextCrop, Crop, Dilation, DownsamplerGT, + DummyCrop, GaussianBlur, KittiCrop, MotionBlur, + PanoCrop, PanoRoll, RandomAutoContrast, + RandomBrightness, RandomColor, RandomColorJitter, + RandomContrast, RandomCut, RandomEqualize, + RandomFiller, RandomFlip, RandomGamma, + RandomGrayscale, RandomInvert, RandomMasking, + RandomPosterize, RandomSaturation, RandomSharpness, + RandomShear, RandomSolarize, RandomTranslate, Resize, + Rotate) + +__all__ = [ + "Resize", + "Rotate", + "RandomFlip", + "RandomBrightness", + "Crop", + "Dilation", + "RandomColor", + "RandomContrast", + "RandomEqualize", + "RandomSaturation", + "Collect", + "AnnotationMask", + "RandomSolarize", + "RandomPosterize", + "RandomSharpness", + "RandomShear", + "RandomTranslate", + "RandomAutoContrast", + "RandomInvert", + "KittiCrop", + "DownsamplerGT", + "RandomCut", + "RandomGamma", + "RandomMasking", + "RandomColorJitter", + "GaussianBlur", + "RandomGrayscale", + "ContextCrop", + "RandomFiller", + "MotionBlur", + "Compose", + "DummyCrop", +] diff --git a/unik3d/datasets/pipelines/formating.py b/unik3d/datasets/pipelines/formating.py new file mode 100644 index 0000000000000000000000000000000000000000..f746d1727e5152bd1047f5972a5e175d8ed03b2d --- /dev/null +++ b/unik3d/datasets/pipelines/formating.py @@ -0,0 +1,113 @@ +from collections.abc import Sequence + +import numpy as np +import torch + + +class Collect(object): + def __init__( + self, + keys, + meta_keys=( + "filename", + "keyframe_idx", + "sequence_name", + "image_filename", + "depth_filename", + "image_ori_shape", + "camera", + "original_camera", + "sfm", + "image_shape", + "resized_shape", + "scale_factor", + "rotation", + "resize_factor", + "flip", + "flip_direction", + "dataset_name", + "paddings", + "max_value", + "log_mean", + "log_std", + "image_rescale", + "focal_rescale", + "depth_rescale", + ), + ): + self.keys = keys + self.meta_keys = meta_keys + + def __call__(self, results): + data_keys = [key for field in self.keys for key in results.get(field, [])] + data = { + key: { + sequence_key: results[key][sequence_key] + for sequence_key in results["sequence_fields"] + } + for key in data_keys + } + data["img_metas"] = { + key: value for key, value in results.items() if key not in data_keys + } + return data + + def __repr__(self): + return ( + self.__class__.__name__ + f"(keys={self.keys}, meta_keys={self.meta_keys})" + ) + + +class AnnotationMask(object): + def __init__(self, min_value, max_value, custom_fn=lambda x: x): + self.min_value = min_value + self.max_value = max_value + self.custom_fn = custom_fn + + def __call__(self, results): + for key in results.get("gt_fields", []): + if key + "_mask" in results["mask_fields"]: + if "flow" in key: + for sequence_idx 
in results.get("sequence_fields", []): + boundaries = (results[key][sequence_idx] >= -1) & ( + results[key][sequence_idx] <= 1 + ) + boundaries = boundaries[:, :1] & boundaries[:, 1:] + results[key + "_mask"][sequence_idx] = ( + results[key + "_mask"][sequence_idx].bool() & boundaries + ) + continue + for sequence_idx in results.get("sequence_fields", []): + # take care of xyz or flow, dim=1 is the channel dim + if results[key][sequence_idx].shape[1] == 1: + mask = results[key][sequence_idx] > self.min_value + else: + mask = ( + results[key][sequence_idx].norm(dim=1, keepdim=True) + > self.min_value + ) + if self.max_value is not None: + if results[key][sequence_idx].shape[1] == 1: + mask = mask & (results[key][sequence_idx] < self.max_value) + else: + mask = mask & ( + results[key][sequence_idx].norm(dim=1, keepdim=True) + < self.max_value + ) + mask = self.custom_fn(mask, info=results) + if key + "_mask" not in results: + results[key + "_mask"] = {} + if sequence_idx not in results[key + "_mask"]: + results[key + "_mask"][sequence_idx] = mask.bool() + else: + results[key + "_mask"][sequence_idx] = ( + results[key + "_mask"][sequence_idx].bool() & mask.bool() + ) + results["mask_fields"].add(key + "_mask") + return results + + def __repr__(self): + return ( + self.__class__.__name__ + + f"(min_value={self.min_value}, max_value={ self.max_value})" + ) diff --git a/unik3d/datasets/pipelines/transforms.py b/unik3d/datasets/pipelines/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..d1f035f1b1ce224494c4019ea6bd97b5621f45cc --- /dev/null +++ b/unik3d/datasets/pipelines/transforms.py @@ -0,0 +1,2176 @@ +import os +import random +from copy import deepcopy +from math import ceil, exp, log, log2, log10, tanh +from typing import Dict, List, Tuple + +import numpy as np +import torch +import torch.nn.functional as F +import torchvision.transforms.v2.functional as TF + +from unik3d.utils.geometric import downsample + + +def euler_to_rotation_matrix(angles): + """ + Convert Euler angles to a 3x3 rotation matrix. + + Args: + angles (torch.Tensor): Euler angles [roll, pitch, yaw]. + + Returns: + torch.Tensor: 3x3 rotation matrix. 
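+
+    Example (illustrative): a yaw of pi/2 about z maps the x-axis onto the y-axis.
+
+        R = euler_to_rotation_matrix(torch.tensor([0.0, 0.0, torch.pi / 2]))
+        R @ torch.tensor([1.0, 0.0, 0.0])  # ~ tensor([0., 1., 0.])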
+ """ + phi, theta, psi = angles + cos_phi, sin_phi = torch.cos(phi), torch.sin(phi) + cos_theta, sin_theta = torch.cos(theta), torch.sin(theta) + cos_psi, sin_psi = torch.cos(psi), torch.sin(psi) + + # Rotation matrices + Rx = torch.tensor([[1, 0, 0], [0, cos_phi, -sin_phi], [0, sin_phi, cos_phi]]) + + Ry = torch.tensor( + [[cos_theta, 0, sin_theta], [0, 1, 0], [-sin_theta, 0, cos_theta]] + ) + + Rz = torch.tensor([[cos_psi, -sin_psi, 0], [sin_psi, cos_psi, 0], [0, 0, 1]]) + + return Rz @ Ry @ Rx + + +def compute_grid(H, W): + meshgrid = torch.meshgrid(torch.arange(W), torch.arange(H), indexing="xy") + id_coords = torch.stack(meshgrid, axis=0).to(torch.float32) + id_coords = id_coords.reshape(2, -1) + id_coords = torch.cat( + [id_coords, torch.ones(1, id_coords.shape[-1])], dim=0 + ) # 3 HW + id_coords = id_coords.unsqueeze(0) + return id_coords + + +def lexsort(keys): + sorted_indices = torch.arange(keys[0].size(0)) + for key in reversed(keys): + _, sorted_indices = key[sorted_indices].sort() + return sorted_indices + + +def masked_bilinear_interpolation(input, mask, target_size): + B, C, H, W = input.shape + target_H, target_W = target_size + mask = mask.float() + + # Generate a grid of coordinates in the target space + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, H - 1, target_H), torch.linspace(0, W - 1, target_W) + ) + grid_y = grid_y.to(input.device) + grid_x = grid_x.to(input.device) + + # Calculate the floor and ceil of the grid coordinates to get the bounding box + x0 = torch.floor(grid_x).long().clamp(0, W - 1) + x1 = (x0 + 1).clamp(0, W - 1) + y0 = torch.floor(grid_y).long().clamp(0, H - 1) + y1 = (y0 + 1).clamp(0, H - 1) + + # Gather depth values at the four corners + Ia = input[..., y0, x0] + Ib = input[..., y1, x0] + Ic = input[..., y0, x1] + Id = input[..., y1, x1] + + # Gather corresponding mask values + ma = mask[..., y0, x0] + mb = mask[..., y1, x0] + mc = mask[..., y0, x1] + md = mask[..., y1, x1] + + # Calculate the areas (weights) for bilinear interpolation + wa = (x1.float() - grid_x) * (y1.float() - grid_y) + wb = (x1.float() - grid_x) * (grid_y - y0.float()) + wc = (grid_x - x0.float()) * (y1.float() - grid_y) + wd = (grid_x - x0.float()) * (grid_y - y0.float()) + + wa = wa.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1) + wb = wb.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1) + wc = wc.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1) + wd = wd.reshape(1, 1, target_H, target_W).repeat(B, C, 1, 1) + + # Only consider valid points for interpolation + weights_sum = (wa * ma) + (wb * mb) + (wc * mc) + (wd * md) + weights_sum = torch.clamp(weights_sum, min=1e-5) + + # Perform the interpolation + interpolated_depth = ( + wa * Ia * ma + wb * Ib * mb + wc * Ic * mc + wd * Id * md + ) / weights_sum + + return interpolated_depth, (ma + mb + mc + md) > 0 + + +def masked_nearest_interpolation(input, mask, target_size): + B, C, H, W = input.shape + target_H, target_W = target_size + mask = mask.float() + + # Generate a grid of coordinates in the target space + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, H - 1, target_H), + torch.linspace(0, W - 1, target_W), + indexing="ij", + ) + grid_y = grid_y.to(input.device) + grid_x = grid_x.to(input.device) + + # Calculate the floor and ceil of the grid coordinates to get the bounding box + x0 = torch.floor(grid_x).long().clamp(0, W - 1) + x1 = (x0 + 1).clamp(0, W - 1) + y0 = torch.floor(grid_y).long().clamp(0, H - 1) + y1 = (y0 + 1).clamp(0, H - 1) + + # Gather depth values at the four corners + Ia = 
input[..., y0, x0] + Ib = input[..., y1, x0] + Ic = input[..., y0, x1] + Id = input[..., y1, x1] + + # Gather corresponding mask values + ma = mask[..., y0, x0] + mb = mask[..., y1, x0] + mc = mask[..., y0, x1] + md = mask[..., y1, x1] + + # Calculate distances to each neighbor + # The distances are calculated from the center (grid_x, grid_y) to each corner + dist_a = (grid_x - x0.float()) ** 2 + (grid_y - y0.float()) ** 2 # Top-left + dist_b = (grid_x - x0.float()) ** 2 + (grid_y - y1.float()) ** 2 # Bottom-left + dist_c = (grid_x - x1.float()) ** 2 + (grid_y - y0.float()) ** 2 # Top-right + dist_d = (grid_x - x1.float()) ** 2 + (grid_y - y1.float()) ** 2 # Bottom-right + + # Stack the neighbors, their masks, and distances + stacked_values = torch.stack( + [Ia, Ib, Ic, Id], dim=-1 + ) # Shape: (B, C, target_H, target_W, 4) + stacked_masks = torch.stack( + [ma, mb, mc, md], dim=-1 + ) # Shape: (B, 1, target_H, target_W, 4) + stacked_distances = torch.stack( + [dist_a, dist_b, dist_c, dist_d], dim=-1 + ) # Shape: (target_H, target_W, 4) + stacked_distances = ( + stacked_distances.unsqueeze(0).unsqueeze(1).repeat(B, 1, 1, 1, 1) + ) # Shape: (B, 1, target_H, target_W, 4) + + # Set distances to infinity for invalid neighbors (so that invalid neighbors are never chosen) + stacked_distances[stacked_masks == 0] = float("inf") + + # Find the index of the nearest valid neighbor (the one with the smallest distance) + nearest_indices = stacked_distances.argmin(dim=-1, keepdim=True)[ + ..., :1 + ] # Shape: (B, 1, target_H, target_W, 1) + + # Select the corresponding depth value using the nearest valid neighbor index + interpolated_depth = torch.gather( + stacked_values, dim=-1, index=nearest_indices.repeat(1, C, 1, 1, 1) + ).squeeze(-1) + + # Set depth to zero where no valid neighbors were found + interpolated_depth = interpolated_depth * stacked_masks.sum(dim=-1).clip( + min=0.0, max=1.0 + ) + + return interpolated_depth + + +def masked_nxn_interpolation(input, mask, target_size, N=2): + B, C, H, W = input.shape + target_H, target_W = target_size + + # Generate a grid of coordinates in the target space + grid_y, grid_x = torch.meshgrid( + torch.linspace(0, H - 1, target_H), + torch.linspace(0, W - 1, target_W), + indexing="ij", + ) + grid_y = grid_y.to(input.device) + grid_x = grid_x.to(input.device) + + # Calculate the top-left corner of the NxN grid + half_N = (N - 1) // 2 + y0 = torch.floor(grid_y - half_N).long().clamp(0, H - 1) + x0 = torch.floor(grid_x - half_N).long().clamp(0, W - 1) + + # Prepare to gather NxN neighborhoods + input_patches = [] + mask_patches = [] + weights = [] + + for i in range(N): + for j in range(N): + yi = (y0 + i).clamp(0, H - 1) + xi = (x0 + j).clamp(0, W - 1) + + # Gather depth and mask values + input_patches.append(input[..., yi, xi]) + mask_patches.append(mask[..., yi, xi]) + + # Compute bilinear weights + weight_y = 1 - torch.abs(grid_y - yi.float()) / N + weight_x = 1 - torch.abs(grid_x - xi.float()) / N + weight = ( + (weight_y * weight_x) + .reshape(1, 1, target_H, target_W) + .repeat(B, C, 1, 1) + ) + weights.append(weight) + + input_patches = torch.stack(input_patches) + mask_patches = torch.stack(mask_patches) + weights = torch.stack(weights) + + # Calculate weighted sum and normalize by the sum of weights + weighted_sum = (input_patches * mask_patches * weights).sum(dim=0) + weights_sum = (mask_patches * weights).sum(dim=0) + interpolated_tensor = weighted_sum / torch.clamp(weights_sum, min=1e-8) + + if N != 2: + interpolated_tensor_2x2, mask_sum_2x2 = 
masked_bilinear_interpolation( + input, mask, target_size + ) + interpolated_tensor = torch.where( + mask_sum_2x2, interpolated_tensor_2x2, interpolated_tensor + ) + + return interpolated_tensor + + +class PanoCrop: + def __init__(self, crop_v=0.15): + self.crop_v = crop_v + + def _crop_data(self, results, crop_size): + offset_w, offset_h = crop_size + left, top, right, bottom = offset_w[0], offset_h[0], offset_w[1], offset_h[1] + H, W = results["image"].shape[-2:] + for key in results.get("image_fields", ["image"]): + img = results[key][..., top : H - bottom, left : W - right] + results[key] = img + results["image_shape"] = tuple(img.shape) + + for key in results.get("gt_fields", []): + results[key] = results[key][..., top : H - bottom, left : W - right] + + for key in results.get("mask_fields", []): + results[key] = results[key][..., top : H - bottom, left : W - right] + + results["camera"] = results["camera"].crop(left, top, right, bottom) + return results + + def __call__(self, results): + H, W = results["image"].shape[-2:] + crop_w = (0, 0) + crop_h = (int(H * self.crop_v), int(H * self.crop_v)) + results = self._crop_data(results, (crop_w, crop_h)) + return results + + +class PanoRoll: + def __init__(self, test_mode, roll=[-0.5, 0.5]): + self.roll = roll + self.test_mode = test_mode + + def __call__(self, results): + if self.test_mode: + return results + W = results["image"].shape[-1] + roll = random.randint(int(W * self.roll[0]), int(W * self.roll[1])) + for key in results.get("image_fields", ["image"]): + img = results[key] + img = torch.roll(img, roll, dims=-1) + results[key] = img + for key in results.get("gt_fields", []): + results[key] = torch.roll(results[key], roll, dims=-1) + for key in results.get("mask_fields", []): + results[key] = torch.roll(results[key], roll, dims=-1) + return results + + +class RandomFlip: + def __init__(self, direction="horizontal", prob=0.5, consistent=False, **kwargs): + self.flip_ratio = prob + valid_directions = ["horizontal", "vertical", "diagonal"] + if isinstance(direction, str): + assert direction in valid_directions + elif isinstance(direction, list): + assert set(direction).issubset(set(valid_directions)) + else: + raise ValueError("direction must be either str or list of str") + self.direction = direction + self.consistent = consistent + + def __call__(self, results): + if "flip" not in results: + # None means non-flip + if isinstance(self.direction, list): + direction_list = self.direction + [None] + else: + direction_list = [self.direction, None] + + if isinstance(self.flip_ratio, list): + non_flip_ratio = 1 - sum(self.flip_ratio) + flip_ratio_list = self.flip_ratio + [non_flip_ratio] + else: + non_flip_ratio = 1 - self.flip_ratio + # exclude non-flip + single_ratio = self.flip_ratio / (len(direction_list) - 1) + flip_ratio_list = [single_ratio] * (len(direction_list) - 1) + [ + non_flip_ratio + ] + + cur_dir = np.random.choice(direction_list, p=flip_ratio_list) + + results["flip"] = cur_dir is not None + + if "flip_direction" not in results: + results["flip_direction"] = cur_dir + + if results["flip"]: + # flip image + if results["flip_direction"] != "vertical": + for key in results.get("image_fields", ["image"]): + results[key] = TF.hflip(results[key]) + for key in results.get("mask_fields", []): + results[key] = TF.hflip(results[key]) + for key in results.get("gt_fields", []): + results[key] = TF.hflip(results[key]) + if "flow" in key: # flip u direction + results[key][:, 0] = -results[key][:, 0] + + H, W = results["image"].shape[-2:] 
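+                # Mirroring the image horizontally is propagated to the camera:
+                # the intrinsics are flipped via camera.flip and cam2w is
+                # pre-multiplied by a diag(-1, 1, 1, 1) reflection (repeated
+                # over the batch) so world-space geometry matches the flipped pixels.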
+ results["camera"] = results["camera"].flip( + H=H, W=W, direction="horizontal" + ) + flip_transform = torch.tensor( + [[-1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 0], [0, 0, 0, 1]], + dtype=torch.float32, + ).unsqueeze(0) + repeats = (results["cam2w"].shape[0],) + (1,) * ( + results["cam2w"].ndim - 1 + ) + results["cam2w"] = flip_transform.repeat(*repeats) @ results["cam2w"] + + if results["flip_direction"] != "horizontal": + for key in results.get("image_fields", ["image"]): + results[key] = TF.vflip(results[key]) + for key in results.get("mask_fields", []): + results[key] = TF.vflip(results[key]) + for key in results.get("gt_fields", []): + results[key] = TF.vflip(results[key]) + results["K"][..., 1, 2] = ( + results["image"].shape[-2] - results["K"][..., 1, 2] + ) + results["flip"] = [results["flip"]] * len(results["image"]) + return results + + +class Crop: + def __init__( + self, + crop_size, + crop_type="absolute", + crop_offset=(0, 0), + ): + if crop_type not in [ + "relative_range", + "relative", + "absolute", + "absolute_range", + ]: + raise ValueError(f"Invalid crop_type {crop_type}.") + if crop_type in ["absolute", "absolute_range"]: + assert crop_size[0] > 0 and crop_size[1] > 0 + assert isinstance(crop_size[0], int) and isinstance(crop_size[1], int) + else: + assert 0 < crop_size[0] <= 1 and 0 < crop_size[1] <= 1 + self.crop_size = crop_size + self.crop_type = crop_type + self.offset_h, self.offset_w = ( + crop_offset[: len(crop_offset) // 2], + crop_offset[len(crop_offset) // 2 :], + ) + + def _get_crop_size(self, image_shape): + h, w = image_shape + if self.crop_type == "absolute": + return (min(self.crop_size[0], h), min(self.crop_size[1], w)) + elif self.crop_type == "absolute_range": + assert self.crop_size[0] <= self.crop_size[1] + crop_h = np.random.randint( + min(h, self.crop_size[0]), min(h, self.crop_size[1]) + 1 + ) + crop_w = np.random.randint( + min(w, self.crop_size[0]), min(w, self.crop_size[1]) + 1 + ) + return crop_h, crop_w + elif self.crop_type == "relative": + crop_h, crop_w = self.crop_size + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + elif self.crop_type == "relative_range": + crop_size = np.asarray(self.crop_size, dtype=np.float32) + crop_h, crop_w = crop_size + np.random.rand(2) * (1 - crop_size) + return int(h * crop_h + 0.5), int(w * crop_w + 0.5) + + def _crop_data(self, results, crop_size): + assert crop_size[0] > 0 and crop_size[1] > 0 + for key in results.get("image_fields", ["image"]): + img = results[key] + img = TF.crop( + img, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1] + ) + results[key] = img + results["image_shape"] = tuple(img.shape) + + for key in results.get("gt_fields", []): + gt = results[key] + results[key] = TF.crop( + gt, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1] + ) + + # crop semantic seg + for key in results.get("mask_fields", []): + mask = results[key] + results[key] = TF.crop( + mask, self.offset_h[0], self.offset_w[0], crop_size[0], crop_size[1] + ) + + results["K"][..., 0, 2] = results["K"][..., 0, 2] - self.offset_w[0] + results["K"][..., 1, 2] = results["K"][..., 1, 2] - self.offset_h[0] + return results + + def __call__(self, results): + image_shape = results["image"].shape[-2:] + crop_size = self._get_crop_size(image_shape) + results = self._crop_data(results, crop_size) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(crop_size={self.crop_size}, " + repr_str += f"crop_type={self.crop_type}, " + return repr_str + + +class 
KittiCrop: + def __init__(self, crop_size): + self.crop_size = crop_size + + def _crop_data(self, results, crop_size): + """Function to randomly crop images, bounding boxes, masks, semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + crop_size (tuple): Expected absolute size after cropping, (h, w). + allow_negative_crop (bool): Whether to allow a crop that does not + contain any bbox area. Default to False. + + Returns: + dict: Randomly cropped results, 'image_shape' key in result dict is + updated according to crop size. + """ + assert crop_size[0] > 0 and crop_size[1] > 0 + for key in results.get("image_fields", ["image"]): + img = results[key] + h, w = img.shape[-2:] + offset_h, offset_w = int(h - self.crop_size[0]), int( + (w - self.crop_size[1]) / 2 + ) + + # crop the image + img = TF.crop(img, offset_h, offset_w, crop_size[0], crop_size[1]) + results[key] = img + results["image_shape"] = tuple(img.shape) + + for key in results.get("gt_fields", []): + gt = results[key] + results[key] = TF.crop(gt, offset_h, offset_w, crop_size[0], crop_size[1]) + + # crop semantic seg + for key in results.get("mask_fields", []): + mask = results[key] + results[key] = TF.crop(mask, offset_h, offset_w, crop_size[0], crop_size[1]) + + results["camera"] = results["camera"].crop(offset_w, offset_h) + return results + + def __call__(self, results): + """Call function to randomly crop images, bounding boxes, masks, + semantic segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Randomly cropped results, 'image_shape' key in result dict is + updated according to crop size. + """ + results = self._crop_data(results, self.crop_size) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(crop_size={self.crop_size}, " + return repr_str + + +class RandomMasking: + def __init__( + self, + mask_ratio, + mask_patch=16, + prob=0.5, + warmup_steps=50000, + sampling="random", + curriculum=False, + ): + self.mask_patch = mask_patch + self.prob = prob + self.mask_ratio = mask_ratio + self.warmup_steps = max(1, warmup_steps) + self.hard_bound = 1 + self.idx = 0 + self.curriculum = curriculum + self.sampling = sampling + self.low_bound = 0.0 + self.up_bound = 0.0 + + def __call__(self, results): + B, _, H, W = results["image"].shape + device = results["image"].device + down_size = H // self.mask_patch, W // self.mask_patch + if np.random.random() > self.prob: # fill with dummy + return self._nop(results, down_size, device) + + validity_mask = results["validity_mask"].float().reshape(B, -1, H, W) + validity_mask = F.interpolate(validity_mask, size=down_size).bool() + validity_mask = validity_mask.reshape(B, 1, *down_size) + is_random = self.is_warmup or results.get("guidance") is None + + if not is_random: + guidance = F.interpolate(results["guidance"], size=(H, W), mode="bilinear") + results["guidance"] = -F.max_pool2d( + -guidance, kernel_size=self.mask_patch, stride=self.mask_patch + ) + + if is_random and self.sampling == "inverse": + sampling = self.inverse_sampling + elif is_random and self.sampling == "random": + sampling = self.random_sampling + else: + sampling = self.guided_sampling + mask_ratio = np.random.uniform(self.low_bound, self.up_bound) + for key in results.get("image_fields", ["image"]): + mask = sampling(results, mask_ratio, down_size, validity_mask, device) + results[key + "_mask"] = mask + return results + + def _nop(self, results, down_size, device): + B = 
results["image"].shape[0] + for key in results.get("image_fields", ["image"]): + mask_blocks = torch.zeros(size=(B, 1, *down_size), device=device) + results[key + "_mask"] = mask_blocks + return results + + def random_sampling(self, results, mask_ratio, down_size, validity_mask, device): + B = results["image"].shape[0] + prob_blocks = torch.rand(size=(B, 1, *down_size), device=device) + mask_blocks = torch.logical_and(prob_blocks < mask_ratio, validity_mask) + return mask_blocks + + def inverse_sampling(self, results, mask_ratio, down_size, validity_mask, device): + # from PIL import Image + # from unik3d.utils import colorize + def area_sample(depth, fx, fy): + dtype = depth.dtype + B = depth.shape[0] + H, W = down_size + depth = downsample(depth, depth.shape[-2] // H) + depth[depth > 200] = 50 # set sky as if depth 50 meters + pixel_area3d = depth / torch.sqrt(fx * fy) + + # Set invalid as -1 (no div problem) -> then clip to 0.0 + pixel_area3d[depth == 0.0] = -1 + prob_density = (1 / pixel_area3d).clamp(min=0.0).square() + prob_density = prob_density / prob_density.sum( + dim=(-1, -2), keepdim=True + ).clamp(min=1e-5) + # Image.fromarray((prob_density[0] * 255 * 100).clamp(min=0.0, max=255.0).squeeze().cpu().byte().numpy()).save("prob_density.png") + + # Sample locations based on prob_density + prob_density_flat = prob_density.view(B, -1) + + # Get the avgerage valid locations, of those we mask self.mask_ratio + valid_locations = (prob_density_flat > 0).to(dtype).sum(dim=1) + + masks = [] + for i in range(B): + num_samples = int(valid_locations[i] * mask_ratio) + mask = torch.zeros_like(prob_density_flat[i]) + # Sample indices + if num_samples > 0: + sampled_indices_flat = torch.multinomial( + prob_density_flat[i], num_samples, replacement=False + ) + mask.scatter_(0, sampled_indices_flat, 1) + masks.append(mask) + return torch.stack(masks).bool().view(B, 1, H, W) + + def random_sample(validity_mask): + prob_blocks = torch.rand( + size=(validity_mask.shape[0], 1, *down_size), device=device + ) + mask = torch.logical_and(prob_blocks < mask_ratio, validity_mask) + return mask + + fx = results["K"][..., 0, 0].view(-1, 1, 1, 1) / self.mask_patch + fy = results["K"][..., 1, 1].view(-1, 1, 1, 1) / self.mask_patch + + valid = ~results["ssi"] & ~results["si"] & results["valid_camera"] + mask_blocks = torch.zeros_like(validity_mask) + if valid.any(): + out = area_sample(results["depth"][valid], fx[valid], fy[valid]) + mask_blocks[valid] = out + if (~valid).any(): + mask_blocks[~valid] = random_sample(validity_mask[~valid]) + + # mask_blocks_ = (mask_blocks.float() * 255).squeeze(1).byte().cpu().numpy() + # Image.fromarray(mask_blocks_[0]).save("mask1.png") + # Image.fromarray(mask_blocks_[-1]).save("mask2.png") + # dd = results["depth"] + # Image.fromarray(colorize(dd[0].squeeze().cpu().numpy())).save("depth1_p.png") + # Image.fromarray(colorize(dd[-1].squeeze().cpu().numpy())).save("depth2_p.png") + # dd = downsample(dd, dd.shape[-2] // down_size[0]) + # Image.fromarray(colorize(dd[0].squeeze().cpu().numpy())).save("depth1.png") + # Image.fromarray(colorize(dd[-1].squeeze().cpu().numpy())).save("depth2.png") + # raise ValueError + + return mask_blocks + + def guided_sampling(self, results, mask_ratio, down_size, validity_mask, device): + # get the lowest (based on guidance) "mask_ratio" quantile of the patches that are in validity mask + B = results["image"].shape[0] + guidance = results["guidance"] + mask_blocks = torch.zeros(size=(B, 1, *down_size), device=device) + for b in range(B): + 
low_bound = torch.quantile( + guidance[b][validity_mask[b]], max(0.0, self.hard_bound - mask_ratio) + ) + up_bound = torch.quantile( + guidance[b][validity_mask[b]], min(1.0, self.hard_bound) + ) + mask_blocks[b] = torch.logical_and( + guidance[b] < up_bound, guidance[b] > low_bound + ) + mask_blocks = torch.logical_and(mask_blocks, validity_mask) + return mask_blocks + + def step(self): + self.idx += 1 + # schedule hard from 1.0 to self.mask_ratio + if self.curriculum: + step = max(0, self.idx / self.warmup_steps / 2 - 0.5) + self.hard_bound = 1 - (1 - self.mask_ratio) * tanh(step) + self.up_bound = self.mask_ratio * tanh(step) + self.low_bound = 0.1 * tanh(step) + + @property + def is_warmup(self): + return self.idx < self.warmup_steps + + +class Resize: + def __init__(self, image_scale=None, image_shape=None, keep_original=False): + assert (image_scale is None) ^ (image_shape is None) + if isinstance(image_scale, (float, int)): + image_scale = (image_scale, image_scale) + if isinstance(image_shape, (float, int)): + image_shape = (int(image_shape), int(image_shape)) + self.image_scale = image_scale + self.image_shape = image_shape + self.keep_original = keep_original + + def _resize_img(self, results): + for key in results.get("image_fields", ["image"]): + img = TF.resize( + results[key], + results["resized_shape"], + interpolation=TF.InterpolationMode.BILINEAR, + antialias=True, + ) + results[key] = img + + def _resize_masks(self, results): + for key in results.get("mask_fields", []): + mask = TF.resize( + results[key], + results["resized_shape"], + interpolation=TF.InterpolationMode.NEAREST_EXACT, + antialias=True, + ) + results[key] = mask + + def _resize_gt(self, results): + for key in results.get("gt_fields", []): + gt = TF.resize( + results[key], + results["resized_shape"], + interpolation=TF.InterpolationMode.NEAREST_EXACT, + antialias=True, + ) + results[key] = gt + + def __call__(self, results): + h, w = results["image"].shape[-2:] + results["K_original"] = results["K"].clone() + if self.image_scale: + image_shape = ( + int(h * self.image_scale[0] + 0.5), + int(w * self.image_scale[1] + 0.5), + ) + image_scale = self.image_scale + elif self.image_shape: + image_scale = (self.image_shape[0] / h, self.image_shape[1] / w) + image_shape = self.image_shape + else: + raise ValueError( + f"In {self.__class__.__name__}: image_scale of image_shape must be set" + ) + + results["resized_shape"] = tuple(image_shape) + results["resize_factor"] = tuple(image_scale) + results["K"][..., 0, 2] = (results["K"][..., 0, 2] - 0.5) * image_scale[1] + 0.5 + results["K"][..., 1, 2] = (results["K"][..., 1, 2] - 0.5) * image_scale[0] + 0.5 + results["K"][..., 0, 0] = results["K"][..., 0, 0] * image_scale[1] + results["K"][..., 1, 1] = results["K"][..., 1, 1] * image_scale[0] + + self._resize_img(results) + if not self.keep_original: + self._resize_masks(results) + self._resize_gt(results) + + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + return repr_str + + +class Rotate: + def __init__( + self, angle, center=None, img_fill_val=(123.68, 116.28, 103.53), prob=0.5 + ): + if isinstance(img_fill_val, (float, int)): + img_fill_val = tuple([float(img_fill_val)] * 3) + elif isinstance(img_fill_val, tuple): + assert len(img_fill_val) == 3, ( + "image_fill_val as tuple must " + f"have 3 elements. got {len(img_fill_val)}." 
+ ) + img_fill_val = tuple([float(val) for val in img_fill_val]) + else: + raise ValueError("image_fill_val must be float or tuple with 3 elements.") + assert np.all( + [0 <= val <= 255 for val in img_fill_val] + ), f"all elements of img_fill_val should between range [0,255] got {img_fill_val}." + assert 0 <= prob <= 1.0, f"The probability should be in range [0,1]bgot {prob}." + self.center = center + self.img_fill_val = img_fill_val + self.prob = prob + self.random = not isinstance(angle, (float, int)) + self.angle = angle + + def _rotate(self, results, angle, center=None, fill_val=0.0): + for key in results.get("image_fields", ["image"]): + img = results[key] + img_rotated = TF.rotate( + img, + angle, + center=center, + interpolation=TF.InterpolationMode.NEAREST_EXACT, + fill=self.img_fill_val, + ) + results[key] = img_rotated.to(img.dtype) + results["image_shape"] = results[key].shape + + for key in results.get("mask_fields", []): + results[key] = TF.rotate( + results[key], + angle, + center=center, + interpolation=TF.InterpolationMode.NEAREST_EXACT, + fill=fill_val, + ) + + for key in results.get("gt_fields", []): + results[key] = TF.rotate( + results[key], + angle, + center=center, + interpolation=TF.InterpolationMode.NEAREST_EXACT, + fill=fill_val, + ) + + def __call__(self, results): + """Call function to rotate images, bounding boxes, masks and semantic + segmentation maps. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Rotated results. + """ + if np.random.random() > self.prob: + return results + + angle = ( + (self.angle[1] - self.angle[0]) * np.random.rand() + self.angle[0] + if self.random + else np.random.choice([-1, 1], size=1) * self.angle + ) + self._rotate(results, angle, None, fill_val=0.0) + results["rotation"] = angle + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(angle={self.angle}, " + repr_str += f"center={self.center}, " + repr_str += f"image_fill_val={self.img_fill_val}, " + repr_str += f"prob={self.prob}, " + return repr_str + + +class RandomColor: + """Apply Color transformation to image. The bboxes, masks, and + segmentations are not modified. + + Args: + level (int | float): Should be in range [0,_MAX_LEVEL]. + prob (float): The probability for performing Color transformation. + """ + + def __init__(self, level, prob=0.5): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + + def _adjust_color_img(self, results, factor=1.0): + """Apply Color transformation to image.""" + for key in results.get("image_fields", ["image"]): + results[key] = TF.adjust_hue(results[key], factor) # .to(img.dtype) + + def __call__(self, results): + """Call function for Color transformation. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Colored results. + """ + if np.random.random() > self.prob: + return results + factor = ( + ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) + if self.random + else self.level + ) + self._adjust_color_img(results, factor) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(level={self.level}, " + repr_str += f"prob={self.prob})" + return repr_str + + +class RandomSaturation: + """Apply Color transformation to image. The bboxes, masks, and + segmentations are not modified. + + Args: + level (int | float): Should be in range [0,_MAX_LEVEL]. + prob (float): The probability for performing Color transformation. 
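+
+    Note: when level is a (min, max) range, the saturation factor is 2 ** x
+    with x drawn uniformly from that range; a scalar level gives a fixed
+    factor of 2 ** level.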
+ """ + + def __init__(self, level, prob=0.5): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + + def _adjust_saturation_img(self, results, factor=1.0): + """Apply Color transformation to image.""" + for key in results.get("image_fields", ["image"]): + # NOTE defaultly the image should be BGR format + results[key] = TF.adjust_saturation(results[key], factor) # .to(img.dtype) + + def __call__(self, results): + """Call function for Color transformation. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Colored results. + """ + if np.random.random() > self.prob: + return results + factor = ( + 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) + if self.random + else 2**self.level + ) + self._adjust_saturation_img(results, factor) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(level={self.level}, " + repr_str += f"prob={self.prob})" + return repr_str + + +class RandomSharpness: + """Apply Color transformation to image. The bboxes, masks, and + segmentations are not modified. + + Args: + level (int | float): Should be in range [0,_MAX_LEVEL]. + prob (float): The probability for performing Color transformation. + """ + + def __init__(self, level, prob=0.5): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + + def _adjust_sharpeness_img(self, results, factor=1.0): + """Apply Color transformation to image.""" + for key in results.get("image_fields", ["image"]): + # NOTE defaultly the image should be BGR format + results[key] = TF.adjust_sharpness(results[key], factor) # .to(img.dtype) + + def __call__(self, results): + """Call function for Color transformation. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Colored results. + """ + if np.random.random() > self.prob: + return results + factor = ( + 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) + if self.random + else 2**self.level + ) + self._adjust_sharpeness_img(results, factor) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(level={self.level}, " + repr_str += f"prob={self.prob})" + return repr_str + + +class RandomSolarize: + """Apply Color transformation to image. The bboxes, masks, and + segmentations are not modified. + + Args: + level (int | float): Should be in range [0,_MAX_LEVEL]. + prob (float): The probability for performing Color transformation. + """ + + def __init__(self, level, prob=0.5): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + + def _adjust_solarize_img(self, results, factor=255.0): + """Apply Color transformation to image.""" + for key in results.get("image_fields", ["image"]): + results[key] = TF.solarize(results[key], factor) # .to(img.dtype) + + def __call__(self, results): + """Call function for Color transformation. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Colored results. 
+ """ + if np.random.random() > self.prob: + return results + factor = ( + ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) + if self.random + else self.level + ) + self._adjust_solarize_img(results, factor) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(level={self.level}, " + repr_str += f"prob={self.prob})" + return repr_str + + +class RandomPosterize: + """Apply Color transformation to image. The bboxes, masks, and + segmentations are not modified. + + Args: + level (int | float): Should be in range [0,_MAX_LEVEL]. + prob (float): The probability for performing Color transformation. + """ + + def __init__(self, level, prob=0.5): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + + def _posterize_img(self, results, factor=1.0): + """Apply Color transformation to image.""" + for key in results.get("image_fields", ["image"]): + results[key] = TF.posterize(results[key], int(factor)) # .to(img.dtype) + + def __call__(self, results): + """Call function for Color transformation. + + Args: + results (dict): Result dict from loading pipeline. + + Returns: + dict: Colored results. + """ + if np.random.random() > self.prob: + return results + factor = ( + ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) + if self.random + else self.level + ) + self._posterize_img(results, factor) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(level={self.level}, " + repr_str += f"prob={self.prob})" + return repr_str + + +class RandomEqualize: + """Apply Equalize transformation to image. The bboxes, masks and + segmentations are not modified. + + Args: + prob (float): The probability for performing Equalize transformation. + """ + + def __init__(self, prob=0.5): + assert 0 <= prob <= 1.0, "The probability should be in range [0,1]." + self.prob = prob + + def _imequalize(self, results): + """Equalizes the histogram of one image.""" + for key in results.get("image_fields", ["image"]): + results[key] = TF.equalize(results[key]) # .to(img.dtype) + + def __call__(self, results): + """Call function for Equalize transformation. + + Args: + results (dict): Results dict from loading pipeline. + + Returns: + dict: Results after the transformation. + """ + if np.random.random() > self.prob: + return results + self._imequalize(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(prob={self.prob})" + + +class RandomBrightness: + """Apply Brightness transformation to image. The bboxes, masks and + segmentations are not modified. + + Args: + level (int | float): Should be in range [0,_MAX_LEVEL]. + prob (float): The probability for performing Brightness transformation. + """ + + def __init__(self, level, prob=0.5): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + + def _adjust_brightness_img(self, results, factor=1.0): + """Adjust the brightness of image.""" + for key in results.get("image_fields", ["image"]): + results[key] = TF.adjust_brightness(results[key], factor) # .to(img.dtype) + + def __call__(self, results, level=None): + """Call function for Brightness transformation. + + Args: + results (dict): Results dict from loading pipeline. + + Returns: + dict: Results after the transformation. 
+ """ + if np.random.random() > self.prob: + return results + factor = ( + 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) + if self.random + else 2**self.level + ) + self._adjust_brightness_img(results, factor) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(level={self.level}, " + repr_str += f"prob={self.prob})" + return repr_str + + +class RandomContrast: + """Apply Contrast transformation to image. The bboxes, masks and + segmentations are not modified. + + Args: + level (int | float): Should be in range [0,_MAX_LEVEL]. + prob (float): The probability for performing Contrast transformation. + """ + + def __init__(self, level, prob=0.5): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + + def _adjust_contrast_img(self, results, factor=1.0): + """Adjust the image contrast.""" + for key in results.get("image_fields", ["image"]): + results[key] = TF.adjust_contrast(results[key], factor) # .to(img.dtype) + + def __call__(self, results, level=None): + """Call function for Contrast transformation. + + Args: + results (dict): Results dict from loading pipeline. + + Returns: + dict: Results after the transformation. + """ + if np.random.random() > self.prob: + return results + factor = ( + 2 ** ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) + if self.random + else 2**self.level + ) + self._adjust_contrast_img(results, factor) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(level={self.level}, " + repr_str += f"prob={self.prob})" + return repr_str + + +class RandomGamma: + def __init__(self, level, prob=0.5): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + + def __call__(self, results, level=None): + """Call function for Contrast transformation. + + Args: + results (dict): Results dict from loading pipeline. + + Returns: + dict: Results after the transformation. 
+ """ + if np.random.random() > self.prob: + return results + factor = (self.level[1] - self.level[0]) * np.random.rand() + self.level[0] + for key in results.get("image_fields", ["image"]): + if "original" not in key: + results[key] = TF.adjust_gamma(results[key], 1 + factor) + return results + + +class RandomInvert: + def __init__(self, prob=0.5): + self.prob = prob + + def __call__(self, results): + if np.random.random() > self.prob: + return results + for key in results.get("image_fields", ["image"]): + if "original" not in key: + results[key] = TF.invert(results[key]) # .to(img.dtype) + return results + + +class RandomAutoContrast: + def __init__(self, prob=0.5): + self.prob = prob + + def _autocontrast_img(self, results): + for key in results.get("image_fields", ["image"]): + img = results[key] + results[key] = TF.autocontrast(img) # .to(img.dtype) + + def __call__(self, results): + if np.random.random() > self.prob: + return results + self._autocontrast_img(results) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(level={self.level}, " + repr_str += f"prob={self.prob})" + return repr_str + + +class Dilation: + def __init__(self, origin, kernel, border_value=-1.0, iterations=1) -> None: + self.structured_element = torch.ones(size=kernel) + self.origin = origin + self.border_value = border_value + self.iterations = iterations + + def dilate(self, image): + image_pad = F.pad( + image, + [ + self.origin[0], + self.structured_element.shape[0] - self.origin[0] - 1, + self.origin[1], + self.structured_element.shape[1] - self.origin[1] - 1, + ], + mode="constant", + value=self.border_value, + ) + if image_pad.ndim < 4: + image_pad = image_pad.unsqueeze(0) + # Unfold the image to be able to perform operation on neighborhoods + image_unfold = F.unfold(image_pad, kernel_size=self.structured_element.shape) + # Flatten the structural element since its two dimensions have been flatten when unfolding + # structured_element_flatten = torch.flatten(self.structured_element).unsqueeze(0).unsqueeze(-1) + # Perform the greyscale operation; sum would be replaced by rest if you want erosion + # sums = image_unfold + structured_element_flatten + # Take maximum over the neighborhood + # since we use depth, we need to take the cloest point (perspectivity) + # thus the min. But min is for "unknown" (0), so put it to a large number + # than take min + + mask = image_unfold < 1e-3 # if == 0, some pixels are not involved, why? 
+ + # Replace the zero elements with a large value, so they don't affect the minimum operation + image_unfold = image_unfold.masked_fill(mask, 1000.0) + + # Calculate the minimum along the neighborhood axis + dilate_image = torch.min(image_unfold, dim=1).values + + # Fill the masked values with 0 to propagate zero if all pixels are zero + dilate_image[mask.all(dim=1)] = 0 + return torch.reshape(dilate_image, image.shape) + + def __call__(self, results): + for key in results.get("gt_fields", []): + gt = results[key] + for _ in range(self.iterations): + gt[gt < 1e-4] = self.dilate(gt)[gt < 1e-4] + results[key] = gt + + return results + + +class RandomShear(object): + def __init__( + self, + level, + prob=0.5, + direction="horizontal", + ): + self.random = not isinstance(level, (float, int)) + self.level = level + self.prob = prob + self.direction = direction + + def _shear_img(self, results, magnitude): + for key in results.get("image_fields", ["image"]): + img_sheared = TF.affine( + results[key], + angle=0.0, + translate=[0.0, 0.0], + scale=1.0, + shear=magnitude, + interpolation=TF.InterpolationMode.BILINEAR, + fill=0.0, + ) + results[key] = img_sheared + + def _shear_masks(self, results, magnitude): + for key in results.get("mask_fields", []): + mask_sheared = TF.affine( + results[key], + angle=0.0, + translate=[0.0, 0.0], + scale=1.0, + shear=magnitude, + interpolation=TF.InterpolationMode.NEAREST_EXACT, + fill=0.0, + ) + results[key] = mask_sheared + + def _shear_gt( + self, + results, + magnitude, + ): + for key in results.get("gt_fields", []): + mask_sheared = TF.affine( + results[key], + angle=0.0, + translate=[0.0, 0.0], + scale=1.0, + shear=magnitude, + interpolation=TF.InterpolationMode.NEAREST_EXACT, + fill=0.0, + ) + results[key] = mask_sheared + + def __call__(self, results): + if np.random.random() > self.prob: + return results + magnitude = ( + ((self.level[1] - self.level[0]) * np.random.rand() + self.level[0]) + if self.random + else np.random.choice([-1, 1], size=1) * self.level + ) + if self.direction == "horizontal": + magnitude = [magnitude, 0.0] + else: + magnitude = [0.0, magnitude] + self._shear_img(results, magnitude) + self._shear_masks(results, magnitude) + self._shear_gt(results, magnitude) + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(level={self.level}, " + repr_str += f"img_fill_val={self.img_fill_val}, " + repr_str += f"seg_ignore_label={self.seg_ignore_label}, " + repr_str += f"prob={self.prob}, " + repr_str += f"direction={self.direction}, " + repr_str += f"max_shear_magnitude={self.max_shear_magnitude}, " + repr_str += f"random_negative_prob={self.random_negative_prob}, " + repr_str += f"interpolation={self.interpolation})" + return repr_str + + +class RandomTranslate(object): + def __init__( + self, + range, + prob=0.5, + direction="horizontal", + ): + self.range = range + self.prob = prob + self.direction = direction + + def _translate_img(self, results, magnitude): + for key in results.get("image_fields", ["image"]): + img_sheared = TF.affine( + results[key], + angle=0.0, + translate=magnitude, + scale=1.0, + shear=[0.0, 0.0], + interpolation=TF.InterpolationMode.BILINEAR, + fill=(123.68, 116.28, 103.53), + ) + results[key] = img_sheared + + def _translate_mask(self, results, magnitude): + for key in results.get("mask_fields", []): + mask_sheared = TF.affine( + results[key], + angle=0.0, + translate=magnitude, + scale=1.0, + shear=[0.0, 0.0], + interpolation=TF.InterpolationMode.NEAREST_EXACT, + fill=0.0, 
+ ) + results[key] = mask_sheared + + def _translate_gt( + self, + results, + magnitude, + ): + for key in results.get("gt_fields", []): + mask_sheared = TF.affine( + results[key], + angle=0.0, + translate=magnitude, + scale=1.0, + shear=[0.0, 0.0], + interpolation=TF.InterpolationMode.NEAREST_EXACT, + fill=0.0, + ) + results[key] = mask_sheared + + def __call__(self, results): + if np.random.random() > self.prob: + return results + magnitude = (self.range[1] - self.range[0]) * np.random.rand() + self.range[0] + if self.direction == "horizontal": + magnitude = [magnitude * results["image"].shape[1], 0] + else: + magnitude = [0, magnitude * results["image"].shape[0]] + self._translate_img(results, magnitude) + self._translate_mask(results, magnitude) + self._translate_gt(results, magnitude) + results["K"][..., 0, 2] = results["K"][..., 0, 2] + magnitude[0] + results["K"][..., 1, 2] = results["K"][..., 1, 2] + magnitude[1] + return results + + def __repr__(self): + repr_str = self.__class__.__name__ + repr_str += f"(range={self.range}, " + repr_str += f"prob={self.prob}, " + repr_str += f"direction={self.direction}, " + return repr_str + + +class RandomCut(object): + def __init__(self, prob=0.5, direction="all"): + self.direction = direction + self.prob = prob + + def _cut_img(self, results, coord, dim): + for key in results.get("image_fields", ["image"]): + img_sheared = torch.roll( + results[key], int(coord * results[key].shape[dim]), dims=dim + ) + results[key] = img_sheared + + def _cut_mask(self, results, coord, dim): + for key in results.get("mask_fields", []): + mask_sheared = torch.roll( + results[key], int(coord * results[key].shape[dim]), dims=dim + ) + results[key] = mask_sheared + + def _cut_gt(self, results, coord, dim): + for key in results.get("gt_fields", []): + gt_sheared = torch.roll( + results[key], int(coord * results[key].shape[dim]), dims=dim + ) + results[key] = gt_sheared + + def __call__(self, results): + if np.random.random() > self.prob: + return results + coord = 0.8 * random.random() + 0.1 + if self.direction == "horizontal": + dim = -1 + elif self.direction == "vertical": + dim = -2 + else: + dim = -1 if random.random() < 0.5 else -2 + + self._cut_img(results, coord, dim) + self._cut_mask(results, coord, dim) + self._cut_gt(results, coord, dim) + return results + + +class DownsamplerGT(object): + def __init__(self, downsample_factor: int, min_depth: float = 0.01): + assert downsample_factor == round( + downsample_factor, 0 + ), f"Downsample factor needs to be an integer, got {downsample_factor}" + self.downsample_factor = downsample_factor + self.min_depth = min_depth + + def _downsample_gt(self, results): + for key in deepcopy(results.get("gt_fields", [])): + gt = results[key] + N, H, W = gt.shape + gt = gt.view( + N, + H // self.downsample_factor, + self.downsample_factor, + W // self.downsample_factor, + self.downsample_factor, + 1, + ) + gt = gt.permute(0, 1, 3, 5, 2, 4) + gt = gt.view(-1, self.downsample_factor * self.downsample_factor) + gt_tmp = torch.where(gt == 0.0, 1e5 * torch.ones_like(gt), gt) + gt = torch.min(gt_tmp, dim=-1).values + gt = gt.view(N, H // self.downsample_factor, W // self.downsample_factor) + gt = torch.where(gt > 1000, torch.zeros_like(gt), gt) + results[f"{key}_downsample"] = gt + results["gt_fields"].append(f"{key}_downsample") + results["downsampled"] = True + return results + + def __call__(self, results): + results = self._downsample_gt(results) + return results + + +class RandomColorJitter: + def __init__(self, level, 
prob=0.9): + self.level = level + self.prob = prob + self.list_transform = [ + self._adjust_brightness_img, + # self._adjust_sharpness_img, + self._adjust_contrast_img, + self._adjust_saturation_img, + self._adjust_color_img, + ] + + def _adjust_contrast_img(self, results, factor=1.0): + for key in results.get("image_fields", ["image"]): + if "original" not in key: + img = results[key] + results[key] = TF.adjust_contrast(img, factor) + + def _adjust_sharpness_img(self, results, factor=1.0): + for key in results.get("image_fields", ["image"]): + if "original" not in key: + img = results[key] + results[key] = TF.adjust_sharpness(img, factor) + + def _adjust_brightness_img(self, results, factor=1.0): + for key in results.get("image_fields", ["image"]): + if "original" not in key: + img = results[key] + results[key] = TF.adjust_brightness(img, factor) + + def _adjust_saturation_img(self, results, factor=1.0): + for key in results.get("image_fields", ["image"]): + if "original" not in key: + img = results[key] + results[key] = TF.adjust_saturation(img, factor / 2.0) + + def _adjust_color_img(self, results, factor=1.0): + for key in results.get("image_fields", ["image"]): + if "original" not in key: + img = results[key] + results[key] = TF.adjust_hue(img, (factor - 1.0) / 4.0) + + def __call__(self, results): + random.shuffle(self.list_transform) + for op in self.list_transform: + if np.random.random() < self.prob: + factor = 1.0 + ( + (self.level[1] - self.level[0]) * np.random.random() + self.level[0] + ) + op(results, factor) + return results + + +class RandomGrayscale: + def __init__(self, prob=0.1, num_output_channels=3): + super().__init__() + self.prob = prob + self.num_output_channels = num_output_channels + + def __call__(self, results): + if np.random.random() > self.prob: + return results + + for key in results.get("image_fields", ["image"]): + if "original" not in key: + results[key] = TF.rgb_to_grayscale( + results[key], num_output_channels=self.num_output_channels + ) + return results + + +class ContextCrop(Resize): + def __init__( + self, + image_shape, + keep_original=False, + test_min_ctx=1.0, + train_ctx_range=[0.5, 1.5], + shape_constraints={}, + ): + super().__init__(image_shape=image_shape, keep_original=keep_original) + self.test_min_ctx = test_min_ctx + self.train_ctx_range = train_ctx_range + + self.shape_mult = shape_constraints["shape_mult"] + self.sample = shape_constraints["sample"] + self.ratio_bounds = shape_constraints["ratio_bounds"] + pixels_min = shape_constraints["pixels_min"] / ( + self.shape_mult * self.shape_mult + ) + pixels_max = shape_constraints["pixels_max"] / ( + self.shape_mult * self.shape_mult + ) + self.pixels_bounds = (pixels_min, pixels_max) + self.keepGT = int(os.environ.get("keepGT", 0)) + self.ctx = None + + def _transform_img(self, results, shapes): + for key in results.get("image_fields", ["image"]): + img = self.crop(results[key], **shapes) + img = TF.resize( + img, + results["resized_shape"], + interpolation=TF.InterpolationMode.BICUBIC, + antialias=True, + ) + results[key] = img + + def _transform_masks(self, results, shapes): + for key in results.get("mask_fields", []): + mask = self.crop(results[key].float(), **shapes).byte() + if "flow" in key: # take pad/crop into flow resize + mask = TF.resize( + mask, + results["resized_shape"], + interpolation=TF.InterpolationMode.NEAREST_EXACT, + antialias=False, + ) + else: + mask = masked_nearest_interpolation( + mask, mask > 0, results["resized_shape"] + ) + results[key] = mask + + def 
_transform_gt(self, results, shapes): + for key in results.get("gt_fields", []): + gt = self.crop(results[key], **shapes) + if not self.keepGT: + if "flow" in key: # take pad/crop into flow resize + gt = self._rescale_flow(gt, results) + gt = TF.resize( + gt, + results["resized_shape"], + interpolation=TF.InterpolationMode.NEAREST_EXACT, + antialias=False, + ) + else: + gt = masked_nearest_interpolation( + gt, gt > 0, results["resized_shape"] + ) + + results[key] = gt + + def _rescale_flow(self, gt, results): + h_new, w_new = gt.shape[-2:] + h_old, w_old = results["image_ori_shape"] + gt[:, 0] = gt[:, 0] * (w_old - 1) / (w_new - 1) + gt[:, 1] = gt[:, 1] * (h_old - 1) / (h_new - 1) + return gt + + @staticmethod + def crop(img, height, width, top, left) -> torch.Tensor: + h, w = img.shape[-2:] + right = left + width + bottom = top + height + padding_ltrb = [ + max(-left + min(0, right), 0), + max(-top + min(0, bottom), 0), + max(right - max(w, left), 0), + max(bottom - max(h, top), 0), + ] + image_cropped = img[..., max(top, 0) : bottom, max(left, 0) : right] + return TF.pad(image_cropped, padding_ltrb) + + def test_closest_shape(self, image_shape): + h, w = image_shape + input_ratio = w / h + if self.sample: + input_pixels = int(ceil(h / self.shape_mult * w / self.shape_mult)) + pixels = max( + min(input_pixels, self.pixels_bounds[1]), self.pixels_bounds[0] + ) + ratio = min(max(input_ratio, self.ratio_bounds[0]), self.ratio_bounds[1]) + h = round((pixels / ratio) ** 0.5) + w = h * ratio + self.image_shape[0] = int(h) * self.shape_mult + self.image_shape[1] = int(w) * self.shape_mult + + def _get_crop_shapes(self, image_shape, ctx=None): + h, w = image_shape + input_ratio = w / h + if self.keep_original: + self.test_closest_shape(image_shape) + ctx = 1.0 + elif ctx is None: + ctx = float( + torch.empty(1) + .uniform_(self.train_ctx_range[0], self.train_ctx_range[1]) + .item() + ) + output_ratio = self.image_shape[1] / self.image_shape[0] + + if output_ratio <= input_ratio: # out like 4:3 in like kitti + if ( + ctx >= 1 + ): # fully in -> use just max_length with sqrt(ctx), here max is width + new_w = w * ctx**0.5 + # sporge un po in una sola dim + # we know that in_width will stick out before in_height, partial overshoot (sporge) + # new_h > old_h via area -> new_h ** 2 * ratio_new = old_h ** 2 * ratio_old * ctx + elif output_ratio / input_ratio * ctx > 1: + new_w = w * ctx + else: # fully contained -> use area + new_w = w * (ctx * output_ratio / input_ratio) ** 0.5 + new_h = new_w / output_ratio + else: + if ctx >= 1: + new_h = h * ctx**0.5 + elif input_ratio / output_ratio * ctx > 1: + new_h = h * ctx + else: + new_h = h * (ctx * input_ratio / output_ratio) ** 0.5 + new_w = new_h * output_ratio + return (int(ceil(new_h - 0.5)), int(ceil(new_w - 0.5))), ctx + + # def sample_view(self, results): + # original_K = results["K"] + # original_image = results["image"] + # original_depth = results["depth"] + # original_validity_mask = results["validity_mask"].float() + # # sample angles and translation + # # sample translation: + # # 10 max of z + + # x = np.random.normal(0, 0.05 / 2) * original_depth.max() + # y = np.random.normal(0, 0.05) + # z = np.random.normal(0, 0.05) * original_depth.max() + + # fov = 2 * np.arctan(original_image.shape[-2] / 2 / results["K"][0, 0, 0]) + # phi = np.random.normal(0, fov / 10) + # theta = np.random.normal(0, fov / 10) + # psi = np.random.normal(0, np.pi / 60) + # translation = torch.tensor([x, y, z]).unsqueeze(0) + # angles = torch.tensor([phi, theta, psi]) + 
# angles = euler_to_rotation_matrix(angles) + # translation = translation @ angles # translation before rotation + + # cam2w = torch.eye(4).unsqueeze(0) + # cam2w[..., :3, :3] = angles + # cam2w[..., :3, 3] = translation + # cam2cam = torch.inverse(cam2w) + # image_warped, depth_warped = forward_warping(original_image, original_depth, original_K, original_K, cam2cam=cam2cam) + # depth_warped[depth_warped > 0] = depth_warped[depth_warped > 0] - z + # validity_mask_warped = image_warped.sum(dim=1, keepdim=True) > 0.0 + + # results["K"] = results["K"].repeat(2, 1, 1) + # results["cam2w"] = torch.cat([torch.eye(4).unsqueeze(0), cam2w]) + # results["image"] = torch.cat([original_image, image_warped]) + # results["depth"] = torch.cat([original_depth, depth_warped]) + # results["validity_mask"] = torch.cat([original_validity_mask, validity_mask_warped], dim=0) + + # # results["cam2w"] = torch.cat([torch.eye(4).unsqueeze(0), torch.eye(4).unsqueeze(0)]) + # # results["image"] = torch.cat([original_image, original_image]) + # # results["depth"] = torch.cat([original_depth, original_depth]) + # # results["validity_mask"] = torch.cat([original_validity_mask, original_validity_mask], dim=0) + # return results + + def __call__(self, results): + h, w = results["image"].shape[-2:] + results["image_ori_shape"] = (h, w) + results["camera_fields"].add("camera_original") + results["camera_original"] = results["camera"].clone() + + results.get("mask_fields", set()).add("validity_mask") + if "validity_mask" not in results: + results["validity_mask"] = torch.ones( + (results["image"].shape[0], 1, h, w), + dtype=torch.uint8, + device=results["image"].device, + ) + + n_iter = 1 if self.keep_original or not self.sample else 100 + + min_valid_area = 0.5 + max_hfov, max_vfov = results["camera"].max_fov[0] # it is a 1-dim list + ctx = None + for ii in range(n_iter): + + (height, width), ctx = self._get_crop_shapes((h, w), ctx=self.ctx or ctx) + margin_h = h - height + margin_w = w - width + + # keep it centered in y direction + top = margin_h // 2 + left = margin_w // 2 + if not self.keep_original: + left = left + np.random.randint( + -self.shape_mult // 2, self.shape_mult // 2 + 1 + ) + top = top + np.random.randint( + -self.shape_mult // 2, self.shape_mult // 2 + 1 + ) + + right = left + width + bottom = top + height + x_zoom = self.image_shape[0] / height + paddings = [ + max(-left + min(0, right), 0), + max(bottom - max(h, top), 0), + max(right - max(w, left), 0), + max(-top + min(0, bottom), 0), + ] + + valid_area = ( + h + * w + / (h + paddings[1] + paddings[3]) + / (w + paddings[0] + paddings[2]) + ) + new_hfov, new_vfov = results["camera_original"].get_new_fov( + new_shape=(height, width), original_shape=(h, w) + )[0] + # if valid_area >= min_valid_area or getattr(self, "ctx", None) is not None: + # break + if ( + valid_area >= min_valid_area + and new_hfov < max_hfov + and new_vfov < max_vfov + ): + break + ctx = ( + ctx * 0.96 + ) # if not enough valid area, try again with less ctx (more zoom) + + # save ctx for next iteration of sequences? 
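+        # ctx is the sampled context (zoom-out) factor; the loop above shrinks it
+        # by 4% per retry until the valid-area and max-FOV constraints are met.
+        # Storing it on the instance lets subsequent calls reuse the same context
+        # (see `ctx=self.ctx or ctx` above), e.g. across frames of one sequence.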
+ self.ctx = ctx + + results["resized_shape"] = self.image_shape + results["paddings"] = paddings # left ,top ,right, bottom + results["image_rescale"] = x_zoom + results["scale_factor"] = results.get("scale_factor", 1.0) * x_zoom + results["camera"] = results["camera"].crop( + left, top, right=w - right, bottom=h - bottom + ) + results["camera"] = results["camera"].resize(x_zoom) + + # print("XAM", results["camera"].params.squeeze(), results["camera"][0].params.squeeze(), results["camera_original"].params.squeeze(), results["camera_original"][0].params.squeeze()) + + shapes = dict(height=height, width=width, top=top, left=left) + self._transform_img(results, shapes) + if not self.keep_original: + self._transform_gt(results, shapes) + self._transform_masks(results, shapes) + else: + # only validity_mask (rgb's masks follows rgb transform) #FIXME + mask = results["validity_mask"].float() + mask = self.crop(mask, **shapes).byte() + mask = TF.resize( + mask, + results["resized_shape"], + interpolation=TF.InterpolationMode.NEAREST, + ) + results["validity_mask"] = mask + + # # print(ii, ctx, results["camera"].hfov[0] * 180 / np.pi, original_hfov * 180 / np.pi, results["camera"].vfov[0] * 180 / np.pi, original_vfov * 180 / np.pi, valid_area) + # from PIL import Image + # from unik3d.utils.visualization import colorize + # img1 = results["image"][0].permute(1,2,0).clip(0, 255.0).cpu().numpy() + # # img2 = results["image"][1].permute(1,2,0).clip(0, 255.0).cpu().numpy() + # Image.fromarray(img1.astype(np.uint8)).save("test_col1.png") + # # Image.fromarray(img2.astype(np.uint8)).save("test_col2.png") + # Image.fromarray(colorize(results["depth"][0].cpu().numpy().squeeze(), 0.0, 10.0)).save("test_dep1.png") + # # Image.fromarray(colorize(results["depth"][1].cpu().numpy().squeeze(), 0.0, 10.0)).save("test_dep2.png") + # raise ValueError + + # keep original images before photo-augment + results["image_original"] = results["image"].clone() + results["image_fields"].add( + *[ + field.replace("image", "image_original") + for field in results["image_fields"] + ] + ) + + # repeat for batch resized shape and paddings + results["paddings"] = [results["paddings"]] * results["image"].shape[0] + results["resized_shape"] = [results["resized_shape"]] * results["image"].shape[ + 0 + ] + return results + + +class RandomFiller: + def __init__(self, test_mode, *args, **kwargs): + super().__init__() + self.test_mode = test_mode + + def _transform(self, results): + def fill_noise(size, device): + return torch.normal(0, 2.0, size=size, device=device) + + def fill_black(size, device): + return -4 * torch.ones(size, device=device, dtype=torch.float32) + + def fill_white(size, device): + return 4 * torch.ones(size, device=device, dtype=torch.float32) + + def fill_zero(size, device): + return torch.zeros(size, device=device, dtype=torch.float32) + + B, C = results["image"].shape[:2] + mismatch = B // results["validity_mask"].shape[0] + if mismatch: + results["validity_mask"] = results["validity_mask"].repeat( + mismatch, 1, 1, 1 + ) + validity_mask = results["validity_mask"].repeat(1, C, 1, 1).bool() + filler_fn = np.random.choice([fill_noise, fill_black, fill_white, fill_zero]) + if self.test_mode: + filler_fn = fill_zero + for key in results.get("image_fields", ["image"]): + results[key][~validity_mask] = filler_fn( + size=results[key][~validity_mask].shape, device=results[key].device + ) + + def __call__(self, results): + # generate mask for filler + if "validity_mask" not in results: + paddings = 
results.get("padding_size", [0] * 4) + height, width = results["image"].shape[-2:] + results.get("mask_fields", set()).add("validity_mask") + results["validity_mask"] = torch.zeros_like(results["image"][:, :1]) + results["validity_mask"][ + ..., + paddings[1] : height - paddings[3], + paddings[0] : width - paddings[2], + ] = 1.0 + self._transform(results) + return results + + +class GaussianBlur: + def __init__(self, kernel_size, sigma=(0.1, 2.0), prob=0.9): + super().__init__() + self.kernel_size = kernel_size + self.sigma = sigma + self.prob = prob + self.padding = kernel_size // 2 + + def apply(self, x, kernel): + # Pad the input tensor + x = F.pad( + x, (self.padding, self.padding, self.padding, self.padding), mode="reflect" + ) + # Apply the convolution with the Gaussian kernel + return F.conv2d(x, kernel, stride=1, padding=0, groups=x.size(1)) + + def _create_kernel(self, sigma): + # Create a 1D Gaussian kernel + kernel_1d = torch.exp( + -torch.arange(-self.padding, self.padding + 1) ** 2 / (2 * sigma**2) + ) + kernel_1d = kernel_1d / kernel_1d.sum() + + # Expand the kernel to 2D and match size of the input + kernel_2d = kernel_1d.unsqueeze(0) * kernel_1d.unsqueeze(1) + kernel_2d = kernel_2d.view(1, 1, self.kernel_size, self.kernel_size).expand( + 3, 1, -1, -1 + ) + return kernel_2d + + def __call__(self, results): + if np.random.random() > self.prob: + return results + sigma = (self.sigma[1] - self.sigma[0]) * np.random.rand() + self.sigma[0] + kernel = self._create_kernel(sigma) + for key in results.get("image_fields", ["image"]): + if "original" not in key: + results[key] = self.apply(results[key], kernel) + return results + + +class MotionBlur: + def __init__(self, kernel_size=(9, 9), angles=(-180, 180), prob=0.1): + super().__init__() + self.kernel_size = kernel_size + self.angles = angles + self.prob = prob + self.padding = kernel_size // 2 + + def _create_kernel(self, angle): + # Generate a 2D grid of coordinates + grid = torch.meshgrid( + torch.arange(self.kernel_size), torch.arange(self.kernel_size) + ) + grid = torch.stack(grid).float() # Shape: (2, kernel_size, kernel_size) + + # Calculate relative coordinates from the center + center = (self.kernel_size - 1) / 2.0 + x_offset = grid[1] - center + y_offset = grid[0] - center + + # Compute motion blur kernel + cos_theta = torch.cos(angle * torch.pi / 180.0) + sin_theta = torch.sin(angle * torch.pi / 180.0) + kernel = (1.0 / self.kernel_size) * ( + 1.0 - torch.abs(x_offset * cos_theta + y_offset * sin_theta) + ) + + # Expand kernel dimensions to match input image channels + kernel = kernel.unsqueeze(0).unsqueeze(0).expand(3, 1, -1, -1) + return kernel + + def apply(self, image, kernel): + x = F.pad( + x, (self.padding, self.padding, self.padding, self.padding), mode="reflect" + ) + # Apply convolution with the motion blur kernel + blurred_image = F.conv2d(image, kernel, stride=1, padding=0, groups=x.size(1)) + return blurred_image + + def __call__(self, results): + if np.random.random() > self.prob: + return results + + angle = np.random.uniform(self.angles[0], self.angles[1]) + kernel = self._create_kernel(angle) + for key in results.get("image_fields", ["image"]): + if "original" in key: + continue + results[key] = self.apply(results[key], kernel) + + return results + + +class JPEGCompression: + def __init__(self, level=(10, 70), prob=0.1): + super().__init__() + self.level = level + self.prob = prob + + def __call__(self, results): + if np.random.random() > self.prob: + return results + + level = 
np.random.uniform(self.level[0], self.level[1]) + for key in results.get("image_fields", ["image"]): + if "original" in key: + continue + results[key] = TF.jpeg(results[key], level) + + return results + + +class Compose: + def __init__(self, transforms): + self.transforms = deepcopy(transforms) + + def __call__(self, results): + for t in self.transforms: + results = t(results) + return results + + def __setattr__(self, name: str, value) -> None: + super().__setattr__(name, value) + for t in self.transforms: + setattr(t, name, value) + + def __repr__(self): + format_string = self.__class__.__name__ + "(" + for t in self.transforms: + format_string += f"\n {t}" + format_string += "\n)" + return format_string + + +class DummyCrop(Resize): + def __init__( + self, + *args, + **kwargs, + ): + # dummy image shape, not really used + super().__init__(image_shape=(512, 512)) + + def __call__(self, results): + h, w = results["image"].shape[-2:] + results["image_ori_shape"] = (h, w) + results["camera_fields"].add("camera_original") + results["camera_original"] = results["camera"].clone() + results.get("mask_fields", set()).add("validity_mask") + if "validity_mask" not in results: + results["validity_mask"] = torch.ones( + (results["image"].shape[0], 1, h, w), + dtype=torch.uint8, + device=results["image"].device, + ) + + self.ctx = 1.0 + + results["resized_shape"] = self.image_shape + results["paddings"] = [0, 0, 0, 0] + results["image_rescale"] = 1.0 + results["scale_factor"] = results.get("scale_factor", 1.0) * 1.0 + results["camera"] = results["camera"].crop(0, 0, right=w, bottom=h) + results["camera"] = results["camera"].resize(1) + + # keep original images before photo-augment + results["image_original"] = results["image"].clone() + results["image_fields"].add( + *[ + field.replace("image", "image_original") + for field in results["image_fields"] + ] + ) + + # repeat for batch resized shape and paddings + results["paddings"] = [results["paddings"]] * results["image"].shape[0] + results["resized_shape"] = [results["resized_shape"]] * results["image"].shape[ + 0 + ] + return results diff --git a/unik3d/datasets/point_odyssey.py b/unik3d/datasets/point_odyssey.py new file mode 100644 index 0000000000000000000000000000000000000000..e3cc3de5419518fa1a397ca757abd2175bc33248 --- /dev/null +++ b/unik3d/datasets/point_odyssey.py @@ -0,0 +1,50 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class PointOdyssey(SequenceDataset): + min_depth = 0.01 + max_depth = 250.0 + depth_scale = 1000.0 + test_split = "test.txt" + train_split = "train.txt" + sequences_file = "sequences_clean.json" + hdf5_paths = [f"PointOdyssey.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] 
= [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/proteus.py b/unik3d/datasets/proteus.py new file mode 100644 index 0000000000000000000000000000000000000000..726fb3e9a5f764fc2c1f605252f923e98eae941a --- /dev/null +++ b/unik3d/datasets/proteus.py @@ -0,0 +1,51 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class Proteus(SequenceDataset): + min_depth = 0.01 + max_depth = 10.0 + depth_scale = 1000.0 + default_fps = 5 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["Proteus.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/samplers.py b/unik3d/datasets/samplers.py new file mode 100644 index 0000000000000000000000000000000000000000..a0a33a032c0fc346f256fcd81959a431727cb51c --- /dev/null +++ b/unik3d/datasets/samplers.py @@ -0,0 +1,242 @@ +import itertools +import warnings +from operator import itemgetter +from typing import Any, Optional + +import numpy as np +import torch +from torch.utils.data import Sampler + +from unik3d.utils import get_dist_info + + +def _get_numpy_dtype(size: int) -> Any: + return np.int32 if size <= 2**31 else np.int64 + + +def _get_torch_dtype(size: int) -> Any: + return torch.int32 if size <= 2**31 else torch.int64 + + +def _generate_randperm_indices(*, size: int, generator: torch.Generator): + """Generate the indices of a random permutation.""" + dtype = _get_torch_dtype(size) + # This is actually matching PyTorch's CPU implementation, see: https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/TensorFactories.cpp#L900-L921 + perm = torch.arange(size, dtype=dtype) + for i in range(size): + j = torch.randint(i, size, size=(1,), generator=generator).item() + + # Always swap even if no-op + value = perm[j].item() + perm[j] = perm[i].item() + perm[i] = value + yield value + + +# The following function is somewhat equivalent to _new_shuffle_tensor_slice below, +# but avoids a full in-place random permutation generation. 
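+# With start=rank and step=world_size (as passed by ShardedInfiniteSampler), each
+# rank shuffles only its own strided slice tensor[start::step] of the shared
+# permutation, so the per-rank shards stay disjoint up to the dropped remainder.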
+def _shuffle_tensor_slice( + *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator +) -> np.ndarray: + stop = len(tensor) + count = stop // step + drop_count = stop - step * count + if drop_count: + warnings.warn(f"# of dropped samples: {drop_count}") + + dtype = _get_numpy_dtype(stop) + result = np.empty(count, dtype=dtype) + + for i in range(count): + j = ( + torch.randint(0, i + 1, size=(1,), generator=generator).item() + if i > 0 + else 0 + ) + + result[i] = result[j] + result[j] = tensor[start + i * step].item() + + return result + + +def _new_shuffle_tensor_slice( + *, tensor: torch.Tensor, start: int = 0, step: int = 1, generator: torch.Generator +) -> np.ndarray: + stop = len(tensor) + count = stop // step + dtype = torch.int64 # Needed for using randperm result as indices + count = stop // step + drop_count = stop - step * count + if drop_count: + warnings.warn(f"# of dropped samples: {drop_count}") + indices = torch.randperm(count, dtype=dtype, generator=generator) + return tensor[start::step][indices].numpy() + + +def _make_seed(seed: int, start: int, iter_count: int) -> int: + # NOTE: Tried a few variants (including iter_count << 32), this one worked best. + return seed + start + (iter_count << 24) + + +class ShardedInfiniteSampler(Sampler): + def __init__( + self, + *, + sample_count: int, + shuffle: bool = False, + seed: int = 0, + start: Optional[int] = None, + step: Optional[int] = None, + advance: int = 0, + use_new_shuffle_tensor_slice: bool = False, + ): + self._sample_count = sample_count + self._seed = seed + self._shuffle = shuffle + rank, world_size = get_dist_info() + self._start = rank if start is None else start + self._step = world_size if step is None else step + self._advance = advance + self._iter_count = 0 + self._shuffle_tensor_slice_fn = ( + _new_shuffle_tensor_slice + if use_new_shuffle_tensor_slice + else _shuffle_tensor_slice + ) + + def __iter__(self): + iter_count = self._advance // self._sample_count + if iter_count > 0: + self._advance -= iter_count * self._sample_count + self._iter_count += iter_count + + if self._shuffle: + iterator = self._shuffled_iterator() + else: + iterator = self._iterator() + + yield from itertools.islice(iterator, self._advance, None) + + def _iterator(self): + assert not self._shuffle + + while True: + iterable = range(self._sample_count) + yield from itertools.islice(iterable, self._start, None, self._step) + + def _shuffled_iterator(self): + assert self._shuffle + + # Instantiate a generator here (rather than in the ctor) to be keep the class + # picklable (requirement of mp.spawn) + generator = torch.Generator() + + # Always shuffle everything first + generator.manual_seed(self._seed) + dtype = _get_torch_dtype(self._sample_count) + perm = torch.randperm(self._sample_count, dtype=dtype, generator=generator) + + while True: + # Re-seed on each iteration to allow skipping whole permutations + seed = _make_seed(self._seed, self._start, self._iter_count) + generator.manual_seed(seed) + + iterable = self._shuffle_tensor_slice_fn( + tensor=perm, start=self._start, step=self._step, generator=generator + ) + yield from iterable + self._iter_count += 1 + + +class DistributedSamplerNoDuplicate(torch.utils.data.DistributedSampler): + """A distributed sampler that doesn't add duplicates. 
Arguments are the same as DistributedSampler""" + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + if not self.drop_last and len(self.dataset) % self.num_replicas != 0: + # some ranks may have less samples, that's fine + if self.rank >= len(self.dataset) % self.num_replicas: + self.num_samples -= 1 + self.total_size = len(self.dataset) + + +class DatasetFromSampler(torch.utils.data.Dataset): + """Dataset to create indexes from `Sampler`. + + Args: + sampler: PyTorch sampler + """ + + def __init__(self, sampler: Sampler): + """Initialisation for DatasetFromSampler.""" + self.sampler = sampler + self.sampler_list = None + + def __getitem__(self, index: int): + """Gets element of the dataset. + + Args: + index: index of the element in the dataset + + Returns: + Single element by index + """ + if self.sampler_list is None: + self.sampler_list = list(self.sampler) + return self.sampler_list[index] + + def __len__(self) -> int: + """ + Returns: + int: length of the dataset + """ + return len(self.sampler) + + +class DistributedSamplerWrapper(torch.utils.data.DistributedSampler): + """ + Wrapper over `Sampler` for distributed training + Allows you to use any sampler in distributed mode. + + It is especially useful in conjunction with + `torch.nn.parallel.DistributedDataParallel`. In such case, each + process can pass a DistributedSamplerWrapper instance as a DataLoader + sampler, and load a subset of subsampled data of the original dataset + that is exclusive to it. + + .. note:: + Sampler is assumed to be of constant size. + """ + + def __init__( + self, + sampler, + num_replicas: Optional[int] = None, + rank: Optional[int] = None, + shuffle: bool = True, + ): + """ + + Args: + sampler: Sampler used for subsampling + num_replicas (int, optional): Number of processes participating in + distributed training + rank (int, optional): Rank of the current process + within ``num_replicas`` + shuffle (bool, optional): If true (default), + sampler will shuffle the indices + """ + super(DistributedSamplerWrapper, self).__init__( + DatasetFromSampler(sampler), + num_replicas=num_replicas, + rank=rank, + shuffle=shuffle, + ) + self.sampler = sampler + + def __iter__(self): + self.dataset = DatasetFromSampler(self.sampler) + indexes_of_indexes = super().__iter__() + subsampler_indexes = self.dataset + return iter(itemgetter(*indexes_of_indexes)(subsampler_indexes)) diff --git a/unik3d/datasets/scannet.py b/unik3d/datasets/scannet.py new file mode 100644 index 0000000000000000000000000000000000000000..178c601df320a4b7b9ca9136ee0b00f650d4812a --- /dev/null +++ b/unik3d/datasets/scannet.py @@ -0,0 +1,49 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class ScanNet(SequenceDataset): + min_depth = 0.005 + max_depth = 10.0 + depth_scale = 1000.0 + test_split = "test.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["ScanNetS.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, 
+ num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/scannetpp.py b/unik3d/datasets/scannetpp.py new file mode 100644 index 0000000000000000000000000000000000000000..ebd1f5b77699b486cec9e61137341b63cf03278c --- /dev/null +++ b/unik3d/datasets/scannetpp.py @@ -0,0 +1,98 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class ScanNetpp(SequenceDataset): + min_depth = 0.001 + max_depth = 10.0 + depth_scale = 1000.0 + test_split = "val_iphone.txt" + train_split = "train_iphone.txt" + sequences_file = "sequences_iphone_clean.json" + hdf5_paths = [f"ScanNetpp_viz.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results + + +class ScanNetpp_F(SequenceDataset): + min_depth = 0.001 + max_depth = 10.0 + depth_scale = 1000.0 + train_split = "train.txt" + test_split = "val_split.txt" + + sequences_file = "sequences_split.json" + hdf5_paths = [f"ScanNetpp_F.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["camera_params", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=( + decode_fields if not test_mode else [*decode_fields, "points"] + ), + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/sequence_dataset.py b/unik3d/datasets/sequence_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..1c0f383f43a88e3c08e36e3cc9827a20cdeaffcc --- /dev/null +++ b/unik3d/datasets/sequence_dataset.py @@ -0,0 +1,301 @@ +import json +import os +from functools import partial +from typing import Any, Dict, Tuple + +import h5py +import numpy as np +import tables +import torch +import torchvision.transforms.v2.functional as TF + +from 
unik3d.datasets.base_dataset import BaseDataset +from unik3d.datasets.utils import DatasetFromList +from unik3d.datasets.utils_decode import (decode_camera, decode_depth, + decode_flow, decode_K, decode_mask, + decode_numpy, decode_rgb, + decode_tensor) +from unik3d.utils.distributed import is_main_process + + +class SequenceDataset(BaseDataset): + DECODE_FNS = { + "image": partial(decode_rgb, name="image"), + "points": partial(decode_numpy, name="points"), + "K": partial(decode_K, name="camera"), + "camera_params": partial(decode_camera, name="camera"), + "cam2w": partial(decode_tensor, name="cam2w"), + "depth": partial(decode_depth, name="depth"), + "flow_fwd": partial(decode_flow, name="flow_fwd"), + "flow_bwd": partial(decode_flow, name="flow_bwd"), + "flow_fwd_mask": partial(decode_mask, name="flow_fwd_mask"), + "flow_bwd_mask": partial(decode_mask, name="flow_bwd_mask"), + } + default_fps = 5 + + def __init__( + self, + image_shape: Tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: Dict[str, Any], + shape_constraints: Dict[str, Any], + resize_method: str, + mini: float, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth", "flow_fwd", "flow_fwd_mask"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + shape_constraints=shape_constraints, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.num_frames = num_frames + self.original_num_frames = num_frames + self.decode_fields = decode_fields + self.inplace_fields = inplace_fields + self.fps = self.default_fps + self.fps_range = kwargs.get("fps_range", None) + if self.fps_range is not None: + self.fps_range[1] = min(self.default_fps, self.fps_range[1]) + + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii").strip() + sequences = np.array(h5file[self.sequences_file]).tostring().decode("ascii") + sequences = json.loads(sequences) + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + if len(line.strip().split(" ")) == 1: + print(line) + continue + sequence_name, num_samples = line.strip().split(" ") + dataset.append( + { + "sequence_name": sequence_name, + "num_samples": int(num_samples), + "chunk_idx": 0, + } + ) + + # filter dataset based on attr "invalid_sequences" + invalid_sequences = getattr(self, "invalid_sequences", []) + dataset = [ + sample + for sample in dataset + if sample["sequence_name"] not in invalid_sequences + ] + + self.dataset = DatasetFromList(dataset) + self.sequences = DatasetFromList( + [sequences[sample["sequence_name"]] for sample in dataset] + ) + self.log_load_dataset() + + def get_random_idxs(self, num_samples_sequence): + if self.num_frames == 1: + return [np.random.randint(0, num_samples_sequence)], 0 + + # Check if we can satisfy the required number of frames + if self.num_frames > num_samples_sequence: + raise ValueError( + "Cannot sample more frames than available in the sequence." 
+ ) + + # Restrict FPS range to be within default FPS + min_fps, max_fps = self.fps_range + max_fps = min(max_fps, self.default_fps) + if min_fps > self.default_fps: + sampled_fps = self.default_fps + else: + # Compute minimal viable FPS + min_required_fps = ( + self.num_frames / num_samples_sequence + ) * self.default_fps + min_fps = max(min_fps, min_required_fps) + + # Sample an FPS from the viable range + sampled_fps = np.random.uniform(min_fps, max_fps) + + # Compute the stride based on the sampled FPS + stride = self.default_fps / sampled_fps + max_start_index = num_samples_sequence - int(stride * (self.num_frames - 1)) + + # Ensure a valid starting position + if max_start_index <= 0: + raise ValueError( + "No valid start position allows sampling num_frames with the chosen FPS." + ) + + start_index = np.random.randint(0, max_start_index + 1) + + # Compute indices based on the sampled FPS + indices = [int(start_index + i * stride) for i in range(self.num_frames)] + + return indices, np.random.randint(0, len(indices)) + + def get_test_idxs(self, num_samples_sequence, keyframe_idx): + if self.num_frames == 1: + return [ + keyframe_idx if keyframe_idx is not None else num_samples_sequence // 2 + ], 0 + + if self.num_frames == -1: + cap_idxs = min(32, num_samples_sequence) # CAP 32 images + idxs = list( + range(max(0, num_samples_sequence - cap_idxs), num_samples_sequence, 1) + ) + return idxs, keyframe_idx + + # pick closest keyframe_idx st they are around it or capped by the 0 and max num_samples_sequence + keyframe_idx = ( + keyframe_idx if keyframe_idx is not None else num_samples_sequence - 1 + ) + excess_tail = 0 - min(0, keyframe_idx - self.num_frames // 2) + excess_head = ( + max(num_samples_sequence, keyframe_idx + (self.num_frames - 1) // 2) + - num_samples_sequence + ) + start = keyframe_idx - self.num_frames // 2 + excess_tail - excess_head + end = keyframe_idx + (self.num_frames - 1) // 2 + excess_head - excess_tail + idxs = list(range(start, 1 + end)) + + return idxs, idxs.index(keyframe_idx) + + def get_single_sequence(self, idx): + self.num_frames = self.original_num_frames + # sequence_name = self.dataset[idx]["sequence_name"] + sample = self.sequences[idx] + chunk_idx = int(sample.get("chunk_idx", 0)) + h5_path = os.path.join(self.data_root, self.hdf5_paths[chunk_idx]) + + num_samples_sequence = len(sample["image"]) + if self.num_frames > 0 and num_samples_sequence < self.num_frames: + raise IndexError(f"Sequence {idx} has less than {self.num_frames} frames") + keyframe_idx = None + + if not self.test_mode: + idxs, keyframe_idx = self.get_random_idxs(num_samples_sequence) + else: + idxs, keyframe_idx = self.get_test_idxs( + num_samples_sequence, sample.get("keyframe_idx", None) + ) + + self.num_frames = len(idxs) + results = {} + results = self.pre_pipeline(results) + results["sequence_fields"] = [(i, 0) for i in range(self.num_frames)] + results["keyframe_idx"] = keyframe_idx + with tables.File( + h5_path, + mode="r", + libver="latest", + swmr=True, + ) as h5file_chunk: + + for i, j in enumerate(idxs): + results[(i, 0)] = { + k: v.copy() for k, v in results.items() if "fields" in k + } + for inplace_field in self.inplace_fields: + inplace_field_ = inplace_field.replace("intrinsics", "K") + inplace_field_ = inplace_field_.replace("extrinsics", "cam2w") + if inplace_field_ == "cam2w": + # take care of missing cam2w -> assume 1 frame and assume identity + if inplace_field_ not in sample: + sample[inplace_field] = [None] * num_samples_sequence + if sample[inplace_field][j] 
is None: + sample[inplace_field][j] = np.eye(4, dtype=np.float32) + results = self.DECODE_FNS[inplace_field_]( + results, sample[inplace_field][j], idx=i, sample=sample, j=j + ) + + for i, j in enumerate(idxs): + for decode_field in self.decode_fields: + results = self.DECODE_FNS[decode_field]( + results, + h5file_chunk, + sample[decode_field][j], + idx=i, + depth_scale=self.depth_scale, + ) + + results["filename"] = sample["image"][j] + + results = self.preprocess(results) + if not self.test_mode: + results = self.augment(results) + results = self.postprocess(results) + return results + + def preprocess(self, results): + results = self.replicate(results) + for i, seq in enumerate(results["sequence_fields"]): + results[seq] = self.resizer(results[seq]) + self.resizer.ctx = None if self.num_copies > 1 else self.resizer.ctx + num_pts = torch.count_nonzero(results[seq]["depth"] > 0) + if num_pts < 50: + raise IndexError(f"Too few points in depth map ({num_pts})") + + for key in results[seq].get("image_fields", ["image"]): + results[seq][key] = results[seq][key].to(torch.float32) / 255 + + # update fields common in sequence + for key in [ + "image_fields", + "gt_fields", + "mask_fields", + "camera_fields", + ]: + if key in results[(0, 0)]: + results[key] = results[(0, 0)][key] + + results = self.pack_batch(results) + return results + + def postprocess(self, results): + # # normalize after because color aug requires [0,255]? + for key in results.get("image_fields", ["image"]): + results[key] = TF.normalize(results[key], **self.normalization_stats) + results = self.filler(results) + results = self.unpack_batch(results) + results = self.masker(results) + results = self.collecter(results) + return results + + def __getitem__(self, idx): + try: + if isinstance(idx, (list, tuple)): + results = [self.get_single_sequence(i) for i in idx] + else: + results = self.get_single_sequence(idx) + except Exception as e: + print(f"Error loading sequence {idx} for {self.__class__.__name__}: {e}") + idx = np.random.randint(0, len(self.dataset)) + results = self[idx] + return results + + def log_load_dataset(self): + if is_main_process(): + info = f"Loaded {self.__class__.__name__} with {sum([len(x['image']) for x in self.sequences])} images in {len(self)} sequences." 
+ print(info) diff --git a/unik3d/datasets/sintel.py b/unik3d/datasets/sintel.py new file mode 100644 index 0000000000000000000000000000000000000000..8b2d0eb34986c81609f42b1e14fa7dd2bd29d7df --- /dev/null +++ b/unik3d/datasets/sintel.py @@ -0,0 +1,50 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class Sintel(SequenceDataset): + min_depth = 0.001 + max_depth = 1000.0 + depth_scale = 1000.0 + test_split = "training.txt" + train_split = "training.txt" + sequences_file = "sequences.json" + hdf5_paths = ["Sintel.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth", "flow_fwd", "flow_fwd_mask"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/sunrgbd.py b/unik3d/datasets/sunrgbd.py new file mode 100644 index 0000000000000000000000000000000000000000..43cc98a4bb25511501010195dae195a0396a6d91 --- /dev/null +++ b/unik3d/datasets/sunrgbd.py @@ -0,0 +1,73 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class SUNRGBD(ImageDataset): + min_depth = 0.005 + max_depth = 8.0 + depth_scale = 1000.0 + test_split = "alltest.txt" + train_split = "alltrain.txt" + intrisics_file = "intrinsics.json" + hdf5_paths = ["SUNRGB.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.crop = crop + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val] + dataset.append(sample) + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + self.dataset = DatasetFromList(dataset) + 
self.log_load_dataset() diff --git a/unik3d/datasets/synscapes.py b/unik3d/datasets/synscapes.py new file mode 100644 index 0000000000000000000000000000000000000000..dff078984285ca1172622d516becbee8248bf0e5 --- /dev/null +++ b/unik3d/datasets/synscapes.py @@ -0,0 +1,50 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class Synscapes(SequenceDataset): + min_depth = 0.1 + max_depth = 1000.0 + depth_scale = 256.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"Synscapes.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/tartanair.py b/unik3d/datasets/tartanair.py new file mode 100644 index 0000000000000000000000000000000000000000..dcd24d4a200f4800ebb51f8e9f742215bc541a46 --- /dev/null +++ b/unik3d/datasets/tartanair.py @@ -0,0 +1,51 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class TartanAir(SequenceDataset): + min_depth = 0.01 + max_depth = 512.0 + depth_scale = 1000.0 + default_fps = 15 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["TartanAir.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/taskonomy.py b/unik3d/datasets/taskonomy.py new file mode 100644 index 0000000000000000000000000000000000000000..69880eacce0289496d9522b233f061feba84f0b3 --- /dev/null +++ b/unik3d/datasets/taskonomy.py @@ -0,0 +1,91 @@ +import json +import os + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import 
DatasetFromList + + +class Taskonomy(ImageDataset): + min_depth = 0.005 + max_depth = 15.0 + depth_scale = 512.0 + test_split = "val.txt" + train_split = "train_clean.txt" + intrisics_file = "intrinsics.json" + hdf5_paths = ["Taskonomy.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii") # [:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + # with open(os.path.join(os.environ["TMPDIR"], self.split_file), "w") as f: + # f.write(txt_string) + # with open(os.path.join(os.environ["TMPDIR"], self.intrisics_file), "w") as f: + # json.dump(intrinsics, f) + + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename, chunk_idx = line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val, chunk_idx] + dataset.append(sample) + h5file.close() + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + if self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=0.01) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def get_mapper(self): + return { + "image_filename": 0, + "depth_filename": 1, + "K": 2, + } + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["quality"] = [2] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/tat_rmvd.py b/unik3d/datasets/tat_rmvd.py new file mode 100644 index 0000000000000000000000000000000000000000..08072cb2d4873afebd7f6efc42b40f6edadbf6d4 --- /dev/null +++ b/unik3d/datasets/tat_rmvd.py @@ -0,0 +1,63 @@ +import json +import os +from copy import deepcopy +from typing import Any + +import h5py +import numpy as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.pipelines import AnnotationMask, KittiCrop +from unik3d.datasets.sequence_dataset import SequenceDataset +from unik3d.datasets.utils import DatasetFromList +from unik3d.utils import identity + + +class TATRMVD(SequenceDataset): + min_depth = 0.001 + max_depth = 50.0 + depth_scale = 1000.0 + default_fps = 6 + test_split = "test.txt" + train_split = "test.txt" + sequences_file = "sequences.json" + hdf5_paths = ["tanks_and_temples_rmvd.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + 
augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [False] * self.num_frames * self.num_copies + results["si"] = [True] * self.num_frames * self.num_copies + results["quality"] = [2] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/theo.py b/unik3d/datasets/theo.py new file mode 100644 index 0000000000000000000000000000000000000000..9be1b2f3f245a98e582a4359398acb102a9d5d51 --- /dev/null +++ b/unik3d/datasets/theo.py @@ -0,0 +1,66 @@ +from typing import Any + +import torch + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class Theo(SequenceDataset): + min_depth = 0.01 + max_depth = 10.0 + depth_scale = 1000.0 + default_fps = 5 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["THEO.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["camera_params", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def preprocess(self, results): + self.resizer.ctx = None + for i, seq in enumerate(results["sequence_fields"]): + # Create a mask where the distance from the center is less than H/2 + H, W = results[seq]["image"].shape[-2:] + x = torch.linspace(-(W - 1) / 2, (W - 1) / 2, W) + y = torch.linspace(-(H - 1) / 2, (H - 1) / 2, H) + xv, yv = torch.meshgrid(x, y, indexing="xy") + distance_from_center = torch.sqrt(xv**2 + yv**2).reshape(1, 1, H, W) + results[seq]["validity_mask"] = distance_from_center < (H - 1) / 2 + + return super().preprocess(results) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/unrealstereo4k.py b/unik3d/datasets/unrealstereo4k.py new file mode 100644 index 0000000000000000000000000000000000000000..64520dff8ae473c2027340663cd759eed4acb75a --- /dev/null +++ b/unik3d/datasets/unrealstereo4k.py @@ -0,0 +1,50 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class UnrealStereo4K(SequenceDataset): + min_depth = 0.01 + max_depth = 200.0 + depth_scale = 1000.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"UnrealStereo4K.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + 
super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/urbansyn.py b/unik3d/datasets/urbansyn.py new file mode 100644 index 0000000000000000000000000000000000000000..b3f72c77cf7fc1ca1981cd7cb09aa97f7701cb2a --- /dev/null +++ b/unik3d/datasets/urbansyn.py @@ -0,0 +1,50 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class UrbanSyn(SequenceDataset): + min_depth = 0.1 + max_depth = 1000.0 + depth_scale = 256.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"UrbanSyn.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/utils.py b/unik3d/datasets/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..2994594d53f48665c72096d870be56e3796fbdb8 --- /dev/null +++ b/unik3d/datasets/utils.py @@ -0,0 +1,229 @@ +import copy +import multiprocessing as mp +import pickle +import random +from collections import defaultdict +from typing import Any, Dict, List + +import numpy as np +import torch +import torch.utils.data + +from unik3d.utils.distributed import (all_gather, get_local_rank, + get_local_size, get_rank, get_world_size) + + +class ConcatDataset(torch.utils.data.ConcatDataset): + def __init__(self, datasets, shape_constraints: dict[str, list[int]] = {}): + super().__init__(datasets) + + self.sample = shape_constraints["sample"] + self.shape_mult = shape_constraints["shape_mult"] + self.ratio_bounds = shape_constraints["ratio_bounds"] + self.pixels_max = float(shape_constraints["pixels_max"]) + self.pixels_min = float(shape_constraints["pixels_min"]) + + self.height_min = shape_constraints["height_min"] + self.width_min = shape_constraints["width_min"] + + def sample_shape(self): + if not self.sample: + return + # 1: sample image ratio + ratio = np.random.uniform(*self.ratio_bounds) + pixels_min = self.pixels_min // (self.shape_mult * self.shape_mult) + pixels_max = self.pixels_max // (self.shape_mult * self.shape_mult) + # 2: sample image height 
or width, if ratio > 1 or < 1 + if ratio > 1: + height_min = max(self.height_min, np.sqrt(pixels_min / ratio)) + height = np.random.uniform(height_min, np.sqrt(pixels_max / ratio)) + width = height * ratio + else: + width_min = max(self.width_min, np.sqrt(pixels_min * ratio)) + width = np.random.uniform(width_min, np.sqrt(pixels_max * ratio)) + height = width / ratio + # 3: get final shape based on the shape_mult + shape = [int(height) * self.shape_mult, int(width) * self.shape_mult] + for dataset in self.datasets: + setattr(dataset, "image_shape", shape) + setattr(dataset.resizer, "image_shape", shape) + + def __getitem__(self, idxs): + self.sample_shape() + samples = [super(ConcatDataset, self).__getitem__(idx) for idx in idxs] + return samples + + +def _paddings(image_shape, network_shape): + cur_h, cur_w = image_shape + h, w = network_shape + pad_top, pad_bottom = (h - cur_h) // 2, h - cur_h - (h - cur_h) // 2 + pad_left, pad_right = (w - cur_w) // 2, w - cur_w - (w - cur_w) // 2 + return pad_left, pad_right, pad_top, pad_bottom + + +def collate_fn(in_data: List[List[Dict[str, Any]]], is_batched: bool = True): + out_data = defaultdict(list) + img_metas = [] + in_data = in_data[0] if is_batched else in_data + + # get max_shape and paddings + shapes = [tensor.shape[-2:] for x in in_data for tensor in x["depth"].values()] + max_shape_tuple = tuple(max(elements) for elements in zip(*shapes)) + paddings = [ + [ + _paddings(tensor.shape[-2:], max_shape_tuple) + for tensor in x["depth"].values() + ] + for x in in_data + ] + + for x in in_data: # here iter over batches + padding = paddings.pop(0) + for k, v in x.items(): + if "img_metas" not in k: + values = list(v.values()) + v = torch.cat(values) + out_data[k].append(v) + else: + v["depth_paddings"] = padding + img_metas.append(v) + + # calculate all valid_samples in batch as depth_mask.sum().long() and then append to out_data as "valid_samples" + num_valid_samples = [ + x.flatten(1).sum(dim=1).long() for i, x in enumerate(out_data["depth_mask"]) + ] + out_data["num_valid_depth"] = [ + torch.stack([x, sum(num_valid_samples) * torch.ones_like(x)], dim=-1) + for x in num_valid_samples + ] + num_valid_samples = [ + x.flatten(1).sum(dim=1).long() for i, x in enumerate(out_data["validity_mask"]) + ] + out_data["num_valid_pix"] = [ + torch.stack([x, sum(num_valid_samples) * torch.ones_like(x)], dim=-1) + for x in num_valid_samples + ] + output_dict = { + "data": {k: torch.stack(v) for k, v in out_data.items()}, + "img_metas": img_metas, + } + if "camera" in output_dict["data"]: + output_dict["data"]["camera"] = output_dict["data"]["camera"].reshape( + *output_dict["data"]["image"].shape[:2] + ) + return output_dict + + +def local_scatter(array: list[Any]): + if get_world_size() == 1: + return array[0] + if get_local_rank() == 0: + assert len(array) == get_local_size() + all_gather(array) + else: + all_data = all_gather(None) + array = all_data[get_rank() - get_local_rank()] + return array[get_local_rank()] + + +class DatasetFromList(torch.utils.data.Dataset): # type: ignore + """Wrap a list to a torch Dataset. + + We serialize and wrap big python objects in a torch.Dataset due to a + memory leak when dealing with large python objects using multiple workers. 
+ See: https://github.com/pytorch/pytorch/issues/13246 + """ + + def __init__(self, lst: List[Any], deepcopy: bool = False, serialize: bool = True): + self._copy = deepcopy + self._serialize = serialize + + def _serialize(data: Any): + buffer = pickle.dumps(data, protocol=pickle.HIGHEST_PROTOCOL) + return torch.frombuffer(buffer, dtype=torch.uint8) + + if self._serialize: + # load only on 0th rank + if get_local_rank() == 0: + _lst = [_serialize(x) for x in lst] + self._addr = torch.cumsum( + torch.tensor([len(x) for x in _lst], dtype=torch.int64), dim=0 + ) + self._lst = torch.concatenate(_lst) + # Move data to shared memory, obtain a handle to send to each local worker. + handles = [None] + [ + bytes(mp.reduction.ForkingPickler.dumps((self._addr, self._lst))) + for _ in range(get_local_size() - 1) + ] + else: + handles = None + + # Each worker receives the handle from local leader (rank 0) + # then materialize the tensor from shared memory + handle = local_scatter(handles) + if get_local_rank() > 0: + self._addr, self._lst = mp.reduction.ForkingPickler.loads(handle) + + else: + self._lst = lst + + def __len__(self) -> int: + if self._serialize: + return len(self._addr) + return len(self._lst) + + def __getitem__(self, idx: int) -> Any: + if self._serialize: + start_addr = 0 if idx == 0 else self._addr[idx - 1] + end_addr = self._addr[idx] + bytes_ = memoryview(self._lst[start_addr:end_addr].numpy()) + return pickle.loads(bytes_) + if self._copy: + return copy.deepcopy(self._lst[idx]) + + return self._lst[idx] + + +def get_weights( + train_datasets: dict[str, torch.utils.data.Dataset], sampling: dict[str, float] +) -> torch.Tensor: + from .image_dataset import ImageDataset + from .sequence_dataset import SequenceDataset + + weights = [] + num_samples = 0 + info_weights = {} + for dataset_name, dataset in train_datasets.items(): + assert ( + dataset_name in sampling + ), f"Dataset {dataset_name} not found in {sampling.keys()}" + + if isinstance(dataset, ImageDataset): + # sum of all samples has weight as in sampling s.t. sampling dataset in general + # is as in sampling inside is uniform + weight = sampling[dataset_name] / len(dataset) + weights.append(torch.full((len(dataset),), weight).double()) + num_samples += len(dataset) + + elif isinstance(dataset, SequenceDataset): + # local weight is num_samples, but global must be as in sampling + # hence is num_samples / (sum num_samples / sampling[dataset_name]) + # s.t. 
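DatasetFromList packs the whole sample list into one flat uint8 buffer plus cumulative offsets, so dataloader workers share a single allocation instead of millions of Python objects whose refcount updates trigger copy-on-write growth (the pytorch issue referenced above). A single-process sketch of the same idea, without the shared-memory hand-off between local ranks:

import pickle
import torch

class SerializedList:
    def __init__(self, items):
        # bytearray keeps the buffer writable, so torch.frombuffer does not warn
        blobs = [
            torch.frombuffer(bytearray(pickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)),
                             dtype=torch.uint8)
            for x in items
        ]
        self._addr = torch.cumsum(torch.tensor([len(b) for b in blobs]), dim=0)
        self._data = torch.cat(blobs)

    def __len__(self):
        return len(self._addr)

    def __getitem__(self, idx):
        start = 0 if idx == 0 else int(self._addr[idx - 1])
        end = int(self._addr[idx])
        return pickle.loads(memoryview(self._data[start:end].numpy()))

lst = SerializedList([{"file": f"img_{i}.png"} for i in range(3)])
assert lst[1]["file"] == "img_1.png"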
sampling anything from the dataset is + # sum(num_samples / (sum num_samples / sampling[dataset_name])) + # -> sampling[dataset_name] + numerator = [int(data["num_samples"]) for data in dataset.dataset] + weights.append( + sampling[dataset_name] + * torch.tensor(numerator).double() + / sum(numerator) + ) + num_samples += sum(numerator) + + else: + weight = sampling[dataset_name] / len(dataset) + weights.append(torch.full((len(dataset),), weight).double()) + + info_weights[dataset_name] = weights[-1][-1] + + return torch.cat(weights), num_samples diff --git a/unik3d/datasets/utils_decode.py b/unik3d/datasets/utils_decode.py new file mode 100644 index 0000000000000000000000000000000000000000..9a628a708c9cfa2b1ac18614546dcdeb23f3678a --- /dev/null +++ b/unik3d/datasets/utils_decode.py @@ -0,0 +1,122 @@ +import io + +import cv2 +import numpy as np +import torch +import torchvision +import torchvision.transforms.v2.functional as TF +from PIL import Image + +from unik3d.utils.camera import (EUCM, MEI, OPENCV, BatchCamera, Fisheye624, + Pinhole, Spherical) + + +def decode_depth(results, h5file, value, idx, depth_scale, name="depth", **kwargs): + file = h5file.get_node("/" + value).read() + decoded_data = Image.open(io.BytesIO(file)) + decoded_data = TF.pil_to_tensor(decoded_data).squeeze() + + if decoded_data.ndim == 3: # 24 channel loading + decoded_channels = [ + (decoded_data[0] & 0xFF).to(torch.int32), + (decoded_data[1] & 0xFF).to(torch.int32), + (decoded_data[2] & 0xFF).to(torch.int32), + ] + # Reshape and extract the original depth map + decoded_data = ( + decoded_channels[0] + | (decoded_channels[1] << 8) + | (decoded_channels[2] << 16) + ) + + decoded_data = decoded_data.to(torch.float32) + results.get("gt_fields", set()).add(name) + results[(idx, 0)].get("gt_fields", set()).add(name) + results[f"{name}_ori_shape"] = decoded_data.shape + results[(idx, 0)][name] = ( + decoded_data.view(1, 1, *decoded_data.shape).contiguous() / depth_scale + ) + return results + + +def decode_numpy(results, h5file, value, idx, name="points", **kwargs): + file = h5file.get_node("/" + value).read() + decoded_data = np.load(io.BytesIO(file), allow_pickle=False) + decoded_data = torch.from_numpy(decoded_data).to(torch.float32) + if decoded_data.ndim > 2: + decoded_data = decoded_data.permute(2, 0, 1) + results.get("gt_fields", set()).add(name) + results[(idx, 0)].get("gt_fields", set()).add(name) + results[(idx, 0)][name] = decoded_data.unsqueeze(0) + return results + + +def decode_tensor(results, value, idx, name, **kwargs): + results.get("camera_fields", set()).add(name) + results[(idx, 0)].get("camera_fields", set()).add(name) + results[(idx, 0)][name] = torch.tensor(value).unsqueeze(0) + return results + + +def decode_camera(results, value, idx, name, sample, j, **kwargs): + results.get("camera_fields", set()).add(name) + results[(idx, 0)].get("camera_fields", set()).add(name) + camera = eval(sample["camera_model"][j])(params=torch.tensor(value).unsqueeze(0)) + results[(idx, 0)][name] = BatchCamera.from_camera(camera) + return results + + +def decode_K(results, value, idx, name, **kwargs): + results.get("camera_fields", set()).add(name) + results[(idx, 0)].get("camera_fields", set()).add(name) + camera = Pinhole(K=torch.tensor(value).unsqueeze(0)) + results[(idx, 0)][name] = BatchCamera.from_camera(camera) + return results + + +def decode_mask(results, h5file, value, idx, name, **kwargs): + file = h5file.get_node("/" + value).read() + mask = 
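The three-channel branch of decode_depth reassembles a 24-bit depth value from the byte planes of a PNG and divides by depth_scale. A standalone sketch with a round-trip check (the fake encoding at the bottom exists only for the test):

import torch

def decode_depth_24bit(rgb: torch.Tensor, depth_scale: float = 256.0) -> torch.Tensor:
    # low byte in channel 0, middle byte in channel 1, high byte in channel 2
    r, g, b = (rgb[i].to(torch.int32) & 0xFF for i in range(3))
    depth_int = r | (g << 8) | (b << 16)          # 0 .. 2**24 - 1
    return depth_int.to(torch.float32) / depth_scale

d = torch.tensor([[1234.5, 0.25]])
code = (d * 256.0).to(torch.int32)
rgb = torch.stack([code & 0xFF, (code >> 8) & 0xFF, (code >> 16) & 0xFF]).to(torch.uint8)
assert torch.allclose(decode_depth_24bit(rgb), d)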
torchvision.io.decode_image(torch.from_numpy(file)).bool().squeeze() + results.get("mask_fields", set()).add(name) + results[(idx, 0)].get("mask_fields", set()).add(name) + results[f"{name}_ori_shape"] = mask.shape[-2:] + results[(idx, 0)][name] = mask.view(1, 1, *mask.shape).contiguous() + return results + + +def decode_rgb(results, h5file, value, idx, name="image", **kwargs): + file = h5file.get_node("/" + value).read() + image = ( + torchvision.io.decode_image(torch.from_numpy(file)).to(torch.uint8).squeeze() + ) + results.get("image_fields", set()).add(name) + results[(idx, 0)].get("image_fields", set()).add(name) + results[f"{name}_ori_shape"] = image.shape[-2:] + if image.ndim == 2: + image = image.unsqueeze(0).repeat(3, 1, 1) + results[(idx, 0)][name] = image.unsqueeze(0) + return results + + +def decode_flow(results, h5file, value, idx, name, **kwargs): + file = h5file.get_node("/" + value).read() + image = ( + torchvision.io.decode_image(torch.from_numpy(file)).to(torch.uint8).squeeze() + ) + decoded_channels = [ + (image[0] & 0xFF).to(torch.int16), + (image[1] & 0xFF).to(torch.int16), + (image[2] & 0xFF).to(torch.int16), + ] + + # Reshape and extract the original 2-channel flow map + flow = torch.zeros((2, image.shape[1], image.shape[2]), dtype=torch.int16) + flow[0] = (decoded_channels[0] | decoded_channels[1] << 8) & 0xFFF + flow[1] = (decoded_channels[1] >> 4 | decoded_channels[2] << 4) & 0xFFF + + results.get("gt_fields", set()).add(name) + results[(idx, 0)].get("gt_fields", set()).add(name) + results[f"{name}_ori_shape"] = flow.shape[-2:] + flow = flow.unsqueeze(0).contiguous().float() + results[(idx, 0)][name] = (0.5 + flow) / 4095.0 * 2 - 1 + return results diff --git a/unik3d/datasets/vkitti.py b/unik3d/datasets/vkitti.py new file mode 100644 index 0000000000000000000000000000000000000000..5864b068808d0049837cab16b158684065a824c2 --- /dev/null +++ b/unik3d/datasets/vkitti.py @@ -0,0 +1,50 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class VKITTI(SequenceDataset): + min_depth = 0.01 + max_depth = 255.0 + depth_scale = 256.0 + test_split = "training.txt" + train_split = "training.txt" + sequences_file = "sequences.json" + hdf5_paths = ["VKITTI2.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth", "flow_fwd", "flow_fwd_mask"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["synthetic"] = [True] * self.num_frames * self.num_copies + results["quality"] = [0] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/void.py b/unik3d/datasets/void.py new file mode 100644 index 0000000000000000000000000000000000000000..f77f7450979fa2aa9ef577e235ad3a854662eaec --- /dev/null +++ b/unik3d/datasets/void.py @@ -0,0 +1,79 @@ +import json +import os + +import h5py +import numpy 
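decode_flow above unpacks two 12-bit flow components from three 8-bit image channels and maps the integer range [0, 4095] to roughly [-1, 1]. A standalone sketch of the same bit layout (int32 is used here instead of int16 purely for clarity):

import torch

def decode_flow_12bit(rgb: torch.Tensor) -> torch.Tensor:
    c0, c1, c2 = (rgb[i].to(torch.int32) & 0xFF for i in range(3))
    u = (c0 | (c1 << 8)) & 0xFFF            # bits 0..11
    v = ((c1 >> 4) | (c2 << 4)) & 0xFFF     # bits 12..23
    flow = torch.stack([u, v]).float()
    return (0.5 + flow) / 4095.0 * 2 - 1

rgb = torch.tensor([[[0x34]], [[0x12]], [[0xAB]]], dtype=torch.uint8)  # u=0x234, v=0xAB1
print(decode_flow_12bit(rgb).flatten())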
as np +import torch + +from unik3d.datasets.image_dataset import ImageDataset +from unik3d.datasets.utils import DatasetFromList + + +class VOID(ImageDataset): + min_depth = 0.01 + max_depth = 10.0 + depth_scale = 256.0 + test_split = "void_val.txt" + train_split = "void_train.txt" + intrisics_file = "void_intrinsics.json" + hdf5_paths = ["void.hdf5"] + + def __init__( + self, + image_shape, + split_file, + test_mode, + crop=None, + benchmark=False, + augmentations_db={}, + normalize=True, + resize_method="hard", + mini=1.0, + **kwargs, + ): + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + **kwargs, + ) + self.test_mode = test_mode + + self.crop = crop + self.load_dataset() + + def load_dataset(self): + h5file = h5py.File( + os.path.join(self.data_root, self.hdf5_paths[0]), + "r", + libver="latest", + swmr=True, + ) + txt_file = np.array(h5file[self.split_file]) + txt_string = txt_file.tostring().decode("ascii")[:-1] # correct the -1 + intrinsics = np.array(h5file[self.intrisics_file]).tostring().decode("ascii") + intrinsics = json.loads(intrinsics) + h5file.close() + dataset = [] + for line in txt_string.split("\n"): + image_filename, depth_filename = line.strip().split(" ") + intrinsics_val = torch.tensor(intrinsics[image_filename]).squeeze()[:, :3] + sample = [image_filename, depth_filename, intrinsics_val] + dataset.append(sample) + + if not self.test_mode: + dataset = self.chunk(dataset, chunk_dim=1, pct=self.mini) + + self.dataset = DatasetFromList(dataset) + self.log_load_dataset() + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_copies + results["quality"] = [2] * self.num_copies + return results diff --git a/unik3d/datasets/waymo.py b/unik3d/datasets/waymo.py new file mode 100644 index 0000000000000000000000000000000000000000..63e34ed7c741df4942f4b6a8ac6f2ac3e4fe36a6 --- /dev/null +++ b/unik3d/datasets/waymo.py @@ -0,0 +1,50 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class Waymo(SequenceDataset): + min_depth = 0.05 + max_depth = 70.0 + depth_scale = 256.0 + test_split = "validation.txt" + train_split = "training.txt" + sequences_file = "sequences.json" + hdf5_paths = [f"Waymo_viz.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [False] * self.num_frames * self.num_copies + results["synthetic"] = [False] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/datasets/wildrgbd.py b/unik3d/datasets/wildrgbd.py new file mode 100644 index 
0000000000000000000000000000000000000000..db863daa395a3b33c0c8617db467ac6b43e0b166 --- /dev/null +++ b/unik3d/datasets/wildrgbd.py @@ -0,0 +1,49 @@ +from typing import Any + +from unik3d.datasets.sequence_dataset import SequenceDataset + + +class WildRGBD(SequenceDataset): + min_depth = 0.01 + max_depth = 10.0 + depth_scale = 1000.0 + test_split = "train.txt" + train_split = "train.txt" + sequences_file = "sequences.json" + hdf5_paths = ["WildRGBD.hdf5"] + + def __init__( + self, + image_shape: tuple[int, int], + split_file: str, + test_mode: bool, + normalize: bool, + augmentations_db: dict[str, Any], + resize_method: str, + mini: float = 1.0, + num_frames: int = 1, + benchmark: bool = False, + decode_fields: list[str] = ["image", "depth"], + inplace_fields: list[str] = ["K", "cam2w"], + **kwargs, + ) -> None: + super().__init__( + image_shape=image_shape, + split_file=split_file, + test_mode=test_mode, + benchmark=benchmark, + normalize=normalize, + augmentations_db=augmentations_db, + resize_method=resize_method, + mini=mini, + num_frames=num_frames, + decode_fields=decode_fields, + inplace_fields=inplace_fields, + **kwargs, + ) + + def pre_pipeline(self, results): + results = super().pre_pipeline(results) + results["dense"] = [True] * self.num_frames * self.num_copies + results["quality"] = [1] * self.num_frames * self.num_copies + return results diff --git a/unik3d/layers/__init__.py b/unik3d/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..96d3c6ddce2dda1eb482e876766c65ad8216bbf2 --- /dev/null +++ b/unik3d/layers/__init__.py @@ -0,0 +1,20 @@ +from .activation import GEGLU, SwiGLU +from .attention import AttentionBlock, AttentionDecoderBlock, AttentionLayer +from .grad_choker import GradChoker +from .mlp import MLP +from .positional_encoding import PositionEmbeddingSine +from .upsample import ResUpsample, ResUpsampleBil, ResUpsampleSH + +__all__ = [ + "SwiGLU", + "GEGLU", + "AttentionBlock", + "AttentionLayer", + "PositionEmbeddingSine", + "MLP", + "AttentionDecoderBlock", + "ResUpsample", + "ResUpsampleSH", + "ResUpsampleBil", + "GradChoker", +] diff --git a/unik3d/layers/activation.py b/unik3d/layers/activation.py new file mode 100644 index 0000000000000000000000000000000000000000..f5787a340013ba59e2956b6b829f724d9cfb7fcc --- /dev/null +++ b/unik3d/layers/activation.py @@ -0,0 +1,15 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + + +class SwiGLU(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, gates = x.chunk(2, dim=-1) + return x * F.silu(gates) + + +class GEGLU(nn.Module): + def forward(self, x: torch.Tensor) -> torch.Tensor: + x, gates = x.chunk(2, dim=-1) + return x * F.gelu(gates) diff --git a/unik3d/layers/attention.py b/unik3d/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..0e4cec63e6202c9d515c3308ef550f6223776966 --- /dev/null +++ b/unik3d/layers/attention.py @@ -0,0 +1,378 @@ +from functools import partial + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange + +from .layer_scale import LayerScale +from .mlp import MLP + + +class SimpleAttention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 4, + dropout: float = 0.0, + cosine: bool = False, + context_dim: int | None = None, + ): + super().__init__() + self.dropout = dropout + self.num_heads = num_heads + self.hidden_dim = dim + context_dim = context_dim or dim + + self.kv = nn.Linear(context_dim, dim * 2, bias=False) + self.q = 
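Both gated activations split the last dimension in half and use one half to gate the other, so the output has half the input channels. A quick shape check using the package exports declared above:

import torch
from unik3d.layers import GEGLU, SwiGLU

x = torch.randn(2, 196, 512)
assert SwiGLU()(x).shape == (2, 196, 256)   # value half gated by silu(gate half)
assert GEGLU()(x).shape == (2, 196, 256)    # same split, gelu-gated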
nn.Linear(dim, dim, bias=False) + self.norm_attnx = nn.LayerNorm(dim) + self.norm_attnctx = nn.LayerNorm(context_dim) + self.cosine = cosine + self.out = nn.Linear(dim, dim, bias=False) + + def forward( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + rope_pos: torch.Tensor | None = None, + ) -> torch.Tensor: + context = x if context is None else context + x = self.norm_attnx(x) + context = self.norm_attnctx(context) + k, v = rearrange( + self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads) + + if rope is not None: + q = rope(q) + k = rope(k) + else: + if pos_embed is not None: + pos_embed = rearrange( + pos_embed, "b n (h d) -> b h n d", h=self.num_heads + ) + q = q + pos_embed + if pos_embed_context is not None: + pos_embed_context = rearrange( + pos_embed_context, "b n (h d) -> b h n d", h=self.num_heads + ) + k = k + pos_embed_context + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, attn_mask=attn_bias + ) + x = rearrange(x, "b h n d -> b n (h d)") + x = self.out(x) + return x + + +class AttentionBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 4, + expansion: int = 4, + dropout: float = 0.0, + cosine: bool = False, + gated: bool = False, + layer_scale: float = 1.0, + context_dim: int | None = None, + detach_query: bool = False, + residual_ls: bool = False, + ): + super().__init__() + self.dropout = dropout + self.num_heads = num_heads + self.hidden_dim = dim + context_dim = dim if context_dim is None else context_dim + self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated) + self.kv = nn.Linear(context_dim, dim * 2, bias=False) + self.q = nn.Linear(dim, dim, bias=False) + self.norm_attnx = nn.LayerNorm(dim) + self.norm_attnctx = nn.LayerNorm(context_dim) + self.cosine = cosine + self.out = nn.Linear(dim, dim, bias=False) + self.ls1_1 = ( + LayerScale(dim, layer_scale) + if layer_scale > 0.0 and not residual_ls + else nn.Identity() + ) + self.ls1_2 = ( + LayerScale(dim, layer_scale) + if layer_scale > 0.0 and residual_ls + else nn.Identity() + ) + self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + self.detach_query = detach_query + + def attn( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + rope_pos: torch.Tensor | None = None, + ) -> torch.Tensor: + if self.detach_query: + x = x.detach() + x = self.norm_attnx(x) + context = self.norm_attnctx(context) + k, v = rearrange( + self.kv(context), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q(x), "b n (h d) -> b h n d", h=self.num_heads) + + if rope is not None: + q = rope(q.permute(0, 2, 1, 3), input_pos=rope_pos).permute(0, 2, 1, 3) + k = rope(k.permute(0, 2, 1, 3), input_pos=rope_pos).permute(0, 2, 1, 3) + else: + if pos_embed is not None: + pos_embed = rearrange( + pos_embed, "b n (h d) -> b h n d", h=self.num_heads + ) + q = q + pos_embed + if pos_embed_context is not None: + pos_embed_context = rearrange( + pos_embed_context, "b n (h d) -> b h 
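The attention blocks above share the same core: split heads with einops, optionally L2-normalise q and k ("cosine" attention), then defer to torch's scaled_dot_product_attention. A standalone sketch of that core (the function name is illustrative):

import torch
import torch.nn.functional as F
from einops import rearrange

def multihead_attention(q, k, v, num_heads, cosine=False):
    # inputs are [batch, tokens, dim]; heads are split out of the channel axis
    q, k, v = (rearrange(t, "b n (h d) -> b h n d", h=num_heads) for t in (q, k, v))
    if cosine:
        q, k = F.normalize(q, p=2, dim=-1), F.normalize(k, p=2, dim=-1)
    out = F.scaled_dot_product_attention(q, k, v)
    return rearrange(out, "b h n d -> b n (h d)")

x = torch.randn(2, 64, 128)
assert multihead_attention(x, x, x, num_heads=4, cosine=True).shape == (2, 64, 128)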
n d", h=self.num_heads + ) + k = k + pos_embed_context + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, attn_mask=attn_bias + ) + x = rearrange(x, "b h n d -> b n (h d)") + x = self.out(x) + return x + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + attn_bias: torch.Tensor | None = None, + rope: nn.Module | None = None, + rope_pos: torch.Tensor | None = None, + ) -> torch.Tensor: + context = x if context is None else context + x = self.ls1_1( + self.attn( + x, + rope=rope, + rope_pos=rope_pos, + attn_bias=attn_bias, + context=context, + pos_embed=pos_embed, + pos_embed_context=pos_embed_context, + ) + ) + self.ls1_2(x) + x = self.ls2(self.mlp(x)) + x + return x + + +class AttentionLayer(nn.Module): + def __init__( + self, + num_blocks: int, + dim: int, + num_heads: int = 4, + expansion: int = 4, + dropout: float = 0.0, + cosine: bool = False, + gated: bool = False, + layer_scale: float = 1.0, + context_dim: int | None = None, + detach_query: bool = False, + residual_ls: bool = False, + ): + super().__init__() + self.layers = nn.ModuleList( + [ + AttentionBlock( + dim=dim, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + cosine=cosine, + gated=gated, + layer_scale=layer_scale, + context_dim=context_dim, + detach_query=detach_query, + residual_ls=residual_ls, + ) + for _ in range(num_blocks) + ] + ) + + def forward( + self, + x: torch.Tensor, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + attn_bias: torch.Tensor | None = None, + rope: nn.Module | None = None, + rope_pos: torch.Tensor | None = None, + ) -> torch.Tensor: + for layer in self.layers: + x = layer( + x, + context=context, + pos_embed=pos_embed, + pos_embed_context=pos_embed_context, + attn_bias=attn_bias, + rope=rope, + rope_pos=rope_pos, + ) + return x + + +class AttentionDecoderBlock(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 4, + expansion: int = 4, + dropout: float = 0.0, + cosine: bool = False, + gated: bool = False, + layer_scale: float = 1.0, + context_dim: int | None = None, + single_head_ca: bool = True, + ): + super().__init__() + self.dropout = dropout + self.num_heads = num_heads + self.hidden_dim = dim + self.single_head_ca = single_head_ca + context_dim = context_dim or dim + self.mlp = MLP(dim, expansion=expansion, dropout=dropout, gated=gated) + self.kv_ca = nn.Linear(context_dim, dim * 2, bias=False) + self.q_ca = nn.Linear(dim, dim, bias=False) + self.kv_sa = nn.Linear(dim, dim * 2, bias=False) + self.q_sa = nn.Linear(dim, dim, bias=False) + self.norm_x_sa = nn.LayerNorm(dim) + self.norm_x_ca = nn.LayerNorm(dim) + self.norm_ctx_ca = nn.LayerNorm(context_dim) + self.cosine = cosine + self.out_ca = nn.Linear(dim, dim, bias=False) + self.out_sa = nn.Linear(dim, dim, bias=False) + self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + self.ls2 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + self.ls3 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + + def cross_attn( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None 
= None, + rope_pos: torch.Tensor | None = None, + ) -> torch.Tensor: + num_heads = 1 if self.single_head_ca else self.num_heads + x = self.norm_x_ca(x) + context = self.norm_ctx_ca(context) + k, v = rearrange( + self.kv_ca(context), "b n (kv h d) -> b h n d kv", h=num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q_ca(x), "b n (h d) -> b h n d", h=num_heads) + + if rope is not None: + q = rope(q) + k = rope(k) + else: + if pos_embed is not None: + pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=num_heads) + q = q + pos_embed + if pos_embed_context is not None: + pos_embed_context = rearrange( + pos_embed_context, "b n (h d) -> b h n d", h=num_heads + ) + k = k + pos_embed_context + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, attn_mask=attn_bias + ) + x = rearrange(x, "b h n d -> b n (h d)") + x = self.out_ca(x) + return x + + def self_attn( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + rope: nn.Module | None = None, + rope_pos: torch.Tensor | None = None, + ) -> torch.Tensor: + x = self.norm_x_sa(x) + k, v = rearrange( + self.kv_sa(x), "b n (kv h d) -> b h n d kv", h=self.num_heads, kv=2 + ).unbind(dim=-1) + q = rearrange(self.q_sa(x), "b n (h d) -> b h n d", h=self.num_heads) + + if rope is not None: + q = rope(q) + k = rope(k) + elif pos_embed is not None: + pos_embed = rearrange(pos_embed, "b n (h d) -> b h n d", h=self.num_heads) + q = q + pos_embed + + if self.cosine: + q, k = map(partial(F.normalize, p=2, dim=-1), (q, k)) # cosine sim + x = F.scaled_dot_product_attention( + q, k, v, dropout_p=self.dropout, attn_mask=attn_bias + ) + x = rearrange(x, "b h n d -> b n (h d)") + x = self.out_sa(x) + return x + + def forward( + self, + x: torch.Tensor, + attn_bias: torch.Tensor | None = None, + context: torch.Tensor | None = None, + pos_embed: torch.Tensor | None = None, + pos_embed_context: torch.Tensor | None = None, + rope: nn.Module | None = None, + rope_pos: torch.Tensor | None = None, + ) -> torch.Tensor: + context = x if context is None else context + x = ( + self.ls1( + self.cross_attn( + x, + rope=rope, + attn_bias=attn_bias, + context=context, + pos_embed=pos_embed, + pos_embed_context=pos_embed_context, + ) + ) + + x + ) + x = ( + self.ls2( + self.self_attn(x, rope=rope, attn_bias=attn_bias, pos_embed=pos_embed) + ) + + x + ) + x = self.ls3(self.mlp(x)) + x + return x diff --git a/unik3d/layers/convnext.py b/unik3d/layers/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..d186885fb8499b98f2fe09e5dd5150a2f9f14b04 --- /dev/null +++ b/unik3d/layers/convnext.py @@ -0,0 +1,80 @@ +import torch +import torch.nn as nn + + +class CvnxtBlock(nn.Module): + def __init__( + self, + dim, + kernel_size=7, + layer_scale=1.0, + expansion=4, + dilation=1, + padding_mode: str = "zeros", + ): + super().__init__() + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=dilation * (kernel_size - 1) // 2, + groups=dim, + dilation=dilation, + padding_mode=padding_mode, + ) # depthwise conv + self.norm = nn.LayerNorm(dim) + self.pwconv1 = nn.Linear(dim, expansion * dim) + self.act = nn.GELU() + self.pwconv2 = nn.Linear(expansion * dim, dim) + self.gamma = ( + nn.Parameter(layer_scale * torch.ones(1, dim, 1, 1)) + if layer_scale > 0.0 + else 1.0 + ) + self.skip_add = nn.quantized.FloatFunctional() + + def forward(self, x): + input = x + x = self.dwconv(x) + x 
= x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + return self.skip_add.add(self.gamma * x.permute(0, 3, 1, 2), input) + + +class SimpleCvnxtBlock(nn.Module): + def __init__( + self, + dim, + output_dim=None, + kernel_size=7, + expansion=4, + dilation=1, + padding_mode: str = "zeros", + ): + super().__init__() + output_dim = output_dim if output_dim is not None else dim + self.dwconv = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=dilation * (kernel_size - 1) // 2, + groups=dim, + dilation=dilation, + padding_mode=padding_mode, + ) # depthwise conv + self.norm = nn.LayerNorm(dim) + self.pwconv1 = nn.Linear(dim, expansion * dim) + self.act = nn.GELU() + self.pwconv2 = nn.Linear(expansion * dim, output_dim) + + def forward(self, x): + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.pwconv2(x) + return x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) diff --git a/unik3d/layers/drop_path.py b/unik3d/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..781ff566500c923b1f199542b0c7dfb862a077ca --- /dev/null +++ b/unik3d/layers/drop_path.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn + + +def drop_path(x: torch.Tensor, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/unik3d/layers/grad_choker.py b/unik3d/layers/grad_choker.py new file mode 100644 index 0000000000000000000000000000000000000000..12b70174f1922a3d7b962c0cba560d892ac80b60 --- /dev/null +++ b/unik3d/layers/grad_choker.py @@ -0,0 +1,25 @@ +import torch +import torch.nn as nn +from torch.autograd import Function + + +class ChockerFunction(Function): + @staticmethod + def forward(ctx, x, alpha): + ctx.alpha = alpha + return x + + @staticmethod + def backward(ctx, grad_output): + grad_input = grad_output * ctx.alpha + return grad_input, None + + +class GradChoker(nn.Module): + def __init__(self, alpha): + super().__init__() + self.alpha = alpha + + def forward(self, x): + alpha = torch.tensor(self.alpha, requires_grad=False, device=x.device) + return ChockerFunction.apply(x, alpha) diff --git a/unik3d/layers/layer_scale.py b/unik3d/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..01b6662490d7296725f103d1abf8790cac84d0f8 --- /dev/null +++ b/unik3d/layers/layer_scale.py @@ -0,0 +1,17 @@ +import torch +import torch.nn as nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: float | torch.Tensor = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/unik3d/layers/misc.py b/unik3d/layers/misc.py new file mode 100644 index 
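GradChoker is an identity in the forward pass that rescales gradients on the way back (alpha=0 detaches a branch, 0 < alpha < 1 merely attenuates it). A self-contained check of that behaviour, re-implementing the Function so the snippet runs on its own:

import torch
from torch.autograd import Function

class ScaleGrad(Function):
    @staticmethod
    def forward(ctx, x, alpha):
        ctx.alpha = alpha
        return x                                   # identity forward

    @staticmethod
    def backward(ctx, grad_output):
        return grad_output * ctx.alpha, None       # no gradient w.r.t. alpha

x = torch.ones(3, requires_grad=True)
(ScaleGrad.apply(x, 0.1) ** 2).sum().backward()
assert torch.allclose(x.grad, torch.full((3,), 0.2))   # d(x**2)/dx = 2, scaled by alpha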
0000000000000000000000000000000000000000..f5d32030e066eaa8107e8ec060658318c675e8a3 --- /dev/null +++ b/unik3d/layers/misc.py @@ -0,0 +1,30 @@ +import torch +import torch.nn as nn + +from .layer_scale import LayerScale + + +class Addition(nn.Module): + def __init__( + self, + dim: int, + layer_scale: float | torch.Tensor = 1e-5, + ) -> None: + super().__init__() + self.ls1 = LayerScale(dim, layer_scale) if layer_scale > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return x + self.ls1(y) + + +class Concat(nn.Module): + def __init__( + self, + dim: int, + layer_scale: float | torch.Tensor = 1e-5, + ) -> None: + super().__init__() + self.project = nn.Linear(2 * dim, dim) + + def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: + return self.project(torch.cat([x, y], dim=-1)) diff --git a/unik3d/layers/mlp.py b/unik3d/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..5f1b9d2e6d020c039a677f566f4eb925aab696d4 --- /dev/null +++ b/unik3d/layers/mlp.py @@ -0,0 +1,35 @@ +import torch +import torch.nn as nn + +from unik3d.utils.misc import default + +from .activation import SwiGLU + + +class MLP(nn.Module): + def __init__( + self, + input_dim: int, + expansion: int = 4, + dropout: float = 0.0, + gated: bool = False, + output_dim: int | None = None, + ): + super().__init__() + if gated: + expansion = int(expansion * 2 / 3) + hidden_dim = int(input_dim * expansion) + output_dim = default(output_dim, input_dim) + self.norm = nn.LayerNorm(input_dim) + self.proj1 = nn.Linear(input_dim, hidden_dim) + self.proj2 = nn.Linear(hidden_dim, output_dim) + self.act = nn.GELU() if not gated else SwiGLU() + self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity() + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.norm(x) + x = self.proj1(x) + x = self.act(x) + x = self.proj2(x) + x = self.dropout(x) + return x diff --git a/unik3d/layers/positional_encoding.py b/unik3d/layers/positional_encoding.py new file mode 100644 index 0000000000000000000000000000000000000000..a76e93499eab94c4dd6fcb43c4baab1493f7c2cb --- /dev/null +++ b/unik3d/layers/positional_encoding.py @@ -0,0 +1,303 @@ +from math import log, pi +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange, repeat + + +class PositionEmbeddingSine(nn.Module): + def __init__( + self, num_pos_feats=64, temperature=10000, normalize=False, scale=None + ): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * pi + self.scale = scale + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if mask is None: + mask = torch.zeros( + (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool + ) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** ( + 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats + ) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + 
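For reference, a common way to wire a SwiGLU feed-forward is to let the first projection emit both the value and the gate, with the expansion scaled by 2/3 so the parameter count stays close to an ungated 4x GELU MLP (about 8*dim^2 weights either way, ignoring biases). This is a generic sketch under that convention, not a copy of the MLP class above:

import torch
import torch.nn.functional as F
from torch import nn

class GatedMLP(nn.Module):
    def __init__(self, dim: int, expansion: int = 4):
        super().__init__()
        hidden = int(dim * expansion * 2 / 3)
        self.norm = nn.LayerNorm(dim)
        self.proj_in = nn.Linear(dim, hidden * 2)   # value and gate halves
        self.proj_out = nn.Linear(hidden, dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        v, g = self.proj_in(self.norm(x)).chunk(2, dim=-1)
        return self.proj_out(v * F.silu(g))

x = torch.randn(2, 16, 128)
assert GatedMLP(128)(x).shape == x.shape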
pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self, _repr_indent=4): + head = "Positional encoding " + self.__class__.__name__ + body = [ + "num_pos_feats: {}".format(self.num_pos_feats), + "temperature: {}".format(self.temperature), + "normalize: {}".format(self.normalize), + "scale: {}".format(self.scale), + ] + # _repr_indent = 4 + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) + + +class LearnedSinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered = torch.cat((x, fouriered), dim=-1) + return fouriered + + +def generate_fourier_features(x, max_freq=64, num_bands=16): + x = x.unsqueeze(-1) + device, dtype, orig_x = x.device, x.dtype, x + + scales = torch.linspace( + -max_freq / 2, max_freq / 2, num_bands, device=device, dtype=dtype + ) + scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] + + x = x * scales * pi + x = torch.cat([x.sin(), x.cos()], dim=-1) + x = torch.cat((x, orig_x), dim=-1) + return x.flatten(-2) + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +class VisionRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs_h = torch.einsum("..., f -> ... f", t, freqs) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + + freqs_w = torch.einsum("..., f -> ... 
f", t, freqs) + freqs_w = repeat(freqs_w, "... n -> ... (n r)", r=2) + + freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) + + self.register_buffer("freqs_cos", freqs.cos()) + self.register_buffer("freqs_sin", freqs.sin()) + + print("======== shape of rope freq", self.freqs_cos.shape, "========") + + def forward(self, t, start_index=0): + rot_dim = self.freqs_cos.shape[-1] + end_index = start_index + rot_dim + assert ( + rot_dim <= t.shape[-1] + ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" + t_left, t, t_right = ( + t[..., :start_index], + t[..., start_index:end_index], + t[..., end_index:], + ) + t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) + return torch.cat((t_left, t, t_right), dim=-1) + + +class VisionRotaryEmbeddingFast(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) + freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) + + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + def forward(self, t): + return t * self.freqs_cos + rotate_half(t) * self.freqs_sin + + +class RotaryPositionalEmbeddings(nn.Module): + def __init__( + self, + dim: int, + max_seq_len: int = 30, + base: int = 10_000, + ) -> None: + super().__init__() + self.dim = dim + self.base = base + self.max_seq_len = max_seq_len + self._rope_init() + + # We need to explicitly define reset_parameters for FSDP initialization, see + # https://github.com/pytorch/pytorch/blob/797d4fbdf423dd9320ebe383fb57ffb1135c4a99/torch/distributed/fsdp/_init_utils.py#L885 + def reset_parameters(self): + self._rope_init() + + def _rope_init(self): + theta = 1.0 / ( + self.base + ** (torch.arange(0, self.dim, 2)[: (self.dim // 2)].float() / self.dim) + ) + self.register_buffer("theta", theta, persistent=False) + self.build_rope_cache(self.max_seq_len) + + def build_rope_cache(self, max_seq_len: int = 4096) -> None: + # Create position indexes `[0, 1, ..., max_seq_len - 1]` + seq_idx = torch.arange( + max_seq_len, dtype=self.theta.dtype, device=self.theta.device + ) + + # Outer product of theta and position index; output tensor has + # a shape of [max_seq_len, dim // 2] + idx_theta = torch.einsum("i, j -> ij", seq_idx, self.theta).float() + + cache = torch.stack([torch.cos(idx_theta), torch.sin(idx_theta)], dim=-1) + self.register_buffer("cache", cache, persistent=False) + + def forward(self, x: torch.Tensor, input_pos: torch.Tensor) -> torch.Tensor: + """ + Args: + x (Tensor): input tensor with shape + [bsz, seq_len, num_heads, head_dim] + input_pos (Optional[Tensor]): contains the position of the current toke + + Returns: + Tensor: output tensor with RoPE applied + + 
Notation used for tensor shapes: + - b: batch size + - s: sequence length + - n_h: num heads + - h_d: head dim + """ + rope_cache = self.cache[input_pos] + + # reshape input; the last dimension is used for computing the output. + # Cast to float to match the reference implementation + # tensor has shape [b, s, n_h, n_d // 2, 2] + xshaped = x.reshape(*x.shape[:-1], -1, 2) + + # reshape the cache for broadcasting + # tensor has shape [b, s, 1, n_d // 2, 2] + rope_cache = rope_cache.unsqueeze(2) + + # tensor has shape [b, s, n_h, n_d // 2, 2] + x_out = torch.stack( + [ + xshaped[..., 0] * rope_cache[..., 0] + - xshaped[..., 1] * rope_cache[..., 1], + xshaped[..., 1] * rope_cache[..., 0] + + xshaped[..., 0] * rope_cache[..., 1], + ], + -1, + ) + + # tensor has shape [b, s, n_h, n_d] + return x_out.flatten(3) diff --git a/unik3d/layers/upsample.py b/unik3d/layers/upsample.py new file mode 100644 index 0000000000000000000000000000000000000000..9a4b9141ed1e0ce747e7e43769c94b709591f115 --- /dev/null +++ b/unik3d/layers/upsample.py @@ -0,0 +1,169 @@ +import torch +import torch.nn as nn +from einops import rearrange + +from unik3d.utils.constants import VERBOSE +from unik3d.utils.misc import profile_method + + +class ResidualConvUnit(nn.Module): + def __init__( + self, + dim, + kernel_size: int = 3, + padding_mode: str = "zeros", + dilation: int = 1, + layer_scale: float = 1.0, + use_norm: bool = False, + ): + super().__init__() + self.conv1 = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=dilation * (kernel_size - 1) // 2, + dilation=dilation, + padding_mode=padding_mode, + ) + self.conv2 = nn.Conv2d( + dim, + dim, + kernel_size=kernel_size, + padding=dilation * (kernel_size - 1) // 2, + dilation=dilation, + padding_mode=padding_mode, + ) + self.activation = nn.LeakyReLU() + self.skip_add = nn.quantized.FloatFunctional() + self.gamma = ( + nn.Parameter(layer_scale * torch.ones(1, dim, 1, 1)) + if layer_scale > 0.0 + else 1.0 + ) + self.norm1 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity() + self.norm2 = nn.GroupNorm(dim // 16, dim) if use_norm else nn.Identity() + + def forward(self, x): + out = self.activation(x) + out = self.conv1(out) + out = self.norm1(out) + out = self.activation(out) + out = self.conv2(out) + out = self.norm2(out) + return self.skip_add.add(self.gamma * out, x) + + +class ResUpsampleBil(nn.Module): + def __init__( + self, + hidden_dim, + output_dim: int = None, + num_layers: int = 2, + kernel_size: int = 3, + layer_scale: float = 1.0, + padding_mode: str = "zeros", + use_norm: bool = False, + **kwargs, + ): + super().__init__() + output_dim = output_dim if output_dim is not None else hidden_dim // 2 + self.convs = nn.ModuleList([]) + for _ in range(num_layers): + self.convs.append( + ResidualConvUnit( + hidden_dim, + kernel_size=kernel_size, + layer_scale=layer_scale, + padding_mode=padding_mode, + use_norm=use_norm, + ) + ) + self.up = nn.Sequential( + nn.Conv2d( + hidden_dim, + output_dim, + kernel_size=1, + padding=0, + padding_mode=padding_mode, + ), + nn.Upsample(scale_factor=2, mode="bilinear", align_corners=False), + ) + + @profile_method(verbose=VERBOSE) + def forward(self, x: torch.Tensor): + for conv in self.convs: + x = conv(x) + x = self.up(x) + return x + + +class ResUpsample(nn.Module): + def __init__( + self, + hidden_dim, + num_layers: int = 2, + kernel_size: int = 3, + layer_scale: float = 1.0, + padding_mode: str = "zeros", + **kwargs, + ): + super().__init__() + self.convs = nn.ModuleList([]) + for _ in range(num_layers): + 
self.convs.append( + ResidualConvUnit( + hidden_dim, + kernel_size=kernel_size, + layer_scale=layer_scale, + padding_mode=padding_mode, + ) + ) + self.up = nn.ConvTranspose2d( + hidden_dim, hidden_dim // 2, kernel_size=2, stride=2, padding=0 + ) + + @profile_method(verbose=VERBOSE) + def forward(self, x: torch.Tensor): + for conv in self.convs: + x = conv(x) + x = self.up(x) + return x + + +class ResUpsampleSH(nn.Module): + def __init__( + self, + hidden_dim, + num_layers: int = 2, + kernel_size: int = 3, + layer_scale: float = 1.0, + padding_mode: str = "zeros", + **kwargs, + ): + super().__init__() + self.convs = nn.ModuleList([]) + for _ in range(num_layers): + self.convs.append( + ResidualConvUnit( + hidden_dim, + kernel_size=kernel_size, + layer_scale=layer_scale, + padding_mode=padding_mode, + ) + ) + self.up = nn.Sequential( + nn.PixelShuffle(2), + nn.Conv2d( + hidden_dim // 4, + hidden_dim // 2, + kernel_size=3, + padding=1, + padding_mode=padding_mode, + ), + ) + + def forward(self, x: torch.Tensor): + for conv in self.convs: + x = conv(x) + x = self.up(x) + return x diff --git a/unik3d/models/__init__.py b/unik3d/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..62d77529507b5eb1ae2dc281cd8c152e84212125 --- /dev/null +++ b/unik3d/models/__init__.py @@ -0,0 +1,3 @@ +from unik3d.models.unik3d import UniK3D + +__all__ = ["UniK3D"] diff --git a/unik3d/models/backbones/__init__.py b/unik3d/models/backbones/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..af81441399818e6a32fd91b23aae7945b8fd4e1c --- /dev/null +++ b/unik3d/models/backbones/__init__.py @@ -0,0 +1,13 @@ +from .convnext import ConvNeXt +from .convnext2 import ConvNeXtV2 +from .dinov2 import _make_dinov2_model +from .swinv2 import SwinTransformerV2 + +# from .svd import StableVideoDiffusion + +__all__ = [ + "SwinTransformerV2", + "ConvNeXtV2", + "_make_dinov2_model", + "ConvNeXt", +] diff --git a/unik3d/models/backbones/convnext.py b/unik3d/models/backbones/convnext.py new file mode 100644 index 0000000000000000000000000000000000000000..c7f9d820f3230e349e6263cc0be4ce1139b2df3f --- /dev/null +++ b/unik3d/models/backbones/convnext.py @@ -0,0 +1,577 @@ +from collections import OrderedDict +from functools import partial +from typing import Callable, Optional, Sequence, Tuple, Union + +import torch +import torch.nn as nn +from timm.layers import (AvgPool2dSame, DropPath, GlobalResponseNormMlp, + LayerNorm, LayerNorm2d, Mlp, create_conv2d, + get_act_layer, make_divisible, to_ntuple, + trunc_normal_) +from torch.utils.checkpoint import checkpoint + + +def get_num_layer_for_convnext(var_name): + """ + Divide [3, 3, 27, 3] layers into 12 groups; each group is three + consecutive blocks, including possible neighboring downsample layers; + adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py + """ + if var_name.startswith("downsample_layers"): + stage_id = int(var_name.split(".")[1]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1 or stage_id == 2: + layer_id = stage_id + 1 + elif stage_id == 3: + layer_id = 12 + + elif var_name.startswith("stages"): + stage_id = int(var_name.split(".")[1]) + block_id = int(var_name.split(".")[3]) + if stage_id == 0 or stage_id == 1: + layer_id = stage_id + 1 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + elif stage_id == 3: + layer_id = 12 + + elif var_name.startswith("stem"): + return 0 + else: + layer_id = 12 + return layer_id + 1 + + +def get_parameter_groups(model, lr, wd=1e-5, 
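A usage sketch for the three upsampling blocks, with the constructor defaults as defined above: each one doubles the spatial resolution and halves the channel count; they differ only in how the upsampling is realised (1x1 conv plus bilinear, transposed conv, or pixel shuffle plus 3x3 conv).

import torch
from unik3d.layers import ResUpsample, ResUpsampleBil, ResUpsampleSH

x = torch.randn(1, 64, 24, 32)
for block in (ResUpsampleBil(64), ResUpsample(64), ResUpsampleSH(64)):
    assert block(x).shape == (1, 32, 48, 64)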
ld=0.9, skip_list=None): + parameter_group_names = {} + parameter_group_vars = {} + skip = set() + if skip_list is not None: + skip = skip_list + if hasattr(model, "no_weight_decay"): + skip.update(model.no_weight_decay()) + num_layers = 12 + layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2)) + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if len(param.shape) == 1 or name.endswith(".bias") or name in skip: + group_name = "no_decay" + this_wd = 0.0 + else: + group_name = "decay" + this_wd = wd + + layer_id = get_num_layer_for_convnext(name) + group_name = "layer_%d_%s" % (layer_id, group_name) + + if group_name not in parameter_group_names: + scale = layer_scale[layer_id] + cur_lr = lr * scale + + parameter_group_names[group_name] = { + "weight_decay": this_wd, + "weight_decay_init": this_wd, + "weight_decay_base": this_wd, + "params": [], + "lr_init": cur_lr, + "lr_base": lr, + "lr": cur_lr, + } + parameter_group_vars[group_name] = { + "weight_decay": this_wd, + "weight_decay_init": this_wd, + "weight_decay_base": this_wd, + "params": [], + "lr_init": cur_lr, + "lr_base": lr, + "lr": cur_lr, + } + if this_wd == 0.0: + parameter_group_names[group_name]["weight_decay_final"] = 0.0 + parameter_group_vars[group_name]["weight_decay_final"] = 0.0 + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + + return list(parameter_group_vars.values()), [ + v["lr"] for k, v in parameter_group_vars.items() + ] + + +class Downsample(nn.Module): + def __init__(self, in_chs, out_chs, stride=1, dilation=1): + super().__init__() + avg_stride = stride if dilation == 1 else 1 + if stride > 1 or dilation > 1: + avg_pool_fn = ( + AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d + ) + self.pool = avg_pool_fn( + 2, avg_stride, ceil_mode=True, count_include_pad=False + ) + else: + self.pool = nn.Identity() + + if in_chs != out_chs: + self.conv = create_conv2d(in_chs, out_chs, 1, stride=1) + else: + self.conv = nn.Identity() + + def forward(self, x): + x = self.pool(x) + x = self.conv(x) + return x + + +class ConvNeXtBlock(nn.Module): + """ConvNeXt Block + There are two equivalent implementations: + (1) DwConv -> LayerNorm (channels_first) -> 1x1 Conv -> GELU -> 1x1 Conv; all in (N, C, H, W) + (2) DwConv -> Permute to (N, H, W, C); LayerNorm (channels_last) -> Linear -> GELU -> Linear; Permute back + + Unlike the official impl, this one allows choice of 1 or 2, 1x1 conv can be faster with appropriate + choice of LayerNorm impl, however as model size increases the tradeoffs appear to change and nn.Linear + is a better choice. This was observed with PyTorch 1.10 on 3090 GPU, it could change over time & w/ different HW. + """ + + def __init__( + self, + in_chs: int, + out_chs: Optional[int] = None, + kernel_size: int = 7, + stride: int = 1, + dilation: Union[int, Tuple[int, int]] = (1, 1), + mlp_ratio: float = 4, + conv_mlp: bool = False, + conv_bias: bool = True, + use_grn: bool = False, + ls_init_value: Optional[float] = 1e-6, + act_layer: Union[str, Callable] = "gelu", + norm_layer: Optional[Callable] = None, + drop_path: float = 0.0, + ): + """ + + Args: + in_chs: Block input channels. + out_chs: Block output channels (same as in_chs if None). + kernel_size: Depthwise convolution kernel size. + stride: Stride of depthwise convolution. + dilation: Tuple specifying input and output dilation of block. + mlp_ratio: MLP expansion ratio. 
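A numeric sketch of the layer-wise learning-rate decay implemented by get_parameter_groups above; the base lr and ld values are illustrative:

base_lr, ld, num_layers = 1e-4, 0.9, 12
scales = [ld ** (num_layers + 1 - i) for i in range(num_layers + 2)]
print([round(base_lr * s, 7) for s in scales])
# group 0 (stem) gets ~2.5e-05, the last group gets the full 1e-04,
# so early layers train roughly 4x slower than the final blocks.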
+ conv_mlp: Use 1x1 convolutions for MLP and a NCHW compatible norm layer if True. + conv_bias: Apply bias for all convolution (linear) layers. + use_grn: Use GlobalResponseNorm in MLP (from ConvNeXt-V2) + ls_init_value: Layer-scale init values, layer-scale applied if not None. + act_layer: Activation layer. + norm_layer: Normalization layer (defaults to LN if not specified). + drop_path: Stochastic depth probability. + """ + super().__init__() + out_chs = out_chs or in_chs + dilation = to_ntuple(2)(dilation) + act_layer = get_act_layer(act_layer) + if not norm_layer: + norm_layer = LayerNorm2d if conv_mlp else LayerNorm + mlp_layer = partial( + GlobalResponseNormMlp if use_grn else Mlp, use_conv=conv_mlp + ) + self.use_conv_mlp = conv_mlp + self.conv_dw = create_conv2d( + in_chs, + out_chs, + kernel_size=kernel_size, + stride=stride, + dilation=dilation[0], + depthwise=True, + bias=conv_bias, + ) + self.norm = norm_layer(out_chs) + self.mlp = mlp_layer(out_chs, int(mlp_ratio * out_chs), act_layer=act_layer) + self.gamma = ( + nn.Parameter(ls_init_value * torch.ones(out_chs)) + if ls_init_value is not None + else None + ) + if in_chs != out_chs or stride != 1 or dilation[0] != dilation[1]: + self.shortcut = Downsample( + in_chs, out_chs, stride=stride, dilation=dilation[0] + ) + else: + self.shortcut = nn.Identity() + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + def forward(self, x): + shortcut = x + x = self.conv_dw(x.contiguous()) + if self.use_conv_mlp: + x = self.norm(x) + x = self.mlp(x) + else: + x = x.permute(0, 2, 3, 1).contiguous() + x = self.norm(x) + x = self.mlp(x) + x = x.permute(0, 3, 1, 2).contiguous() + if self.gamma is not None: + x = x.mul(self.gamma.reshape(1, -1, 1, 1)) + + x = self.drop_path(x) + self.shortcut(shortcut) + return x.contiguous() + + +class ConvNeXtStage(nn.Module): + def __init__( + self, + in_chs, + out_chs, + kernel_size=7, + stride=2, + depth=2, + dilation=(1, 1), + drop_path_rates=None, + ls_init_value=1.0, + conv_mlp=False, + conv_bias=True, + use_grn=False, + act_layer="gelu", + norm_layer=None, + norm_layer_cl=None, + ): + super().__init__() + self.grad_checkpointing = False + + if in_chs != out_chs or stride > 1 or dilation[0] != dilation[1]: + ds_ks = 2 if stride > 1 or dilation[0] != dilation[1] else 1 + pad = ( + "same" if dilation[1] > 1 else 0 + ) # same padding needed if dilation used + self.downsample = nn.Sequential( + norm_layer(in_chs), + create_conv2d( + in_chs, + out_chs, + kernel_size=ds_ks, + stride=stride, + dilation=dilation[0], + padding=pad, + bias=conv_bias, + ), + ) + in_chs = out_chs + else: + self.downsample = nn.Identity() + + drop_path_rates = drop_path_rates or [0.0] * depth + stage_blocks = [] + for i in range(depth): + stage_blocks.append( + ConvNeXtBlock( + in_chs=in_chs, + out_chs=out_chs, + kernel_size=kernel_size, + dilation=dilation[1], + drop_path=drop_path_rates[i], + ls_init_value=ls_init_value, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + use_grn=use_grn, + act_layer=act_layer, + norm_layer=norm_layer if conv_mlp else norm_layer_cl, + ) + ) + in_chs = out_chs + self.blocks = nn.ModuleList(stage_blocks) + + def forward(self, x): + xs = [] + x = self.downsample(x) + for block in self.blocks: + if self.grad_checkpointing: + x = checkpoint(block, x) + else: + x = block(x) + xs.append(x) + return xs + + +class ConvNeXt(nn.Module): + def __init__( + self, + in_chans: int = 3, + output_stride: int = 32, + depths: Tuple[int, ...] = (3, 3, 9, 3), + dims: Tuple[int, ...] 
= (96, 192, 384, 768), + kernel_sizes: Union[int, Tuple[int, ...]] = 7, + ls_init_value: Optional[float] = 1e-6, + stem_type: str = "patch", + patch_size: int = 4, + conv_mlp: bool = False, + conv_bias: bool = True, + use_grn: bool = False, + act_layer: Union[str, Callable] = "gelu", + norm_layer: Optional[Union[str, Callable]] = None, + norm_eps: Optional[float] = None, + drop_path_rate: float = 0.0, + output_idx=[], + use_checkpoint=False, + ): + """ + Args: + in_chans: Number of input image channels. + num_classes: Number of classes for classification head. + global_pool: Global pooling type. + output_stride: Output stride of network, one of (8, 16, 32). + depths: Number of blocks at each stage. + dims: Feature dimension at each stage. + kernel_sizes: Depthwise convolution kernel-sizes for each stage. + ls_init_value: Init value for Layer Scale, disabled if None. + stem_type: Type of stem. + patch_size: Stem patch size for patch stem. + head_init_scale: Init scaling value for classifier weights and biases. + head_norm_first: Apply normalization before global pool + head. + head_hidden_size: Size of MLP hidden layer in head if not None and head_norm_first == False. + conv_mlp: Use 1x1 conv in MLP, improves speed for small networks w/ chan last. + conv_bias: Use bias layers w/ all convolutions. + use_grn: Use Global Response Norm (ConvNeXt-V2) in MLP. + act_layer: Activation layer type. + norm_layer: Normalization layer type. + drop_rate: Head pre-classifier dropout rate. + drop_path_rate: Stochastic depth drop rate. + """ + super().__init__() + self.num_layers = len(depths) + self.depths = output_idx + self.embed_dims = [ + int(dim) for i, dim in enumerate(dims) for _ in range(depths[i]) + ] + self.embed_dim = dims[0] + + assert output_stride in (8, 16, 32) + kernel_sizes = to_ntuple(4)(kernel_sizes) + if norm_layer is None: + norm_layer = LayerNorm2d + norm_layer_cl = norm_layer if conv_mlp else LayerNorm + if norm_eps is not None: + norm_layer = partial(norm_layer, eps=norm_eps) + norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) + else: + assert ( + conv_mlp + ), "If a norm_layer is specified, conv MLP must be used so all norm expect rank-4, channels-first input" + norm_layer_cl = norm_layer + if norm_eps is not None: + norm_layer_cl = partial(norm_layer_cl, eps=norm_eps) + + self.feature_info = [] + + assert stem_type in ("patch", "overlap", "overlap_tiered") + if stem_type == "patch": + # NOTE: this stem is a minimal form of ViT PatchEmbed, as used in SwinTransformer w/ patch_size = 4 + self.stem = nn.Sequential( + nn.Conv2d( + in_chans, + dims[0], + kernel_size=patch_size, + stride=patch_size, + bias=conv_bias, + ), + norm_layer(dims[0]), + ) + stem_stride = patch_size + else: + mid_chs = make_divisible(dims[0] // 2) if "tiered" in stem_type else dims[0] + self.stem = nn.Sequential( + nn.Conv2d( + in_chans, + mid_chs, + kernel_size=3, + stride=2, + padding=1, + bias=conv_bias, + ), + nn.Conv2d( + mid_chs, dims[0], kernel_size=3, stride=2, padding=1, bias=conv_bias + ), + norm_layer(dims[0]), + ) + stem_stride = 4 + + self.stages = nn.Sequential() + dp_rates = [ + x.tolist() + for x in torch.linspace(0, drop_path_rate, sum(depths)).split(depths) + ] + stages = [] + prev_chs = dims[0] + curr_stride = stem_stride + dilation = 1 + # 4 feature resolution stages, each consisting of multiple residual blocks + for i in range(4): + stride = 2 if curr_stride == 2 or i > 0 else 1 + if curr_stride >= output_stride and stride > 1: + dilation *= stride + stride = 1 + curr_stride *= stride + 
first_dilation = 1 if dilation in (1, 2) else 2 + out_chs = dims[i] + stages.append( + ConvNeXtStage( + prev_chs, + out_chs, + kernel_size=kernel_sizes[i], + stride=stride, + dilation=(first_dilation, dilation), + depth=depths[i], + drop_path_rates=dp_rates[i], + ls_init_value=ls_init_value, + conv_mlp=conv_mlp, + conv_bias=conv_bias, + use_grn=use_grn, + act_layer=act_layer, + norm_layer=norm_layer, + norm_layer_cl=norm_layer_cl, + ) + ) + prev_chs = out_chs + # NOTE feature_info use currently assumes stage 0 == stride 1, rest are stride 2 + self.feature_info += [ + dict(num_chs=prev_chs, reduction=curr_stride, module=f"stages.{i}") + ] + self.stages = nn.ModuleList(stages) + self.mask_token = nn.Parameter(torch.zeros(1, self.embed_dim, 1, 1)) + self.num_features = prev_chs + self.apply(self._init_weights) + self.set_grad_checkpointing(use_checkpoint) + + def _init_weights(self, module): + if isinstance(module, nn.Conv2d): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + elif isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + nn.init.zeros_(module.bias) + + def forward(self, x, masks=None): + outs = [] + x = self.stem(x) + if masks is not None: + masks = torch.nn.functional.interpolate( + masks.float(), size=x.shape[-2:], mode="nearest" + ) + x = torch.where(masks.bool(), self.mask_token.to(x.dtype), x).contiguous() + for stage in self.stages: + xs = stage(x) + outs.extend([x.permute(0, 2, 3, 1).contiguous() for x in xs]) + x = xs[-1] + return outs, [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs] + + @torch.jit.ignore + def group_matcher(self, coarse=False): + return dict( + stem=r"^stem", + blocks=( + r"^stages\.(\d+)" + if coarse + else [ + (r"^stages\.(\d+)\.downsample", (0,)), # blocks + (r"^stages\.(\d+)\.blocks\.(\d+)", None), + (r"^norm_pre", (99999,)), + ] + ), + ) + + @torch.jit.ignore + def set_grad_checkpointing(self, enable=True): + for s in self.stages: + s.grad_checkpointing = enable + + def freeze(self) -> None: + for module in self.modules(): + module.eval() + for parameters in self.parameters(): + parameters.requires_grad = False + + def get_params(self, lr, wd, ld, *args, **kwargs): + encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld) + return encoder_p, encoder_lr + + def no_weight_decay(self): + return {"mask_token"} + + @classmethod + def build(cls, config): + obj = globals()[config["model"]["encoder"]["name"]](config) + return obj + + +def checkpoint_filter_fn(state_dict, model): + """Remap FB checkpoints -> timm""" + if "head.norm.weight" in state_dict or "norm_pre.weight" in state_dict: + return state_dict # non-FB checkpoint + if "model" in state_dict: + state_dict = state_dict["model"] + + out_dict = {} + if "visual.trunk.stem.0.weight" in state_dict: + out_dict = { + k.replace("visual.trunk.", ""): v + for k, v in state_dict.items() + if k.startswith("visual.trunk.") + } + if "visual.head.proj.weight" in state_dict: + out_dict["head.fc.weight"] = state_dict["visual.head.proj.weight"] + out_dict["head.fc.bias"] = torch.zeros( + state_dict["visual.head.proj.weight"].shape[0] + ) + elif "visual.head.mlp.fc1.weight" in state_dict: + out_dict["head.pre_logits.fc.weight"] = state_dict[ + "visual.head.mlp.fc1.weight" + ] + out_dict["head.pre_logits.fc.bias"] = state_dict["visual.head.mlp.fc1.bias"] + out_dict["head.fc.weight"] = state_dict["visual.head.mlp.fc2.weight"] + out_dict["head.fc.bias"] = torch.zeros( + state_dict["visual.head.mlp.fc2.weight"].shape[0] + ) + 
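+        # Both OpenCLIP head variants above come without a bias in the checkpoint,
+        # so a zero bias of matching width is synthesized to fit the timm-style
+        # `head.fc` layout.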
return out_dict + + import re + + for k, v in state_dict.items(): + k = k.replace("downsample_layers.0.", "stem.") + k = re.sub(r"stages.([0-9]+).([0-9]+)", r"stages.\1.blocks.\2", k) + k = re.sub( + r"downsample_layers.([0-9]+).([0-9]+)", r"stages.\1.downsample.\2", k + ) + k = k.replace("dwconv", "conv_dw") + k = k.replace("pwconv", "mlp.fc") + if "grn" in k: + k = k.replace("grn.beta", "mlp.grn.bias") + k = k.replace("grn.gamma", "mlp.grn.weight") + v = v.reshape(v.shape[-1]) + k = k.replace("head.", "head.fc.") + if k.startswith("norm."): + k = k.replace("norm", "head.norm") + if v.ndim == 2 and "head" not in k: + model_shape = model.state_dict()[k].shape + v = v.reshape(model_shape) + out_dict[k] = v + + return out_dict + + +HF_URL = { + "convnext_xxlarge_pt": ( + "laion/CLIP-convnext_xxlarge-laion2B-s34B-b82K-augreg-soup", + "open_clip_pytorch_model.bin", + ), + "convnext_large_pt": ( + "laion/CLIP-convnext_large_d_320.laion2B-s29B-b131K-ft-soup", + "open_clip_pytorch_model.bin", + ), + "convnext_large": ( + "timm/convnext_large_mlp.clip_laion2b_soup_ft_in12k_in1k_384", + "pytorch_model.bin", + ), +} diff --git a/unik3d/models/backbones/convnext2.py b/unik3d/models/backbones/convnext2.py new file mode 100644 index 0000000000000000000000000000000000000000..74d5bbf12bf586552ef35e2e82d59e47b9c9cc42 --- /dev/null +++ b/unik3d/models/backbones/convnext2.py @@ -0,0 +1,288 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from timm.models.layers import DropPath, trunc_normal_ + + +def get_num_layer_for_convnext_single(var_name, depths): + """ + Each layer is assigned distinctive layer ids + """ + if var_name.startswith("downsample_layers"): + stage_id = int(var_name.split(".")[1]) + layer_id = sum(depths[:stage_id]) + 1 + return layer_id + + elif var_name.startswith("stages"): + stage_id = int(var_name.split(".")[1]) + block_id = int(var_name.split(".")[2]) + layer_id = sum(depths[:stage_id]) + block_id + 1 + return layer_id + + else: + return sum(depths) + 1 + + +def get_num_layer_for_convnext(var_name): + """ + Divide [3, 3, 27, 3] layers into 12 groups; each group is three + consecutive blocks, including possible neighboring downsample layers; + adapted from https://github.com/microsoft/unilm/blob/master/beit/optim_factory.py + """ + num_max_layer = 12 + if var_name.startswith("downsample_layers"): + stage_id = int(var_name.split(".")[1]) + if stage_id == 0: + layer_id = 0 + elif stage_id == 1 or stage_id == 2: + layer_id = stage_id + 1 + elif stage_id == 3: + layer_id = 12 + return layer_id + + elif var_name.startswith("stages"): + stage_id = int(var_name.split(".")[1]) + block_id = int(var_name.split(".")[2]) + if stage_id == 0 or stage_id == 1: + layer_id = stage_id + 1 + elif stage_id == 2: + layer_id = 3 + block_id // 3 + elif stage_id == 3: + layer_id = 12 + return layer_id + else: + return num_max_layer + 1 + + +def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()): + parameter_group_names = {} + parameter_group_vars = {} + skip = {} + if skip_list is not None: + skip = skip_list + elif hasattr(model, "no_weight_decay"): + skip = model.no_weight_decay() + num_layers = 12 # sum(model.depths) + layer_scale = list(ld ** (num_layers + 1 - i) for i in range(num_layers + 2)) + for name, param in model.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if ( + len(param.shape) == 1 + or name.endswith(".bias") + or name in skip + or name.endswith(".gamma") + or name.endswith(".beta") + ): + group_name = "no_decay" + 
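+            # 1-D parameters (norm weights), biases and the layer-scale / GRN
+            # gamma/beta terms are placed in the no-weight-decay group, as is
+            # standard practice for ConvNeXt / ViT training.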
this_weight_decay = 0.0 + else: + group_name = "decay" + this_weight_decay = wd + + # layer_id = get_num_layer_for_convnext_single(name, model.depths) + layer_id = get_num_layer_for_convnext(name) + group_name = "layer_%d_%s" % (layer_id, group_name) + + if group_name not in parameter_group_names: + scale = layer_scale[layer_id] + cur_lr = lr * scale + + parameter_group_names[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale, + "lr": cur_lr, + } + parameter_group_vars[group_name] = { + "weight_decay": this_weight_decay, + "params": [], + "lr_scale": scale, + "lr": cur_lr, + } + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + # if is_main_process(): + # print("Param groups = %s" % json.dumps(parameter_group_names, indent=2)) + return list(parameter_group_vars.values()), [ + v["lr"] for k, v in parameter_group_vars.items() + ] + + +class LayerNorm(nn.Module): + """LayerNorm that supports two data formats: channels_last (default) or channels_first. + The ordering of the dimensions in the inputs. channels_last corresponds to inputs with + shape (batch_size, height, width, channels) while channels_first corresponds to inputs + with shape (batch_size, channels, height, width). + """ + + def __init__(self, normalized_shape, eps=1e-6, data_format="channels_last"): + super().__init__() + self.weight = nn.Parameter(torch.ones(normalized_shape)) + self.bias = nn.Parameter(torch.zeros(normalized_shape)) + self.eps = eps + self.data_format = data_format + if self.data_format not in ["channels_last", "channels_first"]: + raise NotImplementedError + self.normalized_shape = (normalized_shape,) + + def forward(self, x): + if self.data_format == "channels_last": + return F.layer_norm( + x, self.normalized_shape, self.weight, self.bias, self.eps + ) + elif self.data_format == "channels_first": + u = x.mean(1, keepdim=True) + s = (x - u).pow(2).mean(1, keepdim=True) + x = (x - u) / torch.sqrt(s + self.eps) + x = self.weight[:, None, None] * x + self.bias[:, None, None] + return x + + +class GRN(nn.Module): + """GRN (Global Response Normalization) layer""" + + def __init__(self, dim): + super().__init__() + self.gamma = nn.Parameter(torch.zeros(1, 1, 1, dim)) + self.beta = nn.Parameter(torch.zeros(1, 1, 1, dim)) + + def forward(self, x): + Gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True) + Nx = Gx / (Gx.mean(dim=-1, keepdim=True) + 1e-6) + return self.gamma * (x * Nx) + self.beta + x + + +class Block(nn.Module): + """ConvNeXtV2 Block. + + Args: + dim (int): Number of input channels. + drop_path (float): Stochastic depth rate. 
Default: 0.0 + """ + + def __init__(self, dim, drop_path=0.0, mult=4, use_checkpoint=False): + super().__init__() + self.dwconv = nn.Conv2d( + dim, dim, kernel_size=7, padding=3, groups=dim + ) # depthwise conv + self.norm = LayerNorm(dim, eps=1e-6) + self.pwconv1 = nn.Linear( + dim, mult * dim + ) # pointwise/1x1 convs, implemented with linear layers + self.act = nn.GELU() + self.grn = GRN(mult * dim) + self.pwconv2 = nn.Linear(mult * dim, dim) + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.use_checkpoint = use_checkpoint + + def forward(self, x): + input = x + x = self.dwconv(x) + x = x.permute(0, 2, 3, 1) # (N, C, H, W) -> (N, H, W, C) + x = self.norm(x) + x = self.pwconv1(x) + x = self.act(x) + x = self.grn(x) + x = self.pwconv2(x) + x = x.permute(0, 3, 1, 2) # (N, H, W, C) -> (N, C, H, W) + + x = input + self.drop_path(x) + return x + + +class ConvNeXtV2(nn.Module): + """ConvNeXt V2 + + Args: + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + depths (tuple(int)): Number of blocks at each stage. Default: [3, 3, 9, 3] + dims (int): Feature dimension at each stage. Default: [96, 192, 384, 768] + drop_path_rate (float): Stochastic depth rate. Default: 0. + head_init_scale (float): Init scaling value for classifier weights and biases. Default: 1. + """ + + def __init__( + self, + in_chans=3, + depths=[3, 3, 9, 3], + dims=96, + drop_path_rate=0.0, + output_idx=[], + use_checkpoint=False, + ): + super().__init__() + self.num_layers = len(depths) + self.depths = output_idx + self.embed_dims = [ + int(dim) for i, dim in enumerate(dims) for _ in range(depths[i]) + ] + self.embed_dim = dims[0] + + self.downsample_layers = ( + nn.ModuleList() + ) # stem and 3 intermediate downsampling conv layers + stem = nn.Sequential( + nn.Conv2d(in_chans, dims[0], kernel_size=4, stride=4), + LayerNorm(dims[0], eps=1e-6, data_format="channels_first"), + ) + self.downsample_layers.append(stem) + for i in range(3): + downsample_layer = nn.Sequential( + LayerNorm(dims[i], eps=1e-6, data_format="channels_first"), + nn.Conv2d(dims[i], dims[i + 1], kernel_size=2, stride=2), + ) + self.downsample_layers.append(downsample_layer) + + self.stages = ( + nn.ModuleList() + ) # 4 feature resolution stages, each consisting of multiple residual blocks + self.out_norms = nn.ModuleList() + dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))] + cur = 0 + for i in range(4): + stage = nn.ModuleList( + [ + Block( + dim=dims[i], + drop_path=dp_rates[cur + j], + use_checkpoint=use_checkpoint, + ) + for j in range(depths[i]) + ] + ) + self.stages.append(stage) + cur += depths[i] + + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, (nn.Conv2d, nn.Linear)): + trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + def forward(self, x): + outs = [] + for i in range(4): + x = self.downsample_layers[i](x) + for stage in self.stages[i]: + x = stage(x) + outs.append(x.permute(0, 2, 3, 1)) + cls_tokens = [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in outs] + return outs, cls_tokens + + def get_params(self, lr, wd, ld, *args, **kwargs): + encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld) + return encoder_p, encoder_lr + + def freeze(self) -> None: + for module in self.modules(): + module.eval() + for parameters in self.parameters(): + parameters.requires_grad = False + + @classmethod + def build(cls, config): + obj = 
globals()[config["model"]["encoder"]["name"]](config) + return obj diff --git a/unik3d/models/backbones/dinov2.py b/unik3d/models/backbones/dinov2.py new file mode 100644 index 0000000000000000000000000000000000000000..163fbaac32c723251bb2614f765e30337d143d11 --- /dev/null +++ b/unik3d/models/backbones/dinov2.py @@ -0,0 +1,521 @@ +import contextlib +import logging +import math +from functools import partial +from typing import Callable, Sequence + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ + +from unik3d.models.metadinov2 import (Block, MemEffAttention, Mlp, PatchEmbed, + SwiGLUFFNFused) + + +def named_apply( + fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False +) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply( + fn=fn, + module=child_module, + name=child_name, + depth_first=depth_first, + include_root=True, + ) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +def get_parameter_groups(model, lr, wd=1e-5, ld=0.9, skip_list=()): + parameter_group_names = {} + parameter_group_vars = {} + skip = {} + if skip_list is not None: + skip = skip_list + elif hasattr(model, "no_weight_decay"): + skip = model.no_weight_decay() + + num_layers = model.n_blocks + layer_scale = list(ld ** (num_layers - i) for i in range(num_layers)) + + for name, param in model.named_parameters(): + if not param.requires_grad: + continue + + if len(param.shape) == 1: # norm + group_name = "no_decay" + this_wd = 0.0 + # layer scale, bias beta? + elif ( + name in skip + or name.endswith(".gamma") + or name.endswith(".beta") + or name.endswith(".bias") + ): + group_name = "no_decay" + this_wd = 0.0 + elif "cls_token" in name or "pos_embed" in name or "mask_token" in name: + group_name = "no_decay" + this_wd = 0.0 + else: + group_name = "decay" + this_wd = wd + + if name.startswith("blocks"): + layer_id = int(name.split(".")[1]) + elif name.startswith("patch_embed"): + layer_id = 0 + else: + layer_id = 0 + + group_name = f"layer_{layer_id}_{group_name}" + + if group_name not in parameter_group_names: + scale = layer_scale[layer_id] + cur_lr = lr * scale + + parameter_group_names[group_name] = { + "weight_decay": this_wd, + "params": [], + "lr_init": cur_lr, + "lr_base": lr, + "lr": cur_lr, + } + parameter_group_vars[group_name] = { + "weight_decay": this_wd, + "params": [], + "lr_init": cur_lr, + "lr_base": lr, + "lr": cur_lr, + } + parameter_group_vars[group_name]["params"].append(param) + parameter_group_names[group_name]["params"].append(name) + + # for group_name in parameter_group_names.keys(): + # for k, v in zip(parameter_group_names[group_name]["params"], parameter_group_vars[group_name]["params"]): + # print(group_name,k) + return list(parameter_group_vars.values()), [ + v["lr"] for k, v in parameter_group_vars.items() + ] + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DummyModule(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + for i in range(100): + setattr(self, f"layer{i}", nn.Linear(2048, 2048)) + + def forward(self, x): + return self.layer(x) + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, 
+ proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + output_idx=[5, 12, 18, 24], + checkpoint: bool = False, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.0, + use_norm=True, + frozen_stages=0, + freeze_norm=True, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + """ + super().__init__() + # norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = embed_dim # num_features for consistency with other models + self.frozen_stages = frozen_stages + self.embed_dims = [embed_dim] * output_idx[-1] + self.embed_dim = embed_dim + self.num_tokens = 1 + self.freeze_norm = freeze_norm + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.depths = output_idx + self.checkpoint = checkpoint + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + ) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter( + torch.zeros(1, num_patches + self.num_tokens, embed_dim) + ) + assert num_register_tokens >= 0 + self.register_tokens = nn.Parameter( + torch.zeros(1, max(1, num_register_tokens), embed_dim) + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, depth) + ] # stochastic depth decay rule + + if ffn_layer == "mlp": + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + # nn.Identity() + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=nn.LayerNorm, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = nn.LayerNorm(embed_dim) + self.use_norm = use_norm + self.head = nn.Identity() + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + self.init_weights() + + def init_weights(self): + 
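+        # Initialization follows the DINOv2 reference implementation: truncated-normal
+        # positional embeddings (std=0.02), near-zero cls/register tokens, and
+        # timm-style ViT init for the linear layers via `init_weights_vit_timm`.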
trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.num_register_tokens: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to( + previous_dtype + ) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + with torch.no_grad() if self.frozen_stages > -1 else contextlib.nullcontext(): + x = self.patch_embed(x) + if masks is not None: + masks = masks.bool().view(B, -1, 1) + x = torch.where(masks, self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.num_register_tokens: + x = torch.cat( + (x[:, :1], self.register_tokens.expand(x.shape[0], -1, -1), x[:, 1:]), + dim=1, + ) + + return x + + def forward_features(self, x, masks=None): + shapes = [val // self.patch_size for val in x.shape[-2:]] + batch_size = x.shape[0] + outputs = [] + + x = self.prepare_tokens_with_masks(x, masks) + + for i, blk in enumerate(self.blocks): + with ( + torch.no_grad() if i < self.frozen_stages else contextlib.nullcontext() + ): + x = blk(x) + outputs.append(x) + + if self.use_norm: + with ( + torch.no_grad() + if self.frozen_stages >= len(self.blocks) + else contextlib.nullcontext() + ): + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, :1] for out in outputs] + outputs = [out[:, self.num_register_tokens + 1 :] for out in outputs] + outputs = [out.reshape(batch_size, *shapes, -1) for out in outputs] + + return (outputs, class_tokens) + + def get_params(self, lr, wd, ld, *args, **kwargs): + encoder_p, encoder_lr = get_parameter_groups(self, lr, wd, ld) + return encoder_p, encoder_lr + + def freeze(self) -> None: + for module in self.modules(): + module.eval() + for parameters in self.parameters(): + parameters.requires_grad = False + + def train(self, mode=True): + super().train(mode) + + if self.freeze_norm: + for module in self.modules(): + if isinstance(module, nn.LayerNorm): + for param in module.parameters(): + param.requires_grad = False + module.eval() + + if self.frozen_stages > -1: + for p 
in self.patch_embed.parameters(): + p.requires_grad = False + + for i, blk in enumerate(self.blocks): + if i < self.frozen_stages: + blk.eval() + for p in blk.parameters(): + p.requires_grad = False + + for p in self.norm.parameters(): + p.requires_grad = self.frozen_stages <= len(self.blocks) + + self.cls_token.requires_grad = self.frozen_stages < 1 + self.pos_embed.requires_grad = self.frozen_stages < 1 + self.mask_token.requires_grad = False + self.register_tokens.requires_grad = False + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + return ret + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + num_register_tokens=num_register_tokens, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + num_register_tokens=num_register_tokens, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + **kwargs, + ) + return model + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + return f"dinov2_{compact_arch_name}{patch_size}" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + pretrained: str = "", + output_idx: Sequence[int] = [], + num_register_tokens: int = 0, + drop_path_rate: float = 0.0, + use_norm: bool = False, + interpolate_offset: float = 0.0, + frozen_stages: int = 0, + freeze_norm: bool = True, + **kwargs, +): + model_name = _make_dinov2_model_name(arch_name, patch_size) + + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + output_idx=output_idx, + drop_path_rate=drop_path_rate, + num_register_tokens=num_register_tokens, + use_norm=use_norm, + interpolate_offset=interpolate_offset, + frozen_stages=frozen_stages, + freeze_norm=freeze_norm, + ) + vit_kwargs.update(**kwargs) + model = eval(arch_name)(**vit_kwargs) + + if pretrained == "": + url = _DINOV2_BASE_URL + f"/{model_name}/{model_name}" + if num_register_tokens > 0: + url += "_reg4" + url += "_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + 
) + info = model.load_state_dict(state_dict, strict=False) + del state_dict + elif pretrained is not None: + state_dict = torch.load(pretrained, map_location="cpu", weights_only=False) + info = model.load_state_dict(state_dict, strict=False) + del state_dict + else: + info = {} + + print(f"DINOv2 loaded from {pretrained} with info:", info) + + return model diff --git a/unik3d/models/backbones/swinv2.py b/unik3d/models/backbones/swinv2.py new file mode 100644 index 0000000000000000000000000000000000000000..9f352aa3b784214637c9ced580a2a3e44063008b --- /dev/null +++ b/unik3d/models/backbones/swinv2.py @@ -0,0 +1,942 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint as checkpoint +from timm.models.layers import DropPath, to_2tuple, trunc_normal_ + +from unik3d.utils.misc import get_params, load_checkpoint_swin + + +class Mlp(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x + + +def window_partition(x, window_size): + """ + Args: + x: (B, H, W, C) + window_size (int): window size + + Returns: + windows: (num_windows*B, window_size, window_size, C) + """ + B, H, W, C = x.shape + x = x.view( + B, H // window_size[0], window_size[0], W // window_size[1], window_size[1], C + ) + windows = ( + x.permute(0, 1, 3, 2, 4, 5) + .contiguous() + .view(-1, window_size[0], window_size[1], C) + ) + return windows + + +def window_reverse(windows, window_size, H, W): + """ + Args: + windows: (num_windows*B, window_size, window_size, C) + window_size (int): Window size + H (int): Height of image + W (int): Width of image + + Returns: + x: (B, H, W, C) + """ + B = int(windows.shape[0] / (H * W / window_size[0] / window_size[1])) + x = windows.view( + B, H // window_size[0], W // window_size[1], window_size[0], window_size[1], -1 + ) + x = x.permute(0, 1, 3, 2, 4, 5).contiguous().view(B, H, W, -1) + return x + + +class WindowAttention(nn.Module): + r"""Window based multi-head self attention (W-MSA) module with relative position bias. + It supports both of shifted and non-shifted window. + + Args: + dim (int): Number of input channels. + window_size (tuple[int]): The height and width of the window. + num_heads (int): Number of attention heads. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + attn_drop (float, optional): Dropout ratio of attention weight. Default: 0.0 + proj_drop (float, optional): Dropout ratio of output. Default: 0.0 + pretrained_window_size (tuple[int]): The height and width of the window in pre-training. 
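+
+    Note:
+        Following Swin-V2, the relative position bias is generated by a small MLP
+        over log-spaced relative coordinates (continuous position bias), and the
+        attention uses cosine similarity with a learnable, clamped logit scale.
+
+    Example (shape sketch; window_size and num_heads below are illustrative):
+        >>> attn = WindowAttention(dim=96, window_size=(8, 8), num_heads=4)
+        >>> x = torch.randn(2, 64, 96)  # (num_windows*B, Wh*Ww, C)
+        >>> attn(x).shape
+        torch.Size([2, 64, 96])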
+ """ + + def __init__( + self, + dim, + window_size, + num_heads, + qkv_bias=True, + attn_drop=0.0, + proj_drop=0.0, + pretrained_window_size=[0, 0], + ): + + super().__init__() + self.dim = dim + self.window_size = window_size # Wh, Ww + self.pretrained_window_size = pretrained_window_size + self.num_heads = num_heads + + self.logit_scale = nn.Parameter( + torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True + ) + + # mlp to generate continuous relative position bias + self.rpe_mlp = nn.Sequential( + nn.Linear(2, 512, bias=True), + nn.ReLU(inplace=True), + nn.Linear(512, num_heads, bias=False), + ) + + # get relative_coords_table + relative_coords_h = torch.arange( + -(self.window_size[0] - 1), self.window_size[0], dtype=torch.float32 + ) + relative_coords_w = torch.arange( + -(self.window_size[1] - 1), self.window_size[1], dtype=torch.float32 + ) + relative_coords_table = ( + torch.stack(torch.meshgrid([relative_coords_h, relative_coords_w])) + .permute(1, 2, 0) + .contiguous() + .unsqueeze(0) + ) # 1, 2*Wh-1, 2*Ww-1, 2 + if pretrained_window_size[0] > 0: + relative_coords_table[:, :, :, 0] /= pretrained_window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= pretrained_window_size[1] - 1 + else: + relative_coords_table[:, :, :, 0] /= self.window_size[0] - 1 + relative_coords_table[:, :, :, 1] /= self.window_size[1] - 1 + relative_coords_table *= 8 # normalize to -8, 8 + relative_coords_table = ( + torch.sign(relative_coords_table) + * torch.log2(torch.abs(relative_coords_table) + 1.0) + / np.log2(8) + ) + + self.register_buffer("relative_coords_table", relative_coords_table) + + # get pair-wise relative position index for each token inside the window + coords_h = torch.arange(self.window_size[0]) + coords_w = torch.arange(self.window_size[1]) + coords = torch.stack(torch.meshgrid([coords_h, coords_w])) # 2, Wh, Ww + coords_flatten = torch.flatten(coords, 1) # 2, Wh*Ww + relative_coords = ( + coords_flatten[:, :, None] - coords_flatten[:, None, :] + ) # 2, Wh*Ww, Wh*Ww + relative_coords = relative_coords.permute( + 1, 2, 0 + ).contiguous() # Wh*Ww, Wh*Ww, 2 + relative_coords[:, :, 0] += self.window_size[0] - 1 # shift to start from 0 + relative_coords[:, :, 1] += self.window_size[1] - 1 + relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1 + relative_position_index = relative_coords.sum(-1) # Wh*Ww, Wh*Ww + self.register_buffer("relative_position_index", relative_position_index) + + self.qkv = nn.Linear(dim, dim * 3, bias=False) + if qkv_bias: + self.q_bias = nn.Parameter(torch.zeros(dim)) + self.v_bias = nn.Parameter(torch.zeros(dim)) + else: + self.q_bias = None + self.v_bias = None + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + self.softmax = nn.Softmax(dim=-1) + + def forward(self, x, mask=None): + """ + Args: + x: input features with shape of (num_windows*B, N, C) + mask: (0/-inf) mask with shape of (num_windows, Wh*Ww, Wh*Ww) or None + """ + B_, N, C = x.shape + qkv_bias = None + if self.q_bias is not None: + qkv_bias = torch.cat( + ( + self.q_bias, + torch.zeros_like(self.v_bias, requires_grad=False), + self.v_bias, + ) + ) + qkv = F.linear(input=x, weight=self.qkv.weight, bias=qkv_bias) + qkv = qkv.reshape(B_, N, 3, self.num_heads, -1).permute(2, 0, 3, 1, 4) + q, k, v = ( + qkv[0], + qkv[1], + qkv[2], + ) # make torchscript happy (cannot use tensor as tuple) + + # cosine attention + attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1) + logit_scale = torch.clamp( + 
self.logit_scale, + max=torch.log(torch.tensor(1.0 / 0.01, device=self.logit_scale.device)), + ).exp() + attn = attn * logit_scale + + relative_position_bias_table = self.rpe_mlp(self.relative_coords_table).view( + -1, self.num_heads + ) + relative_position_bias = relative_position_bias_table[ + self.relative_position_index.view(-1) + ].view( + self.window_size[0] * self.window_size[1], + self.window_size[0] * self.window_size[1], + -1, + ) # Wh*Ww,Wh*Ww,nH + relative_position_bias = relative_position_bias.permute( + 2, 0, 1 + ).contiguous() # nH, Wh*Ww, Wh*Ww + relative_position_bias = 16 * torch.sigmoid(relative_position_bias) + attn = attn + relative_position_bias.unsqueeze(0) + + if mask is not None: + nW = mask.shape[0] + attn = attn.view(B_ // nW, nW, self.num_heads, N, N) + mask.unsqueeze( + 1 + ).unsqueeze(0) + attn = attn.view(-1, self.num_heads, N, N) + + attn = self.softmax(attn) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B_, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, window_size={self.window_size}, " + f"pretrained_window_size={self.pretrained_window_size}, num_heads={self.num_heads}" + ) + + def flops(self, N): + # calculate flops for 1 window with token length of N + flops = 0 + # qkv = self.qkv(x) + flops += N * self.dim * 3 * self.dim + # attn = (q @ k.transpose(-2, -1)) + flops += self.num_heads * N * (self.dim // self.num_heads) * N + # x = (attn @ v) + flops += self.num_heads * N * N * (self.dim // self.num_heads) + # x = self.proj(x) + flops += N * self.dim * self.dim + return flops + + +class SwinTransformerBlock(nn.Module): + r"""Swin Transformer Block. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resulotion. + num_heads (int): Number of attention heads. + window_size (int): Window size. + shift_size (int): Shift size for SW-MSA. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float, optional): Stochastic depth rate. Default: 0.0 + act_layer (nn.Module, optional): Activation layer. Default: nn.GELU + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + pretrained_window_size (int): Window size in pre-training. 
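+
+    Note:
+        As in Swin-V2, normalization is applied after attention and after the MLP
+        (res-post-norm), i.e. `x = x + drop_path(norm(f(x)))`, and the norm weights
+        are zero-initialized via `_init_respostnorm`.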
+ """ + + def __init__( + self, + dim, + input_resolution, + num_heads, + window_size=7, + shift_size=0, + mlp_ratio=4.0, + qkv_bias=True, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + act_layer=nn.GELU, + norm_layer=nn.LayerNorm, + pretrained_window_size=0, + ): + super().__init__() + self.dim = dim + self.num_heads = num_heads + self.window_size = window_size + self.shift_size = shift_size + self.mlp_ratio = mlp_ratio + + if input_resolution[0] <= self.window_size[0]: + self.shift_size[0] = 0 + self.window_size[0] = input_resolution[0] + if input_resolution[1] <= self.window_size[1]: + self.shift_size[1] = 0 + self.window_size[1] = input_resolution[1] + + assert ( + 0 <= self.shift_size[1] < self.window_size[1] + ), "shift_size must in 0-window_size" + assert ( + 0 <= self.shift_size[0] < self.window_size[0] + ), "shift_size must in 0-window_size" + self.norm1 = norm_layer(dim) + self.attn = WindowAttention( + dim, + window_size=self.window_size, + num_heads=num_heads, + qkv_bias=qkv_bias, + attn_drop=attn_drop, + proj_drop=drop, + pretrained_window_size=pretrained_window_size, + ) + + self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = Mlp( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + ) + + # if self.shift_size > 0: + # # calculate attention mask for SW-MSA + # H, W = self.input_resolution + # img_mask = torch.zeros((1, H, W, 1)) # 1 H W 1 + # h_slices = (slice(0, -self.window_size), + # slice(-self.window_size, -self.shift_size), + # slice(-self.shift_size, None)) + # w_slices = (slice(0, -self.window_size), + # slice(-self.window_size, -self.shift_size), + # slice(-self.shift_size, None)) + # cnt = 0 + # for h in h_slices: + # for w in w_slices: + # img_mask[:, h, w, :] = cnt + # cnt += 1 + + # mask_windows = window_partition(img_mask, self.window_size) # nW, window_size, window_size, 1 + # mask_windows = mask_windows.view(-1, self.window_size * self.window_size) + # attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + # attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill(attn_mask == 0, float(0.0)) + # else: + # attn_mask = None + + # self.register_buffer("attn_mask", attn_mask) + + def forward(self, x, mask_matrix): + H, W = self.H, self.W + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + + shortcut = x + x = x.view(B, H, W, C) + + # pad feature maps to multiples of window size + pad_l = pad_t = 0 + pad_r = (self.window_size[1] - W % self.window_size[1]) % self.window_size[1] + pad_b = (self.window_size[0] - H % self.window_size[0]) % self.window_size[0] + if pad_r > 0 or pad_b > 0: + x = F.pad(x, (0, 0, pad_l, pad_r, pad_t, pad_b)) + _, Hp, Wp, _ = x.shape + + # cyclic shift + if self.shift_size[0] > 0 or self.shift_size[1] > 0: + shifted_x = torch.roll( + x, shifts=(-self.shift_size[0], -self.shift_size[1]), dims=(1, 2) + ) + attn_mask = mask_matrix + else: + shifted_x = x + attn_mask = None + + # partition windows + x_windows = window_partition( + shifted_x, self.window_size + ) # nW*B, window_size, window_size, C + x_windows = x_windows.view( + -1, self.window_size[0] * self.window_size[1], C + ) # nW*B, window_size*window_size, C + + # W-MSA/SW-MSA + attn_windows = self.attn( + x_windows, mask=attn_mask + ) # nW*B, window_size*window_size, C + + # merge windows + attn_windows = attn_windows.view( + -1, self.window_size[0], self.window_size[1], C + ) + shifted_x = 
window_reverse(attn_windows, self.window_size, Hp, Wp) # B H' W' C + + # reverse cyclic shift + if self.shift_size[0] > 0 or self.shift_size[1] > 0: + x = torch.roll( + shifted_x, shifts=(self.shift_size[0], self.shift_size[1]), dims=(1, 2) + ) + else: + x = shifted_x + + if pad_r > 0 or pad_b > 0: + x = x[:, :H, :W, :].contiguous() + + x = x.view(B, H * W, C) + + x = shortcut + self.drop_path(self.norm1(x)) + + # FFN + x = x + self.drop_path(self.norm2(self.mlp(x))) + + return x + + def extra_repr(self) -> str: + return ( + f"dim={self.dim}, input_resolution={self.input_resolution}, num_heads={self.num_heads}, " + f"window_size={self.window_size}, shift_size={self.shift_size}, mlp_ratio={self.mlp_ratio}" + ) + + def flops(self): + flops = 0 + H, W = self.input_resolution + # norm1 + flops += self.dim * H * W + # W-MSA/SW-MSA + nW = H * W / self.window_size / self.window_size + flops += nW * self.attn.flops(self.window_size * self.window_size) + # mlp + flops += 2 * H * W * self.dim * self.dim * self.mlp_ratio + # norm2 + flops += self.dim * H * W + return flops + + +class PatchMerging(nn.Module): + r"""Patch Merging Layer. + + Args: + input_resolution (tuple[int]): Resolution of input feature. + dim (int): Number of input channels. + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + """ + + def __init__(self, dim, norm_layer=nn.LayerNorm): + super().__init__() + self.dim = dim + self.reduction = nn.Linear(4 * dim, 2 * dim, bias=False) + self.norm = norm_layer(2 * dim) + + def forward(self, x, H, W): + """ + x: B, H*W, C + """ + B, L, C = x.shape + assert L == H * W, "input feature has wrong size" + assert H % 2 == 0 and W % 2 == 0, f"x size ({H}*{W}) are not even." + + x = x.view(B, H, W, C) + + x0 = x[:, 0::2, 0::2, :] # B H/2 W/2 C + x1 = x[:, 1::2, 0::2, :] # B H/2 W/2 C + x2 = x[:, 0::2, 1::2, :] # B H/2 W/2 C + x3 = x[:, 1::2, 1::2, :] # B H/2 W/2 C + x = torch.cat([x0, x1, x2, x3], -1) # B H/2 W/2 4*C + x = x.view(B, -1, 4 * C) # B H/2*W/2 4*C + + x = self.reduction(x) + x = self.norm(x) + + return x + + def extra_repr(self) -> str: + return f"input_resolution={self.input_resolution}, dim={self.dim}" + + def flops(self): + H, W = self.input_resolution + flops = (H // 2) * (W // 2) * 4 * self.dim * 2 * self.dim + flops += H * W * self.dim // 2 + return flops + + +class BasicLayer(nn.Module): + """A basic Swin Transformer layer for one stage. + + Args: + dim (int): Number of input channels. + input_resolution (tuple[int]): Input resolution. + depth (int): Number of blocks. + num_heads (int): Number of attention heads. + window_size (int): Local window size. + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. + qkv_bias (bool, optional): If True, add a learnable bias to query, key, value. Default: True + drop (float, optional): Dropout rate. Default: 0.0 + attn_drop (float, optional): Attention dropout rate. Default: 0.0 + drop_path (float | tuple[float], optional): Stochastic depth rate. Default: 0.0 + norm_layer (nn.Module, optional): Normalization layer. Default: nn.LayerNorm + downsample (nn.Module | None, optional): Downsample layer at the end of the layer. Default: None + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False. + pretrained_window_size (int): Local window size in pre-training. 
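+        use_shift (bool): If False, disable shifted-window (SW-MSA) blocks in this stage. Default: True
+
+    Note:
+        The SW-MSA attention mask is rebuilt in `forward` from the current padded
+        resolution, so the stage supports variable input sizes.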
+ """ + + def __init__( + self, + dim, + input_resolution, + depth, + num_heads, + window_size, + use_shift=True, + mlp_ratio=4.0, + qkv_bias=True, + drop=0.0, + attn_drop=0.0, + drop_path=0.0, + norm_layer=nn.LayerNorm, + downsample=None, + use_checkpoint=False, + pretrained_window_size=0, + ): + + super().__init__() + self.dim = dim + self.depth = depth + self.use_checkpoint = use_checkpoint + self.window_size = list(to_2tuple(window_size)) + pretrained_window_size = list(to_2tuple(pretrained_window_size)) + self.shift_size = ( + [x // 2 for x in window_size] + if isinstance(window_size, (tuple, list)) + else window_size // 2 + ) + self.shift_size = list(to_2tuple(self.shift_size)) + + # build blocks + self.blocks = nn.ModuleList( + [ + SwinTransformerBlock( + dim=dim, + input_resolution=input_resolution, + num_heads=num_heads, + window_size=self.window_size, + shift_size=self.shift_size if (i % 2 and use_shift) else [0, 0], + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + drop=drop, + attn_drop=attn_drop, + drop_path=( + drop_path[i] if isinstance(drop_path, list) else drop_path + ), + norm_layer=norm_layer, + pretrained_window_size=pretrained_window_size, + ) + for i in range(depth) + ] + ) + + # patch merging layer + if downsample is not None: + self.downsample = downsample(dim=dim, norm_layer=norm_layer) + else: + self.downsample = None + + def forward(self, x, H, W): + # calculate attention mask for SW-MSA + Hp = int(np.ceil(H / self.window_size[0])) * self.window_size[0] + Wp = int(np.ceil(W / self.window_size[1])) * self.window_size[1] + img_mask = torch.zeros((1, Hp, Wp, 1), device=x.device) # 1 Hp Wp 1 + h_slices = ( + slice(0, -self.window_size[0]), + slice(-self.window_size[0], -self.shift_size[0]), + slice(-self.shift_size[0], None), + ) + w_slices = ( + slice(0, -self.window_size[1]), + slice(-self.window_size[1], -self.shift_size[1]), + slice(-self.shift_size[1], None), + ) + cnt = 0 + for h in h_slices: + for w in w_slices: + img_mask[:, h, w, :] = cnt + cnt += 1 + + mask_windows = window_partition( + img_mask, self.window_size + ) # nW, window_size, window_size, 1 + mask_windows = mask_windows.view(-1, self.window_size[0] * self.window_size[1]) + attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2) + attn_mask = attn_mask.masked_fill(attn_mask != 0, float(-100.0)).masked_fill( + attn_mask == 0, float(0.0) + ) + + x_outs, cls_tokens = [], [] + for blk in self.blocks: + blk.H, blk.W = H, W + if self.use_checkpoint: + x = checkpoint.checkpoint(blk, x, attn_mask) + else: + x = blk(x, attn_mask) + x_outs.append(x) + if self.downsample is not None: + x_down = self.downsample(x, H, W) + Wh, Ww = (H + 1) // 2, (W + 1) // 2 + return x_outs, H, W, x_down, Wh, Ww + else: + return x_outs, H, W, x, H, W + + def extra_repr(self) -> str: + return f"dim={self.dim}, input_resolution={self.input_resolution}, depth={self.depth}" + + def flops(self): + flops = 0 + for blk in self.blocks: + flops += blk.flops() + if self.downsample is not None: + flops += self.downsample.flops() + return flops + + def _init_respostnorm(self): + for blk in self.blocks: + nn.init.constant_(blk.norm1.bias, 0) + nn.init.constant_(blk.norm1.weight, 0) + nn.init.constant_(blk.norm2.bias, 0) + nn.init.constant_(blk.norm2.weight, 0) + + +class PatchEmbed(nn.Module): + r"""Image to Patch Embedding + + Args: + img_size (int): Image size. Default: 224. + patch_size (int): Patch token size. Default: 4. + in_chans (int): Number of input image channels. Default: 3. 
+ embed_dim (int): Number of linear projection output channels. Default: 96. + norm_layer (nn.Module, optional): Normalization layer. Default: None + """ + + def __init__( + self, img_size=224, patch_size=4, in_chans=3, embed_dim=96, norm_layer=None + ): + super().__init__() + img_size = to_2tuple(img_size) + patch_size = to_2tuple(patch_size) + patches_resolution = [ + img_size[0] // patch_size[0], + img_size[1] // patch_size[1], + ] + self.img_size = img_size + self.patch_size = patch_size + self.patches_resolution = patches_resolution + self.num_patches = patches_resolution[0] * patches_resolution[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_size, stride=patch_size + ) + if norm_layer is not None: + self.norm = norm_layer(embed_dim) + else: + self.norm = None + + def forward(self, x): + """Forward function.""" + # padding + _, _, H, W = x.size() + if W % self.patch_size[1] != 0: + x = F.pad(x, (0, self.patch_size[1] - W % self.patch_size[1])) + if H % self.patch_size[0] != 0: + x = F.pad(x, (0, 0, 0, self.patch_size[0] - H % self.patch_size[0])) + + x = self.proj(x) # B C Wh Ww + if self.norm is not None: + Wh, Ww = x.size(2), x.size(3) + x = self.norm(x.flatten(2).transpose(1, 2)) + x = x.transpose(1, 2).view(-1, self.embed_dim, Wh, Ww) + + return x + + def flops(self): + Ho, Wo = self.patches_resolution + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops + + +class SwinTransformerV2(nn.Module): + r"""Swin Transformer + A PyTorch impl of : `Swin Transformer: Hierarchical Vision Transformer using Shifted Windows` - + https://arxiv.org/pdf/2103.14030 + + Args: + img_size (int | tuple(int)): Input image size. Default 224 + patch_size (int | tuple(int)): Patch size. Default: 4 + in_chans (int): Number of input image channels. Default: 3 + num_classes (int): Number of classes for classification head. Default: 1000 + embed_dim (int): Patch embedding dimension. Default: 96 + depths (tuple(int)): Depth of each Swin Transformer layer. + num_heads (tuple(int)): Number of attention heads in different layers. + window_size (int): Window size. Default: 7 + mlp_ratio (float): Ratio of mlp hidden dim to embedding dim. Default: 4 + qkv_bias (bool): If True, add a learnable bias to query, key, value. Default: True + drop_rate (float): Dropout rate. Default: 0 + attn_drop_rate (float): Attention dropout rate. Default: 0 + drop_path_rate (float): Stochastic depth rate. Default: 0.1 + norm_layer (nn.Module): Normalization layer. Default: nn.LayerNorm. + ape (bool): If True, add absolute position embedding to the patch embedding. Default: False + patch_norm (bool): If True, add normalization after patch embedding. Default: True + use_checkpoint (bool): Whether to use checkpointing to save memory. Default: False + pretrained_window_sizes (tuple(int)): Pretrained window sizes of each layer. 
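+        use_shift (bool | tuple(bool)): Whether to use shifted windows in each stage. Default: True
+        pretrained (str, optional): Path to a checkpoint loaded via `load_checkpoint_swin`. Default: None
+        frozen_stages (int): Number of leading stages to freeze (-1 freezes nothing). Default: -1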
+ """ + + def __init__( + self, + img_size=224, + patch_size=4, + in_chans=3, + embed_dim=96, + depths=[2, 2, 6, 2], + num_heads=[3, 6, 12, 24], + window_size=7, + mlp_ratio=4.0, + qkv_bias=True, + drop_rate=0.0, + attn_drop_rate=0.0, + drop_path_rate=0.1, + norm_layer=nn.LayerNorm, + ape=False, + patch_norm=True, + use_checkpoint=False, + use_shift=True, + pretrained_window_sizes=[0, 0, 0, 0], + pretrained=None, + frozen_stages=-1, + output_idx=[2, 4, 22, 24], + **kwargs, + ): + super().__init__() + self.num_layers = len(depths) + self.depths = output_idx + self.embed_dim = embed_dim + dims = [embed_dim * 2**i for i in range(len(depths))] + self.embed_dims = [ + int(dim) for i, dim in enumerate(dims) for _ in range(depths[i]) + ] + + self.ape = ape + self.patch_norm = patch_norm + self.num_features = [int(embed_dim * 2**i) for i in range(self.num_layers)] + self.mlp_ratio = mlp_ratio + self.frozen_stages = frozen_stages + if isinstance(window_size, int): + window_size = [window_size] * self.num_layers + if isinstance(use_shift, bool): + use_shift = [use_shift] * self.num_layers + + # self.mask_token = nn.Parameter(torch.zeros(1, 1, self.embed_dim)) + # trunc_normal_(self.mask_token, mean=0., std=.02) + # split image into non-overlapping patches + self.patch_embed = PatchEmbed( + img_size=img_size, + patch_size=patch_size, + in_chans=in_chans, + embed_dim=embed_dim, + norm_layer=norm_layer if self.patch_norm else None, + ) + num_patches = self.patch_embed.num_patches + patches_resolution = self.patch_embed.patches_resolution + self.patches_resolution = patches_resolution + + # absolute position embedding + if self.ape: + self.absolute_pos_embed = nn.Parameter( + torch.zeros(1, num_patches, embed_dim) + ) + trunc_normal_(self.absolute_pos_embed, std=0.02) + + self.pos_drop = nn.Dropout(p=drop_rate) + + # stochastic depth + dpr = [ + x.item() for x in torch.linspace(0, drop_path_rate, sum(depths)) + ] # stochastic depth decay rule + + # build layers + self.layers = nn.ModuleList() + for i_layer in range(self.num_layers): + layer = BasicLayer( + dim=int(embed_dim * 2**i_layer), + input_resolution=[ + img_size[0] // (2 ** (2 + i_layer)), + img_size[1] // (2 ** (2 + i_layer)), + ], + depth=depths[i_layer], + num_heads=num_heads[i_layer], + window_size=window_size[i_layer], + use_shift=use_shift[i_layer], + mlp_ratio=self.mlp_ratio, + qkv_bias=qkv_bias, + drop=drop_rate, + attn_drop=attn_drop_rate, + drop_path=dpr[sum(depths[:i_layer]) : sum(depths[: i_layer + 1])], + norm_layer=norm_layer, + downsample=PatchMerging if (i_layer < self.num_layers - 1) else None, + use_checkpoint=use_checkpoint, + pretrained_window_size=pretrained_window_sizes[i_layer], + ) + self.layers.append(layer) + + self.apply(self._init_weights) + for bly in self.layers: + bly._init_respostnorm() + + if pretrained is not None: + pretrained_state = torch.load(pretrained, map_location="cpu")["model"] + pretrained_state_filtered = load_checkpoint_swin(self, pretrained_state) + msg = self.load_state_dict(pretrained_state_filtered, strict=False) + + self._freeze_stages() + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + nn.init.constant_(m.bias, 0) + nn.init.constant_(m.weight, 1.0) + + @torch.jit.ignore + def no_weight_decay(self): + return {"absolute_pos_embed"} + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"rpe_mlp", "logit_scale", 
"relative_position_bias_table", "mask_token"} + + def forward(self, x, mask=None): + """Forward function.""" + # Add requires_grad_() to all input to support freezing with gradient checkpointing! + x = self.patch_embed(x.requires_grad_()) + B, Wh, Ww = x.size(0), x.size(2), x.size(3) + if self.ape: + # interpolate the position embedding to the corresponding size + absolute_pos_embed = F.interpolate( + self.absolute_pos_embed, + size=(Wh, Ww), + mode="bicubic", + align_corners=True, + ) + x = (x + absolute_pos_embed).flatten(2).transpose(1, 2) # B Wh*Ww C + else: + x = x.flatten(2).transpose(1, 2) + x = self.pos_drop(x) + + # B, L, _ = x.shape + # if mask is not None: + # mask_tokens = self.mask_token.expand(B, L, -1) + # mask = mask.flatten(1).unsqueeze(-1).type_as(mask_tokens) + # else: + # mask = torch.zeros_like(x) + # mask_tokens = torch.zeros_like(self.mask_token).expand(B, L, -1) + # x = x * (1. - mask) + mask_tokens * mask + + outs, cls_tokens = [], [] + for i in range(self.num_layers): + layer = self.layers[i] + x_outs, H, W, x, Wh, Ww = layer(x.requires_grad_(), Wh, Ww) + out = [ + x_out.view(-1, H, W, self.num_features[i]).contiguous() + for x_out in x_outs + ] + outs.extend(out) + cls_token_ = [x.mean(dim=(1, 2)).unsqueeze(1).contiguous() for x in out] + cls_tokens.extend(cls_token_) + return outs, cls_tokens + + def train(self, mode=True): + super().train(mode) + self._freeze_stages() + + def freeze(self) -> None: + for module in self.modules(): + module.eval() + for parameters in self.parameters(): + parameters.requires_grad = False + + def _freeze_stages(self): + if self.frozen_stages >= 0: + self.patch_embed.eval() + for param in self.patch_embed.parameters(): + param.requires_grad = False + if self.ape: + self.absolute_pos_embed.requires_grad = False + self.pos_drop.eval() + + for i in range(1, self.frozen_stages + 1): + m = self.layers[i - 1] + m.eval() + for param in m.parameters(): + param.requires_grad = False + + def flops(self): + flops = 0 + flops += self.patch_embed.flops() + for i, layer in enumerate(self.layers): + flops += layer.flops() + flops += ( + self.num_features + * self.patches_resolution[0] + * self.patches_resolution[1] + // (2**self.num_layers) + ) + return flops + + def get_params(self, lr, wd, *args, **kwargs): + encoder_p, encoder_lr = get_params(self, lr, wd) + return encoder_p, encoder_lr + + @classmethod + def build(cls, config): + obj = globals()[config["name"]](config) + return obj diff --git a/unik3d/models/camera_augmenter.py b/unik3d/models/camera_augmenter.py new file mode 100644 index 0000000000000000000000000000000000000000..fd088daaac1ea7810da3aa388e481b4be02e1e06 --- /dev/null +++ b/unik3d/models/camera_augmenter.py @@ -0,0 +1,145 @@ +import numpy as np +import torch +from einops import rearrange + +from unik3d.utils.camera import CameraSampler +from unik3d.utils.coordinate import coords_grid +from unik3d.utils.geometric import iou + +try: + from splatting import splatting_function +except Exception as e: + splatting_function = None + print( + f"Splatting not available, please install it from github.com/hperrot/splatting" + ) + + +def fill(self, rgb, mask): + def fill_noise(size, device): + return torch.normal(0, 1.0, size=size, device=device) + + def fill_black(size, device): + return -2 * torch.ones(size, device=device, dtype=torch.float32) + + def fill_white(size, device): + return 2 * torch.ones(size, device=device, dtype=torch.float32) + + def fill_zero(size, device): + return torch.zeros(size, device=device, dtype=torch.float32) + 
+ B, C = rgb.shape[:2] + validity_mask = mask.repeat(1, C, 1, 1).bool() + for i in range(B): + filler_fn = np.random.choice([fill_noise, fill_black, fill_white, fill_zero]) + rgb[i][~validity_mask[i]] = filler_fn( + size=rgb[i][~validity_mask[i]].shape, device=rgb.device + ) + return rgb + + +@torch.autocast(device_type="cuda", enabled=True, dtype=torch.float32) +def augment_camera(self, inputs, camera_sampler): + rgb = inputs["image"] + gt = inputs["depth"].clone() + guidance = inputs[ + "depth_guidance" + ] # from GT if dense/synthetic or from a model's metric output + validity_mask = inputs["validity_mask"].bool() + dtype, device = gt.dtype, gt.device + B, C, H, W = rgb.shape + augmentable_indices = inputs["valid_camera"] & ( + inputs["depth_mask"].reshape(B, -1).float().mean(dim=1) > 0.0 + ) + + augment_indices = torch.rand(B, 1, 1, device=device, dtype=dtype) > 0.9 + augment_indices[~augmentable_indices] = False + id_coords = coords_grid(B, H, W, device=device) + # get rescaled depth + augment_indices = augment_indices.reshape(-1) + for i, is_augment in enumerate(augment_indices): + if not is_augment: + continue + + pinhole_camera = inputs["camera"][i] + fov = max(pinhole_camera.hfov[0], pinhole_camera.vfov[0]) * 180 / np.pi + ratio = min(70.0 / fov, 1.0) # decrease effect for larger fov + if fov < 40.0: # skips ~5% + augment_indices[i] = False + continue + + rgb_i = rgb[i : i + 1] + id_coords_i = id_coords[i : i + 1] + + validity_mask_i = validity_mask[i : i + 1] + depth = guidance[i : i + 1] + + if (depth < 0.0).any(): + augment_indices[i] = False + continue + + depth = depth.sqrt() # why sqrt?? + depth[~validity_mask_i] = depth.max() * 2.0 + + fx, fy, cx, cy = pinhole_camera.params[:, :4].unbind(dim=-1) + new_camera = camera_sampler(fx, fy, cx, cy, mult=1.0, ratio=ratio, H=H) + unprojected = pinhole_camera.reconstruct(depth) + projected = new_camera.project(unprojected) + projection_mask = new_camera.projection_mask + overlap_mask = ( + new_camera.overlap_mask + if new_camera.overlap_mask is not None + else torch.ones_like(projection_mask) + ) + mask = validity_mask_i & overlap_mask + + # if it is actually going out, we need to remember the regions + # remember when the tengetial distortion was keeping the validaty_mask border after re-warpingi + # need a better way to define overlap class, in case of vortex style if will mask wrong parts... + # also is_collapse does not take into consideration when we have vortex effect, + # how can we avoid vortex in the first place???? 
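
Before the splatting step below, `augment_camera` builds a per-pixel flow field by unprojecting each pixel with the source camera and re-projecting it with the sampled target camera. The sketch below illustrates that flow for the simplest case of two pinhole cameras, where depth cancels and the warp reduces to the homography K_tgt K_src^{-1}; the actual code also supports distorted target cameras, where the guidance depth matters. All names here are illustrative:

    import torch

    def intrinsics_change_flow(K_src: torch.Tensor, K_tgt: torch.Tensor, H: int, W: int):
        ys, xs = torch.meshgrid(
            torch.arange(H, dtype=torch.float32),
            torch.arange(W, dtype=torch.float32),
            indexing="ij",
        )
        ones = torch.ones_like(xs)
        pix = torch.stack([xs, ys, ones], dim=0).reshape(3, -1)  # homogeneous pixel grid
        warped = K_tgt @ torch.linalg.inv(K_src) @ pix           # target pixel coordinates
        warped = warped[:2] / warped[2:].clamp(min=1e-6)
        flow = warped - pix[:2]                                  # displacement w.r.t. identity grid
        return flow.reshape(2, H, W)

    K_src = torch.tensor([[500.0, 0.0, 320.0], [0.0, 500.0, 240.0], [0.0, 0.0, 1.0]])
    K_tgt = torch.tensor([[450.0, 0.0, 320.0], [0.0, 450.0, 240.0], [0.0, 0.0, 1.0]])  # wider fov
    flow = intrinsics_change_flow(K_src, K_tgt, 480, 640)
    print(flow.shape)  # torch.Size([2, 480, 640]); largest displacements at the image borders
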
+ is_collapse = (projected[0, 1, 0, :] >= 0.0).all() + if is_collapse: + projected[~mask.repeat(1, 2, 1, 1)] = id_coords_i[~mask.repeat(1, 2, 1, 1)] + flow = projected - id_coords_i + depth[~mask] = depth.max() * 2.0 + + if flow.norm(dim=1).median() / max(H, W) > 0.1: # extreme cases + augment_indices[i] = False + continue + + # warp via soft splat + depth_image = torch.cat([rgb_i, guidance[i : i + 1], mask], dim=1) + depth_image = splatting_function( + "softmax", depth_image, flow, -torch.log(1 + depth.clip(0.01)) + ) + rgb_warp = depth_image[:, :3] + validity_mask_i = depth_image[:, -1:] > 0.0 + + expanding = validity_mask_i.sum() > validity_mask[i : i + 1].sum() + threshold = 0.7 if expanding else 0.25 + _iou = iou(validity_mask_i, validity_mask[i : i + 1]) + if _iou < threshold: # too strong augmentation, lose most of the image + augment_indices[i] = False + continue + + # where it goes out + mask_unwarpable = projection_mask & overlap_mask + inputs["depth_mask"][i] = inputs["depth_mask"][i] & mask_unwarpable.squeeze(0) + + # compute new rays, and use the for supervision + rays = new_camera.get_rays(shapes=(1, H, W)) + rays = rearrange(rays, "b c h w -> b (h w) c") + inputs["rays"][i] = torch.where( + rays.isnan().any(dim=-1, keepdim=True), 0.0, rays + )[0] + + # update image, camera and validity_mask + inputs["camera"][i] = new_camera + inputs["image"][i] = self.fill(rgb_warp, validity_mask_i)[0] + inputs["validity_mask"][i] = inputs["validity_mask"][i] & mask_unwarpable[0] + + # needed to reverse the augmentation for loss-computation (i.e. un-warp the prediction) + inputs["grid_sample"][i] = projected[0] + + return inputs diff --git a/unik3d/models/decoder.py b/unik3d/models/decoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d5f8ac3e77e987914945f61e60e57b6a2028243f --- /dev/null +++ b/unik3d/models/decoder.py @@ -0,0 +1,562 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from math import tanh + +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from timm.models.layers import trunc_normal_ + +from unik3d.layers import (MLP, AttentionBlock, AttentionLayer, GradChoker, + PositionEmbeddingSine, ResUpsampleBil) +from unik3d.utils.coordinate import coords_grid +from unik3d.utils.geometric import flat_interpolate +from unik3d.utils.misc import get_params +from unik3d.utils.positional_embedding import generate_fourier_features +from unik3d.utils.sht import rsh_cart_3 + + +def orthonormal_init(num_tokens, dims): + pe = torch.randn(num_tokens, dims) + + # Apply Gram-Schmidt process to make the matrix orthonormal + # Awful loop.. 
+ for i in range(num_tokens): + for j in range(i): + pe[i] -= torch.dot(pe[i], pe[j]) * pe[j] + pe[i] = F.normalize(pe[i], p=2, dim=0) + + return pe + + +class ListAdapter(nn.Module): + def __init__(self, input_dims: list[int], hidden_dim: int): + super().__init__() + self.input_adapters = nn.ModuleList([]) + self.num_chunks = len(input_dims) + self.checkpoint = True + for input_dim in input_dims: + self.input_adapters.append(nn.Linear(input_dim, hidden_dim)) + + def forward(self, xs: torch.Tensor) -> torch.Tensor: + outs = [self.input_adapters[i](x) for i, x in enumerate(xs)] + return outs + + +class AngularModule(nn.Module): + def __init__( + self, + hidden_dim: int, + num_heads: int = 8, + expansion: int = 4, + dropout: float = 0.0, + layer_scale: float = 1.0, + **kwargs, + ): + super().__init__() + self.pin_params = 3 + self.deg1_params = 3 + self.deg2_params = 5 + self.deg3_params = 7 + self.num_params = ( + self.pin_params + self.deg1_params + self.deg2_params + self.deg3_params + ) + + self.aggregate1 = AttentionBlock( + hidden_dim, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + layer_scale=layer_scale, + ) + self.aggregate2 = AttentionBlock( + hidden_dim, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + layer_scale=layer_scale, + ) + self.latents_pos = nn.Parameter( + torch.randn(1, self.num_params, hidden_dim), requires_grad=True + ) + self.in_features = nn.Identity() + + self.project_pin = nn.Linear( + hidden_dim, self.pin_params * hidden_dim, bias=False + ) + self.project_deg1 = nn.Linear( + hidden_dim, self.deg1_params * hidden_dim, bias=False + ) + self.project_deg2 = nn.Linear( + hidden_dim, self.deg2_params * hidden_dim, bias=False + ) + self.project_deg3 = nn.Linear( + hidden_dim, self.deg3_params * hidden_dim, bias=False + ) + + self.out_pinhole = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=1) + self.out_deg1 = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=3) + self.out_deg2 = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=3) + self.out_deg3 = MLP(hidden_dim, expansion=1, dropout=dropout, output_dim=3) + + def fill_intrinsics(self, x): + hfov, cx, cy = x.unbind(dim=-1) + hfov = torch.sigmoid(hfov - 1.1) # 1.1 magic number s.t hfov = pi/2 for x=0 + ratio = self.shapes[0] / self.shapes[1] + vfov = hfov * ratio + cx = torch.sigmoid(cx) + cy = torch.sigmoid(cy) + correction_tensor = torch.tensor( + [2 * torch.pi, 2 * torch.pi, self.shapes[1], self.shapes[0]], + device=x.device, + dtype=x.dtype, + ) + + intrinsics = torch.stack([hfov, vfov, cx, cy], dim=1) + intrinsics = correction_tensor.unsqueeze(0) * intrinsics + return intrinsics + + def forward(self, cls_tokens) -> torch.Tensor: + latents_pos = self.latents_pos.expand(cls_tokens.shape[0], -1, -1) + + pin_tokens, deg1_tokens, deg2_tokens, deg3_tokens = cls_tokens.chunk(4, dim=1) + pin_tokens = rearrange( + self.project_pin(pin_tokens), "b n (h c) -> b (n h) c", h=self.pin_params + ) + deg1_tokens = rearrange( + self.project_deg1(deg1_tokens), "b n (h c) -> b (n h) c", h=self.deg1_params + ) + deg2_tokens = rearrange( + self.project_deg2(deg2_tokens), "b n (h c) -> b (n h) c", h=self.deg2_params + ) + deg3_tokens = rearrange( + self.project_deg3(deg3_tokens), "b n (h c) -> b (n h) c", h=self.deg3_params + ) + tokens = torch.cat([pin_tokens, deg1_tokens, deg2_tokens, deg3_tokens], dim=1) + + tokens = self.aggregate1(tokens, pos_embed=latents_pos) + tokens = self.aggregate2(tokens, pos_embed=latents_pos) + + tokens_pinhole, tokens_deg1, tokens_deg2, 
tokens_deg3 = torch.split( + tokens, + [self.pin_params, self.deg1_params, self.deg2_params, self.deg3_params], + dim=1, + ) + x = self.out_pinhole(tokens_pinhole).squeeze(-1) + d1 = self.out_deg1(tokens_deg1) + d2 = self.out_deg2(tokens_deg2) + d3 = self.out_deg3(tokens_deg3) + + camera_intrinsics = self.fill_intrinsics(x) + return camera_intrinsics, torch.cat([d1, d2, d3], dim=1) + + def set_shapes(self, shapes: tuple[int, int]): + self.shapes = shapes + + +class RadialModule(nn.Module): + def __init__( + self, + hidden_dim: int, + num_heads: int = 8, + expansion: int = 4, + depths: int | list[int] = 4, + camera_dim: int = 256, + dropout: float = 0.0, + kernel_size: int = 7, + layer_scale: float = 1.0, + out_dim: int = 1, + num_prompt_blocks: int = 1, + use_norm: bool = False, + **kwargs, + ) -> None: + super().__init__() + self.camera_dim = camera_dim + self.out_dim = out_dim + self.hidden_dim = hidden_dim + + self.ups = nn.ModuleList([]) + self.depth_mlp = nn.ModuleList([]) + self.process_features = nn.ModuleList([]) + self.project_features = nn.ModuleList([]) + self.out = nn.ModuleList([]) + self.prompt_camera = nn.ModuleList([]) + mult = 2 + self.to_latents = nn.Linear(hidden_dim, hidden_dim) + + for _ in range(4): + self.prompt_camera.append( + AttentionLayer( + num_blocks=num_prompt_blocks, + dim=hidden_dim, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + layer_scale=-1.0, + context_dim=hidden_dim, + ) + ) + + for i, depth in enumerate(depths): + current_dim = min(hidden_dim, mult * hidden_dim // int(2**i)) + next_dim = mult * hidden_dim // int(2 ** (i + 1)) + output_dim = max(next_dim, out_dim) + self.process_features.append( + nn.ConvTranspose2d( + hidden_dim, + current_dim, + kernel_size=max(1, 2 * i), + stride=max(1, 2 * i), + padding=0, + ) + ) + self.ups.append( + ResUpsampleBil( + current_dim, + output_dim=output_dim, + expansion=expansion, + layer_scale=layer_scale, + kernel_size=kernel_size, + num_layers=depth, + use_norm=use_norm, + ) + ) + depth_mlp = ( + nn.Sequential(nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim)) + if i == len(depths) - 1 + else nn.Identity() + ) + self.depth_mlp.append(depth_mlp) + + self.confidence_mlp = nn.Sequential( + nn.LayerNorm(next_dim), nn.Linear(next_dim, output_dim) + ) + + self.to_depth_lr = nn.Conv2d( + output_dim, + output_dim // 2, + kernel_size=3, + padding=1, + padding_mode="reflect", + ) + self.to_confidence_lr = nn.Conv2d( + output_dim, + output_dim // 2, + kernel_size=3, + padding=1, + padding_mode="reflect", + ) + self.to_depth_hr = nn.Sequential( + nn.Conv2d( + output_dim // 2, 32, kernel_size=3, padding=1, padding_mode="reflect" + ), + nn.LeakyReLU(), + nn.Conv2d(32, 1, kernel_size=1), + ) + self.to_confidence_hr = nn.Sequential( + nn.Conv2d( + output_dim // 2, 32, kernel_size=3, padding=1, padding_mode="reflect" + ), + nn.LeakyReLU(), + nn.Conv2d(32, 1, kernel_size=1), + ) + + def set_original_shapes(self, shapes: tuple[int, int]): + self.original_shapes = shapes + + def set_shapes(self, shapes: tuple[int, int]): + self.shapes = shapes + + def embed_rays(self, rays): + rays_embedding = flat_interpolate( + rays, old=self.original_shapes, new=self.shapes, antialias=True + ) + rays_embedding = rays_embedding / torch.norm( + rays_embedding, dim=-1, keepdim=True + ).clip(min=1e-4) + x, y, z = rays_embedding[..., 0], rays_embedding[..., 1], rays_embedding[..., 2] + polar = torch.acos(z) + x_clipped = x.abs().clip(min=1e-3) * (2 * (x >= 0).int() - 1) + azimuth = torch.atan2(y, x_clipped) + rays_embedding = 
torch.stack([polar, azimuth], dim=-1) + rays_embedding = generate_fourier_features( + rays_embedding, + dim=self.hidden_dim, + max_freq=max(self.shapes) // 2, + use_log=True, + cat_orig=False, + ) + return rays_embedding + + def condition(self, feat, rays_embeddings): + conditioned_features = [ + prompter(rearrange(feature, "b h w c -> b (h w) c"), rays_embeddings) + for prompter, feature in zip(self.prompt_camera, feat) + ] + return conditioned_features + + def process(self, features_list, rays_embeddings): + conditioned_features = self.condition(features_list, rays_embeddings) + init_latents = self.to_latents(conditioned_features[0]) + init_latents = rearrange( + init_latents, "b (h w) c -> b c h w", h=self.shapes[0], w=self.shapes[1] + ).contiguous() + conditioned_features = [ + rearrange( + x, "b (h w) c -> b c h w", h=self.shapes[0], w=self.shapes[1] + ).contiguous() + for x in conditioned_features + ] + latents = init_latents + + out_features = [] + for i, up in enumerate(self.ups): + latents = latents + self.process_features[i](conditioned_features[i + 1]) + latents = up(latents) + out_features.append(latents) + + return out_features, init_latents + + def depth_proj(self, out_features): + depths = [] + h_out, w_out = out_features[-1].shape[-2:] + # aggregate output and project to depth + for i, (layer, features) in enumerate(zip(self.depth_mlp, out_features)): + out_depth_features = layer(features.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) + if i < len(self.depth_mlp) - 1: + continue + depths.append(out_depth_features) + out_depth_features = F.interpolate( + out_depth_features, size=(h_out, w_out), mode="bilinear", align_corners=True + ) + logdepth = self.to_depth_lr(out_depth_features) + logdepth = F.interpolate( + logdepth, size=self.original_shapes, mode="bilinear", align_corners=True + ) + logdepth = self.to_depth_hr(logdepth) + return logdepth + + def confidence_proj(self, out_features): + highres_features = out_features[-1].permute(0, 2, 3, 1) + confidence = self.confidence_mlp(highres_features).permute(0, 3, 1, 2) + confidence = self.to_confidence_lr(confidence) + confidence = F.interpolate( + confidence, size=self.original_shapes, mode="bilinear", align_corners=True + ) + confidence = self.to_confidence_hr(confidence) + return confidence + + def decode(self, out_features): + logdepth = self.depth_proj(out_features) + confidence = self.confidence_proj(out_features) + return logdepth, confidence + + def forward( + self, + features: list[torch.Tensor], + rays_hr: torch.Tensor, + pos_embed, + level_embed, + ) -> torch.Tensor: + rays_embeddings = self.embed_rays(rays_hr) + features, lowres_features = self.process(features, rays_embeddings) + logdepth, logconf = self.decode(features) + return logdepth, logconf, lowres_features + + +class Decoder(nn.Module): + def __init__( + self, + config, + ): + super().__init__() + self.build(config) + self.apply(self._init_weights) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.Conv2d): + trunc_normal_(m.weight, std=0.02) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + elif isinstance(m, nn.LayerNorm): + if m.bias is not None: + nn.init.constant_(m.bias, 0) + if m.weight is not None: + nn.init.constant_(m.weight, 1.0) + + def run_camera(self, cls_tokens, original_shapes, rays_gt): + H, W = original_shapes + + # camera layer + intrinsics, sh_coeffs = self.angular_module(cls_tokens=cls_tokens) + B, N = 
intrinsics.shape + device = intrinsics.device + dtype = intrinsics.dtype + + id_coords = coords_grid(B, H, W, device=sh_coeffs.device) + + # This is fov based + longitude = ( + (id_coords[:, 0] - intrinsics[:, 2].view(-1, 1, 1)) + / W + * intrinsics[:, 0].view(-1, 1, 1) + ) + latitude = ( + (id_coords[:, 1] - intrinsics[:, 3].view(-1, 1, 1)) + / H + * intrinsics[:, 1].view(-1, 1, 1) + ) + x = torch.cos(latitude) * torch.sin(longitude) + z = torch.cos(latitude) * torch.cos(longitude) + y = -torch.sin(latitude) + unit_sphere = torch.stack([x, y, z], dim=-1) + unit_sphere = unit_sphere / torch.norm(unit_sphere, dim=-1, keepdim=True).clip( + min=1e-5 + ) + + harmonics = rsh_cart_3(unit_sphere)[..., 1:] # remove constant-value harmonic + rays_pred = torch.einsum("bhwc,bcd->bhwd", harmonics, sh_coeffs) + rays_pred = rays_pred / torch.norm(rays_pred, dim=-1, keepdim=True).clip( + min=1e-5 + ) + rays_pred = rays_pred.permute(0, 3, 1, 2) + + ### LEGACY CODE for training + # if self.training: + # prob = 1 - tanh(self.steps / self.num_steps) + # where_use_gt_rays = torch.rand(B, 1, 1, device=device, dtype=dtype) < prob + # where_use_gt_rays = where_use_gt_rays.int() + # rays = rays_gt * where_use_gt_rays + rays_pred * (1 - where_use_gt_rays) + + # should clean also nans + if self.training: + rays = rays_pred + else: + rays = rays_gt if rays_gt is not None else rays_pred + rays = rearrange(rays, "b c h w -> b (h w) c") + + return intrinsics, rays + + def forward(self, inputs, image_metas) -> torch.Tensor: + B, C, H, W = inputs["image"].shape + device = inputs["image"].device + + rays_gt = inputs.get("rays", None) + + # get features in b n d format + common_shape = inputs["features"][0].shape[1:3] + + # input shapes repeat shapes for each level, times the amount of the layers: + features = self.input_adapter(inputs["features"]) + + # positional embeddings, spatial and level + level_embed = self.level_embeds.repeat( + B, common_shape[0] * common_shape[1], 1, 1 + ) + level_embed = rearrange(level_embed, "b n l d -> b (n l) d") + dummy_tensor = torch.zeros( + B, 1, common_shape[0], common_shape[1], device=device, requires_grad=False + ) + pos_embed = self.pos_embed(dummy_tensor) + pos_embed = rearrange(pos_embed, "b c h w -> b (h w) c").repeat(1, 4, 1) + + # get cls tokens projections + camera_tokens = inputs["tokens"] + camera_tokens = [self.choker(x.contiguous()) for x in camera_tokens] + camera_tokens = self.camera_token_adapter(camera_tokens) + self.angular_module.set_shapes((H, W)) + + intrinsics, rays = self.run_camera( + torch.cat(camera_tokens, dim=1), + original_shapes=(H, W), + rays_gt=rays_gt, + ) + + # run bulk of the model + self.radial_module.set_shapes(common_shape) + self.radial_module.set_original_shapes((H, W)) + logradius, logconfidence, lowres_features = self.radial_module( + features=features, + rays_hr=rays, + pos_embed=pos_embed, + level_embed=level_embed, + ) + radius = torch.exp(logradius.clip(min=-8.0, max=8.0) + 2.0) + confidence = torch.exp(logconfidence.clip(min=-8.0, max=10.0)) + + outputs = { + "distance": radius, + "lowres_features": lowres_features, + "confidence": confidence, + "K": intrinsics, + "rays": rays, + } + + return outputs + + @torch.jit.ignore + def no_weight_decay_keywords(self): + return {"latents_pos", "level_embeds"} + + def get_params(self, lr, wd): + angles_p, _ = get_params(self.angular_module, lr, wd) + radius_p, _ = get_params(self.radial_module, lr, wd) + tokens_p, _ = get_params(self.camera_token_adapter, lr, wd) + input_p, _ = 
get_params(self.input_adapter, lr, wd) + return [*tokens_p, *angles_p, *input_p, *radius_p] + + def build(self, config): + input_dims = config["model"]["pixel_encoder"]["embed_dims"] + hidden_dim = config["model"]["pixel_decoder"]["hidden_dim"] + expansion = config["model"]["expansion"] + num_heads = config["model"]["num_heads"] + dropout = config["model"]["pixel_decoder"]["dropout"] + layer_scale = config["model"]["layer_scale"] + depth = config["model"]["pixel_decoder"]["depths"] + depths_encoder = config["model"]["pixel_encoder"]["depths"] + out_dim = config["model"]["pixel_decoder"]["out_dim"] + kernel_size = config["model"]["pixel_decoder"]["kernel_size"] + self.slices_encoder = list(zip([d - 1 for d in depths_encoder], depths_encoder)) + input_dims = [input_dims[d - 1] for d in depths_encoder] + self.steps = 0 + self.num_steps = config["model"].get("num_steps", 100000) + + camera_dims = input_dims + self.choker = GradChoker(config["model"]["pixel_decoder"]["detach"]) + self.input_adapter = ListAdapter(input_dims, hidden_dim) + self.camera_token_adapter = ListAdapter(camera_dims, hidden_dim) + self.angular_module = AngularModule( + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion=expansion, + dropout=dropout, + layer_scale=layer_scale, + ) + self.radial_module = RadialModule( + hidden_dim=hidden_dim, + num_heads=num_heads, + expansion=expansion, + depths=depth, + dropout=dropout, + camera_dim=96, + layer_scale=layer_scale, + out_dim=out_dim, + kernel_size=kernel_size, + num_prompt_blocks=config["model"]["pixel_decoder"]["num_prompt_blocks"], + use_norm=False, + ) + self.pos_embed = PositionEmbeddingSine(hidden_dim // 2, normalize=True) + self.level_embeds = nn.Parameter( + orthonormal_init(len(input_dims), hidden_dim).reshape( + 1, 1, len(input_dims), hidden_dim + ), + requires_grad=False, + ) diff --git a/unik3d/models/encoder.py b/unik3d/models/encoder.py new file mode 100644 index 0000000000000000000000000000000000000000..d3e8c4551983ec6566797645bed92927669cce60 --- /dev/null +++ b/unik3d/models/encoder.py @@ -0,0 +1,231 @@ +from functools import partial + +import torch +import torch.nn as nn +from timm.models.vision_transformer import _cfg + +from unik3d.models.backbones import (ConvNeXt, ConvNeXtV2, SwinTransformerV2, + _make_dinov2_model) + + +def swin2_tiny( + config, + pretrained=None, + *args, + **kwargs, +): + model = SwinTransformerV2( + img_size=config["image_shape"], + patch_size=4, + window_size=config.get("window_size", 16), + embed_dim=96, + num_heads=[3, 6, 12, 24], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 6, 2], + drop_path_rate=0.2, + pretrained=pretrained, + pretrained_window_sizes=[12, 12, 12, 6], + output_idx=config.get("output_idx", [2, 4, 10, 12]), + use_shift=config.get("use_shift", True), + use_checkpoint=config.get("use_checkpoint", False), + frozen_stages=-1, + ) + model.default_cfg = _cfg() + return model + + +def swin2_base( + config, + pretrained=None, + *args, + **kwargs, +): + model = SwinTransformerV2( + img_size=config["image_shape"], + patch_size=4, + window_size=config.get("window_size", 12), + embed_dim=128, + num_heads=[4, 8, 16, 32], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 18, 2], + drop_path_rate=0.3, + pretrained=pretrained, + pretrained_window_sizes=[12, 12, 12, 6], + use_shift=config.get("use_shift", True), + use_checkpoint=config["use_checkpoint"], + frozen_stages=-1, + ) + model.default_cfg = _cfg() + 
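
Stepping back to `Decoder.run_camera` in decoder.py above: the predicted (hfov, vfov, cx, cy) define per-pixel longitude/latitude angles, which are turned into unit ray directions before the degree-1..3 spherical-harmonics coefficients (`rsh_cart_3`) refine them for non-pinhole cameras. A self-contained sketch of that fov-based ray construction, with illustrative names and scalar fov/principal-point inputs:

    import math
    import torch

    def rays_from_fov(hfov, vfov, cx, cy, H, W):
        ys, xs = torch.meshgrid(
            torch.arange(H, dtype=torch.float32),
            torch.arange(W, dtype=torch.float32),
            indexing="ij",
        )
        longitude = (xs - cx) / W * hfov
        latitude = (ys - cy) / H * vfov
        x = torch.cos(latitude) * torch.sin(longitude)
        y = -torch.sin(latitude)
        z = torch.cos(latitude) * torch.cos(longitude)
        rays = torch.stack([x, y, z], dim=-1)
        return rays / rays.norm(dim=-1, keepdim=True).clamp(min=1e-5)

    rays = rays_from_fov(hfov=math.pi / 2, vfov=math.pi / 3, cx=320.0, cy=240.0, H=480, W=640)
    print(rays.shape)      # torch.Size([480, 640, 3])
    print(rays[240, 320])  # ~[0, 0, 1]: the principal ray points straight forward
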
return model + + +def swin2_large( + config, + pretrained=None, + *args, + **kwargs, +): + model = SwinTransformerV2( + img_size=config["image_shape"], + patch_size=4, + window_size=config.get("window_size", 12), + embed_dim=192, + num_heads=[6, 12, 24, 48], + mlp_ratios=[4, 4, 4, 4], + qkv_bias=True, + norm_layer=partial(nn.LayerNorm, eps=1e-6), + depths=[2, 2, 18, 2], + drop_path_rate=0.3, + pretrained=pretrained, + pretrained_window_sizes=[12, 12, 12, 6], + use_shift=config.get("use_shift", True), + use_checkpoint=config["use_checkpoint"], + frozen_stages=-1, + ) + model.default_cfg = _cfg() + return model + + +def convnextv2_base(config, **kwargs): + model = ConvNeXtV2( + depths=[3, 3, 27, 3], + dims=[128, 256, 512, 1024], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config["use_checkpoint"], + **kwargs, + ) + url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_base_22k_384_ema.pt" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + )["model"] + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnextv2_large(config, **kwargs): + model = ConvNeXtV2( + depths=[3, 3, 27, 3], + dims=[192, 384, 768, 1536], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config["use_checkpoint"], + **kwargs, + ) + url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/im22k/convnextv2_large_22k_384_ema.pt" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + )["model"] + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnextv2_large_mae(config, **kwargs): + model = ConvNeXtV2( + depths=[3, 3, 27, 3], + dims=[192, 384, 768, 1536], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config["use_checkpoint"], + **kwargs, + ) + url = "https://dl.fbaipublicfiles.com/convnext/convnextv2/pt_only/convnextv2_large_1k_224_fcmae.pt" + state_dict = torch.hub.load_state_dict_from_url( + url, map_location="cpu", progress=False + )["model"] + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def convnext_large(config, **kwargs): + model = ConvNeXt( + depths=[3, 3, 27, 3], + dims=[192, 384, 768, 1536], + output_idx=config.get("output_idx", [3, 6, 33, 36]), + use_checkpoint=config.get("use_checkpoint", False), + drop_path_rate=config.get("drop_path", 0.0), + **kwargs, + ) + from huggingface_hub import hf_hub_download + from huggingface_hub.utils import disable_progress_bars + + from unik3d.models.backbones.convnext import HF_URL, checkpoint_filter_fn + + disable_progress_bars() + repo_id, filename = HF_URL["convnext_large"] + state_dict = torch.load(hf_hub_download(repo_id=repo_id, filename=filename)) + state_dict = checkpoint_filter_fn(state_dict, model) + info = model.load_state_dict(state_dict, strict=False) + print(info) + return model + + +def dinov2_vits14(config, pretrained: bool = True, **kwargs): + vit = _make_dinov2_model( + arch_name="vit_small", + pretrained=config["pretrained"], + output_idx=config.get("output_idx", [3, 6, 9, 12]), + checkpoint=config.get("use_checkpoint", False), + drop_path_rate=config.get("drop_path", 0.0), + num_register_tokens=config.get("num_register_tokens", 0), + use_norm=config.get("use_norm", False), + interpolate_offset=config.get("interpolate_offset", 0.0), + frozen_stages=config.get("frozen_stages", 0), + freeze_norm=config.get("freeze_norm", False), + **kwargs, + ) + return 
vit + + +def dinov2_vitb14(config, pretrained: bool = True, **kwargs): + vit = _make_dinov2_model( + arch_name="vit_base", + pretrained=config["pretrained"], + output_idx=config.get("output_idx", [3, 6, 9, 12]), + checkpoint=config.get("use_checkpoint", False), + drop_path_rate=config.get("drop_path", 0.0), + num_register_tokens=config.get("num_register_tokens", 0), + use_norm=config.get("use_norm", False), + interpolate_offset=config.get("interpolate_offset", 0.0), + frozen_stages=config.get("frozen_stages", 0), + freeze_norm=config.get("freeze_norm", False), + **kwargs, + ) + return vit + + +def dinov2_vitl14(config, pretrained: str = "", **kwargs): + vit = _make_dinov2_model( + arch_name="vit_large", + pretrained=config["pretrained"], + output_idx=config.get("output_idx", [5, 12, 18, 24]), + checkpoint=config.get("use_checkpoint", False), + drop_path_rate=config.get("drop_path", 0.0), + num_register_tokens=config.get("num_register_tokens", 0), + use_norm=config.get("use_norm", False), + interpolate_offset=config.get("interpolate_offset", 0.0), + frozen_stages=config.get("frozen_stages", 0), + freeze_norm=config.get("freeze_norm", False), + **kwargs, + ) + return vit + + +def dinov2_vitg14(config, pretrained: str = "", **kwargs): + vit = _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + pretrained=config["pretrained"], + output_idx=config.get("output_idx", [10, 20, 30, 40]), + checkpoint=config.get("use_checkpoint", False), + drop_path_rate=config.get("drop_path", 0.0), + num_register_tokens=config.get("num_register_tokens", 0), + use_norm=config.get("use_norm", False), + interpolate_offset=config.get("interpolate_offset", 0.0), + **kwargs, + ) + return vit diff --git a/unik3d/models/export.py b/unik3d/models/export.py new file mode 100644 index 0000000000000000000000000000000000000000..65a8704816f30850307342f10442b532726c949a --- /dev/null +++ b/unik3d/models/export.py @@ -0,0 +1,165 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import argparse +import json +import os +from math import ceil + +import huggingface_hub +import torch.nn.functional as F +import torch.onnx + +from unik3d.models.unik3d import UniK3D + + +class UniK3DONNX(UniK3D): + def __init__( + self, + config, + eps: float = 1e-6, + **kwargs, + ): + super().__init__(config, eps) + + def forward(self, rgbs): + B, _, H, W = rgbs.shape + features, tokens = self.pixel_encoder(rgbs) + + inputs = {} + inputs["image"] = rgbs + inputs["features"] = [ + self.stacking_fn(features[i:j]).contiguous() + for i, j in self.slices_encoder_range + ] + inputs["tokens"] = [ + self.stacking_fn(tokens[i:j]).contiguous() + for i, j in self.slices_encoder_range + ] + outputs = self.pixel_decoder(inputs, []) + outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W) + pts_3d = outputs["rays"] * outputs["radius"] + + return pts_3d, outputs["confidence"] + + +class UniK3DONNXcam(UniK3D): + def __init__( + self, + config, + eps: float = 1e-6, + **kwargs, + ): + super().__init__(config, eps) + + def forward(self, rgbs, rays): + B, _, H, W = rgbs.shape + features, tokens = self.pixel_encoder(rgbs) + + inputs = {} + inputs["image"] = rgbs + inputs["rays"] = rays + inputs["features"] = [ + self.stacking_fn(features[i:j]).contiguous() + for i, j in self.slices_encoder_range + ] + inputs["tokens"] = [ + self.stacking_fn(tokens[i:j]).contiguous() + for i, j in self.slices_encoder_range + ] + outputs = self.pixel_decoder(inputs, []) + 
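
The backbone factories above all load pretrained weights with `strict=False` and print the resulting mismatch report, so dropped classification heads or renamed keys surface as `missing_keys` / `unexpected_keys` instead of hard errors. A small self-contained illustration of what that report contains (the tiny model and checkpoint below are made up for the example):

    import torch
    import torch.nn as nn

    class TinyBackbone(nn.Module):  # stand-in model, not one of the repository's backbones
        def __init__(self):
            super().__init__()
            self.stem = nn.Conv2d(3, 8, 3, padding=1)
            self.head = nn.Linear(8, 10)

    model = TinyBackbone()
    checkpoint = {
        "stem.weight": torch.zeros(8, 3, 3, 3),
        "stem.bias": torch.zeros(8),
        "classifier.weight": torch.zeros(10, 8),  # extra key, as real checkpoints often have
    }
    info = model.load_state_dict(checkpoint, strict=False)
    print(info.missing_keys)     # ['head.weight', 'head.bias']
    print(info.unexpected_keys)  # ['classifier.weight']
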
outputs["rays"] = outputs["rays"].permute(0, 2, 1).reshape(B, 3, H, W) + pts_3d = outputs["rays"] * outputs["radius"] + + return pts_3d, outputs["confidence"] + + +def export(model, path, shape=(462, 630), with_camera=False): + model.eval() + image = torch.rand(1, 3, *shape) + dynamic_axes_in = {"rgbs": {0: "batch"}} + inputs = [image] + if with_camera: + rays = torch.rand(1, 3, *shape) + inputs.append(rays) + dynamic_axes_in["rays"] = {0: "batch"} + + dynamic_axes_out = { + "pts_3d": {0: "batch"}, + "confidence": {0: "batch"}, + } + torch.onnx.export( + model, + tuple(inputs), + path, + input_names=list(dynamic_axes_in.keys()), + output_names=list(dynamic_axes_out.keys()), + opset_version=14, + dynamic_axes={**dynamic_axes_in, **dynamic_axes_out}, + ) + print(f"Model exported to {path}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Export UniK3D model to ONNX") + parser.add_argument( + "--backbone", + type=str, + default="vitl", + choices=["vits", "vitb", "vitl"], + help="Backbone model", + ) + parser.add_argument( + "--shape", + type=int, + nargs=2, + default=(462, 630), + help="Input shape. No dyamic shape supported!", + ) + parser.add_argument( + "--output-path", type=str, default="unik3d.onnx", help="Output ONNX file" + ) + parser.add_argument( + "--with-camera", + action="store_true", + help="Export model that expects GT camera as unprojected rays at inference", + ) + args = parser.parse_args() + + backbone = args.backbone + shape = args.shape + output_path = args.output_path + with_camera = args.with_camera + + # force shape to be multiple of 14 + shape_rounded = [14 * ceil(x // 14 - 0.5) for x in shape] + if list(shape) != list(shape_rounded): + print(f"Shape {shape} is not multiple of 14. Rounding to {shape_rounded}") + shape = shape_rounded + + # assumes command is from root of repo + with open(os.path.join("configs", f"config_{backbone}.json")) as f: + config = json.load(f) + + # tell DINO not to use efficient attention: not exportable + config["training"]["export"] = True + + model = UniK3DONNX(config) if not with_camera else UniK3DONNXcam(config) + path = huggingface_hub.hf_hub_download( + repo_id=f"lpiccinelli/unik3d-{backbone}", + filename=f"pytorch_model.bin", + repo_type="model", + ) + info = model.load_state_dict(torch.load(path), strict=False) + print(f"UUniK3D_{backbone} is loaded with:") + print(f"\t missing keys: {info.missing_keys}") + print(f"\t additional keys: {info.unexpected_keys}") + + export( + model=model, + path=os.path.join(os.environ.get("TMPDIR", "."), output_path), + shape=shape, + with_camera=with_camera, + ) diff --git a/unik3d/models/metadinov2/__init__.py b/unik3d/models/metadinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f0534f6bab5034b20875641adba68f8aa1f857ac --- /dev/null +++ b/unik3d/models/metadinov2/__init__.py @@ -0,0 +1,12 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from .attention import Attention, MemEffAttention +from .block import Block, NestedTensorBlock +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused diff --git a/unik3d/models/metadinov2/attention.py b/unik3d/models/metadinov2/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..ea22c9f4e8e1f6d78e1bbd0f4899d56e69424c4e --- /dev/null +++ b/unik3d/models/metadinov2/attention.py @@ -0,0 +1,84 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging + +import torch.nn as nn +from torch import Tensor + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha, memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) if attn_drop > 0.0 else nn.Identity() + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) if proj_drop > 0.0 else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + B, N, C = x.shape + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, C // self.num_heads) + .permute(2, 0, 3, 1, 4) + ) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE or x.device.type == "cpu": + assert attn_bias is None, "xFormers is required for nested tensors usage" + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/unik3d/models/metadinov2/block.py b/unik3d/models/metadinov2/block.py new file mode 100644 index 0000000000000000000000000000000000000000..7e6cfc5232905b3e0ddfe1c126bd8c5798f974f4 --- /dev/null +++ b/unik3d/models/metadinov2/block.py @@ -0,0 +1,281 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
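
`MemEffAttention` above only swaps the explicit softmax attention of `Attention.forward` for xFormers' `memory_efficient_attention`, which computes the same result with lower peak memory. A quick numerical check of that equivalence using plain PyTorch's fused `scaled_dot_product_attention` (PyTorch >= 2.0) as the reference, so it runs without xFormers:

    import torch
    import torch.nn.functional as F

    B, H, N, D = 2, 4, 16, 32
    q, k, v = (torch.randn(B, H, N, D) for _ in range(3))

    attn = (q * D**-0.5) @ k.transpose(-2, -1)       # same scaling as Attention.forward
    explicit = attn.softmax(dim=-1) @ v
    fused = F.scaled_dot_product_attention(q, k, v)  # default scale is 1/sqrt(D)

    print(torch.allclose(explicit, fused, atol=1e-5))  # True
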
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +from typing import Any, Callable, Dict, List, Tuple + +import torch +import torch.nn as nn + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + +logger = logging.getLogger("dinov2") + + +try: + from xformers.ops import fmha, index_select_cat, scaled_index_add + + XFORMERS_AVAILABLE = True +except ImportError: + logger.warning("xFormers not available") + XFORMERS_AVAILABLE = False + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = ( + LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + ) + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: torch.Tensor) -> torch.Tensor: + def attn_residual_func(x: torch.Tensor) -> torch.Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: torch.Tensor) -> torch.Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: torch.Tensor, + residual_func, #: Callable[[torch.Tensor], torch.Tensor], + sample_drop_ratio: float = 0.0, +) -> torch.Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add( + x_flat, 0, brange, 
residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add( + x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor + ) + else: + x_plus_residual = scaled_index_add( + x, + brange, + residual.to(dtype=x.dtype), + scaling=scaling_vector, + alpha=residual_scale_factor, + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = ( + [b.shape[0] for b in branges] + if branges is not None + else [x.shape[0] for x in x_list] + ) + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view( + 1, -1, x_list[0].shape[-1] + ) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[torch.Tensor], + residual_func, #: Callable[[torch.Tensor, Any], torch.Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> torch.Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [ + get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list + ] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip( + x_list, branges, residual_list, residual_scale_factors + ): + outputs.append( + add_residual( + x, brange, residual, residual_scale_factor, scaling_vector + ).view_as(x) + ) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + 
sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=( + self.ls1.gamma if isinstance(self.ls1, LayerScale) else None + ), + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=( + self.ls2.gamma if isinstance(self.ls1, LayerScale) else None + ), + ) + return x_list + else: + + def attn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: torch.Tensor, attn_bias=None) -> torch.Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, torch.Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + assert ( + XFORMERS_AVAILABLE + ), "Please install xFormers for nested tensors usage" + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/unik3d/models/metadinov2/dino_head.py b/unik3d/models/metadinov2/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..1147dd3a3c046aee8d427b42b1055f38a218275b --- /dev/null +++ b/unik3d/models/metadinov2/dino_head.py @@ -0,0 +1,68 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp( + nlayers, + in_dim, + bottleneck_dim, + hidden_dim=hidden_dim, + use_bn=use_bn, + bias=mlp_bias, + ) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp( + nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True +): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/unik3d/models/metadinov2/drop_path.py b/unik3d/models/metadinov2/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..35b1a620d06ba862ea05297d271d8c2c625b5f93 --- /dev/null +++ b/unik3d/models/metadinov2/drop_path.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. 
+# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +import torch.nn as nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * ( + x.ndim - 1 + ) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/unik3d/models/metadinov2/layer_scale.py b/unik3d/models/metadinov2/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..40d18b5427183534d5516652b076f9883a609fc6 --- /dev/null +++ b/unik3d/models/metadinov2/layer_scale.py @@ -0,0 +1,28 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +import torch.nn as nn +from torch import Tensor + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/unik3d/models/metadinov2/mlp.py b/unik3d/models/metadinov2/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..af598999855d897948142cc986fce82abc9e3b53 --- /dev/null +++ b/unik3d/models/metadinov2/mlp.py @@ -0,0 +1,41 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
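
The `drop_path` function above rescales surviving samples by `1 / keep_prob`, so the expected output of a dropped residual branch matches the undropped one and activation statistics do not shift during training. A short check of that property; the helper is a condensed copy of the function above:

    import torch

    def drop_path(x, drop_prob=0.0, training=False):
        if drop_prob == 0.0 or not training:
            return x
        keep_prob = 1 - drop_prob
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)
        random_tensor = x.new_empty(shape).bernoulli_(keep_prob).div_(keep_prob)
        return x * random_tensor

    torch.manual_seed(0)
    x = torch.ones(4096, 8)
    out = torch.stack([drop_path(x, drop_prob=0.3, training=True) for _ in range(10)])
    print(out.mean().item())  # ~1.0, even though ~30% of the rows are zeroed in each draw
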
+ +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) if drop > 0.0 else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/unik3d/models/metadinov2/patch_embed.py b/unik3d/models/metadinov2/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..e5a56c02609e67922eb8f859588ef274e5298b55 --- /dev/null +++ b/unik3d/models/metadinov2/patch_embed.py @@ -0,0 +1,101 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +import torch.nn as nn +from torch import Tensor + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. 
+ """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d( + in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW + ) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert ( + H % patch_H == 0 + ), f"Input image height {H} is not a multiple of patch height {patch_H}" + assert ( + W % patch_W == 0 + ), f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = ( + Ho + * Wo + * self.embed_dim + * self.in_chans + * (self.patch_size[0] * self.patch_size[1]) + ) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/unik3d/models/metadinov2/swiglu_ffn.py b/unik3d/models/metadinov2/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..e82999e9b09b41cd6aba9edbc4c05d51ab663a1e --- /dev/null +++ b/unik3d/models/metadinov2/swiglu_ffn.py @@ -0,0 +1,63 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. 
+ +from typing import Callable, Optional + +import torch.nn.functional as F +from torch import Tensor, nn + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +try: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/unik3d/models/unik3d.py b/unik3d/models/unik3d.py new file mode 100644 index 0000000000000000000000000000000000000000..6366d137598514a76fe5bb8df57c8b2782e461f9 --- /dev/null +++ b/unik3d/models/unik3d.py @@ -0,0 +1,475 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +import importlib +import warnings +from copy import deepcopy +from math import ceil + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torchvision.transforms.v2.functional as TF +from einops import rearrange +from huggingface_hub import PyTorchModelHubMixin + +from unik3d.models.decoder import Decoder +from unik3d.utils.camera import BatchCamera, Camera +from unik3d.utils.constants import IMAGENET_DATASET_MEAN, IMAGENET_DATASET_STD +from unik3d.utils.distributed import is_main_process +from unik3d.utils.misc import get_params, last_stack, match_gt + +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +def orthonormal_init(num_tokens, dims): + pe = torch.randn(num_tokens, dims) + # use Gram-Schmidt process to make the matrix orthonormal + for i in range(num_tokens): + for j in range(i): + pe[i] -= torch.dot(pe[i], pe[j]) * pe[j] + pe[i] = F.normalize(pe[i], p=2, dim=0) + return pe + + +def get_paddings(original_shape, aspect_ratio_range): + # Original dimensions + H_ori, W_ori = original_shape + orig_aspect_ratio = W_ori / H_ori + + # Determine the closest aspect ratio within the range + min_ratio, max_ratio = aspect_ratio_range + target_aspect_ratio = min(max_ratio, max(min_ratio, orig_aspect_ratio)) + + if orig_aspect_ratio > target_aspect_ratio: # Too wide + W_new = W_ori + H_new = int(W_ori / target_aspect_ratio) + pad_top = (H_new - H_ori) // 2 + pad_bottom = H_new - H_ori - pad_top + pad_left, pad_right = 0, 0 + else: # Too tall + H_new = H_ori + W_new = int(H_ori * target_aspect_ratio) + pad_left = (W_new - W_ori) // 2 + pad_right = W_new - W_ori - pad_left + pad_top, pad_bottom = 0, 0 + + return (pad_left, pad_right, pad_top, pad_bottom), (H_new, W_new) + + +def 
get_resize_factor(original_shape, pixels_range, shape_multiplier=14): + # Original dimensions + H_ori, W_ori = original_shape + n_pixels_ori = W_ori * H_ori + + # Determine the closest number of pixels within the range + min_pixels, max_pixels = pixels_range + target_pixels = min(max_pixels, max(min_pixels, n_pixels_ori)) + + # Calculate the resize factor + resize_factor = (target_pixels / n_pixels_ori) ** 0.5 + new_width = int(W_ori * resize_factor) + new_height = int(H_ori * resize_factor) + new_height = ceil(new_height / shape_multiplier) * shape_multiplier + new_width = ceil(new_width / shape_multiplier) * shape_multiplier + + return resize_factor, (new_height, new_width) + + +def _postprocess(tensor, shapes, paddings, interpolation_mode="bilinear"): + + # interpolate to original size + tensor = F.interpolate( + tensor, size=shapes, mode=interpolation_mode, align_corners=False + ) + + # remove paddings + pad1_l, pad1_r, pad1_t, pad1_b = paddings + tensor = tensor[..., pad1_t : shapes[0] - pad1_b, pad1_l : shapes[1] - pad1_r] + return tensor + + +class UniK3D( + nn.Module, + PyTorchModelHubMixin, + library_name="UniK3D", + repo_url="https://github.com/lpiccinelli-eth/UniK3D", + tags=["monocular-metric-3D-estimation"], +): + def __init__( + self, + config, + eps: float = 1e-6, + **kwargs, + ): + super().__init__() + self.eps = eps + self.build(config) + self.build_losses(config) + + def pack_sequence( + self, + inputs: dict[str, torch.Tensor], + ): + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + inputs[key] = value.reshape(-1, *value.shape[2:]) + elif isinstance(value, BatchCamera): + inputs[key] = value.reshape(-1) + return inputs + + def unpack_sequence(self, inputs: dict[str, torch.Tensor], B: int, T: int): + for key, value in inputs.items(): + if isinstance(value, torch.Tensor): + inputs[key] = value.reshape(B, T, *value.shape[1:]) + elif isinstance(value, BatchCamera): + inputs[key] = value.reshape(B, T) + return inputs + + def forward_train(self, inputs, image_metas): + losses = {"opt": {}, "stat": {}} + B, T = inputs["image"].shape[:2] + image_metas[0]["B"], image_metas[0]["T"] = B, T + inputs = self.pack_sequence(inputs) # move from B, T, ... -> B*T, ... + + inputs, outputs = self.encode_decode(inputs, image_metas) + validity_mask = inputs["validity_mask"] + + # be careful on possible NaNs in reconstruced 3D (unprojection out-of-bound) + pts_gt = inputs["camera"].reconstruct(inputs["depth"]) * validity_mask.float() + pts_gt = torch.where(pts_gt.isnan().any(dim=1, keepdim=True), 0.0, pts_gt) + mask_pts_gt_nan = ~pts_gt.isnan().any(dim=1, keepdim=True) + mask = ( + inputs["depth_mask"].bool() & validity_mask.bool() & mask_pts_gt_nan.bool() + ) + + # compute loss! + inputs["radius"] = torch.norm(pts_gt, dim=1, keepdim=True) + inputs["points"] = pts_gt + inputs["depth_mask"] = mask + losses = self.compute_losses(outputs, inputs, image_metas) + + outputs = self.unpack_sequence(outputs, B, T) + return ( + outputs, + losses, + ) + + def forward_test(self, inputs, image_metas): + B, T = inputs["image"].shape[:2] + image_metas[0]["B"], image_metas[0]["T"] = B, T + # move from B, T, ... -> B*T, ... 
+ inputs = self.pack_sequence(inputs) + inputs, outputs = self.encode_decode(inputs, image_metas) + + # you can add a dummy tensor with the actual output shape + depth_gt = inputs["depth"] + + outs = {} + outs["points"] = match_gt( + outputs["points"], depth_gt, padding1=inputs["paddings"], padding2=None + ) + outs["confidence"] = match_gt( + outputs["confidence"], depth_gt, padding1=inputs["paddings"], padding2=None + ) + outs["distance"] = outs["points"].norm(dim=1, keepdim=True) + outs["depth"] = outs["points"][:, -1:] + outs["rays"] = outs["points"] / torch.norm( + outs["points"], dim=1, keepdim=True + ).clip(min=1e-5) + + outs = self.unpack_sequence(outs, B, T) + return outs + + def forward(self, inputs, image_metas): + if self.training: + return self.forward_train(inputs, image_metas) + else: + return self.forward_test(inputs, image_metas) + + def encode_decode(self, inputs, image_metas=[]): + B, _, H, W = inputs["image"].shape + + # shortcut eval should avoid errors + if len(image_metas) and "paddings" in image_metas[0]: + # lrtb + inputs["paddings"] = torch.tensor( + [image_meta["paddings"] for image_meta in image_metas], + device=self.device, + )[..., [0, 2, 1, 3]] + inputs["depth_paddings"] = torch.tensor( + [image_meta["depth_paddings"] for image_meta in image_metas], + device=self.device, + ) + # at inference we do not have image paddings on top of depth ones (we have not "crop" on gt in ContextCrop) + if self.training: + inputs["depth_paddings"] = inputs["depth_paddings"] + inputs["paddings"] + else: + inputs["paddings"] = inputs["paddings"].squeeze(0) + inputs["depth_paddings"] = inputs["depth_paddings"].squeeze(0) + + if inputs.get("camera", None) is not None: + inputs["rays"] = inputs["camera"].get_rays(shapes=(B, H, W)) + + features, tokens = self.pixel_encoder(inputs["image"]) + inputs["features"] = [ + self.stacking_fn(features[i:j]).contiguous() + for i, j in self.slices_encoder_range + ] + inputs["tokens"] = [ + self.stacking_fn(tokens[i:j]).contiguous() + for i, j in self.slices_encoder_range + ] + + outputs = self.pixel_decoder(inputs, image_metas) + outputs["rays"] = rearrange(outputs["rays"], "b (h w) c -> b c h w", h=H, w=W) + pts_3d = outputs["rays"] * outputs["distance"] + outputs.update({"points": pts_3d, "depth": pts_3d[:, -1:]}) + + return inputs, outputs + + def compute_losses(self, outputs, inputs, image_metas): + B, _, H, W = inputs["image"].shape + losses = {"opt": {}, "stat": {}} + losses_to_be_computed = list(self.losses.keys()) + + # depth loss + si = torch.tensor( + [x.get("si", False) for x in image_metas], device=self.device + ).reshape(B) + loss = self.losses["depth"] + depth_losses = loss( + outputs["depth"], + target=inputs["depth"], + mask=inputs["depth_mask"].clone(), + si=si, + ) + losses["opt"][loss.name] = loss.weight * depth_losses.mean() + losses_to_be_computed.remove("depth") + + loss = self.losses["camera"] + camera_losses = loss( + outputs["rays"], target=inputs["rays"], mask=inputs["validity_mask"].bool() + ) + losses["opt"][loss.name] = loss.weight * camera_losses.mean() + losses_to_be_computed.remove("camera") + + # remaining losses, we expect no more losses to be computed + loss = self.losses["confidence"] + conf_losses = loss( + outputs["confidence"], + target_gt=inputs["depth"], + target_pred=outputs["depth"], + mask=inputs["depth_mask"].clone(), + ) + losses["opt"][loss.name + "_conf"] = loss.weight * conf_losses.mean() + losses_to_be_computed.remove("confidence") + + assert ( + not losses_to_be_computed + ), f"Losses 
{losses_to_be_computed} not computed, revise `compute_loss` method" + + return losses + + @torch.no_grad() + @torch.autocast(device_type=DEVICE, enabled=True, dtype=torch.float16) + def infer( + self, + rgb: torch.Tensor, + camera: torch.Tensor | Camera | None = None, + rays=None, + normalize=True, + ): + ratio_bounds = self.shape_constraints["ratio_bounds"] + pixels_bounds = [ + self.shape_constraints["pixels_min"], + self.shape_constraints["pixels_max"], + ] + if hasattr(self, "resolution_level"): + assert ( + self.resolution_level >= 0 and self.resolution_level < 10 + ), "resolution_level should be in [0, 10)" + pixels_range = pixels_bounds[1] - pixels_bounds[0] + interval = pixels_range / 10 + new_lowbound = self.resolution_level * interval + pixels_bounds[0] + new_upbound = (self.resolution_level + 1) * interval + pixels_bounds[0] + pixels_bounds = (new_lowbound, new_upbound) + else: + warnings.warn("!! self.resolution_level not set, using default bounds !!") + + # houskeeping on cpu/cuda and batchify + if rgb.ndim == 3: + rgb = rgb.unsqueeze(0) + if camera is not None: + camera = BatchCamera.from_camera(camera) + camera = camera.to(self.device) + + B, _, H, W = rgb.shape + rgb = rgb.to(self.device) + + # preprocess + paddings, (padded_H, padded_W) = get_paddings((H, W), ratio_bounds) + (pad_left, pad_right, pad_top, pad_bottom) = paddings + resize_factor, (new_H, new_W) = get_resize_factor( + (padded_H, padded_W), pixels_bounds + ) + # -> rgb preprocess (input std-ized and resized) + if normalize: + rgb = TF.normalize( + rgb.float() / 255.0, + mean=IMAGENET_DATASET_MEAN, + std=IMAGENET_DATASET_STD, + ) + rgb = F.pad(rgb, (pad_left, pad_right, pad_top, pad_bottom), value=0.0) + rgb = F.interpolate( + rgb, size=(new_H, new_W), mode="bilinear", align_corners=False + ) + # -> camera preprocess + if camera is not None: + camera = camera.crop( + left=-pad_left, top=-pad_top, right=-pad_right, bottom=-pad_bottom + ) + camera = camera.resize(resize_factor) + + # prepare inputs + inputs = {"image": rgb} + if camera is not None: + inputs["camera"] = camera + rays = camera.get_rays(shapes=(B, new_H, new_W), noisy=False).reshape( + B, 3, new_H, new_W + ) + inputs["rays"] = rays + + if rays is not None: + rays = rays.to(self.device) + if rays.ndim == 3: + rays = rays.unsqueeze(0) + rays = F.pad( + rays, + ( + max(0, pad_left), + max(0, pad_right), + max(0, pad_top), + max(0, pad_bottom), + ), + value=0.0, + ) + rays = F.interpolate( + rays, size=(new_H, new_W), mode="bilinear", align_corners=False + ) + inputs["rays"] = rays + + # run model + _, model_outputs = self.encode_decode(inputs, image_metas={}) + + # collect outputs + out = {} + out["confidence"] = _postprocess( + model_outputs["confidence"], + (padded_H, padded_W), + paddings=paddings, + interpolation_mode=self.interpolation_mode, + ) + points = _postprocess( + model_outputs["points"], + (padded_H, padded_W), + paddings=paddings, + interpolation_mode=self.interpolation_mode, + ) + rays = _postprocess( + model_outputs["rays"], + (padded_H, padded_W), + paddings=paddings, + interpolation_mode=self.interpolation_mode, + ) + + out["distance"] = points.norm(dim=1, keepdim=True) + out["depth"] = points[:, -1:] + out["points"] = points + out["rays"] = rays / torch.norm(rays, dim=1, keepdim=True).clip(min=1e-5) + out["lowres_features"] = model_outputs["lowres_features"] + return out + + def load_pretrained(self, model_file): + dict_model = torch.load(model_file, map_location="cpu", weights_only=False) + if "model" in dict_model: + dict_model = 
dict_model["model"] + info = self.load_state_dict(dict_model, strict=False) + if is_main_process(): + print( + f"Loaded from {model_file} for {self.__class__.__name__} results in:", + info, + ) + + def build(self, config): + mod = importlib.import_module("unik3d.models.encoder") + pixel_encoder_factory = getattr(mod, config["model"]["pixel_encoder"]["name"]) + pixel_encoder_config = { + **config["training"], + **config["model"]["pixel_encoder"], + **config["data"], + } + pixel_encoder = pixel_encoder_factory(pixel_encoder_config) + pixel_encoder_embed_dims = ( + pixel_encoder.embed_dims + if hasattr(pixel_encoder, "embed_dims") + else [getattr(pixel_encoder, "embed_dim") * 2**i for i in range(4)] + ) + config["model"]["pixel_encoder"]["embed_dim"] = getattr( + pixel_encoder, "embed_dim" + ) + config["model"]["pixel_encoder"]["embed_dims"] = pixel_encoder_embed_dims + config["model"]["pixel_encoder"]["depths"] = pixel_encoder.depths + config["model"]["pixel_encoder"]["cls_token_embed_dims"] = getattr( + pixel_encoder, "cls_token_embed_dims", pixel_encoder_embed_dims + ) + + pixel_decoder = Decoder(config) + + self.pixel_encoder = pixel_encoder + self.pixel_decoder = pixel_decoder + self.slices_encoder_range = list( + zip([0, *self.pixel_encoder.depths[:-1]], self.pixel_encoder.depths) + ) + self.stacking_fn = last_stack + self.shape_constraints = config["data"]["shape_constraints"] + self.interpolation_mode = "bilinear" + + def build_losses(self, config): + self.losses = {} + for loss_name, loss_config in config["training"]["losses"].items(): + mod = importlib.import_module("unik3d.ops.losses") + loss_factory = getattr(mod, loss_config["name"]) + self.losses[loss_name] = loss_factory.build(loss_config) + + def get_params(self, config): + if hasattr(self.pixel_encoder, "get_params"): + encoder_p, _ = self.pixel_encoder.get_params( + config["model"]["pixel_encoder"]["lr"], + config["training"]["wd"], + config["training"]["ld"], + ) + else: + encoder_p, _ = get_params( + self.pixel_encoder, + config["model"]["pixel_encoder"]["lr"], + config["training"]["wd"], + ) + decoder_p = self.pixel_decoder.get_params( + config["training"]["lr"], config["training"]["wd"] + ) + return [*encoder_p, *decoder_p] + + def step(self): + self.pixel_decoder.steps += 1 + + def parameters_grad(self): + for p in self.parameters(): + if p.requires_grad: + yield p + + @property + def device(self): + return next(self.parameters()).device diff --git a/unik3d/ops/__init__.py b/unik3d/ops/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..eb598ed6fbf6ffae6052dce0f11645adb0d3fda3 --- /dev/null +++ b/unik3d/ops/__init__.py @@ -0,0 +1,18 @@ +from .losses import (Confidence, Dummy, LocalNormal, LocalSSI, PolarRegression, + Regression, RobustLoss, Scale, SILog, SpatialGradient) +from .scheduler import CosineScheduler, PlainCosineScheduler + +__all__ = [ + "Dummy", + "SpatialGradient", + "LocalSSI", + "Regression", + "LocalNormal", + "RobustLoss", + "SILog", + "CosineScheduler", + "PlainCosineScheduler", + "PolarRegression", + "Scale", + "Confidence", +] diff --git a/unik3d/ops/knn/__init__.py b/unik3d/ops/knn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ba469c8fcb582177fd183702a468d6e72d262a7f --- /dev/null +++ b/unik3d/ops/knn/__init__.py @@ -0,0 +1,6 @@ +from .functions.knn import knn_gather, knn_points + +__all__ = [ + "knn_points", + "knn_gather", +] diff --git a/unik3d/ops/knn/compile.sh b/unik3d/ops/knn/compile.sh new file mode 100755 index 
0000000000000000000000000000000000000000..ccf4338cd44a1291d75c260fa180a24790ed98b4 --- /dev/null +++ b/unik3d/ops/knn/compile.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +export TORCH_CUDA_ARCH_LIST="8.0 8.6+PTX" +# export FORCE_CUDA=1 #if you do not actually have cuda, workaround +python setup.py build install \ No newline at end of file diff --git a/unik3d/ops/knn/functions/__init__.py b/unik3d/ops/knn/functions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6a54211c5e208129ba33507b0011733cb4a31491 --- /dev/null +++ b/unik3d/ops/knn/functions/__init__.py @@ -0,0 +1,6 @@ +from .knn import knn_gather, knn_points + +__all__ = [ + "knn_points", + "knn_gather", +] diff --git a/unik3d/ops/knn/functions/knn.py b/unik3d/ops/knn/functions/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..42c36344d42952f123290cb0bec10e39a6f2da3b --- /dev/null +++ b/unik3d/ops/knn/functions/knn.py @@ -0,0 +1,249 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from collections import namedtuple +from typing import Union + +import torch +from KNN import knn_points_backward, knn_points_idx +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +_KNN = namedtuple("KNN", "dists idx knn") + + +class _knn_points(Function): + """ + Torch autograd Function wrapper for KNN C++/CUDA implementations. + """ + + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + ctx, + p1, + p2, + lengths1, + lengths2, + K, + version, + norm: int = 2, + return_sorted: bool = True, + ): + """ + K-Nearest neighbors on point clouds. + + Args: + p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each + containing up to P1 points of dimension D. + p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each + containing up to P2 points of dimension D. + lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the + length of each pointcloud in p1. Or None to indicate that every cloud has + length P1. + lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the + length of each pointcloud in p2. Or None to indicate that every cloud has + length P2. + K: Integer giving the number of nearest neighbors to return. + version: Which KNN implementation to use in the backend. If version=-1, + the correct implementation is selected based on the shapes of the inputs. + norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2). + return_sorted: (bool) whether to return the nearest neighbors sorted in + ascending order of distance. + + Returns: + p1_dists: Tensor of shape (N, P1, K) giving the squared distances to + the nearest neighbors. This is padded with zeros both where a cloud in p2 + has fewer than K points and where a cloud in p1 has fewer than P1 points. + + p1_idx: LongTensor of shape (N, P1, K) giving the indices of the + K nearest neighbors from points in p1 to points in p2. + Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest + neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud + in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points. 
+ """ + if not ((norm == 1) or (norm == 2)): + raise ValueError("Support for 1 or 2 norm.") + + idx, dists = knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version) + + # sort KNN in ascending order if K > 1 + if K > 1 and return_sorted: + if lengths2.min() < K: + P1 = p1.shape[1] + mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None] + # mask has shape [N, K], true where dists irrelevant + mask = mask[:, None].expand(-1, P1, -1) + # mask has shape [N, P1, K], true where dists irrelevant + dists[mask] = float("inf") + dists, sort_idx = dists.sort(dim=2) + dists[mask] = 0 + else: + dists, sort_idx = dists.sort(dim=2) + idx = idx.gather(2, sort_idx) + + ctx.save_for_backward(p1, p2, lengths1, lengths2, idx) + ctx.mark_non_differentiable(idx) + ctx.norm = norm + return dists, idx + + @staticmethod + @once_differentiable + def backward(ctx, grad_dists, grad_idx): + p1, p2, lengths1, lengths2, idx = ctx.saved_tensors + norm = ctx.norm + # TODO(gkioxari) Change cast to floats once we add support for doubles. + if not (grad_dists.dtype == torch.float32): + grad_dists = grad_dists.float() + if not (p1.dtype == torch.float32): + p1 = p1.float() + if not (p2.dtype == torch.float32): + p2 = p2.float() + grad_p1, grad_p2 = knn_points_backward( + p1, p2, lengths1, lengths2, idx, norm, grad_dists + ) + return grad_p1, grad_p2, None, None, None, None, None, None + + +def knn_points( + p1: torch.Tensor, + p2: torch.Tensor, + lengths1: Union[torch.Tensor, None] = None, + lengths2: Union[torch.Tensor, None] = None, + norm: int = 2, + K: int = 1, + version: int = -1, + return_nn: bool = False, + return_sorted: bool = True, +) -> _KNN: + """ + K-Nearest neighbors on point clouds. + + Args: + p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each + containing up to P1 points of dimension D. + p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each + containing up to P2 points of dimension D. + lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the + length of each pointcloud in p1. Or None to indicate that every cloud has + length P1. + lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the + length of each pointcloud in p2. Or None to indicate that every cloud has + length P2. + norm: Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2. + K: Integer giving the number of nearest neighbors to return. + version: Which KNN implementation to use in the backend. If version=-1, + the correct implementation is selected based on the shapes of the inputs. + return_nn: If set to True returns the K nearest neighbors in p2 for each point in p1. + return_sorted: (bool) whether to return the nearest neighbors sorted in + ascending order of distance. + + Returns: + dists: Tensor of shape (N, P1, K) giving the squared distances to + the nearest neighbors. This is padded with zeros both where a cloud in p2 + has fewer than K points and where a cloud in p1 has fewer than P1 points. + + idx: LongTensor of shape (N, P1, K) giving the indices of the + K nearest neighbors from points in p1 to points in p2. + Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest + neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud + in p2 has fewer than K points and where a cloud in p1 has fewer than P1 + points. + + nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for + each point in p1. 
Concretely, `p2_nn[n, i, k]` gives the k-th nearest neighbor + for `p1[n, i]`. Returned if `return_nn` is True. + The nearest neighbors are collected using `knn_gather` + + .. code-block:: + + p2_nn = knn_gather(p2, p1_idx, lengths2) + + which is a helper function that allows indexing any tensor of shape (N, P2, U) with + the indices `p1_idx` returned by `knn_points`. The output is a tensor + of shape (N, P1, K, U). + + """ + if p1.shape[0] != p2.shape[0]: + raise ValueError("pts1 and pts2 must have the same batch dimension.") + if p1.shape[2] != p2.shape[2]: + raise ValueError("pts1 and pts2 must have the same point dimension.") + + p1 = p1.contiguous() + p2 = p2.contiguous() + + P1 = p1.shape[1] + P2 = p2.shape[1] + + if lengths1 is None: + lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device) + if lengths2 is None: + lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device) + + p1_dists, p1_idx = _knn_points.apply( + p1, p2, lengths1, lengths2, K, version, norm, return_sorted + ) + + p2_nn = None + if return_nn: + p2_nn = knn_gather(p2, p1_idx, lengths2) + + return _KNN(dists=p1_dists, idx=p1_idx, knn=p2_nn if return_nn else None) + + +def knn_gather( + x: torch.Tensor, idx: torch.Tensor, lengths: Union[torch.Tensor, None] = None +): + """ + A helper function for knn that allows indexing a tensor x with the indices `idx` + returned by `knn_points`. + + For example, if `dists, idx = knn_points(p, x, lengths_p, lengths, K)` + where p is a tensor of shape (N, L, D) and x a tensor of shape (N, M, D), + then one can compute the K nearest neighbors of p with `p_nn = knn_gather(x, idx, lengths)`. + It can also be applied for any tensor x of shape (N, M, U) where U != D. + + Args: + x: Tensor of shape (N, M, U) containing U-dimensional features to + be gathered. + idx: LongTensor of shape (N, L, K) giving the indices returned by `knn_points`. + lengths: LongTensor of shape (N,) of values in the range [0, M], giving the + length of each example in the batch in x. Or None to indicate that every + example has length M. + Returns: + x_out: Tensor of shape (N, L, K, U) resulting from gathering the elements of x + with idx, s.t. `x_out[n, l, k] = x[n, idx[n, l, k]]`. + If `k > lengths[n]` then `x_out[n, l, k]` is filled with 0.0. 
+ """ + N, M, U = x.shape + _N, L, K = idx.shape + + if N != _N: + raise ValueError("x and idx must have same batch dimension.") + + if lengths is None: + lengths = torch.full((x.shape[0],), M, dtype=torch.int64, device=x.device) + + idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, U) + # idx_expanded has shape [N, L, K, U] + + x_out = x[:, :, None].expand(-1, -1, K, -1).gather(1, idx_expanded) + # p2_nn has shape [N, L, K, U] + + needs_mask = lengths.min() < K + if needs_mask: + # mask has shape [N, K], true where idx is irrelevant because + # there is less number of points in p2 than K + mask = lengths[:, None] <= torch.arange(K, device=x.device)[None] + + # expand mask to shape [N, L, K, U] + mask = mask[:, None].expand(-1, L, -1) + mask = mask[:, :, :, None].expand(-1, -1, -1, U) + x_out[mask] = 0.0 + + return x_out diff --git a/unik3d/ops/knn/setup.py b/unik3d/ops/knn/setup.py new file mode 100644 index 0000000000000000000000000000000000000000..6e7f7dd9953b7da22310cdb7d227acccf2749f36 --- /dev/null +++ b/unik3d/ops/knn/setup.py @@ -0,0 +1,61 @@ +import glob +import os + +import torch +from setuptools import find_packages, setup +from torch.utils.cpp_extension import CUDA_HOME, CppExtension, CUDAExtension + +requirements = ["torch", "torchvision"] + + +def get_extensions(): + this_dir = os.path.dirname(os.path.abspath(__file__)) + extensions_dir = os.path.join(this_dir, "src") + + main_file = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cpu = glob.glob(os.path.join(extensions_dir, "*.cpp")) + source_cuda = glob.glob(os.path.join(extensions_dir, "*.cu")) + + sources = main_file + source_cpu + extension = CppExtension + extra_compile_args = {"cxx": ["-O3"]} + define_macros = [] + + if torch.cuda.is_available() and CUDA_HOME is not None: + extension = CUDAExtension + sources += source_cuda + define_macros += [("WITH_CUDA", None)] + extra_compile_args["nvcc"] = [ + "-O3", + ] + else: + raise NotImplementedError("Cuda is not available") + + sources = list(set([os.path.join(extensions_dir, s) for s in sources])) + include_dirs = [extensions_dir] + ext_modules = [ + extension( + "KNN", + sources, + include_dirs=include_dirs, + define_macros=define_macros, + extra_compile_args=extra_compile_args, + ) + ] + + return ext_modules + + +setup( + name="KNN", + version="0.1", + author="Luigi Piccinelli", + ext_modules=get_extensions(), + packages=find_packages( + exclude=( + "configs", + "tests", + ) + ), + cmdclass={"build_ext": torch.utils.cpp_extension.BuildExtension}, +) diff --git a/unik3d/ops/knn/src/knn.cu b/unik3d/ops/knn/src/knn.cu new file mode 100644 index 0000000000000000000000000000000000000000..ba0732dca6ee00732286d5e3a15191aaaf7c5409 --- /dev/null +++ b/unik3d/ops/knn/src/knn.cu @@ -0,0 +1,587 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#include +#include +#include +#include +#include +#include + +#include "utils/dispatch.cuh" +#include "utils/mink.cuh" + +// A chunk of work is blocksize-many points of P1. +// The number of potential chunks to do is N*(1+(P1-1)/blocksize) +// call (1+(P1-1)/blocksize) chunks_per_cloud +// These chunks are divided among the gridSize-many blocks. +// In block b, we work on chunks b, b+gridSize, b+2*gridSize etc . +// In chunk i, we work on cloud i/chunks_per_cloud on points starting from +// blocksize*(i%chunks_per_cloud). 
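A small numeric sketch of the chunking scheme described in the comment above. The numbers are purely illustrative and not tied to the actual kernel launch configuration:

```python
# Illustrative walk-through of the chunk/block mapping used by the KNN kernels.
N, P1, blocksize, gridsize = 2, 1000, 256, 8

chunks_per_cloud = 1 + (P1 - 1) // blocksize   # 4 chunks of up to 256 points each
chunks_to_do = N * chunks_per_cloud            # 8 chunks in total

for block in range(gridsize):
    # block b handles chunks b, b + gridsize, b + 2*gridsize, ...
    for chunk in range(block, chunks_to_do, gridsize):
        cloud = chunk // chunks_per_cloud                     # which point cloud
        start_point = blocksize * (chunk % chunks_per_cloud)  # first point of this chunk
        # thread t of the block then processes point start_point + t, if it is in range
```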
+ +template +__global__ void KNearestNeighborKernelV0( + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + const int64_t* __restrict__ lengths1, + const int64_t* __restrict__ lengths2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2, + const size_t D, + const size_t K, + const size_t norm) { + // Store both dists and indices for knn in global memory. + const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x); + const int64_t chunks_to_do = N * chunks_per_cloud; + for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { + const int64_t n = chunk / chunks_per_cloud; + const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); + int64_t p1 = start_point + threadIdx.x; + if (p1 >= lengths1[n]) + continue; + int offset = n * P1 * K + p1 * K; + int64_t length2 = lengths2[n]; + MinK mink(dists + offset, idxs + offset, K); + for (int p2 = 0; p2 < length2; ++p2) { + // Find the distance between points1[n, p1] and points[n, p2] + scalar_t dist = 0; + for (int d = 0; d < D; ++d) { + scalar_t coord1 = points1[n * P1 * D + p1 * D + d]; + scalar_t coord2 = points2[n * P2 * D + p2 * D + d]; + scalar_t diff = coord1 - coord2; + scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff); + dist += norm_diff; + } + mink.add(dist, p2); + } + } +} + +template +__global__ void KNearestNeighborKernelV1( + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + const int64_t* __restrict__ lengths1, + const int64_t* __restrict__ lengths2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2, + const size_t K, + const size_t norm) { + // Same idea as the previous version, but hoist D into a template argument + // so we can cache the current point in a thread-local array. We still store + // the current best K dists and indices in global memory, so this should work + // for very large K and fairly large D. + scalar_t cur_point[D]; + const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x); + const int64_t chunks_to_do = N * chunks_per_cloud; + for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { + const int64_t n = chunk / chunks_per_cloud; + const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); + int64_t p1 = start_point + threadIdx.x; + if (p1 >= lengths1[n]) + continue; + for (int d = 0; d < D; ++d) { + cur_point[d] = points1[n * P1 * D + p1 * D + d]; + } + int offset = n * P1 * K + p1 * K; + int64_t length2 = lengths2[n]; + MinK mink(dists + offset, idxs + offset, K); + for (int p2 = 0; p2 < length2; ++p2) { + // Find the distance between cur_point and points[n, p2] + scalar_t dist = 0; + for (int d = 0; d < D; ++d) { + scalar_t diff = cur_point[d] - points2[n * P2 * D + p2 * D + d]; + scalar_t norm_diff = (norm == 2) ? 
(diff * diff) : abs(diff); + dist += norm_diff; + } + mink.add(dist, p2); + } + } +} + +// This is a shim functor to allow us to dispatch using DispatchKernel1D +template +struct KNearestNeighborV1Functor { + static void run( + size_t blocks, + size_t threads, + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + const int64_t* __restrict__ lengths1, + const int64_t* __restrict__ lengths2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2, + const size_t K, + const size_t norm) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + KNearestNeighborKernelV1<<>>( + points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, K, norm); + } +}; + +template +__global__ void KNearestNeighborKernelV2( + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + const int64_t* __restrict__ lengths1, + const int64_t* __restrict__ lengths2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const int64_t N, + const int64_t P1, + const int64_t P2, + const size_t norm) { + // Same general implementation as V2, but also hoist K into a template arg. + scalar_t cur_point[D]; + scalar_t min_dists[K]; + int min_idxs[K]; + const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x); + const int64_t chunks_to_do = N * chunks_per_cloud; + for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { + const int64_t n = chunk / chunks_per_cloud; + const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); + int64_t p1 = start_point + threadIdx.x; + if (p1 >= lengths1[n]) + continue; + for (int d = 0; d < D; ++d) { + cur_point[d] = points1[n * P1 * D + p1 * D + d]; + } + int64_t length2 = lengths2[n]; + MinK mink(min_dists, min_idxs, K); + for (int p2 = 0; p2 < length2; ++p2) { + scalar_t dist = 0; + for (int d = 0; d < D; ++d) { + int offset = n * P2 * D + p2 * D + d; + scalar_t diff = cur_point[d] - points2[offset]; + scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff); + dist += norm_diff; + } + mink.add(dist, p2); + } + for (int k = 0; k < mink.size(); ++k) { + idxs[n * P1 * K + p1 * K + k] = min_idxs[k]; + dists[n * P1 * K + p1 * K + k] = min_dists[k]; + } + } +} + +// This is a shim so we can dispatch using DispatchKernel2D +template +struct KNearestNeighborKernelV2Functor { + static void run( + size_t blocks, + size_t threads, + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + const int64_t* __restrict__ lengths1, + const int64_t* __restrict__ lengths2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const int64_t N, + const int64_t P1, + const int64_t P2, + const size_t norm) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + KNearestNeighborKernelV2<<>>( + points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm); + } +}; + +template +__global__ void KNearestNeighborKernelV3( + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + const int64_t* __restrict__ lengths1, + const int64_t* __restrict__ lengths2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2, + const size_t norm) { + // Same idea as V2, but use register indexing for thread-local arrays. + // Enabling sorting for this version leads to huge slowdowns; I suspect + // that it forces min_dists into local memory rather than registers. + // As a result this version is always unsorted. 
+ scalar_t cur_point[D]; + scalar_t min_dists[K]; + int min_idxs[K]; + const int64_t chunks_per_cloud = (1 + (P1 - 1) / blockDim.x); + const int64_t chunks_to_do = N * chunks_per_cloud; + for (int64_t chunk = blockIdx.x; chunk < chunks_to_do; chunk += gridDim.x) { + const int64_t n = chunk / chunks_per_cloud; + const int64_t start_point = blockDim.x * (chunk % chunks_per_cloud); + int64_t p1 = start_point + threadIdx.x; + if (p1 >= lengths1[n]) + continue; + for (int d = 0; d < D; ++d) { + cur_point[d] = points1[n * P1 * D + p1 * D + d]; + } + int64_t length2 = lengths2[n]; + RegisterMinK mink(min_dists, min_idxs); + for (int p2 = 0; p2 < length2; ++p2) { + scalar_t dist = 0; + for (int d = 0; d < D; ++d) { + int offset = n * P2 * D + p2 * D + d; + scalar_t diff = cur_point[d] - points2[offset]; + scalar_t norm_diff = (norm == 2) ? (diff * diff) : abs(diff); + dist += norm_diff; + } + mink.add(dist, p2); + } + for (int k = 0; k < mink.size(); ++k) { + idxs[n * P1 * K + p1 * K + k] = min_idxs[k]; + dists[n * P1 * K + p1 * K + k] = min_dists[k]; + } + } +} + +// This is a shim so we can dispatch using DispatchKernel2D +template +struct KNearestNeighborKernelV3Functor { + static void run( + size_t blocks, + size_t threads, + const scalar_t* __restrict__ points1, + const scalar_t* __restrict__ points2, + const int64_t* __restrict__ lengths1, + const int64_t* __restrict__ lengths2, + scalar_t* __restrict__ dists, + int64_t* __restrict__ idxs, + const size_t N, + const size_t P1, + const size_t P2, + const size_t norm) { + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + KNearestNeighborKernelV3<<>>( + points1, points2, lengths1, lengths2, dists, idxs, N, P1, P2, norm); + } +}; + +constexpr int V1_MIN_D = 1; +constexpr int V1_MAX_D = 32; + +constexpr int V2_MIN_D = 1; +constexpr int V2_MAX_D = 8; +constexpr int V2_MIN_K = 1; +constexpr int V2_MAX_K = 32; + +constexpr int V3_MIN_D = 1; +constexpr int V3_MAX_D = 8; +constexpr int V3_MIN_K = 1; +constexpr int V3_MAX_K = 4; + +bool InBounds(const int64_t min, const int64_t x, const int64_t max) { + return min <= x && x <= max; +} + +bool KnnCheckVersion(int version, const int64_t D, const int64_t K) { + if (version == 0) { + return true; + } else if (version == 1) { + return InBounds(V1_MIN_D, D, V1_MAX_D); + } else if (version == 2) { + return InBounds(V2_MIN_D, D, V2_MAX_D) && InBounds(V2_MIN_K, K, V2_MAX_K); + } else if (version == 3) { + return InBounds(V3_MIN_D, D, V3_MAX_D) && InBounds(V3_MIN_K, K, V3_MAX_K); + } + return false; +} + +int ChooseVersion(const int64_t D, const int64_t K) { + for (int version = 3; version >= 1; version--) { + if (KnnCheckVersion(version, D, K)) { + return version; + } + } + return 0; +} + +std::tuple KNearestNeighborIdxCuda( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const int norm, + const int K, + int version) { + // Check inputs are on the same device + at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2}, + lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4}; + at::CheckedFrom c = "KNearestNeighborIdxCuda"; + at::checkAllSameGPU(c, {p1_t, p2_t, lengths1_t, lengths2_t}); + at::checkAllSameType(c, {p1_t, p2_t}); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(p1.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const auto N = p1.size(0); + const auto P1 = p1.size(1); + const auto P2 = p2.size(1); + const auto D = p2.size(2); + const 
int64_t K_64 = K; + + TORCH_CHECK((norm == 1) || (norm == 2), "Norm must be 1 or 2."); + + TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension"); + auto long_dtype = lengths1.options().dtype(at::kLong); + auto idxs = at::zeros({N, P1, K}, long_dtype); + auto dists = at::zeros({N, P1, K}, p1.options()); + + if (idxs.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(idxs, dists); + } + + if (version < 0) { + version = ChooseVersion(D, K); + } else if (!KnnCheckVersion(version, D, K)) { + int new_version = ChooseVersion(D, K); + std::cout << "WARNING: Requested KNN version " << version + << " is not compatible with D = " << D << "; K = " << K + << ". Falling back to version = " << new_version << std::endl; + version = new_version; + } + + // At this point we should have a valid version no matter what data the user + // gave us. But we can check once more to be sure; however this time + // assert fail since failing at this point means we have a bug in our version + // selection or checking code. + AT_ASSERTM(KnnCheckVersion(version, D, K), "Invalid version"); + + const size_t threads = 256; + const size_t blocks = 256; + if (version == 0) { + AT_DISPATCH_FLOATING_TYPES( + p1.scalar_type(), "knn_kernel_cuda", ([&] { + KNearestNeighborKernelV0<<>>( + p1.contiguous().data_ptr(), + p2.contiguous().data_ptr(), + lengths1.contiguous().data_ptr(), + lengths2.contiguous().data_ptr(), + dists.data_ptr(), + idxs.data_ptr(), + N, + P1, + P2, + D, + K, + norm); + })); + } else if (version == 1) { + AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] { + DispatchKernel1D< + KNearestNeighborV1Functor, + scalar_t, + V1_MIN_D, + V1_MAX_D>( + D, + blocks, + threads, + p1.contiguous().data_ptr(), + p2.contiguous().data_ptr(), + lengths1.contiguous().data_ptr(), + lengths2.contiguous().data_ptr(), + dists.data_ptr(), + idxs.data_ptr(), + N, + P1, + P2, + K, + norm); + })); + } else if (version == 2) { + AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] { + DispatchKernel2D< + KNearestNeighborKernelV2Functor, + scalar_t, + V2_MIN_D, + V2_MAX_D, + V2_MIN_K, + V2_MAX_K>( + D, + K_64, + blocks, + threads, + p1.contiguous().data_ptr(), + p2.contiguous().data_ptr(), + lengths1.contiguous().data_ptr(), + lengths2.contiguous().data_ptr(), + dists.data_ptr(), + idxs.data_ptr(), + N, + P1, + P2, + norm); + })); + } else if (version == 3) { + AT_DISPATCH_FLOATING_TYPES(p1.scalar_type(), "knn_kernel_cuda", ([&] { + DispatchKernel2D< + KNearestNeighborKernelV3Functor, + scalar_t, + V3_MIN_D, + V3_MAX_D, + V3_MIN_K, + V3_MAX_K>( + D, + K_64, + blocks, + threads, + p1.contiguous().data_ptr(), + p2.contiguous().data_ptr(), + lengths1.contiguous().data_ptr(), + lengths2.contiguous().data_ptr(), + dists.data_ptr(), + idxs.data_ptr(), + N, + P1, + P2, + norm); + })); + } + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(idxs, dists); +} + +// ------------------------------------------------------------- // +// Backward Operators // +// ------------------------------------------------------------- // + +// TODO(gkioxari) support all data types once AtomicAdd supports doubles. +// Currently, support is for floats only. 
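As a reference for the backward kernel that follows, the per-coordinate contribution it atomically accumulates is just the derivative of the L1/L2 distance scaled by the incoming gradient. A plain-Python sketch of that arithmetic (a hypothetical helper, not part of the codebase):

```python
def grad_contribution(p1_val, p2_val, grad_dist, norm):
    """Gradient of |p1 - p2| (norm=1) or (p1 - p2)^2 (norm=2) w.r.t. p1, times grad_dist."""
    if norm == 1:
        diff = grad_dist * (1.0 if p1_val > p2_val else -1.0)  # sign(p1 - p2)
    else:  # norm == 2
        diff = 2.0 * grad_dist * (p1_val - p2_val)
    return diff, -diff  # contributions added to grad_p1 and grad_p2 respectively
```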
+__global__ void KNearestNeighborBackwardKernel( + const float* __restrict__ p1, // (N, P1, D) + const float* __restrict__ p2, // (N, P2, D) + const int64_t* __restrict__ lengths1, // (N,) + const int64_t* __restrict__ lengths2, // (N,) + const int64_t* __restrict__ idxs, // (N, P1, K) + const float* __restrict__ grad_dists, // (N, P1, K) + float* __restrict__ grad_p1, // (N, P1, D) + float* __restrict__ grad_p2, // (N, P2, D) + const size_t N, + const size_t P1, + const size_t P2, + const size_t K, + const size_t D, + const size_t norm) { + const size_t tid = blockIdx.x * blockDim.x + threadIdx.x; + const size_t stride = gridDim.x * blockDim.x; + + for (size_t i = tid; i < N * P1 * K * D; i += stride) { + const size_t n = i / (P1 * K * D); // batch index + size_t rem = i % (P1 * K * D); + const size_t p1_idx = rem / (K * D); // index of point in p1 + rem = rem % (K * D); + const size_t k = rem / D; // k-th nearest neighbor + const size_t d = rem % D; // d-th dimension in the feature vector + + const size_t num1 = lengths1[n]; // number of valid points in p1 in batch + const size_t num2 = lengths2[n]; // number of valid points in p2 in batch + if ((p1_idx < num1) && (k < num2)) { + const float grad_dist = grad_dists[n * P1 * K + p1_idx * K + k]; + // index of point in p2 corresponding to the k-th nearest neighbor + const int64_t p2_idx = idxs[n * P1 * K + p1_idx * K + k]; + // If the index is the pad value of -1 then ignore it + if (p2_idx == -1) { + continue; + } + float diff = 0.0; + if (norm == 1) { + float sign = + (p1[n * P1 * D + p1_idx * D + d] > p2[n * P2 * D + p2_idx * D + d]) + ? 1.0 + : -1.0; + diff = grad_dist * sign; + } else { // norm is 2 + diff = 2.0 * grad_dist * + (p1[n * P1 * D + p1_idx * D + d] - p2[n * P2 * D + p2_idx * D + d]); + } + atomicAdd(grad_p1 + n * P1 * D + p1_idx * D + d, diff); + atomicAdd(grad_p2 + n * P2 * D + p2_idx * D + d, -1.0f * diff); + } + } +} + +std::tuple KNearestNeighborBackwardCuda( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const at::Tensor& idxs, + int norm, + const at::Tensor& grad_dists) { + // Check inputs are on the same device + at::TensorArg p1_t{p1, "p1", 1}, p2_t{p2, "p2", 2}, + lengths1_t{lengths1, "lengths1", 3}, lengths2_t{lengths2, "lengths2", 4}, + idxs_t{idxs, "idxs", 5}, grad_dists_t{grad_dists, "grad_dists", 6}; + at::CheckedFrom c = "KNearestNeighborBackwardCuda"; + at::checkAllSameGPU( + c, {p1_t, p2_t, lengths1_t, lengths2_t, idxs_t, grad_dists_t}); + at::checkAllSameType(c, {p1_t, p2_t, grad_dists_t}); + + // This is nondeterministic because atomicAdd + at::globalContext().alertNotDeterministic("KNearestNeighborBackwardCuda"); + + // Set the device for the kernel launch based on the device of the input + at::cuda::CUDAGuard device_guard(p1.device()); + cudaStream_t stream = at::cuda::getCurrentCUDAStream(); + + const auto N = p1.size(0); + const auto P1 = p1.size(1); + const auto P2 = p2.size(1); + const auto D = p2.size(2); + const auto K = idxs.size(2); + + TORCH_CHECK(p1.size(2) == D, "Point sets must have the same last dimension"); + TORCH_CHECK(idxs.size(0) == N, "KNN idxs must have the same batch dimension"); + TORCH_CHECK( + idxs.size(1) == P1, "KNN idxs must have the same point dimension as p1"); + TORCH_CHECK(grad_dists.size(0) == N); + TORCH_CHECK(grad_dists.size(1) == P1); + TORCH_CHECK(grad_dists.size(2) == K); + + auto grad_p1 = at::zeros({N, P1, D}, p1.options()); + auto grad_p2 = at::zeros({N, P2, D}, p2.options()); + + if (grad_p1.numel() 
== 0 || grad_p2.numel() == 0) { + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_p1, grad_p2); + } + + const int blocks = 64; + const int threads = 512; + + KNearestNeighborBackwardKernel<<>>( + p1.contiguous().data_ptr(), + p2.contiguous().data_ptr(), + lengths1.contiguous().data_ptr(), + lengths2.contiguous().data_ptr(), + idxs.contiguous().data_ptr(), + grad_dists.contiguous().data_ptr(), + grad_p1.data_ptr(), + grad_p2.data_ptr(), + N, + P1, + P2, + K, + D, + norm); + + AT_CUDA_CHECK(cudaGetLastError()); + return std::make_tuple(grad_p1, grad_p2); +} \ No newline at end of file diff --git a/unik3d/ops/knn/src/knn.h b/unik3d/ops/knn/src/knn.h new file mode 100644 index 0000000000000000000000000000000000000000..e43a0231119bf170502e1078d77110a3cd4510aa --- /dev/null +++ b/unik3d/ops/knn/src/knn.h @@ -0,0 +1,157 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include +#include +#include "utils/pytorch3d_cutils.h" + +// Compute indices of K nearest neighbors in pointcloud p2 to points +// in pointcloud p1. +// +// Args: +// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each +// containing P1 points of dimension D. +// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each +// containing P2 points of dimension D. +// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud. +// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud. +// norm: int specifying the norm for the distance (1 for L1, 2 for L2) +// K: int giving the number of nearest points to return. +// version: Integer telling which implementation to use. +// +// Returns: +// p1_neighbor_idx: LongTensor of shape (N, P1, K), where +// p1_neighbor_idx[n, i, k] = j means that the kth nearest +// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j]. +// It is padded with zeros so that it can be used easily in a later +// gather() operation. +// +// p1_neighbor_dists: FloatTensor of shape (N, P1, K) containing the squared +// distance from each point p1[n, p, :] to its K neighbors +// p2[n, p1_neighbor_idx[n, p, k], :]. + +// CPU implementation. +std::tuple KNearestNeighborIdxCpu( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const int norm, + const int K); + +// CUDA implementation +std::tuple KNearestNeighborIdxCuda( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const int norm, + const int K, + const int version); + +// Implementation which is exposed. +std::tuple KNearestNeighborIdx( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const int norm, + const int K, + const int version) { + if (p1.is_cuda() || p2.is_cuda()) { +#ifdef WITH_CUDA + CHECK_CUDA(p1); + CHECK_CUDA(p2); + return KNearestNeighborIdxCuda( + p1, p2, lengths1, lengths2, norm, K, version); +#else + AT_ERROR("Not compiled with GPU support."); +#endif + } + return KNearestNeighborIdxCpu(p1, p2, lengths1, lengths2, norm, K); +} + +// Compute gradients with respect to p1 and p2 +// +// Args: +// p1: FloatTensor of shape (N, P1, D) giving a batch of pointclouds each +// containing P1 points of dimension D. 
+// p2: FloatTensor of shape (N, P2, D) giving a batch of pointclouds each +// containing P2 points of dimension D. +// lengths1: LongTensor, shape (N,), giving actual length of each P1 cloud. +// lengths2: LongTensor, shape (N,), giving actual length of each P2 cloud. +// p1_neighbor_idx: LongTensor of shape (N, P1, K), where +// p1_neighbor_idx[n, i, k] = j means that the kth nearest +// neighbor to p1[n, i] in the cloud p2[n] is p2[n, j]. +// It is padded with zeros so that it can be used easily in a later +// gather() operation. This is computed from the forward pass. +// norm: int specifying the norm for the distance (1 for L1, 2 for L2) +// grad_dists: FLoatTensor of shape (N, P1, K) which contains the input +// gradients. +// +// Returns: +// grad_p1: FloatTensor of shape (N, P1, D) containing the output gradients +// wrt p1. +// grad_p2: FloatTensor of shape (N, P2, D) containing the output gradients +// wrt p2. + +// CPU implementation. +std::tuple KNearestNeighborBackwardCpu( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const at::Tensor& idxs, + const int norm, + const at::Tensor& grad_dists); + +// CUDA implementation +std::tuple KNearestNeighborBackwardCuda( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const at::Tensor& idxs, + const int norm, + const at::Tensor& grad_dists); + +// Implementation which is exposed. +std::tuple KNearestNeighborBackward( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const at::Tensor& idxs, + const int norm, + const at::Tensor& grad_dists) { + if (p1.is_cuda() || p2.is_cuda()) { +#ifdef WITH_CUDA + CHECK_CUDA(p1); + CHECK_CUDA(p2); + return KNearestNeighborBackwardCuda( + p1, p2, lengths1, lengths2, idxs, norm, grad_dists); +#else + AT_ERROR("Not compiled with GPU support."); +#endif + } + return KNearestNeighborBackwardCpu( + p1, p2, lengths1, lengths2, idxs, norm, grad_dists); +} + +// Utility to check whether a KNN version can be used. +// +// Args: +// version: Integer in the range 0 <= version <= 3 indicating one of our +// KNN implementations. +// D: Number of dimensions for the input and query point clouds +// K: Number of neighbors to be found +// +// Returns: +// Whether the indicated KNN version can be used. +bool KnnCheckVersion(int version, const int64_t D, const int64_t K); \ No newline at end of file diff --git a/unik3d/ops/knn/src/knn_cpu.cpp b/unik3d/ops/knn/src/knn_cpu.cpp new file mode 100644 index 0000000000000000000000000000000000000000..694ab11ea62c2eb285e1a03f6e4c0862498c24ae --- /dev/null +++ b/unik3d/ops/knn/src/knn_cpu.cpp @@ -0,0 +1,128 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. 
+ */ + +#include +#include +#include + +std::tuple KNearestNeighborIdxCpu( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const int norm, + const int K) { + const int N = p1.size(0); + const int P1 = p1.size(1); + const int D = p1.size(2); + + auto long_opts = lengths1.options().dtype(torch::kInt64); + torch::Tensor idxs = torch::full({N, P1, K}, 0, long_opts); + torch::Tensor dists = torch::full({N, P1, K}, 0, p1.options()); + + auto p1_a = p1.accessor(); + auto p2_a = p2.accessor(); + auto lengths1_a = lengths1.accessor(); + auto lengths2_a = lengths2.accessor(); + auto idxs_a = idxs.accessor(); + auto dists_a = dists.accessor(); + + for (int n = 0; n < N; ++n) { + const int64_t length1 = lengths1_a[n]; + const int64_t length2 = lengths2_a[n]; + for (int64_t i1 = 0; i1 < length1; ++i1) { + // Use a priority queue to store (distance, index) tuples. + std::priority_queue> q; + for (int64_t i2 = 0; i2 < length2; ++i2) { + float dist = 0; + for (int d = 0; d < D; ++d) { + float diff = p1_a[n][i1][d] - p2_a[n][i2][d]; + if (norm == 1) { + dist += abs(diff); + } else { // norm is 2 (default) + dist += diff * diff; + } + } + int size = static_cast(q.size()); + if (size < K || dist < std::get<0>(q.top())) { + q.emplace(dist, i2); + if (size >= K) { + q.pop(); + } + } + } + while (!q.empty()) { + auto t = q.top(); + q.pop(); + const int k = q.size(); + dists_a[n][i1][k] = std::get<0>(t); + idxs_a[n][i1][k] = std::get<1>(t); + } + } + } + return std::make_tuple(idxs, dists); +} + +// ------------------------------------------------------------- // +// Backward Operators // +// ------------------------------------------------------------- // + +std::tuple KNearestNeighborBackwardCpu( + const at::Tensor& p1, + const at::Tensor& p2, + const at::Tensor& lengths1, + const at::Tensor& lengths2, + const at::Tensor& idxs, + const int norm, + const at::Tensor& grad_dists) { + const int N = p1.size(0); + const int P1 = p1.size(1); + const int D = p1.size(2); + const int P2 = p2.size(1); + const int K = idxs.size(2); + + torch::Tensor grad_p1 = torch::full({N, P1, D}, 0, p1.options()); + torch::Tensor grad_p2 = torch::full({N, P2, D}, 0, p2.options()); + + auto p1_a = p1.accessor(); + auto p2_a = p2.accessor(); + auto lengths1_a = lengths1.accessor(); + auto lengths2_a = lengths2.accessor(); + auto idxs_a = idxs.accessor(); + auto grad_dists_a = grad_dists.accessor(); + auto grad_p1_a = grad_p1.accessor(); + auto grad_p2_a = grad_p2.accessor(); + + for (int n = 0; n < N; ++n) { + const int64_t length1 = lengths1_a[n]; + int64_t length2 = lengths2_a[n]; + length2 = (length2 < K) ? length2 : K; + for (int64_t i1 = 0; i1 < length1; ++i1) { + for (int64_t k = 0; k < length2; ++k) { + const int64_t i2 = idxs_a[n][i1][k]; + // If the index is the pad value of -1 then ignore it + if (i2 == -1) { + continue; + } + for (int64_t d = 0; d < D; ++d) { + float diff = 0.0; + if (norm == 1) { + float sign = (p1_a[n][i1][d] > p2_a[n][i2][d]) ? 
1.0 : -1.0; + diff = grad_dists_a[n][i1][k] * sign; + } else { // norm is 2 (default) + diff = 2.0f * grad_dists_a[n][i1][k] * + (p1_a[n][i1][d] - p2_a[n][i2][d]); + } + grad_p1_a[n][i1][d] += diff; + grad_p2_a[n][i2][d] += -1.0f * diff; + } + } + } + } + return std::make_tuple(grad_p1, grad_p2); +} \ No newline at end of file diff --git a/unik3d/ops/knn/src/knn_ext.cpp b/unik3d/ops/knn/src/knn_ext.cpp new file mode 100644 index 0000000000000000000000000000000000000000..f2fc9b4374ea5a2196d8ddee4e3cd5c0a907ddaa --- /dev/null +++ b/unik3d/ops/knn/src/knn_ext.cpp @@ -0,0 +1,10 @@ +#include +#include "knn.h" + +PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { +#ifdef WITH_CUDA + m.def("knn_check_version", &KnnCheckVersion); +#endif + m.def("knn_points_idx", &KNearestNeighborIdx); + m.def("knn_points_backward", &KNearestNeighborBackward); +} \ No newline at end of file diff --git a/unik3d/ops/knn/src/utils/dispatch.cuh b/unik3d/ops/knn/src/utils/dispatch.cuh new file mode 100644 index 0000000000000000000000000000000000000000..af197b4392106391236f9f00d999a02e6ac2defa --- /dev/null +++ b/unik3d/ops/knn/src/utils/dispatch.cuh @@ -0,0 +1,357 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +// This file provides utilities for dispatching to specialized versions of +// functions. This is especially useful for CUDA kernels, since specializing +// them to particular input sizes can often allow the compiler to unroll loops +// and place arrays into registers, which can give huge performance speedups. +// +// As an example, suppose we have the following function which is specialized +// based on a compile-time int64_t value: +// +// template +// struct SquareOffset { +// static void run(T y) { +// T val = x * x + y; +// std::cout << val << std::endl; +// } +// } +// +// This function takes one compile-time argument x, and one run-time argument y. +// We might want to compile specialized versions of this for x=0, x=1, etc and +// then dispatch to the correct one based on the runtime value of x. +// One simple way to achieve this is with a lookup table: +// +// template +// void DispatchSquareOffset(const int64_t x, T y) { +// if (x == 0) { +// SquareOffset::run(y); +// } else if (x == 1) { +// SquareOffset::run(y); +// } else if (x == 2) { +// SquareOffset::run(y); +// } +// } +// +// This function takes both x and y as run-time arguments, and dispatches to +// different specialized versions of SquareOffset based on the run-time value +// of x. This works, but it's tedious and error-prone. If we want to change the +// set of x values for which we provide compile-time specializations, then we +// will need to do a lot of tedius editing of the dispatch function. Also, if we +// want to provide compile-time specializations for another function other than +// SquareOffset, we will need to duplicate the entire lookup table. +// +// To solve these problems, we can use the DispatchKernel1D function provided by +// this file instead: +// +// template +// void DispatchSquareOffset(const int64_t x, T y) { +// constexpr int64_t xmin = 0; +// constexpr int64_t xmax = 2; +// DispatchKernel1D(x, y); +// } +// +// DispatchKernel1D uses template metaprogramming to compile specialized +// versions of SquareOffset for all values of x with xmin <= x <= xmax, and +// then dispatches to the correct one based on the run-time value of x. 
If we +// want to change the range of x values for which SquareOffset is specialized +// at compile-time, then all we have to do is change the values of the +// compile-time constants xmin and xmax. +// +// This file also allows us to similarly dispatch functions that depend on two +// compile-time int64_t values, using the DispatchKernel2D function like this: +// +// template +// struct Sum { +// static void run(T z, T w) { +// T val = x + y + z + w; +// std::cout << val << std::endl; +// } +// } +// +// template +// void DispatchSum(const int64_t x, const int64_t y, int z, int w) { +// constexpr int64_t xmin = 1; +// constexpr int64_t xmax = 3; +// constexpr int64_t ymin = 2; +// constexpr int64_t ymax = 5; +// DispatchKernel2D(x, y, z, w); +// } +// +// Like its 1D counterpart, DispatchKernel2D uses template metaprogramming to +// compile specialized versions of sum for all values of (x, y) with +// xmin <= x <= xmax and ymin <= y <= ymax, then dispatches to the correct +// specialized version based on the runtime values of x and y. + +// Define some helper structs in an anonymous namespace. +namespace { + +// 1D dispatch: general case. +// Kernel is the function we want to dispatch to; it should take a typename and +// an int64_t as template args, and it should define a static void function +// run which takes any number of arguments of any type. +// In order to dispatch, we will take an additional template argument curN, +// and increment it via template recursion until it is equal to the run-time +// argument N. +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t curN, + typename... Args> +struct DispatchKernelHelper1D { + static void run(const int64_t N, Args... args) { + if (curN == N) { + // The compile-time value curN is equal to the run-time value N, so we + // can dispatch to the run method of the Kernel. + Kernel::run(args...); + } else if (curN < N) { + // Increment curN via template recursion + DispatchKernelHelper1D::run( + N, args...); + } + // We shouldn't get here -- throw an error? + } +}; + +// 1D dispatch: Specialization when curN == maxN +// We need this base case to avoid infinite template recursion. +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + typename... Args> +struct DispatchKernelHelper1D { + static void run(const int64_t N, Args... args) { + if (N == maxN) { + Kernel::run(args...); + } + // We shouldn't get here -- throw an error? + } +}; + +// 2D dispatch, general case. +// This is similar to the 1D case: we take additional template args curN and +// curM, and increment them via template recursion until they are equal to +// the run-time values of N and M, at which point we dispatch to the run +// method of the kernel. +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t curN, + int64_t minM, + int64_t maxM, + int64_t curM, + typename... Args> +struct DispatchKernelHelper2D { + static void run(const int64_t N, const int64_t M, Args... args) { + if (curN == N && curM == M) { + Kernel::run(args...); + } else if (curN < N && curM < M) { + // Increment both curN and curM. This isn't strictly necessary; we could + // just increment one or the other at each step. But this helps to cut + // on the number of recursive calls we make. 
+ DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN + 1, + minM, + maxM, + curM + 1, + Args...>::run(N, M, args...); + } else if (curN < N) { + // Increment curN only + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN + 1, + minM, + maxM, + curM, + Args...>::run(N, M, args...); + } else if (curM < M) { + // Increment curM only + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN, + minM, + maxM, + curM + 1, + Args...>::run(N, M, args...); + } + } +}; + +// 2D dispatch, specialization for curN == maxN +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t minM, + int64_t maxM, + int64_t curM, + typename... Args> +struct DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + maxN, + minM, + maxM, + curM, + Args...> { + static void run(const int64_t N, const int64_t M, Args... args) { + if (maxN == N && curM == M) { + Kernel::run(args...); + } else if (curM < maxM) { + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + maxN, + minM, + maxM, + curM + 1, + Args...>::run(N, M, args...); + } + // We should not get here -- throw an error? + } +}; + +// 2D dispatch, specialization for curM == maxM +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t curN, + int64_t minM, + int64_t maxM, + typename... Args> +struct DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN, + minM, + maxM, + maxM, + Args...> { + static void run(const int64_t N, const int64_t M, Args... args) { + if (curN == N && maxM == M) { + Kernel::run(args...); + } else if (curN < maxN) { + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + curN + 1, + minM, + maxM, + maxM, + Args...>::run(N, M, args...); + } + // We should not get here -- throw an error? + } +}; + +// 2D dispatch, specialization for curN == maxN, curM == maxM +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t minM, + int64_t maxM, + typename... Args> +struct DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + maxN, + minM, + maxM, + maxM, + Args...> { + static void run(const int64_t N, const int64_t M, Args... args) { + if (maxN == N && maxM == M) { + Kernel::run(args...); + } + // We should not get here -- throw an error? + } +}; + +} // namespace + +// This is the function we expect users to call to dispatch to 1D functions +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + typename... Args> +void DispatchKernel1D(const int64_t N, Args... args) { + if (minN <= N && N <= maxN) { + // Kick off the template recursion by calling the Helper with curN = minN + DispatchKernelHelper1D::run( + N, args...); + } + // Maybe throw an error if we tried to dispatch outside the allowed range? +} + +// This is the function we expect users to call to dispatch to 2D functions +template < + template + class Kernel, + typename T, + int64_t minN, + int64_t maxN, + int64_t minM, + int64_t maxM, + typename... Args> +void DispatchKernel2D(const int64_t N, const int64_t M, Args... args) { + if (minN <= N && N <= maxN && minM <= M && M <= maxM) { + // Kick off the template recursion by calling the Helper with curN = minN + // and curM = minM + DispatchKernelHelper2D< + Kernel, + T, + minN, + maxN, + minN, + minM, + maxM, + minM, + Args...>::run(N, M, args...); + } + // Maybe throw an error if we tried to dispatch outside the specified range? 
+}
\ No newline at end of file
diff --git a/unik3d/ops/knn/src/utils/index_utils.cuh b/unik3d/ops/knn/src/utils/index_utils.cuh
new file mode 100644
index 0000000000000000000000000000000000000000..d3f7f7afa0d9fdc8c2c23187cd1c12d2ccd670f5
--- /dev/null
+++ b/unik3d/ops/knn/src/utils/index_utils.cuh
@@ -0,0 +1,224 @@
+/*
+ * Copyright (c) Meta Platforms, Inc. and affiliates.
+ * All rights reserved.
+ *
+ * This source code is licensed under the BSD-style license found in the
+ * LICENSE file in the root directory of this source tree.
+ */
+
+// This converts dynamic array lookups into static array lookups, for small
+// arrays up to size 32.
+//
+// Suppose we have a small thread-local array:
+//
+// float vals[10];
+//
+// Ideally we should only index this array using static indices:
+//
+// for (int i = 0; i < 10; ++i) vals[i] = i * i;
+//
+// If we do so, then the CUDA compiler may be able to place the array into
+// registers, which can have a big performance improvement. However if we
+// access the array dynamically, the compiler may force the array into
+// local memory, which has the same latency as global memory.
+//
+// These functions convert dynamic array access into static array access
+// using a brute-force lookup table. It can be used like this:
+//
+// float vals[10];
+// int idx = 3;
+// float val = 3.14f;
+// RegisterIndexUtils<float, 10>::set(vals, idx, val);
+// float val2 = RegisterIndexUtils<float, 10>::get(vals, idx);
+//
+// The implementation is based on fbcuda/RegisterUtils.cuh:
+// https://github.com/facebook/fbcuda/blob/master/RegisterUtils.cuh
+// To avoid depending on the entire library, we just reimplement these two
+// functions. The fbcuda implementation is a bit more sophisticated, and uses
+// the preprocessor to generate switch statements that go up to N for each
+// value of N. We are lazy and just have a giant explicit switch statement.
+//
+// We might be able to use a template metaprogramming approach similar to
+// DispatchKernel1D for this. However DispatchKernel1D is intended to be used
+// for dispatching to the correct CUDA kernel on the host, while this is
+// intended to run on the device. I was concerned that a metaprogramming
+// approach for this might lead to extra function calls at runtime if the
+// compiler fails to optimize them away, which could be very slow on device.
+// However I didn't actually benchmark or test this.
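A minimal usage sketch may help before the lookup table itself: the kernel below keeps a small per-thread buffer and routes a runtime index through RegisterIndexUtils instead of a raw dynamic subscript. The kernel name, buffer size, and indexing scheme are illustrative assumptions and not part of the patch.

// Illustrative sketch only: write into a fixed-size per-thread buffer through
// RegisterIndexUtils so the compiler can keep the buffer in registers.
__global__ void ScatterIntoLocalBuffer(const float* in, float* out, int n) {
  constexpr int kBufSize = 8; // must be <= 32 for RegisterIndexUtils
  float buf[kBufSize];
  for (int i = 0; i < kBufSize; ++i) {
    buf[i] = 0.0f; // static indexing: register-friendly
  }
  const int tid = blockIdx.x * blockDim.x + threadIdx.x;
  if (tid < n) {
    const int slot = tid % kBufSize; // runtime index
    // Dynamic access goes through the switch-based lookup instead of buf[slot].
    RegisterIndexUtils<float, kBufSize>::set(buf, slot, in[tid]);
    out[tid] = RegisterIndexUtils<float, kBufSize>::get(buf, slot);
  }
}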
+template +struct RegisterIndexUtils { + __device__ __forceinline__ static T get(const T arr[N], int idx) { + if (idx < 0 || idx >= N) + return T(); + switch (idx) { + case 0: + return arr[0]; + case 1: + return arr[1]; + case 2: + return arr[2]; + case 3: + return arr[3]; + case 4: + return arr[4]; + case 5: + return arr[5]; + case 6: + return arr[6]; + case 7: + return arr[7]; + case 8: + return arr[8]; + case 9: + return arr[9]; + case 10: + return arr[10]; + case 11: + return arr[11]; + case 12: + return arr[12]; + case 13: + return arr[13]; + case 14: + return arr[14]; + case 15: + return arr[15]; + case 16: + return arr[16]; + case 17: + return arr[17]; + case 18: + return arr[18]; + case 19: + return arr[19]; + case 20: + return arr[20]; + case 21: + return arr[21]; + case 22: + return arr[22]; + case 23: + return arr[23]; + case 24: + return arr[24]; + case 25: + return arr[25]; + case 26: + return arr[26]; + case 27: + return arr[27]; + case 28: + return arr[28]; + case 29: + return arr[29]; + case 30: + return arr[30]; + case 31: + return arr[31]; + }; + return T(); + } + + __device__ __forceinline__ static void set(T arr[N], int idx, T val) { + if (idx < 0 || idx >= N) + return; + switch (idx) { + case 0: + arr[0] = val; + break; + case 1: + arr[1] = val; + break; + case 2: + arr[2] = val; + break; + case 3: + arr[3] = val; + break; + case 4: + arr[4] = val; + break; + case 5: + arr[5] = val; + break; + case 6: + arr[6] = val; + break; + case 7: + arr[7] = val; + break; + case 8: + arr[8] = val; + break; + case 9: + arr[9] = val; + break; + case 10: + arr[10] = val; + break; + case 11: + arr[11] = val; + break; + case 12: + arr[12] = val; + break; + case 13: + arr[13] = val; + break; + case 14: + arr[14] = val; + break; + case 15: + arr[15] = val; + break; + case 16: + arr[16] = val; + break; + case 17: + arr[17] = val; + break; + case 18: + arr[18] = val; + break; + case 19: + arr[19] = val; + break; + case 20: + arr[20] = val; + break; + case 21: + arr[21] = val; + break; + case 22: + arr[22] = val; + break; + case 23: + arr[23] = val; + break; + case 24: + arr[24] = val; + break; + case 25: + arr[25] = val; + break; + case 26: + arr[26] = val; + break; + case 27: + arr[27] = val; + break; + case 28: + arr[28] = val; + break; + case 29: + arr[29] = val; + break; + case 30: + arr[30] = val; + break; + case 31: + arr[31] = val; + break; + } + } +}; \ No newline at end of file diff --git a/unik3d/ops/knn/src/utils/mink.cuh b/unik3d/ops/knn/src/utils/mink.cuh new file mode 100644 index 0000000000000000000000000000000000000000..7512aabc2ca2a536eba8b4f14562ef47718cd064 --- /dev/null +++ b/unik3d/ops/knn/src/utils/mink.cuh @@ -0,0 +1,165 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#define MINK_H + +#include "index_utils.cuh" + +// A data structure to keep track of the smallest K keys seen so far as well +// as their associated values, intended to be used in device code. +// This data structure doesn't allocate any memory; keys and values are stored +// in arrays passed to the constructor. +// +// The implementation is generic; it can be used for any key type that supports +// the < operator, and can be used with any value type. +// +// Example usage: +// +// float keys[K]; +// int values[K]; +// MinK mink(keys, values, K); +// for (...) 
{ +// // Produce some key and value from somewhere +// mink.add(key, value); +// } +// mink.sort(); +// +// Now keys and values store the smallest K keys seen so far and the values +// associated to these keys: +// +// for (int k = 0; k < K; ++k) { +// float key_k = keys[k]; +// int value_k = values[k]; +// } +template +class MinK { + public: + // Constructor. + // + // Arguments: + // keys: Array in which to store keys + // values: Array in which to store values + // K: How many values to keep track of + __device__ MinK(key_t* keys, value_t* vals, int K) + : keys(keys), vals(vals), K(K), _size(0) {} + + // Try to add a new key and associated value to the data structure. If the key + // is one of the smallest K seen so far then it will be kept; otherwise it + // it will not be kept. + // + // This takes O(1) operations if the new key is not kept, or if the structure + // currently contains fewer than K elements. Otherwise this takes O(K) time. + // + // Arguments: + // key: The key to add + // val: The value associated to the key + __device__ __forceinline__ void add(const key_t& key, const value_t& val) { + if (_size < K) { + keys[_size] = key; + vals[_size] = val; + if (_size == 0 || key > max_key) { + max_key = key; + max_idx = _size; + } + _size++; + } else if (key < max_key) { + keys[max_idx] = key; + vals[max_idx] = val; + max_key = key; + for (int k = 0; k < K; ++k) { + key_t cur_key = keys[k]; + if (cur_key > max_key) { + max_key = cur_key; + max_idx = k; + } + } + } + } + + // Get the number of items currently stored in the structure. + // This takes O(1) time. + __device__ __forceinline__ int size() { + return _size; + } + + // Sort the items stored in the structure using bubble sort. + // This takes O(K^2) time. + __device__ __forceinline__ void sort() { + for (int i = 0; i < _size - 1; ++i) { + for (int j = 0; j < _size - i - 1; ++j) { + if (keys[j + 1] < keys[j]) { + key_t key = keys[j]; + value_t val = vals[j]; + keys[j] = keys[j + 1]; + vals[j] = vals[j + 1]; + keys[j + 1] = key; + vals[j + 1] = val; + } + } + } + } + + private: + key_t* keys; + value_t* vals; + int K; + int _size; + key_t max_key; + int max_idx; +}; + +// This is a version of MinK that only touches the arrays using static indexing +// via RegisterIndexUtils. If the keys and values are stored in thread-local +// arrays, then this may allow the compiler to place them in registers for +// fast access. +// +// This has the same API as RegisterMinK, but doesn't support sorting. +// We found that sorting via RegisterIndexUtils gave very poor performance, +// and suspect it may have prevented the compiler from placing the arrays +// into registers. 
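A minimal sketch of how these trackers are typically driven from a k-NN style kernel: each thread owns small keys/values arrays, pushes every candidate's squared distance through add(), and reads back the K best at the end. The function name, float3 inputs, and the template parameter K are illustrative assumptions, not code from this patch.

// Illustrative sketch only: keep the K smallest squared distances for one query.
template <int K>
__device__ void NearestKForQuery(
    const float3 query,
    const float3* points,
    const int num_points,
    float* out_dists, // length K
    int* out_idxs) {  // length K
  float keys[K]; // squared distances (the "keys")
  int vals[K];   // indices of the corresponding points (the "values")
  MinK<float, int> mink(keys, vals, K);
  for (int p = 0; p < num_points; ++p) {
    const float dx = points[p].x - query.x;
    const float dy = points[p].y - query.y;
    const float dz = points[p].z - query.z;
    mink.add(dx * dx + dy * dy + dz * dz, p);
  }
  mink.sort(); // ascending by distance; O(K^2) bubble sort as documented above
  for (int k = 0; k < mink.size(); ++k) {
    out_dists[k] = keys[k];
    out_idxs[k] = vals[k];
  }
}

RegisterMinK below is a drop-in alternative for the add()/size() portion of this pattern when the keys/values arrays are thread-local and sorting is not required.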
+template +class RegisterMinK { + public: + __device__ RegisterMinK(key_t* keys, value_t* vals) + : keys(keys), vals(vals), _size(0) {} + + __device__ __forceinline__ void add(const key_t& key, const value_t& val) { + if (_size < K) { + RegisterIndexUtils::set(keys, _size, key); + RegisterIndexUtils::set(vals, _size, val); + if (_size == 0 || key > max_key) { + max_key = key; + max_idx = _size; + } + _size++; + } else if (key < max_key) { + RegisterIndexUtils::set(keys, max_idx, key); + RegisterIndexUtils::set(vals, max_idx, val); + max_key = key; + for (int k = 0; k < K; ++k) { + key_t cur_key = RegisterIndexUtils::get(keys, k); + if (cur_key > max_key) { + max_key = cur_key; + max_idx = k; + } + } + } + } + + __device__ __forceinline__ int size() { + return _size; + } + + private: + key_t* keys; + value_t* vals; + int _size; + key_t max_key; + int max_idx; +}; \ No newline at end of file diff --git a/unik3d/ops/knn/src/utils/pytorch3d_cutils.h b/unik3d/ops/knn/src/utils/pytorch3d_cutils.h new file mode 100644 index 0000000000000000000000000000000000000000..c46b5aea5682bbef6960c7eb2cc1bf052fb6e32a --- /dev/null +++ b/unik3d/ops/knn/src/utils/pytorch3d_cutils.h @@ -0,0 +1,17 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * All rights reserved. + * + * This source code is licensed under the BSD-style license found in the + * LICENSE file in the root directory of this source tree. + */ + +#pragma once +#include + +#define CHECK_CUDA(x) TORCH_CHECK(x.is_cuda(), #x " must be a CUDA tensor.") +#define CHECK_CONTIGUOUS(x) \ + TORCH_CHECK(x.is_contiguous(), #x " must be contiguous.") +#define CHECK_CONTIGUOUS_CUDA(x) \ + CHECK_CUDA(x); \ + CHECK_CONTIGUOUS(x) \ No newline at end of file diff --git a/unik3d/ops/losses/__init__.py b/unik3d/ops/losses/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..07032e80bee5dd6da40eafb943919a5034296c1d --- /dev/null +++ b/unik3d/ops/losses/__init__.py @@ -0,0 +1,22 @@ +from .confidence import Confidence +from .dummy import Dummy +from .edge import SpatialGradient +from .local_ssi import LocalSSI +from .normals import LocalNormal +from .regression import PolarRegression, Regression +from .robust_loss import RobustLoss +from .scale import Scale +from .silog import SILog + +__all__ = [ + "Confidence", + "Dummy", + "SpatialGradient", + "LocalSSI", + "Regression", + "LocalNormal", + "RobustLoss", + "SILog", + "Scale", + "PolarRegression", +] diff --git a/unik3d/ops/losses/confidence.py b/unik3d/ops/losses/confidence.py new file mode 100644 index 0000000000000000000000000000000000000000..dca6ea4f63fadc10b3916e82ef2b82f076137219 --- /dev/null +++ b/unik3d/ops/losses/confidence.py @@ -0,0 +1,63 @@ +import torch +import torch.nn as nn + +from .utils import FNS, masked_mean + + +class Confidence(nn.Module): + def __init__( + self, + weight: float, + output_fn: str = "sqrt", + input_fn: str = "linear", + rescale: bool = True, + eps: float = 1e-5, + ): + super(Confidence, self).__init__() + self.name: str = self.__class__.__name__ + self.weight = weight + self.rescale = rescale + self.eps = eps + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, + input: torch.Tensor, + target_pred: torch.Tensor, + target_gt: torch.Tensor, + mask: torch.Tensor, + ): + B, C = target_gt.shape[:2] + mask = mask.bool() + target_gt = target_gt.float().reshape(B, C, -1) + target_pred = target_pred.float().reshape(B, C, -1) + input = 
input.float().reshape(B, -1) + mask = mask.reshape(B, -1) + + if self.rescale: + target_pred = torch.stack( + [ + p * torch.median(gt[:, m]) / torch.median(p[:, m]) + for p, gt, m in zip(target_pred, target_gt, mask) + ] + ) + + error = torch.abs( + (self.input_fn(target_pred) - self.input_fn(target_gt)).norm(dim=1) - input + ) + losses = masked_mean(error, dim=[-1], mask=mask).squeeze(dim=-1) + losses = self.output_fn(losses) + + return losses + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + output_fn=config["output_fn"], + input_fn=config["input_fn"], + rescale=config.get("rescale", True), + ) + return obj diff --git a/unik3d/ops/losses/dummy.py b/unik3d/ops/losses/dummy.py new file mode 100644 index 0000000000000000000000000000000000000000..77a4c99dc888bf2a5d80c7d324257db863622d28 --- /dev/null +++ b/unik3d/ops/losses/dummy.py @@ -0,0 +1,17 @@ +import torch +import torch.nn as nn + + +class Dummy(nn.Module): + def __init__(self, *args, **kwargs): + super().__init__() + self.name: str = self.__class__.__name__ + self.weight = 1.0 + + def forward(self, dummy: torch.Tensor, *args, **kwargs) -> torch.Tensor: + return torch.tensor([0.0] * dummy.shape[0], device=dummy.device) + + @classmethod + def build(cls, config): + obj = cls() + return obj diff --git a/unik3d/ops/losses/edge.py b/unik3d/ops/losses/edge.py new file mode 100644 index 0000000000000000000000000000000000000000..20e5265532a2eb462cf94b2858483c48df6d8922 --- /dev/null +++ b/unik3d/ops/losses/edge.py @@ -0,0 +1,191 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F + +from unik3d.utils.constants import VERBOSE +from unik3d.utils.geometric import dilate, erode +from unik3d.utils.misc import profile_method + +from .utils import (FNS, REGRESSION_DICT, masked_mean, masked_mean_var, + masked_quantile) + + +class SpatialGradient(torch.nn.Module): + def __init__( + self, + weight: float, + input_fn: str, + output_fn: str, + fn: str, + scales: int = 1, + gamma: float = 1.0, + quantile: float = 0.0, + laplacian: bool = False, + canny_edge: bool = False, + **kwargs, + ): + super().__init__() + self.name: str = self.__class__.__name__ + self.weight = weight + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.fn = REGRESSION_DICT[fn] + self.gamma = gamma + sobel_kernel_x = ( + torch.tensor([[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]], dtype=torch.float32) + .unsqueeze(0) + .unsqueeze(0) + ) + sobel_kernel_y = ( + torch.tensor([[-1, -2, -1], [0, 0, 0], [1, 2, 1]], dtype=torch.float32) + .unsqueeze(0) + .unsqueeze(0) + ) + laplacian_kernel = ( + torch.tensor([[0, 1, 0], [1, -4, 1], [0, 1, 0]], dtype=torch.float32) + .unsqueeze(0) + .unsqueeze(0) + ) + ones = torch.ones(1, 1, 3, 3, dtype=torch.float32) + self.sobel_kernel_x = nn.Parameter(sobel_kernel_x, requires_grad=False) + self.sobel_kernel_y = nn.Parameter(sobel_kernel_y, requires_grad=False) + self.ones = nn.Parameter(ones, requires_grad=False) + self.laplacian_kernel = nn.Parameter(laplacian_kernel, requires_grad=False) + + self.quantile = quantile + self.scales = scales + self.laplacian = laplacian + self.canny_edge = canny_edge + + @profile_method(verbose=VERBOSE) + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, + input, + target, + mask, + quality=None, + ): + B = input.shape[0] + input = self.input_fn(input.float()) + target = self.input_fn(target.float()) + + # normalize to avoid scale issue, shift is not important as we are computing gradients + input_mean, 
input_var = masked_mean_var(input.detach(), mask, dim=(-3, -2, -1)) + target_mean, target_var = masked_mean_var(target, mask, dim=(-3, -2, -1)) + input = (input - input_mean) / (input_var + 1e-6) ** 0.5 + target = (target - target_mean) / (target_var + 1e-6) ** 0.5 + + loss = 0.0 + norm_factor = sum([(i + 1) ** 2 for i in range(self.scales)]) + for scale in range(self.scales): + if scale > 0: + input = F.interpolate( + input, + size=(input.shape[-2] // 2, input.shape[-1] // 2), + mode="bilinear", + align_corners=False, + antialias=True, + ) + target = F.interpolate( + target, + size=(target.shape[-2] // 2, target.shape[-1] // 2), + mode="nearest", + ) + mask = ( + F.interpolate( + mask.float(), + size=(mask.shape[-2] // 2, mask.shape[-1] // 2), + mode="nearest", + ) + > 0.9 + ) + grad_loss = self.loss(input, target, mask, quality) + # keep per pixel same weight + loss = loss + grad_loss * (self.scales - scale) ** 2 / norm_factor + + loss = self.output_fn(loss) + return loss + + def loss(self, input, target, mask, quality): + device, dtype = input.device, input.dtype + B, C, H, W = input.shape + + # sobel + input_edge_x = ( + F.conv2d(input, self.sobel_kernel_x.repeat(C, 1, 1, 1), groups=C) / 8 + ) + target_edge_x = ( + F.conv2d(target, self.sobel_kernel_x.repeat(C, 1, 1, 1), groups=C) / 8 + ) + input_edge_y = ( + F.conv2d(input, self.sobel_kernel_y.repeat(C, 1, 1, 1), groups=C) / 8 + ) + target_edge_y = ( + F.conv2d(target, self.sobel_kernel_y.repeat(C, 1, 1, 1), groups=C) / 8 + ) + input_edge = torch.stack([input_edge_x, input_edge_y], dim=-1) + target_edge = torch.stack([target_edge_x, target_edge_y], dim=-1) + + mask = F.conv2d(mask.clone().to(input.dtype), self.ones) == 9 + mask = mask.squeeze(1) + + error = input_edge - target_edge + error = error.norm(dim=-1).norm( + dim=1 + ) # take RMSE over xy-dir (isotropic) and over channel-dir (isotropic) + + if quality is not None: + for quality_level in [1, 2]: + current_quality = quality == quality_level + if current_quality.sum() > 0: + error_qtl = error[current_quality].detach() + mask_qtl = error_qtl < masked_quantile( + error_qtl, + mask[current_quality], + dims=[1, 2], + q=1 - self.quantile * quality_level, + ).view(-1, 1, 1) + mask[current_quality] = mask[current_quality] & mask_qtl + else: + error_qtl = error.detach() + mask = mask & ( + error_qtl + < masked_quantile( + error_qtl, mask, dims=[1, 2], q=1 - self.quantile + ).view(-1, 1, 1) + ) + + loss = masked_mean(error, mask, dim=(-2, -1)).squeeze(dim=(-2, -1)) + + if self.laplacian: + input_laplacian = ( + F.conv2d(input, self.laplacian_kernel.repeat(C, 1, 1, 1), groups=C) / 8 + ) + target_laplacian = ( + F.conv2d(target, self.laplacian_kernel.repeat(C, 1, 1, 1), groups=C) / 8 + ) + error_laplacian = self.fn( + input_laplacian - target_laplacian, gamma=self.gamma + ) + error_laplacian = (torch.mean(error_laplacian**2, dim=1) + 1e-6) ** 0.5 + loss_laplacian = masked_mean(error_laplacian, mask, dim=(-2, -1)).squeeze( + dim=(-2, -1) + ) + loss = loss + 0.1 * loss_laplacian + + return loss + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + input_fn=config["input_fn"], + output_fn=config["output_fn"], + fn=config["fn"], + gamma=config["gamma"], + quantile=config["quantile"], + scales=config["scales"], + laplacian=config["laplacian"], + ) + return obj diff --git a/unik3d/ops/losses/local_ssi.py b/unik3d/ops/losses/local_ssi.py new file mode 100644 index 0000000000000000000000000000000000000000..933506c5a581d3cc7f4fb2b1021b314c08c60dfd --- /dev/null +++ 
b/unik3d/ops/losses/local_ssi.py @@ -0,0 +1,315 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F + +from unik3d.utils.constants import VERBOSE +from unik3d.utils.geometric import downsample, erode +from unik3d.utils.misc import profile_method + +from .utils import (FNS, REGRESSION_DICT, ind2sub, masked_mean, + masked_quantile, ssi, ssi_nd) + + +def sample_strong_edges(edges_img, quantile=0.95, reshape=8): + # flat + edges_img = F.interpolate( + edges_img, scale_factor=1 / reshape, mode="bilinear", align_corners=False + ) + edges_img_flat = edges_img.flatten(1) + + # Find strong edges + edges_mask = edges_img_flat > torch.quantile( + edges_img_flat, quantile, dim=-1, keepdim=True + ) + num_samples = edges_mask.sum(dim=-1) + if (num_samples < 10).any(): + # sample random edges where num_samples < 2 + random = torch.rand_like(edges_img_flat[num_samples < 10, :]) > quantile + edges_mask[num_samples < 10, :] = torch.logical_or( + edges_mask[num_samples < 10, :], random + ) + num_samples = edges_mask.sum(dim=-1) + + min_samples = num_samples.min() + + # Compute the coordinates of the strong edges as B, N, 2 + edges_coords = torch.stack( + [torch.nonzero(x, as_tuple=False)[:min_samples].squeeze() for x in edges_mask] + ) + edges_coords = ( + torch.stack(ind2sub(edges_coords, edges_img.shape[-1]), dim=-1) * reshape + ) + return edges_coords + + +@torch.jit.script +def extract_patches(tensor, sample_coords, patch_size: tuple[int, int] = (32, 32)): + """ + Extracts patches around specified edge coordinates, with zero padding. + + Parameters: + - tensor: tenosr to be gatherd based on sampling (B, 1, H, W). + - sample_coords: Batch of edge coordinates as a PyTorch tensor of shape (B, num_coords, 2). + - patch_size: Tuple (width, height) representing the size of the patches. + + Returns: + - Patches as a PyTorch tensor of shape (B, num_coords, patch_height, patch_width). 
+ """ + + N, _, H, W = tensor.shape + device = tensor.device + dtype = tensor.dtype + patch_width, patch_height = patch_size + pad_width = patch_width // 2 + pad_height = patch_height // 2 + + # Pad the RGB images for both sheep + tensor_padded = F.pad( + tensor, + (pad_width, pad_width, pad_height, pad_height), + mode="constant", + value=0.0, + ) + + # Adjust edge coordinates to account for padding + sample_coords_padded = sample_coords + torch.tensor( + [pad_height, pad_width], dtype=dtype, device=device + ).reshape(1, 1, 2) + + # Calculate the indices for gather operation + x_centers = sample_coords_padded[:, :, 1].int() + y_centers = sample_coords_padded[:, :, 0].int() + + all_patches = [] + for tensor_i, x_centers_i, y_centers_i in zip(tensor_padded, x_centers, y_centers): + patches = [] + for x_center, y_center in zip(x_centers_i, y_centers_i): + y_start, y_end = y_center - pad_height, y_center + pad_height + 1 + x_start, x_end = x_center - pad_width, x_center + pad_width + 1 + patches.append(tensor_i[..., y_start:y_end, x_start:x_end]) + all_patches.append(torch.stack(patches, dim=0)) + + return torch.stack(all_patches, dim=0).reshape(N, -1, patch_height * patch_width) + + +class LocalSSI(nn.Module): + def __init__( + self, + weight: float, + output_fn: str = "sqrt", + patch_size: tuple[int, int] = (32, 32), + min_samples: int = 4, + num_levels: int = 4, + fn: str = "l1", + rescale_fn: str = "ssi", + input_fn: str = "linear", + quantile: float = 0.1, + gamma: float = 1.0, + alpha: float = 1.0, + relative: bool = False, + eps: float = 1e-5, + ): + super(LocalSSI, self).__init__() + self.name: str = self.__class__.__name__ + self.weight = weight + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.fn = REGRESSION_DICT[fn] + self.min_samples = min_samples + self.eps = eps + patch_logrange = np.linspace( + start=np.log2(min(patch_size)), + stop=np.log2(max(patch_size)), + endpoint=True, + num=num_levels + 1, + ) + self.patch_logrange = [ + (x, y) for x, y in zip(patch_logrange[:-1], patch_logrange[1:]) + ] + self.rescale_fn = eval(rescale_fn) + self.quantile = quantile + self.gamma = gamma + self.alpha = alpha + self.relative = relative + + @profile_method(verbose=VERBOSE) + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + quality: torch.Tensor = None, + down_ratio: int = 1, + *args, + **kwargs, + ) -> torch.Tensor: + mask = mask.bool() + + if down_ratio > 1: + input = downsample(input, down_ratio) + target = downsample(target, down_ratio) + # downsample will ignore 0s in the patch "min", if there is a 1 -> set mask to 1 there + mask = downsample(mask.float(), down_ratio).bool() + + input = self.input_fn(input.float()) + target = self.input_fn(target.float()) + B, C, H, W = input.shape + total_errors = [] + + # save = random() < - 0.001 and is_main_process() + for ii, patch_logrange in enumerate(self.patch_logrange): + + log_kernel = ( + np.random.uniform(*patch_logrange) + if self.training + else np.mean(patch_logrange) + ) + kernel_size = int( + (2**log_kernel) * min(input.shape[-2:]) + ) # always smaller than min_shape + kernel_size = (kernel_size, kernel_size) + stride = (int(kernel_size[0] * 0.9), int(kernel_size[1] * 0.9)) + # stride = kernel_size + + # unfold is always exceeding right/bottom, roll image only negative + # to have them back in the unfolding window + max_roll = ( + (W - kernel_size[1]) % stride[1], + (H - kernel_size[0]) % 
stride[0], + ) + roll_x, roll_y = np.random.randint(-max_roll[0], 1), np.random.randint( + -max_roll[1], 1 + ) + input_fold = torch.roll(input, shifts=(roll_y, roll_x), dims=(2, 3)) + target_fold = torch.roll(target, shifts=(roll_y, roll_x), dims=(2, 3)) + mask_fold = torch.roll(mask.float(), shifts=(roll_y, roll_x), dims=(2, 3)) + + # unfold in patches + input_fold = F.unfold( + input_fold, kernel_size=kernel_size, stride=stride + ).permute( + 0, 2, 1 + ) # B N C*H_p*W_p + target_fold = F.unfold( + target_fold, kernel_size=kernel_size, stride=stride + ).permute(0, 2, 1) + mask_fold = ( + F.unfold(mask_fold, kernel_size=kernel_size, stride=stride) + .bool() + .permute(0, 2, 1) + ) + + # calculate error patchwise, then mean over patch, then over image based if sample size is significant + input_fold, target_fold, _ = self.rescale_fn( + input_fold, target_fold, mask_fold, dim=(-1,) + ) + error = self.fn( + input_fold - target_fold, alpha=self.alpha, gamma=self.gamma + ) + + # calculate elements more then 95 percentile and lower than 5percentile of error + if quality is not None: + N_patches = mask_fold.shape[1] + for quality_level in [1, 2]: + current_quality = quality == quality_level + if current_quality.sum() > 0: + error_qtl = error[current_quality].detach() + mask_qtl = error_qtl < masked_quantile( + error_qtl, + mask_fold[current_quality], + dims=[2], + q=1 - self.quantile * quality_level, + ).view(-1, N_patches, 1) + mask_fold[current_quality] = ( + mask_fold[current_quality] & mask_qtl + ) + else: + error_qtl = error.detach() + mask_fold = mask_fold & ( + error_qtl + < masked_quantile( + error_qtl, mask_fold, dims=[2], q=1 - self.quantile + ).view(B, -1, 1) + ) + + valid_patches = mask_fold.sum(dim=-1) >= self.min_samples + error_mean_patch = masked_mean(error, mask_fold, dim=(-1,)).squeeze(-1) + error_mean_image = self.output_fn(error_mean_patch.clamp(min=self.eps)) + error_mean_image = masked_mean( + error_mean_image, mask=valid_patches, dim=(-1,) + ) + total_errors.append(error_mean_image.squeeze(-1)) + + # global + input_rescale = input.reshape(B, C, -1).clone() + target_rescale = target.reshape(B, C, -1) + mask = mask.reshape(B, 1, -1).clone() + input, target, _ = self.rescale_fn( + input_rescale, + target_rescale, + mask, + dim=(-1,), + target_info=target_rescale.norm(dim=1, keepdim=True), + input_info=input_rescale.norm(dim=1, keepdim=True), + ) + error = input - target + error = error.norm(dim=1) if C > 1 else error.squeeze(1) + if self.relative: + error = error * torch.log( + 1.0 + 10.0 / target_rescale.norm(dim=1).clip(min=0.01) + ) + + error = self.fn(error, alpha=self.alpha, gamma=self.gamma).squeeze(1) + + mask = mask.squeeze(1) + valid_patches = mask.sum(dim=-1) >= 3 * self.min_samples # 30 samples per image + if quality is not None: + for quality_level in [1, 2]: + current_quality = quality == quality_level + if current_quality.sum() > 0: + error_qtl = error[current_quality].detach() + mask_qtl = error_qtl < masked_quantile( + error_qtl, + mask[current_quality], + dims=[1], + q=1 - self.quantile * quality_level, + ).view(-1, 1) + mask[current_quality] = mask[current_quality] & mask_qtl + else: + error_qtl = error.detach() + mask = mask & ( + error_qtl + < masked_quantile(error_qtl, mask, dims=[1], q=1 - self.quantile).view( + -1, 1 + ) + ) + + error_mean_image = masked_mean(error, mask, dim=(-1,)).squeeze(-1) + error_mean_image = ( + self.output_fn(error_mean_image.clamp(min=self.eps)) * valid_patches.float() + ) + + total_errors.append(error_mean_image) + + errors = 
torch.stack(total_errors).mean(dim=0) + return errors + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + patch_size=config["patch_size"], + output_fn=config["output_fn"], + min_samples=config["min_samples"], + num_levels=config["num_levels"], + input_fn=config["input_fn"], + quantile=config["quantile"], + gamma=config["gamma"], + alpha=config["alpha"], + rescale_fn=config["rescale_fn"], + fn=config["fn"], + relative=config["relative"], + ) + return obj diff --git a/unik3d/ops/losses/normals.py b/unik3d/ops/losses/normals.py new file mode 100644 index 0000000000000000000000000000000000000000..18f243cb70699d635f5eaae71514239a32e18f17 --- /dev/null +++ b/unik3d/ops/losses/normals.py @@ -0,0 +1,229 @@ +import itertools + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from unik3d.utils.geometric import dilate, downsample, erode + +from .utils import FNS, masked_mean, masked_quantile + + +class LocalNormal(nn.Module): + def __init__( + self, + weight: float, + output_fn: str = "sqrt", + min_samples: int = 4, + quantile: float = 0.2, + eps: float = 1e-5, + ): + super(LocalNormal, self).__init__() + self.name: str = self.__class__.__name__ + self.weight = weight + self.output_fn = FNS[output_fn] + self.min_samples = min_samples + self.eps = eps + self.patch_weight = torch.ones(1, 1, 3, 3, device="cuda") + self.quantile = quantile + + def bilateral_filter(self, rgb, surf, mask, patch_size=(9, 9)): + B, _, H, W = rgb.shape + sigma_surf = 0.4 + sigma_color = 0.3 + sigma_loc = 0.3 * max(H, W) + + grid_y, grid_x = torch.meshgrid(torch.arange(H), torch.arange(W)) + grid = torch.stack([grid_x, grid_y], dim=0).to(rgb.device) + grid = grid.unsqueeze(0).repeat(B, 1, 1, 1) + + paddings = [patch_size[0] // 2, patch_size[1] // 2] + rgbd = torch.cat([rgb, grid.float(), surf], dim=1) + + # format to B,H*W,C,H_p*W_p format + rgbd_neigh = F.pad(rgbd, 2 * paddings, mode="constant") + rgbd_neigh = F.unfold(rgbd_neigh, kernel_size=patch_size) + rgbd_neigh = rgbd_neigh.permute(0, 2, 1).reshape( + B, H * W, 8, -1 + ) # B N 8 H_p*W_p + mask_neigh = F.pad(mask.float(), 2 * paddings, mode="constant") + mask_neigh = F.unfold(mask_neigh, kernel_size=patch_size) + mask_neigh = mask_neigh.permute(0, 2, 1).reshape(B, H * W, -1) + rgbd = rgbd.permute(0, 2, 3, 1).reshape(B, H * W, 8, 1) # B H*W 8 1 + rgb_neigh = rgbd_neigh[:, :, :3, :] + grid_neigh = rgbd_neigh[:, :, 3:5, :] + surf_neigh = rgbd_neigh[:, :, 5:, :] + rgb = rgbd[:, :, :3, :] + grid = rgbd[:, :, 3:5, :] + surf = rgbd[:, :, 5:, :] + + # calc distance + rgb_dist = torch.norm(rgb - rgb_neigh, dim=-2, p=2) ** 2 + grid_dist = torch.norm(grid - grid_neigh, dim=-2, p=2) ** 2 + surf_dist = torch.norm(surf - surf_neigh, dim=-2, p=2) ** 2 + rgb_sim = torch.exp(-rgb_dist / 2 / sigma_color**2) + grid_sim = torch.exp(-grid_dist / 2 / sigma_loc**2) + surf_sim = torch.exp(-surf_dist / 2 / sigma_surf**2) + + weight = mask_neigh * rgb_sim * grid_sim * surf_sim # B H*W H_p*W_p + weight = weight / weight.sum(dim=-1, keepdim=True).clamp(min=1e-5) + z = (surf_neigh * weight.unsqueeze(-2)).sum(dim=-1) + return z.reshape(B, H, W, 3).permute(0, 3, 1, 2) + + def get_surface_normal(self, xyz: torch.Tensor, mask: torch.Tensor): + P0 = xyz + mask = mask.float() + normals, masks_valid_triangle = [], [] + combinations = list(itertools.combinations_with_replacement([-2, -1, 1, 2], 2)) + combinations += [c[::-1] for c in combinations] + # combinations = [(1, 1), (-1, -1), (1, -1), (-1, 1)] + for shift_0, shift_1 in set(combinations): + P1 
= torch.roll(xyz, shifts=(0, shift_0), dims=(-1, -2)) + P2 = torch.roll(xyz, shifts=(shift_1, 0), dims=(-1, -2)) + if (shift_0 > 0) ^ (shift_1 > 0): + P1, P2 = P2, P1 + vec1, vec2 = P1 - P0, P2 - P0 + normal = torch.cross(vec1, vec2, dim=1) + vec1_norm = torch.norm(vec1, dim=1, keepdim=True).clip(min=1e-8) + vec2_norm = torch.norm(vec2, dim=1, keepdim=True).clip(min=1e-8) + normal_norm = torch.norm(normal, dim=1, keepdim=True).clip(min=1e-8) + normals.append(normal / normal_norm) + is_valid = ( + torch.roll(mask, shifts=(0, shift_0), dims=(-1, -2)) + + torch.roll(mask, shifts=(shift_1, 0), dims=(-1, -2)) + + mask + == 3 + ) + is_valid = ( + (normal_norm > 1e-6) + & (vec1_norm > 1e-6) + & (vec2_norm > 1e-6) + & is_valid + ) + masks_valid_triangle.append(is_valid) + + normals = torch.stack(normals, dim=-1) + mask_valid_triangle = torch.stack(masks_valid_triangle, dim=-1).float() + mask_valid = mask_valid_triangle.sum(dim=-1) + normals = (normals * mask_valid_triangle).sum(dim=-1) / mask_valid.clamp( + min=1.0 + ) + normals_norm = torch.norm(normals, dim=1, keepdim=True).clip(min=1e-8) + normals = normals / normals_norm + mask_valid = ( + (mask_valid > 0.001) + & (~normals.sum(dim=1, keepdim=True).isnan()) + & (normals_norm > 1e-6) + ) + return normals, mask_valid # B 3 H W, B 1 H W + + # def get_surface_normal(self, xyz: torch.Tensor, mask: torch.Tensor): + # x, y, z = torch.unbind(xyz, dim=1) # B 3 H W + # x = x.unsqueeze(1) # B 1 H W + # y = y.unsqueeze(1) + # z = z.unsqueeze(1) + + # mask_float = mask.float() + # paddings = [self.patch_weight.shape[-2] // 2, self.patch_weight.shape[-1] // 2] + # num_samples = F.conv2d(mask_float, weight=self.patch_weight, padding=paddings).clamp(min=1.0) # B 1 H W + # mask_invalid = num_samples < self.min_samples + + # xx = x * x + # yy = y * y + # zz = z * z + # xy = x * y + # xz = x * z + # yz = y * z + # xx_patch = F.conv2d(xx * mask_float, weight=self.patch_weight, padding=paddings) / num_samples + # yy_patch = F.conv2d(yy * mask_float, weight=self.patch_weight, padding=paddings) / num_samples + # zz_patch = F.conv2d(zz * mask_float, weight=self.patch_weight, padding=paddings) / num_samples + # xy_patch = F.conv2d(xy * mask_float, weight=self.patch_weight, padding=paddings) / num_samples + # xz_patch = F.conv2d(xz * mask_float, weight=self.patch_weight, padding=paddings) / num_samples + # yz_patch = F.conv2d(yz * mask_float, weight=self.patch_weight, padding=paddings) / num_samples + + # x_patch = F.conv2d(x * mask_float, weight=self.patch_weight, padding=paddings) / num_samples + # y_patch = F.conv2d(y * mask_float, weight=self.patch_weight, padding=paddings) / num_samples + # z_patch = F.conv2d(z * mask_float, weight=self.patch_weight, padding=paddings) / num_samples + + # ATA = torch.stack([xx_patch, xy_patch, xz_patch, xy_patch, yy_patch, yz_patch, xz_patch, yz_patch, zz_patch], dim=-1).squeeze(1) # B H W 9 + # ATA = torch.reshape(ATA, (ATA.shape[0], ATA.shape[1], ATA.shape[2], 3, 3)) # B H W 3 3 + # eps_identity = torch.eye(3, device=ATA.device, dtype=ATA.dtype).unsqueeze(0) # 1 3 3 + # ATA = ATA + 1e-6 * eps_identity + + # AT1 = torch.stack([x_patch, y_patch, z_patch], dim=-1).squeeze(1).unsqueeze(-1) # B H W 3 1 + + # det = torch.linalg.det(ATA) + # mask_invalid_inverse = det.abs() < 1e-12 + # mask_invalid = mask_invalid.squeeze(1) | mask_invalid_inverse + # AT1[mask_invalid, :, :] = 0 + # ATA[mask_invalid, :, :] = eps_identity + + # ATA_inv = torch.linalg.inv(ATA) + # normals = (ATA_inv @ AT1).squeeze(dim=-1) # B H W 3 + # normals = normals / 
torch.norm(normals, dim=-1, keepdim=True).clip(min=1e-8) + # mask_invalid = mask_invalid | (torch.sum(normals, dim=-1) == 0.0) + + # # flip normals, based if a * x + b * y + c * z < 0 -> change sign of normals + # mean_patch_xyz = AT1.squeeze(-1) + # orient_mask = torch.sum(normals * mean_patch_xyz, dim=-1).unsqueeze(-1) > 0 + # normals = (2 * orient_mask.to(ATA.dtype) - 1) * normals + + # return normals.permute(0, 3, 1, 2), ~mask_invalid.unsqueeze(1) # B 3 H W, B H W + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward(self, input: torch.Tensor, target: torch.Tensor, mask, valid): + if not valid.any(): + return 0.0 * input.mean(dim=(1, 2, 3)) + + input = input.float() + target = target.float() + + mask = erode(mask, kernel_size=3) + target_normal, mask_target = self.get_surface_normal(target[valid], mask[valid]) + input_normal, mask_input = self.get_surface_normal( + input[valid], torch.ones_like(mask[valid]) + ) + + gt_similarity = F.cosine_similarity(input_normal, target_normal, dim=1) # B H W + mask_target = ( + mask_target.squeeze(1) & (gt_similarity < 0.999) & (gt_similarity > -0.999) + ) + error = F.relu((1 - gt_similarity) / 2 - 0.01) + + error_full = torch.ones_like(mask.squeeze(1).float()) + error_full[valid] = error + mask_full = torch.ones_like(mask.squeeze(1)) + mask_full[valid] = mask_target + + error_qtl = error_full.detach() + mask_full = mask_full & ( + error_qtl + < masked_quantile( + error_qtl, mask_full, dims=[1, 2], q=1 - self.quantile + ).view(-1, 1, 1) + ) + + loss = masked_mean(error_full, mask=mask_full, dim=(-2, -1)).squeeze( + dim=(-2, -1) + ) # B + loss = self.output_fn(loss) + return loss + + def von_mises(self, input, target, mask, kappa): + score = torch.cosine_similarity(input, target, dim=1).unsqueeze(1) + mask_cosine = torch.logical_and( + mask, torch.logical_and(score.detach() < 0.999, score.detach() > -0.999) + ) + nll = masked_mean( + kappa * (1 - score), mask=mask_cosine, dim=(-1, -2, -3) + ).squeeze() + return nll + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + output_fn=config["output_fn"], + quantile=config.get("quantile", 0.2), + ) + return obj diff --git a/unik3d/ops/losses/regression.py b/unik3d/ops/losses/regression.py new file mode 100644 index 0000000000000000000000000000000000000000..6f16af625325cadadcc77744fae8508d495da372 --- /dev/null +++ b/unik3d/ops/losses/regression.py @@ -0,0 +1,166 @@ +import torch +import torch.nn as nn + +from .utils import FNS, REGRESSION_DICT, masked_mean, masked_quantile + + +class Regression(nn.Module): + def __init__( + self, + weight: float, + gamma: float, + fn: str, + input_fn: str, + output_fn: str, + alpha: float = 1.0, + dims: tuple[int] = (-1,), + quantile: float = 0.0, + **kwargs, + ): + super().__init__() + self.name = self.__class__.__name__ + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.fn = REGRESSION_DICT[fn] + self.weight = weight + self.gamma = gamma + self.alpha = alpha + self.dims = dims + self.quantile = quantile + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + if mask is not None: # usually it is just repeated + mask = mask[:, 0] + + input = self.input_fn(input.float()) + target = self.input_fn(target.float()) + error = self.fn(input - target, gamma=self.gamma, alpha=self.alpha).mean(dim=1) + if self.quantile > 0.0: + 
mask_quantile = error < masked_quantile( + error, mask, dims=self.dims, q=1 - self.quantile + ).view(-1, *((1,) * len(self.dims))) + mask = mask & mask_quantile if mask is not None else mask_quantile + mean_error = masked_mean(data=error, mask=mask, dim=self.dims).squeeze( + self.dims + ) + mean_error = self.output_fn(mean_error) + return mean_error + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + fn=config["fn"], + gamma=config["gamma"], + alpha=config.get("alpha", 1.0), + output_fn=config["output_fn"], + input_fn=config["input_fn"], + dims=config.get("dims", (-1,)), + quantile=config.get("quantile", 0.0), + ) + return obj + + +class PolarRegression(nn.Module): + def __init__( + self, + weight: float, + gamma: float, + fn: str, + input_fn: str, + output_fn: str, + alpha: float = 1.0, + dims: list[int] = [-1, -2], + polar_weight: float = 1.0, + polar_asym: float = 0.5, + **kwargs, + ): + super().__init__() + self.name = self.__class__.__name__ + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.fn = REGRESSION_DICT[fn] + self.weight = weight + self.gamma = gamma + self.alpha = alpha + self.dims = dims + self.polar_weight = polar_weight + self.polar_asym = polar_asym + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + if mask is not None: # usually it is just repeated + mask = mask.squeeze(1) + + input = self.input_fn(input.float()) + target = self.input_fn(target.float()) + input = input / torch.norm(input, dim=1, keepdim=True).clamp(min=1e-5) + target = target / torch.norm(target, dim=1, keepdim=True).clamp(min=1e-5) + + x_target, y_target, z_target = target.unbind(dim=1) + z_clipped = z_target.clip(min=-0.99999, max=0.99999) + x_clipped = x_target.abs().clip(min=1e-5) * (2 * (x_target > 0).float() - 1) + polar_target = torch.arccos(z_clipped) + azimuth_target = torch.atan2(y_target, x_clipped) + + x_input, y_input, z_input = input.unbind(dim=1) + z_clipped = z_input.clip(min=-0.99999, max=0.99999) + x_clipped = x_input.abs().clip(min=1e-5) * (2 * (x_input > 0).float() - 1) + polar_input = torch.arccos(z_clipped) + azimuth_input = torch.atan2(y_input, x_clipped) + + polar_error = self.fn( + polar_input - polar_target, gamma=self.gamma, alpha=self.alpha + ) + azimuth_error = self.fn( + azimuth_input - azimuth_target, gamma=self.gamma, alpha=self.alpha + ) + + quantile_weight = torch.ones_like(polar_input) + quantile_weight[ + (polar_target > polar_input) & (polar_target > torch.pi / 2) + ] = (2 * self.polar_asym) + quantile_weight[ + (polar_target <= polar_input) & (polar_target > torch.pi / 2) + ] = 2 * (1 - self.polar_asym) + + mean_polar_error = masked_mean( + data=polar_error * quantile_weight, mask=mask, dim=self.dims + ).squeeze(self.dims) + mean_azimuth_error = masked_mean( + data=azimuth_error, mask=mask, dim=self.dims + ).squeeze(self.dims) + mean_error = (self.polar_weight * mean_polar_error + mean_azimuth_error) / ( + 1 + self.polar_weight + ) + + mean_error = self.output_fn(mean_error) + return mean_error + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + fn=config["fn"], + gamma=config["gamma"], + alpha=config.get("alpha", 1.0), + output_fn=config["output_fn"], + input_fn=config["input_fn"], + dims=config.get("dims", (-1,)), + polar_weight=config["polar_weight"], + polar_asym=config["polar_asym"], + ) + return obj diff --git 
a/unik3d/ops/losses/resources/__init__.py b/unik3d/ops/losses/resources/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/unik3d/ops/losses/resources/partition_spline.npz b/unik3d/ops/losses/resources/partition_spline.npz new file mode 100644 index 0000000000000000000000000000000000000000..e2813d0d8397d603eeed56c3b75e045ad5bd226b Binary files /dev/null and b/unik3d/ops/losses/resources/partition_spline.npz differ diff --git a/unik3d/ops/losses/robust_loss.py b/unik3d/ops/losses/robust_loss.py new file mode 100644 index 0000000000000000000000000000000000000000..2ef99be1a9b117e3b537dbe5fc6a3829fea7f081 --- /dev/null +++ b/unik3d/ops/losses/robust_loss.py @@ -0,0 +1,709 @@ +import numpy as np +import torch +import torch.nn as nn + +from .utils import FNS, masked_mean + + +def log_safe(x): + """The same as torch.log(x), but clamps the input to prevent NaNs.""" + x = torch.as_tensor(x) + return torch.log(torch.min(x, torch.tensor(33e37).to(x))) + + +def log1p_safe(x): + """The same as torch.log1p(x), but clamps the input to prevent NaNs.""" + x = torch.as_tensor(x) + return torch.log1p(torch.min(x, torch.tensor(33e37).to(x))) + + +def exp_safe(x): + """The same as torch.exp(x), but clamps the input to prevent NaNs.""" + x = torch.as_tensor(x) + return torch.exp(torch.min(x, torch.tensor(87.5).to(x))) + + +def expm1_safe(x): + """The same as tf.math.expm1(x), but clamps the input to prevent NaNs.""" + x = torch.as_tensor(x) + return torch.expm1(torch.min(x, torch.tensor(87.5).to(x))) + + +def inv_softplus(y): + """The inverse of tf.nn.softplus().""" + y = torch.as_tensor(y) + return torch.where(y > 87.5, y, torch.log(torch.expm1(y))) + + +def logit(y): + """The inverse of tf.nn.sigmoid().""" + y = torch.as_tensor(y) + return -torch.log(1.0 / y - 1.0) + + +def affine_sigmoid(logits, lo=0, hi=1): + """Maps reals to (lo, hi), where 0 maps to (lo+hi)/2.""" + if not lo < hi: + raise ValueError("`lo` (%g) must be < `hi` (%g)" % (lo, hi)) + logits = torch.as_tensor(logits) + lo = torch.as_tensor(lo) + hi = torch.as_tensor(hi) + alpha = torch.sigmoid(logits) * (hi - lo) + lo + return alpha + + +def inv_affine_sigmoid(probs, lo=0, hi=1): + """The inverse of affine_sigmoid(., lo, hi).""" + if not lo < hi: + raise ValueError("`lo` (%g) must be < `hi` (%g)" % (lo, hi)) + probs = torch.as_tensor(probs) + lo = torch.as_tensor(lo) + hi = torch.as_tensor(hi) + logits = logit((probs - lo) / (hi - lo)) + return logits + + +def affine_softplus(x, lo=0, ref=1): + """Maps real numbers to (lo, infinity), where 0 maps to ref.""" + if not lo < ref: + raise ValueError("`lo` (%g) must be < `ref` (%g)" % (lo, ref)) + x = torch.as_tensor(x) + lo = torch.as_tensor(lo) + ref = torch.as_tensor(ref) + shift = inv_softplus(torch.tensor(1.0)) + y = (ref - lo) * torch.nn.Softplus()(x + shift) + lo + return y + + +def inv_affine_softplus(y, lo=0, ref=1): + """The inverse of affine_softplus(., lo, ref).""" + if not lo < ref: + raise ValueError("`lo` (%g) must be < `ref` (%g)" % (lo, ref)) + y = torch.as_tensor(y) + lo = torch.as_tensor(lo) + ref = torch.as_tensor(ref) + shift = inv_softplus(torch.tensor(1.0)) + x = inv_softplus((y - lo) / (ref - lo)) - shift + return x + + +def students_t_nll(x, df, scale): + """The NLL of a Generalized Student's T distribution (w/o including TFP).""" + x = torch.as_tensor(x) + df = torch.as_tensor(df) + scale = torch.as_tensor(scale) + log_partition = ( + torch.log(torch.abs(scale)) + + torch.lgamma(0.5 * df) + - 
torch.lgamma(0.5 * df + torch.tensor(0.5)) + + torch.tensor(0.5 * np.log(np.pi)) + ) + return ( + 0.5 * ((df + 1.0) * torch.log1p((x / scale) ** 2.0 / df) + torch.log(df)) + + log_partition + ) + + +def lossfun(x, alpha, scale, approximate=False, epsilon=1e-6): + r"""Implements the general form of the loss. + + This implements the rho(x, \alpha, c) function described in "A General and + Adaptive Robust Loss Function", Jonathan T. Barron, + https://arxiv.org/abs/1701.03077. + + Args: + x: The residual for which the loss is being computed. x can have any shape, + and alpha and scale will be broadcasted to match x's shape if necessary. + Must be a tensor of floats. + alpha: The shape parameter of the loss (\alpha in the paper), where more + negative values produce a loss with more robust behavior (outliers "cost" + less), and more positive values produce a loss with less robust behavior + (outliers are penalized more heavily). Alpha can be any value in + [-infinity, infinity], but the gradient of the loss with respect to alpha + is 0 at -infinity, infinity, 0, and 2. Must be a tensor of floats with the + same precision as `x`. Varying alpha allows + for smooth interpolation between a number of discrete robust losses: + alpha=-Infinity: Welsch/Leclerc Loss. + alpha=-2: Geman-McClure loss. + alpha=0: Cauchy/Lortentzian loss. + alpha=1: Charbonnier/pseudo-Huber loss. + alpha=2: L2 loss. + scale: The scale parameter of the loss. When |x| < scale, the loss is an + L2-like quadratic bowl, and when |x| > scale the loss function takes on a + different shape according to alpha. Must be a tensor of single-precision + floats. + approximate: a bool, where if True, this function returns an approximate and + faster form of the loss, as described in the appendix of the paper. This + approximation holds well everywhere except as x and alpha approach zero. + epsilon: A float that determines how inaccurate the "approximate" version of + the loss will be. Larger values are less accurate but more numerically + stable. Must be great than single-precision machine epsilon. + + Returns: + The losses for each element of x, in the same shape and precision as x. + """ + assert (scale > 0).all() + if approximate: + # Compute an approximate form of the loss which is faster, but innacurate + # when x and alpha are near zero. + b = torch.abs(alpha - 2) + epsilon + d = torch.where(alpha >= 0, alpha + epsilon, alpha - epsilon) + loss = (b / d) * (torch.pow((x / scale) ** 2 / b + 1.0, 0.5 * d) - 1.0) + else: + # This will be used repeatedly. + squared_scaled_x = (x / scale) ** 2 + + # The loss when alpha == 2. + loss_two = 0.5 * squared_scaled_x + # The loss when alpha == 0. + loss_zero = log1p_safe(0.5 * squared_scaled_x) + # The loss when alpha == -infinity. + loss_neginf = -torch.expm1(-0.5 * squared_scaled_x) + # The loss when alpha == +infinity. + loss_posinf = expm1_safe(0.5 * squared_scaled_x) + + # The loss when not in one of the above special cases. + machine_epsilon = torch.tensor(torch.finfo(alpha.dtype).eps).to(x) + # Clamp |2-alpha| to be >= machine epsilon so that it's safe to divide by. + beta_safe = torch.max(machine_epsilon, torch.abs(alpha - 2.0)) + # Clamp |alpha| to be >= machine epsilon so that it's safe to divide by. 
+ alpha_safe = torch.where( + alpha >= 0, torch.ones_like(alpha), -torch.ones_like(alpha) + ) * torch.max(machine_epsilon, torch.abs(alpha)) + loss_otherwise = (beta_safe / alpha_safe) * ( + torch.pow(squared_scaled_x / beta_safe + 1.0, 0.5 * alpha) - 1.0 + ) + + # Select which of the cases of the loss to return. + loss = torch.where( + alpha == -float("inf"), + loss_neginf, + torch.where( + alpha == 0, + loss_zero, + torch.where( + alpha == 2, + loss_two, + torch.where(alpha == float("inf"), loss_posinf, loss_otherwise), + ), + ), + ) + + return loss + + +from pkg_resources import resource_stream + + +def interpolate1d(x, values, tangents): + r"""Perform cubic hermite spline interpolation on a 1D spline. + + The x coordinates of the spline knots are at [0 : 1 : len(values)-1]. + Queries outside of the range of the spline are computed using linear + extrapolation. See https://en.wikipedia.org/wiki/Cubic_Hermite_spline + for details, where "x" corresponds to `x`, "p" corresponds to `values`, and + "m" corresponds to `tangents`. + + Args: + x: A tensor of any size of single or double precision floats containing the + set of values to be used for interpolation into the spline. + values: A vector of single or double precision floats containing the value + of each knot of the spline being interpolated into. Must be the same + length as `tangents` and the same type as `x`. + tangents: A vector of single or double precision floats containing the + tangent (derivative) of each knot of the spline being interpolated into. + Must be the same length as `values` and the same type as `x`. + + Returns: + The result of interpolating along the spline defined by `values`, and + `tangents`, using `x` as the query values. Will be the same length and type + as `x`. + """ + assert torch.is_tensor(x) + assert torch.is_tensor(values) + assert torch.is_tensor(tangents) + float_dtype = x.dtype + assert values.dtype == float_dtype + assert tangents.dtype == float_dtype + assert len(values.shape) == 1 + assert len(tangents.shape) == 1 + assert values.shape[0] == tangents.shape[0] + + x_lo = torch.floor(torch.clamp(x, torch.as_tensor(0), values.shape[0] - 2)).type( + torch.int64 + ) + x_hi = x_lo + 1 + + # Compute the relative distance between each `x` and the knot below it. + t = x - x_lo.type(float_dtype) + + # Compute the cubic hermite expansion of `t`. + t_sq = t**2 + t_cu = t * t_sq + h01 = -2.0 * t_cu + 3.0 * t_sq + h00 = 1.0 - h01 + h11 = t_cu - t_sq + h10 = h11 - t_sq + t + + # Linearly extrapolate above and below the extents of the spline for all + # values. + value_before = tangents[0] * t + values[0] + value_after = tangents[-1] * (t - 1.0) + values[-1] + + # Cubically interpolate between the knots below and above each query point. + neighbor_values_lo = values[x_lo] + neighbor_values_hi = values[x_hi] + neighbor_tangents_lo = tangents[x_lo] + neighbor_tangents_hi = tangents[x_hi] + value_mid = ( + neighbor_values_lo * h00 + + neighbor_values_hi * h01 + + neighbor_tangents_lo * h10 + + neighbor_tangents_hi * h11 + ) + + # Return the interpolated or extrapolated values for each query point, + # depending on whether or not the query lies within the span of the spline. + return torch.where( + t < 0.0, value_before, torch.where(t > 1.0, value_after, value_mid) + ) + + +def partition_spline_curve(alpha): + """Applies a curve to alpha >= 0 to compress its range before interpolation. 
+ + This is a weird hand-crafted function designed to take in alpha values and + curve them to occupy a short finite range that works well when using spline + interpolation to model the partition function Z(alpha). Because Z(alpha) + is only varied in [0, 4] and is especially interesting around alpha=2, this + curve is roughly linear in [0, 4] with a slope of ~1 at alpha=0 and alpha=4 + but a slope of ~10 at alpha=2. When alpha > 4 the curve becomes logarithmic. + Some (input, output) pairs for this function are: + [(0, 0), (1, ~1.2), (2, 4), (3, ~6.8), (4, 8), (8, ~8.8), (400000, ~12)] + This function is continuously differentiable. + + Args: + alpha: A numpy array or tensor (float32 or float64) with values >= 0. + + Returns: + An array/tensor of curved values >= 0 with the same type as `alpha`, to be + used as input x-coordinates for spline interpolation. + """ + alpha = torch.as_tensor(alpha) + x = torch.where( + alpha < 4, + (2.25 * alpha - 4.5) / (torch.abs(alpha - 2) + 0.25) + alpha + 2, + 5.0 / 18.0 * log_safe(4 * alpha - 15) + 8, + ) + return x + + +def inv_partition_spline_curve(x): + """The inverse of partition_spline_curve().""" + x = torch.as_tensor(x) + assert (x >= 0).all() + alpha = torch.where( + x < 8, + 0.5 * x + + torch.where( + x <= 4, + 1.25 - torch.sqrt(1.5625 - x + 0.25 * x**2), + -1.25 + torch.sqrt(9.5625 - 3 * x + 0.25 * x**2), + ), + 3.75 + 0.25 * exp_safe(x * 3.6 - 28.8), + ) + return alpha + + +class Distribution: + # This is only a class so that we can pre-load the partition function spline. + def __init__(self): + # Load the values, tangents, and x-coordinate scaling of a spline that + # approximates the partition function. This was produced by running + # the script in fit_partition_spline.py + with resource_stream(__name__, "resources/partition_spline.npz") as spline_file: + with np.load(spline_file, allow_pickle=False) as f: + self._spline_x_scale = torch.tensor(f["x_scale"]) + self._spline_values = torch.tensor(f["values"]) + self._spline_tangents = torch.tensor(f["tangents"]) + + def log_base_partition_function(self, alpha): + r"""Approximate the distribution's log-partition function with a 1D spline. + + Because the partition function (Z(\alpha) in the paper) of the distribution + is difficult to model analytically, we approximate it with a (transformed) + cubic hermite spline: Each alpha is pushed through a nonlinearity before + being used to interpolate into a spline, which allows us to use a relatively + small spline to accurately model the log partition function over the range + of all non-negative input values. + + Args: + alpha: A tensor or scalar of single or double precision floats containing + the set of alphas for which we would like an approximate log partition + function. Must be non-negative, as the partition function is undefined + when alpha < 0. + + Returns: + An approximation of log(Z(alpha)) accurate to within 1e-6 + """ + alpha = torch.as_tensor(alpha) + assert (alpha >= 0).all() + # Transform `alpha` to the form expected by the spline. + x = partition_spline_curve(alpha) + # Interpolate into the spline. + return interpolate1d( + x * self._spline_x_scale.to(x), + self._spline_values.to(x), + self._spline_tangents.to(x), + ) + + def nllfun(self, x, alpha, scale): + r"""Implements the negative log-likelihood (NLL). + + Specifically, we implement -log(p(x | 0, \alpha, c) of Equation 16 in the + paper as nllfun(x, alpha, shape). + + Args: + x: The residual for which the NLL is being computed. 
x can have any shape, + and alpha and scale will be broadcasted to match x's shape if necessary. + Must be a tensor or numpy array of floats. + alpha: The shape parameter of the NLL (\alpha in the paper), where more + negative values cause outliers to "cost" more and inliers to "cost" + less. Alpha can be any non-negative value, but the gradient of the NLL + with respect to alpha has singularities at 0 and 2 so you may want to + limit usage to (0, 2) during gradient descent. Must be a tensor or numpy + array of floats. Varying alpha in that range allows for smooth + interpolation between a Cauchy distribution (alpha = 0) and a Normal + distribution (alpha = 2) similar to a Student's T distribution. + scale: The scale parameter of the loss. When |x| < scale, the NLL is like + that of a (possibly unnormalized) normal distribution, and when |x| > + scale the NLL takes on a different shape according to alpha. Must be a + tensor or numpy array of floats. + + Returns: + The NLLs for each element of x, in the same shape and precision as x. + """ + # `scale` and `alpha` must have the same type as `x`. + # x = torch.as_tensor(x) + # alpha = torch.as_tensor(alpha) + # scale = torch.as_tensor(scale) + # assert (alpha >= 0).all() + # assert (scale >= 0).all() + # float_dtype = x.dtype + # assert alpha.dtype == float_dtype + # assert scale.dtype == float_dtype + + loss = lossfun(x, alpha, scale, approximate=False) + log_partition = torch.log(scale) + self.log_base_partition_function(alpha) + nll = loss + log_partition + return nll + + def draw_samples(self, alpha, scale): + r"""Draw samples from the robust distribution. + + This function implements Algorithm 1 the paper. This code is written to + allow + for sampling from a set of different distributions, each parametrized by its + own alpha and scale values, as opposed to the more standard approach of + drawing N samples from the same distribution. This is done by repeatedly + performing N instances of rejection sampling for each of the N distributions + until at least one proposal for each of the N distributions has been + accepted. + All samples are drawn with a zero mean, to use a non-zero mean just add each + mean to each sample. + + Args: + alpha: A tensor/scalar or numpy array/scalar of floats where each element + is the shape parameter of that element's distribution. + scale: A tensor/scalar or numpy array/scalar of floats where each element + is the scale parameter of that element's distribution. Must be the same + shape as `alpha`. + + Returns: + A tensor with the same shape and precision as `alpha` and `scale` where + each element is a sample drawn from the distribution specified for that + element by `alpha` and `scale`. + """ + alpha = torch.as_tensor(alpha) + scale = torch.as_tensor(scale) + assert (alpha >= 0).all() + assert (scale >= 0).all() + float_dtype = alpha.dtype + assert scale.dtype == float_dtype + + cauchy = torch.distributions.cauchy.Cauchy(0.0, np.sqrt(2.0)) + uniform = torch.distributions.uniform.Uniform(0, 1) + samples = torch.zeros_like(alpha) + accepted = torch.zeros(alpha.shape).type(torch.bool) + while not accepted.type(torch.uint8).all(): + # Draw N samples from a Cauchy, our proposal distribution. + cauchy_sample = torch.reshape( + cauchy.sample((np.prod(alpha.shape),)), alpha.shape + ) + cauchy_sample = cauchy_sample.type(alpha.dtype) + + # Compute the likelihood of each sample under its target distribution. 
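+            # (Descriptive note: the proposal computed here is later accepted with
+            # probability exp(nll_bound - nll), i.e. the ratio between the target
+            # density and the Cauchy-shaped envelope that upper-bounds it for every
+            # alpha >= 0, which is what makes the rejection sampling valid.)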
+ nll = self.nllfun( + cauchy_sample, + torch.as_tensor(alpha).to(cauchy_sample), + torch.tensor(1).to(cauchy_sample), + ) + + # Bound the NLL. We don't use the approximate loss as it may cause + # unpredictable behavior in the context of sampling. + nll_bound = lossfun( + cauchy_sample, + torch.tensor(0.0, dtype=cauchy_sample.dtype), + torch.tensor(1.0, dtype=cauchy_sample.dtype), + approximate=False, + ) + self.log_base_partition_function(alpha) + + # Draw N samples from a uniform distribution, and use each uniform sample + # to decide whether or not to accept each proposal sample. + uniform_sample = torch.reshape( + uniform.sample((np.prod(alpha.shape),)), alpha.shape + ) + uniform_sample = uniform_sample.type(alpha.dtype) + accept = uniform_sample <= torch.exp(nll_bound - nll) + + # If a sample is accepted, replace its element in `samples` with the + # proposal sample, and set its bit in `accepted` to True. + samples = torch.where(accept, cauchy_sample, samples) + accepted = accepted | accept + + # Because our distribution is a location-scale family, we sample from + # p(x | 0, \alpha, 1) and then scale each sample by `scale`. + samples *= scale + return samples + + +class AdaptiveLossFunction(nn.Module): + """The adaptive loss function on a matrix. + + This class behaves differently from lossfun() and + distribution.nllfun(), which are "stateless", allow the caller to specify the + shape and scale of the loss, and allow for arbitrary sized inputs. This + class only allows for rank-2 inputs for the residual `x`, and expects that + `x` is of the form [batch_index, dimension_index]. This class then + constructs free parameters (torch Parameters) that define the alpha and scale + parameters for each dimension of `x`, such that all alphas are in + (`alpha_lo`, `alpha_hi`) and all scales are in (`scale_lo`, Infinity). + The assumption is that `x` is, say, a matrix where x[i,j] corresponds to a + pixel at location j for image i, with the idea being that all pixels at + location j should be modeled with the same shape and scale parameters across + all images in the batch. If the user wants to fix alpha or scale to be a + constant, + this can be done by setting alpha_lo=alpha_hi or scale_lo=scale_init + respectively. + """ + + def __init__( + self, + num_dims, + alpha_lo=0.001, + alpha_hi=1.999, + alpha_init=None, + scale_lo=1e-5, + scale_init=1.0, + ): + """Sets up the loss function. + + Args: + num_dims: The number of dimensions of the input to come. + float_dtype: The floating point precision of the inputs to come. + device: The device to run on (cpu, cuda, etc). + alpha_lo: The lowest possible value for loss's alpha parameters, must be + >= 0 and a scalar. Should probably be in (0, 2). + alpha_hi: The highest possible value for loss's alpha parameters, must be + >= alpha_lo and a scalar. Should probably be in (0, 2). + alpha_init: The value that the loss's alpha parameters will be initialized + to, must be in (`alpha_lo`, `alpha_hi`), unless `alpha_lo` == `alpha_hi` + in which case this will be ignored. Defaults to (`alpha_lo` + + `alpha_hi`) / 2 + scale_lo: The lowest possible value for the loss's scale parameters. Must + be > 0 and a scalar. This value may have more of an effect than you + think, as the loss is unbounded as scale approaches zero (say, at a + delta function). + scale_init: The initial value used for the loss's scale parameters. 
This + also defines the zero-point of the latent representation of scales, so + SGD may cause optimization to gravitate towards producing scales near + this value. + """ + super(AdaptiveLossFunction, self).__init__() + + if not np.isscalar(alpha_lo): + raise ValueError( + "`alpha_lo` must be a scalar, but is of type {}".format(type(alpha_lo)) + ) + if not np.isscalar(alpha_hi): + raise ValueError( + "`alpha_hi` must be a scalar, but is of type {}".format(type(alpha_hi)) + ) + if alpha_init is not None and not np.isscalar(alpha_init): + raise ValueError( + "`alpha_init` must be None or a scalar, but is of type {}".format( + type(alpha_init) + ) + ) + if not alpha_lo >= 0: + raise ValueError("`alpha_lo` must be >= 0, but is {}".format(alpha_lo)) + if not alpha_hi >= alpha_lo: + raise ValueError( + "`alpha_hi` = {} must be >= `alpha_lo` = {}".format(alpha_hi, alpha_lo) + ) + if alpha_init is not None and alpha_lo != alpha_hi: + if not (alpha_init > alpha_lo and alpha_init < alpha_hi): + raise ValueError( + "`alpha_init` = {} must be in (`alpha_lo`, `alpha_hi`) = ({} {})".format( + alpha_init, alpha_lo, alpha_hi + ) + ) + if not np.isscalar(scale_lo): + raise ValueError( + "`scale_lo` must be a scalar, but is of type {}".format(type(scale_lo)) + ) + if not np.isscalar(scale_init): + raise ValueError( + "`scale_init` must be a scalar, but is of type {}".format( + type(scale_init) + ) + ) + if not scale_lo > 0: + raise ValueError("`scale_lo` must be > 0, but is {}".format(scale_lo)) + if not scale_init >= scale_lo: + raise ValueError( + "`scale_init` = {} must be >= `scale_lo` = {}".format( + scale_init, scale_lo + ) + ) + + self.num_dims = num_dims + + self.distribution = Distribution() + + if alpha_lo == alpha_hi: + # If the range of alphas is a single item, then we just fix `alpha` to be a constant. + fixed_alpha = torch.tensor([[alpha_lo]]).repeat(1, self.num_dims) + self.register_parameter( + "fixed_alpha", torch.nn.Parameter(fixed_alpha, requires_grad=False) + ) + self.alpha = lambda: self.fixed_alpha + else: + # Otherwise we construct a "latent" alpha variable and define `alpha` + # As an affine function of a sigmoid on that latent variable, initialized + # such that `alpha` starts off as `alpha_init`. + if alpha_init is None: + alpha_init = (alpha_lo + alpha_hi) / 2.0 + latent_alpha_init = inv_affine_sigmoid(alpha_init, lo=alpha_lo, hi=alpha_hi) + self.register_parameter( + "latent_alpha", + torch.nn.Parameter( + latent_alpha_init.clone() + .detach() + .view(1, 1) + .repeat(1, self.num_dims), + requires_grad=True, + ), + ) + self.alpha = lambda: affine_sigmoid( + self.latent_alpha, lo=alpha_lo, hi=alpha_hi + ) + + if scale_lo == scale_init: + # If the difference between the minimum and initial scale is zero, then + # we just fix `scale` to be a constant. + fixed_scale = torch.tensor([[scale_init]]).repeat(1, self.num_dims) + self.register_parameter( + "fixed_scale", torch.nn.Parameter(fixed_scale, requires_grad=False) + ) + + self.scale = lambda: self.fixed_scale + else: + # Otherwise we construct a "latent" scale variable and define `scale` + # As an affine function of a softplus on that latent variable. + self.register_parameter( + "latent_scale", + torch.nn.Parameter(torch.zeros((1, self.num_dims)), requires_grad=True), + ) + self.scale = lambda: affine_softplus( + self.latent_scale, lo=scale_lo, ref=scale_init + ) + + def lossfun(self, x, **kwargs): + """Computes the loss on a matrix. + + Args: + x: The residual for which the loss is being computed. 
Must be a rank-2 + tensor, where the innermost dimension is the batch index, and the + outermost dimension must be equal to self.num_dims. Must be a tensor or + numpy array of type self.float_dtype. + **kwargs: Arguments to be passed to the underlying distribution.nllfun(). + + Returns: + A tensor of the same type and shape as input `x`, containing the loss at + each element of `x`. These "losses" are actually negative log-likelihoods + (as produced by distribution.nllfun()) and so they are not actually + bounded from below by zero. You'll probably want to minimize their sum or + mean. + """ + assert len(x.shape) == 2 + return self.distribution.nllfun(x, self.alpha(), self.scale(), **kwargs) + + +class RobustLoss(nn.Module): + def __init__( + self, + weight: float = 1.0, + adaptive: bool = False, + scale: float = 1.0, + alpha: float = 1.0, + input_fn: str = "linear", + output_fn: str = "linear", + num_dims: int = 3, + ): + super().__init__() + self.name: str = self.__class__.__name__ + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.weight: float = weight + if not adaptive: + self.loss = AdaptiveLossFunction( + num_dims=num_dims, + alpha_lo=alpha, + alpha_hi=alpha, + scale_lo=scale, + scale_init=scale, + ) + else: + self.loss = AdaptiveLossFunction( + num_dims=num_dims, alpha_init=alpha, scale_init=scale + ) + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + error = self.input_fn(input) - self.input_fn(target) + loss_map = self.loss.lossfun(error.reshape(-1, error.shape[-1])) + + mean_error = masked_mean( + data=loss_map.view(*error.shape).mean(dim=-1), mask=mask, dim=(-1,) + ).mean(dim=-1) + mean_error = self.output_fn(mean_error) + return mean_error + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + alpha=config["alpha"], + scale=config["scale"], + adaptive=config["adaptive"], + output_fn=config["output_fn"], + input_fn=config["input_fn"], + ) + return obj diff --git a/unik3d/ops/losses/scale.py b/unik3d/ops/losses/scale.py new file mode 100644 index 0000000000000000000000000000000000000000..080fac9e162b3dede7a0084e7bdc8e4275dab129 --- /dev/null +++ b/unik3d/ops/losses/scale.py @@ -0,0 +1,87 @@ +import torch +import torch.nn as nn + +from unik3d.utils.constants import VERBOSE +from unik3d.utils.misc import profile_method + +from .utils import FNS, REGRESSION_DICT, masked_mean, masked_quantile + + +class Scale(nn.Module): + def __init__( + self, + weight: float, + output_fn: str = "sqrt", + input_fn: str = "disp", + fn: str = "l1", + quantile: float = 0.0, + gamma: float = 1.0, + alpha: float = 1.0, + eps: float = 1e-5, + ): + super().__init__() + self.name: str = self.__class__.__name__ + self.weight: float = weight + self.dims = [-2, -1] + self.output_fn = FNS[output_fn] + self.input_fn = FNS[input_fn] + self.fn = REGRESSION_DICT[fn] + self.gamma = gamma + self.alpha = alpha + self.quantile = quantile + self.eps = eps + + @profile_method(verbose=VERBOSE) + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + quality: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + mask = mask.bool() + input = self.input_fn(input.float()) + target = self.input_fn(target.float()) + error = self.fn(target - input, alpha=self.alpha, gamma=self.gamma) + + if 
self.quantile > 0.0: + if quality is not None: + for quality_level in [1, 2]: + current_quality = quality == quality_level + if current_quality.sum() > 0: + error_qtl = error[current_quality].detach().abs() + mask_qtl = error_qtl < masked_quantile( + error_qtl, + mask[current_quality], + dims=[1, 2, 3], + q=1 - self.quantile * quality_level, + ).view(-1, 1, 1, 1) + mask[current_quality] = mask[current_quality] & mask_qtl + else: + error_qtl = error.detach().abs() + mask = mask & ( + error_qtl + < masked_quantile( + error_qtl, mask, dims=[1, 2, 3], q=1 - self.quantile + ).view(-1, 1, 1, 1) + ) + + error_image = masked_mean(data=error, mask=mask, dim=self.dims).squeeze(1, 2, 3) + + error_image = self.output_fn(error_image) + return error_image + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + input_fn=config["input_fn"], + fn=config["fn"], + output_fn=config["output_fn"], + gamma=config["gamma"], + alpha=config["alpha"], + quantile=config.get("quantile", 0.1), + ) + return obj diff --git a/unik3d/ops/losses/silog.py b/unik3d/ops/losses/silog.py new file mode 100644 index 0000000000000000000000000000000000000000..f7b4d4b46ba1db34c328a507a015abbb46b4a3e2 --- /dev/null +++ b/unik3d/ops/losses/silog.py @@ -0,0 +1,127 @@ +import torch +import torch.nn as nn + +from unik3d.utils.constants import VERBOSE +from unik3d.utils.misc import profile_method + +from .utils import (FNS, REGRESSION_DICT, masked_mean, masked_mean_var, + masked_quantile) + + +class SILog(nn.Module): + def __init__( + self, + weight: float, + input_fn: str = "linear", + output_fn: str = "sqrt", + fn: str = "l1", + integrated: bool = False, + dims: bool = (-3, -2, -1), + quantile: float = 0.0, + alpha: float = 1.0, + gamma: float = 1.0, + eps: float = 1e-5, + ): + super().__init__() + self.name: str = self.__class__.__name__ + self.weight: float = weight + + self.dims = dims + self.input_fn = FNS[input_fn] + self.output_fn = FNS[output_fn] + self.fn = REGRESSION_DICT[fn] + self.eps: float = eps + self.integrated = integrated + self.quantile = quantile + self.alpha = alpha + self.gamma = gamma + + @profile_method(verbose=VERBOSE) + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def forward( + self, + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor | None = None, + si: torch.Tensor | None = None, + quality: torch.Tensor | None = None, + **kwargs, + ) -> torch.Tensor: + mask = mask.bool() + + if si.any(): + rescale = torch.stack( + [x[m > 0].median() for x, m in zip(target, target)] + ) / torch.stack([x[m > 0].detach().median() for x, m in zip(input, target)]) + if rescale.isnan().any(): + print( + "NaN in rescale", rescale.isnan().squeeze(), mask.sum(dim=[1, 2, 3]) + ) + rescale = torch.nan_to_num(rescale, nan=1.0) + input = (1 - si.int()).view(-1, 1, 1, 1) * input + ( + rescale * si.int() + ).view(-1, 1, 1, 1) * input + + error = self.input_fn(input.float()) - self.input_fn(target.float()) + if quality is not None: + for quality_level in [1, 2]: + current_quality = quality == quality_level + if current_quality.sum() > 0: + error_qtl = error[current_quality].detach().abs() + mask_qtl = error_qtl < masked_quantile( + error_qtl, + mask[current_quality], + dims=[1, 2, 3], + q=1 - self.quantile * quality_level, + ).view(-1, 1, 1, 1) + mask[current_quality] = mask[current_quality] & mask_qtl + else: + error_qtl = error.detach().abs() + mask = mask & ( + error_qtl + < masked_quantile( + error_qtl, mask, dims=[1, 2, 3], q=1 - self.quantile + ).view(-1, 1, 1, 
1) + ) + + mean_error, var_error = masked_mean_var( + data=error, mask=mask, dim=self.dims, keepdim=False + ) + if var_error.ndim > 1: + var_error = var_error.mean(dim=-1) + + if self.integrated > 0.0: + scale_error = masked_mean( + self.fn(error, alpha=self.alpha, gamma=self.gamma), + mask=mask, + dim=self.dims, + ).reshape(-1) + var_error = var_error + self.integrated * scale_error + + out_loss = self.output_fn(var_error) + if out_loss.isnan().any(): + print( + "NaN in SILog variance, input, target, mask, target>0, error", + var_error.isnan().squeeze(), + input[mask].isnan().any(), + target[mask].isnan().any(), + mask.any(dim=[1, 2, 3]), + (target > 0.0).any(dim=[1, 2, 3]), + error[mask].isnan().any(), + ) + return out_loss + + @classmethod + def build(cls, config): + obj = cls( + weight=config["weight"], + dims=config["dims"], + output_fn=config["output_fn"], + input_fn=config["input_fn"], + fn=config["fn"], + alpha=config["alpha"], + gamma=config["gamma"], + integrated=config.get("integrated", False), + quantile=config["quantile"], + ) + return obj diff --git a/unik3d/ops/losses/utils.py b/unik3d/ops/losses/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..4a0dc5928a0a71448ba27e5e4b19a2f0f8436fd9 --- /dev/null +++ b/unik3d/ops/losses/utils.py @@ -0,0 +1,314 @@ +from math import log, pi, prod +from typing import Any, Dict, List, Optional, Tuple + +import torch + +FNS = { + "sqrto": lambda x: torch.sqrt(x + 1), + "sqrt": lambda x: torch.sqrt(x + 1e-4), + "log": lambda x: torch.log(x + 1e-4), + "log1": lambda x: torch.log(x + 1), + # transition from log(1/x) to 1/x at x=100 + # if x -> 0 : log(1/x), if x -> inf : log(1+1/x) -> 1/x + hot + "log1i": lambda x: torch.log(1 + 50 / (1e-4 + x)), + "log10": lambda x: torch.log10(1e-4 + x), + "log2": lambda x: torch.log2(1e-4 + x), + "linear": lambda x: x, + "square": torch.square, + "disp": lambda x: 1 / (x + 1e-4), + "disp1": lambda x: 1 / (1 + x), +} + + +FNS_INV = { + "sqrt": torch.square, + "log": torch.exp, + "log1": lambda x: torch.exp(x) - 1, + "linear": lambda x: x, + "square": torch.sqrt, + "disp": lambda x: 1 / x, +} + + +def masked_mean_var( + data: torch.Tensor, mask: torch.Tensor, dim: List[int], keepdim: bool = True +): + if mask is None: + return data.mean(dim=dim, keepdim=keepdim), data.var(dim=dim, keepdim=keepdim) + # if data[mask].isnan().any(): + # print("Warning: NaN in masked_mean_var, valid_pixels before and after", mask.sum(dim=dim).squeeze(), (mask & ~data.isnan()).sum(dim=dim).squeeze()) + mask = (mask & ~data.isnan().any(dim=1, keepdim=True)).float() + data = torch.nan_to_num(data, nan=0.0) + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( + mask_sum, min=1.0 + ) + mask_var = torch.sum( + mask * (data - mask_mean) ** 2, dim=dim, keepdim=True + ) / torch.clamp(mask_sum, min=1.0) + if not keepdim: + mask_mean, mask_var = mask_mean.squeeze(dim), mask_var.squeeze(dim) + return mask_mean, mask_var + + +def masked_mean(data: torch.Tensor, mask: torch.Tensor | None, dim: List[int]): + if mask is None: + return data.mean(dim=dim, keepdim=True) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum( + torch.nan_to_num(data, nan=0.0) * mask, dim=dim, keepdim=True + ) / mask_sum.clamp(min=1.0) + return mask_mean + + +def masked_quantile( + data: torch.Tensor, mask: torch.Tensor | None, dims: List[int], q: float +): + """ + Compute the quantile of the data only where the mask is 1 along 
specified dimensions. + + Args: + data (torch.Tensor): The input data tensor. + mask (torch.Tensor): The mask tensor with the same shape as data, containing 1s where data should be considered. + dims (list of int): The dimensions to compute the quantile over. + q (float): The quantile to compute, must be between 0 and 1. + + Returns: + torch.Tensor: The quantile computed over the specified dimensions, ignoring masked values. + """ + masked_data = data * mask if mask is not None else data + + # Get a list of all dimensions + all_dims = list(range(masked_data.dim())) + + # Revert negative dimensions + dims = [d % masked_data.dim() for d in dims] + + # Find the dimensions to keep (not included in the `dims` list) + keep_dims = [d for d in all_dims if d not in dims] + + # Permute dimensions to bring `dims` to the front + permute_order = dims + keep_dims + permuted_data = masked_data.permute(permute_order) + + # Reshape into 2D: (-1, remaining_dims) + collapsed_shape = ( + -1, + prod([permuted_data.size(d) for d in range(len(dims), permuted_data.dim())]), + ) + reshaped_data = permuted_data.reshape(collapsed_shape) + if mask is None: + return torch.quantile(reshaped_data, q, dim=0) + + permuted_mask = mask.permute(permute_order) + reshaped_mask = permuted_mask.reshape(collapsed_shape) + + # Calculate quantile along the first dimension where mask is true + quantiles = [] + for i in range(reshaped_data.shape[1]): + valid_data = reshaped_data[:, i][reshaped_mask[:, i]] + if valid_data.numel() == 0: + # print("Warning: No valid data found for quantile calculation.") + quantiles.append(reshaped_data[:, i].min() * 0.99) + else: + quantiles.append(torch.quantile(valid_data, q, dim=0)) + + # Stack back into a tensor with reduced dimensions + quantiles = torch.stack(quantiles) + quantiles = quantiles.reshape( + [permuted_data.size(d) for d in range(len(dims), permuted_data.dim())] + ) + + return quantiles + + +def masked_median(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): + ndim = data.ndim + data = data.flatten(ndim - len(dim)) + mask = mask.flatten(ndim - len(dim)) + mask_median = torch.median(data[..., mask], dim=-1).values + return mask_median + + +def masked_median_mad(data: torch.Tensor, mask: torch.Tensor, dim: List[int]): + ndim = data.ndim + data = data.flatten(ndim - len(dim)) + mask = mask.flatten(ndim - len(dim)) + mask_median = torch.median(data[mask], dim=-1, keepdim=True).values + mask_mad = masked_mean((data - mask_median).abs(), mask, dim=(-1,)) + return mask_median, mask_mad + + +def masked_weighted_mean_var( + data: torch.Tensor, mask: torch.Tensor, weights: torch.Tensor, dim: Tuple[int, ...] 
+): + if mask is None: + return data.mean(dim=dim, keepdim=True), data.var(dim=dim, keepdim=True) + mask = mask.float() + mask_mean = torch.sum(data * mask * weights, dim=dim, keepdim=True) / torch.sum( + mask * weights, dim=dim, keepdim=True + ).clamp(min=1.0) + # V1**2 - V2, V1: sum w_i, V2: sum w_i**2 + denom = torch.sum(weights * mask, dim=dim, keepdim=True).square() - torch.sum( + (mask * weights).square(), dim=dim, keepdim=True + ) + # correction is V1 / (V1**2 - V2), if w_i=1 => N/(N**2 - N) => 1/(N-1) (unbiased estimator of variance, cvd) + correction_factor = torch.sum(mask * weights, dim=dim, keepdim=True) / denom.clamp( + min=1.0 + ) + mask_var = correction_factor * torch.sum( + weights * mask * (data - mask_mean) ** 2, dim=dim, keepdim=True + ) + return mask_mean, mask_var + + +def stable_masked_mean_var( + input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, dim: list[int] +): + # recalculate mask with points in 95% confidence interval + input_detach = input.detach() + input_mean, input_var = masked_mean_var(input_detach, mask=mask, dim=dim) + target_mean, target_var = masked_mean_var(target, mask=mask, dim=dim) + input_std = (input_var).clip(min=1e-6).sqrt() + target_std = (target_var).clip(min=1e-6).sqrt() + stable_points_input = torch.logical_and( + input_detach > input_mean - 1.96 * input_std, + input_detach < input_mean + 1.96 * input_std, + ) + stable_points_target = torch.logical_and( + target > target_mean - 1.96 * target_std, + target < target_mean + 1.96 * target_std, + ) + stable_mask = stable_points_target & stable_points_input & mask + + input_mean, input_var = masked_mean_var(input, mask=stable_mask, dim=dim) + target_mean, target_var = masked_mean_var(target, mask=stable_mask, dim=dim) + return input_mean, input_var, target_mean, target_var, stable_mask + + +def ssi( + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + dim: list[int], + *args, + **kwargs, +) -> torch.Tensor: + # recalculate mask with points in 95% confidence interval + input_mean, input_var, target_mean, target_var, stable_mask = ( + stable_masked_mean_var(input, target, mask, dim) + ) + + # if target_var.min() < 1e-6: + # print( + # "Warning: target low", + # list(zip(target_var.squeeze().cpu().numpy(), + # target_mean.squeeze().cpu().numpy(), + # mask.reshape(target_var.shape[0], -1).sum(dim=-1).squeeze().cpu().numpy(), + # stable_mask.reshape(target_var.shape[0], -1).sum(dim=-1).squeeze().cpu().numpy())) + # ) + # if input_var.min() < 1e-6: + # print("Warning: input variance is too low", input_var.squeeze(), input_mean.squeeze()) + if input_var.isnan().any(): + print("Warning: input variance is nan") + if input_var.isinf().any(): + print("Warning: input variance is isinf") + if input_mean.isnan().any(): + print("Warning: input m is nan") + if input_mean.isinf().any(): + print("Warning: input m is isinf") + target_normalized = (target - target_mean) / FNS["sqrt"](target_var) + input_normalized = (input - input_mean) / FNS["sqrt"](input_var) + return input_normalized, target_normalized, stable_mask + + +def ssi_nd( + input: torch.Tensor, + target: torch.Tensor, + mask: torch.Tensor, + dim: list[int], + input_info: torch.Tensor, + target_info: torch.Tensor, +) -> torch.Tensor: + input_mean, input_var, target_mean, target_var, stable_mask = ( + stable_masked_mean_var(input_info, target_info, mask, dim) + ) + if input_var.isnan().any(): + print("Warning: input variance is nan") + if input_var.isinf().any(): + print("Warning: input variance is isinf") + if 
input_mean.isnan().any(): + print("Warning: input m is nan") + if input_mean.isinf().any(): + print("Warning: input m is isinf") + target_normalized = (target - target_mean) / FNS["sqrt"](target_var) + input_normalized = (input - input_mean) / FNS["sqrt"](input_var) + return input_normalized, target_normalized, stable_mask + + +def stable_ssi( + input: torch.Tensor, target: torch.Tensor, mask: torch.Tensor, dim: list[int] +) -> torch.Tensor: + input_mean, input_var = masked_mean_var(input, mask=mask, dim=dim) + target_mean, target_var = masked_mean_var(target, mask=mask, dim=dim) + target_normalized = (target - target_mean) / torch.sqrt(target_var.clamp(min=1e-6)) + input_normalized = (input - input_mean) / torch.sqrt(input_var.clamp(min=1e-6)) + return input_normalized, target_normalized, mask + + +def ind2sub(idx, cols): + r = idx // cols + c = idx % cols + return r, c + + +def sub2ind(r, c, cols): + idx = r * cols + c + return idx + + +def l2(input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs) -> torch.Tensor: + return (input_tensor / gamma) ** 2 + + +def l1(input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs) -> torch.Tensor: + return torch.abs(input_tensor) + + +def charbonnier( + input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs +) -> torch.Tensor: + return gamma * torch.sqrt(torch.square(input_tensor / gamma) + 1) - 1 + + +def cauchy( + input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs +) -> torch.Tensor: + return gamma * torch.log(torch.square(input_tensor / gamma) + 1) + log(gamma * pi) + + +def geman_mcclure( + input_tensor: torch.Tensor, gamma: float = 1.0, *args, **kwargs +) -> torch.Tensor: + return gamma * torch.square(input_tensor) / (torch.square(input_tensor) + gamma) + + +def robust_loss( + input_tensor: torch.Tensor, alpha: float, gamma: float = 1.0, *args, **kwargs +) -> torch.Tensor: + coeff = abs(alpha - 2) / alpha + power = torch.square(input_tensor / gamma) / abs(alpha - 2) + 1 + return ( + gamma * coeff * (torch.pow(power, alpha / 2) - 1) + ) # mult gamma to keep grad magnitude invariant wrt gamma + + +REGRESSION_DICT = { + "l2": l2, + "l1": l1, + "cauchy": cauchy, + "charbonnier": charbonnier, + "geman_mcclure": geman_mcclure, + "robust_loss": robust_loss, +} diff --git a/unik3d/ops/scheduler.py b/unik3d/ops/scheduler.py new file mode 100644 index 0000000000000000000000000000000000000000..7c7776b667a4c79603b14615223e5cc12857ab06 --- /dev/null +++ b/unik3d/ops/scheduler.py @@ -0,0 +1,128 @@ +import weakref + +import numpy as np + + +class PlainCosineScheduler(object): + def __init__( + self, + klass, + key, + warmup_iters, + total_iters, + overwrite=False, + init_value=None, + base_value=None, + final_value=None, + step_init=-1, + ): + super().__init__() + self.iter = step_init + self.overwrite = overwrite + self.base_value = base_value + self.init_value = init_value if init_value is not None else base_value + self.final_value = final_value + self.total_iters = total_iters + self.warmup_iters = warmup_iters + self.key = key + self.klass = klass + self.schedulers = [self.get_scheduler()] + + def get_scheduler(self): + init_value = self.init_value + base_value = self.base_value + final_value = self.final_value + warmup_iters = self.warmup_iters + total_iters = self.total_iters + + # normalize in 0,1, then apply function (power) and denormalize + normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True) + normalized_schedule = np.power(normalized_schedule, 1) + warmup_schedule = (base_value - init_value) 
* normalized_schedule + init_value + + # main scheduling + iters = np.arange(total_iters - warmup_iters + 1) + schedule = final_value + 0.5 * (base_value - final_value) * ( + 1 + np.cos(np.pi * iters / (len(iters) - 1)) + ) + return np.concatenate((warmup_schedule, schedule)) + + def step(self): + self.iter = self.iter + 1 + vals = self[self.iter] + for i, val in enumerate(vals): + setattr(self.klass, self.key, val) + + def __getitem__(self, it): + it = min(it, self.total_iters) + return [scheduler[it] for scheduler in self.schedulers] + + +class CosineScheduler(object): + def __init__( + self, + optimizer, + warmup_iters, + total_iters, + key, + overwrite=False, + init_value=None, + base_value=None, + final_value=None, + flat_iters=0, + step_init=-1, + ): + super().__init__() + self.iter = step_init + self.overwrite = overwrite + self.optimizer = optimizer + self.base_value = base_value + self.init_value = init_value + self.final_value = final_value + self.total_iters = total_iters + self.warmup_iters = warmup_iters + self.flat_iters = flat_iters + self.key = key + self.schedulers = [ + self.get_schedulers(group) for group in optimizer.param_groups + ] + + def get_schedulers(self, group): + init_value = group.get(self.key + "_init", self.init_value) + base_value = group.get(self.key + "_base", self.base_value) + final_value = group.get(self.key + "_final", self.final_value) + warmup_iters = self.warmup_iters + total_iters = self.total_iters + flat_iters = self.flat_iters + if self.overwrite: + final_value = self.final_value + + # normalize in 0,1, then apply function (power) and denormalize + normalized_schedule = np.linspace(0, 1, warmup_iters, endpoint=True) + normalized_schedule = np.power(normalized_schedule, 1) + warmup_schedule = (base_value - init_value) * normalized_schedule + init_value + + # flat scheduling] + flat_schedule = np.ones(flat_iters) * base_value + + # decay scheduling + decay_iters = np.arange(total_iters - warmup_iters - flat_iters + 1) + decay_schedule = final_value + 0.5 * (base_value - final_value) * ( + 1 + np.cos(np.pi * decay_iters / (len(decay_iters) - 1)) + ) + return np.concatenate((warmup_schedule, flat_schedule, decay_schedule)) + + def step(self): + self.iter = self.iter + 1 + vals = self[self.iter] + for group, val in zip(self.optimizer.param_groups, vals): + if isinstance(group[self.key], (tuple, list)): + val = (val, *group[self.key][1:]) + group[self.key] = val + + def __getitem__(self, it): + it = min(it, self.total_iters) + return [scheduler[it] for scheduler in self.schedulers] + + def get(self): + return [group[self.key] for group in self.optimizer.param_groups] diff --git a/unik3d/utils/__init__.py b/unik3d/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a6a7f7619dab9e9c117bc0e9348175b10f072dc7 --- /dev/null +++ b/unik3d/utils/__init__.py @@ -0,0 +1,41 @@ +from .camera import invert_pinhole, project_pinhole, unproject_pinhole +from .distributed import (barrier, get_dist_info, get_rank, get_world_size, + is_main_process, setup_multi_processes, setup_slurm, + sync_tensor_across_gpus) +from .evaluation_depth import (DICT_METRICS, DICT_METRICS_3D, eval_3d, + eval_depth) +from .geometric import spherical_zbuffer_to_euclidean, unproject_points +from .misc import (format_seconds, get_params, identity, recursive_index, + remove_padding, to_cpu) +from .validation import validate +from .visualization import colorize, image_grid, log_train_artifacts + +__all__ = [ + "eval_depth", + "eval_3d", + "DICT_METRICS", + 
"DICT_METRICS_3D", + "colorize", + "image_grid", + "log_train_artifacts", + "format_seconds", + "remove_padding", + "get_params", + "identity", + "is_main_process", + "setup_multi_processes", + "setup_slurm", + "sync_tensor_across_gpus", + "barrier", + "get_world_size", + "get_rank", + "unproject_points", + "spherical_zbuffer_to_euclidean", + "validate", + "get_dist_info", + "to_cpu", + "recursive_index", + "invert_pinhole", + "unproject_pinhole", + "project_pinhole", +] diff --git a/unik3d/utils/camera.py b/unik3d/utils/camera.py new file mode 100644 index 0000000000000000000000000000000000000000..561c3b3377237043c0f2e94451607fb35d694363 --- /dev/null +++ b/unik3d/utils/camera.py @@ -0,0 +1,1487 @@ +""" +Author: Luigi Piccinelli +Licensed under the CC-BY NC 4.0 license (http://creativecommons.org/licenses/by-nc/4.0/) +""" + +from copy import deepcopy + +import numpy as np +import torch +import torch.nn.functional as F + +from .coordinate import coords_grid +from .misc import recursive_apply, squeeze_list + + +def invert_pinhole(K): + fx = K[..., 0, 0] + fy = K[..., 1, 1] + cx = K[..., 0, 2] + cy = K[..., 1, 2] + K_inv = torch.zeros_like(K) + K_inv[..., 0, 0] = 1.0 / fx + K_inv[..., 1, 1] = 1.0 / fy + K_inv[..., 0, 2] = -cx / fx + K_inv[..., 1, 2] = -cy / fy + K_inv[..., 2, 2] = 1.0 + return K_inv + + +def unproject_pinhole(depth, K): + b, _, h, w = depth.shape + K_inv = invert_pinhole(K) + grid = coords_grid(b, h, w, homogeneous=True, device=depth.device) # [B, 3, H, W] + grid_flat = grid.reshape(b, -1, h * w) # [B, 3, H*W] + cam_coords = K_inv @ grid_flat + pcd = cam_coords.reshape(b, -1, h, w) * depth + return pcd + + +def project_pinhole(pcd, K): + b, _, h, w = pcd.shape + pcd_flat = pcd.reshape(b, 3, -1) # [B, 3, H*W] + cam_coords = K @ pcd_flat + pcd_proj = cam_coords[:, :2] / cam_coords[:, 2:].clamp(min=0.01) + pcd_proj = pcd_proj.reshape(b, 2, h, w) + return pcd_proj + + +class Camera: + def __init__(self, params=None, K=None): + if params.ndim == 1: + params = params.unsqueeze(0) + + if K is None: + K = ( + torch.eye(3, device=params.device, dtype=params.dtype) + .unsqueeze(0) + .repeat(params.shape[0], 1, 1) + ) + K[..., 0, 0] = params[..., 0] + K[..., 1, 1] = params[..., 1] + K[..., 0, 2] = params[..., 2] + K[..., 1, 2] = params[..., 3] + + self.params = params + self.K = K + self.overlap_mask = None + self.projection_mask = None + + def project(self, xyz): + raise NotImplementedError + + def unproject(self, uv): + raise NotImplementedError + + def get_projection_mask(self): + return self.projection_mask + + def get_overlap_mask(self): + return self.overlap_mask + + def reconstruct(self, depth): + id_coords = coords_grid( + 1, depth.shape[-2], depth.shape[-1], device=depth.device + ) + rays = self.unproject(id_coords) + return ( + rays / rays[:, -1:].clamp(min=1e-4) * depth.clamp(min=1e-4) + ) # assumption z>0!!! 
+ + def resize(self, factor): + self.K[..., :2, :] *= factor + self.params[..., :4] *= factor + return self + + def to(self, device, non_blocking=False): + self.params = self.params.to(device, non_blocking=non_blocking) + self.K = self.K.to(device, non_blocking=non_blocking) + return self + + def get_rays(self, shapes, noisy=False): + b, h, w = shapes + uv = coords_grid(1, h, w, device=self.K.device, noisy=noisy) + rays = self.unproject(uv) + return rays / torch.norm(rays, dim=1, keepdim=True).clamp(min=1e-4) + + def get_pinhole_rays(self, shapes, noisy=False): + b, h, w = shapes + uv = coords_grid(b, h, w, device=self.K.device, homogeneous=True, noisy=noisy) + rays = (invert_pinhole(self.K) @ uv.reshape(b, 3, -1)).reshape(b, 3, h, w) + return rays / torch.norm(rays, dim=1, keepdim=True).clamp(min=1e-4) + + def flip(self, H, W, direction="horizontal"): + new_cx = ( + W - self.params[:, 2] if direction == "horizontal" else self.params[:, 2] + ) + new_cy = H - self.params[:, 3] if direction == "vertical" else self.params[:, 3] + self.params = torch.stack( + [self.params[:, 0], self.params[:, 1], new_cx, new_cy], dim=1 + ) + self.K[..., 0, 2] = new_cx + self.K[..., 1, 2] = new_cy + return self + + def clone(self): + return deepcopy(self) + + def crop(self, left, top, right=None, bottom=None): + self.K[..., 0, 2] -= left + self.K[..., 1, 2] -= top + self.params[..., 2] -= left + self.params[..., 3] -= top + return self + + # helper function to get how fov changes based on new original size and new size + def get_new_fov(self, new_shape, original_shape): + new_hfov = 2 * torch.atan( + self.params[..., 2] / self.params[..., 0] * new_shape[1] / original_shape[1] + ) + new_vfov = 2 * torch.atan( + self.params[..., 3] / self.params[..., 1] * new_shape[0] / original_shape[0] + ) + return new_hfov, new_vfov + + def mask_overlap_projection(self, projected): + B, _, H, W = projected.shape + id_coords = coords_grid(B, H, W, device=projected.device) + + # check for mask where flow would overlap with other part of the image + # eleemtns coming from the border are then masked out + flow = projected - id_coords + gamma = 0.1 + sample_grid = gamma * flow + id_coords # sample along the flow + sample_grid[:, 0] = sample_grid[:, 0] / (W - 1) * 2 - 1 + sample_grid[:, 1] = sample_grid[:, 1] / (H - 1) * 2 - 1 + sampled_flow = F.grid_sample( + flow, + sample_grid.permute(0, 2, 3, 1), + mode="bilinear", + align_corners=False, + padding_mode="border", + ) + mask = ( + (1 - gamma) * torch.norm(flow, dim=1, keepdim=True) + < torch.norm(sampled_flow, dim=1, keepdim=True) + ) | (torch.norm(flow, dim=1, keepdim=True) < 1) + return mask + + def _pad_params(self): + # Ensure params are padded to length 16 + if self.params.shape[1] < 16: + padding = torch.zeros( + 16 - self.params.shape[1], + device=self.params.device, + dtype=self.params.dtype, + ) + padding = padding.view(*[(self.params.ndim - 1) * [1] + [-1]]) + padding = padding.repeat(self.params.shape[:-1] + (1,)) + return torch.cat([self.params, padding], dim=-1) + return self.params + + @staticmethod + def flatten_cameras(cameras): # -> list[Camera]: + # Recursively flatten BatchCamera into primitive cameras + flattened_cameras = [] + for camera in cameras: + if isinstance(camera, BatchCamera): + flattened_cameras.extend(BatchCamera.flatten_cameras(camera.cameras)) + elif isinstance(camera, list): + flattened_cameras.extend(camera) + else: + flattened_cameras.append(camera) + return flattened_cameras + + @staticmethod + def _stack_or_cat_cameras(cameras, func, 
**kwargs): + # Generalized method to handle stacking or concatenation + flat_cameras = BatchCamera.flatten_cameras(cameras) + K_matrices = [camera.K for camera in flat_cameras] + padded_params = [camera._pad_params() for camera in flat_cameras] + + stacked_K = func(K_matrices, **kwargs) + stacked_params = func(padded_params, **kwargs) + + # Keep track of the original classes + original_class = [x.__class__.__name__ for x in flat_cameras] + return BatchCamera(stacked_params, stacked_K, original_class, flat_cameras) + + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + if kwargs is None: + kwargs = {} + + if func is torch.cat: + return Camera._stack_or_cat_cameras(args[0], func, **kwargs) + + if func is torch.stack: + return Camera._stack_or_cat_cameras(args[0], func, **kwargs) + + if func is torch.flatten: + return Camera._stack_or_cat_cameras(args[0], torch.cat, **kwargs) + return super().__torch_function__(func, types, args, kwargs) + + @property + def device(self): + return self.K.device + + # here we assume that cx,cy are more or less H/2 and W/2 + @property + def hfov(self): + return 2 * torch.atan(self.params[..., 2] / self.params[..., 0]) + + @property + def vfov(self): + return 2 * torch.atan(self.params[..., 3] / self.params[..., 1]) + + @property + def max_fov(self): + return 150.0 / 180.0 * np.pi, 150.0 / 180.0 * np.pi + + +class Pinhole(Camera): + def __init__(self, params=None, K=None): + assert params is not None or K is not None + # params = [fx, fy, cx, cy] + if params is None: + params = torch.stack( + [K[..., 0, 0], K[..., 1, 1], K[..., 0, 2], K[..., 1, 2]], dim=-1 + ) + super().__init__(params=params, K=K) + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def project(self, pcd): + b, _, h, w = pcd.shape + pcd_flat = pcd.reshape(b, 3, -1) # [B, 3, H*W] + cam_coords = self.K @ pcd_flat + pcd_proj = cam_coords[:, :2] / cam_coords[:, -1:].clamp(min=0.01) + pcd_proj = pcd_proj.reshape(b, 2, h, w) + invalid = ( + (pcd_proj[:, 0] < 0) + & (pcd_proj[:, 0] >= w) + & (pcd_proj[:, 1] < 0) + & (pcd_proj[:, 1] >= h) + ) + self.projection_mask = (~invalid).unsqueeze(1) + return pcd_proj + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def unproject(self, uv): + b, _, h, w = uv.shape + uv_flat = uv.reshape(b, 2, -1) # [B, 2, H*W] + uv_homogeneous = torch.cat( + [uv_flat, torch.ones(b, 1, h * w, device=uv.device)], dim=1 + ) # [B, 3, H*W] + K_inv = torch.inverse(self.K.float()) + xyz = K_inv @ uv_homogeneous + xyz = xyz / xyz[:, -1:].clip(min=1e-4) + xyz = xyz.reshape(b, 3, h, w) + self.unprojection_mask = xyz[:, -1:] > 1e-4 + return xyz + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def reconstruct(self, depth): + b, _, h, w = depth.shape + uv = coords_grid(b, h, w, device=depth.device) + xyz = self.unproject(uv) * depth.clip(min=0.0) + return xyz + + +class EUCM(Camera): + def __init__(self, params): + # params = [fx, fy, cx, cy, alpha, beta] + super().__init__(params=params, K=None) + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def project(self, xyz): + H, W = xyz.shape[-2:] + fx, fy, cx, cy, alpha, beta = self.params[:6].unbind(dim=1) + x, y, z = xyz.unbind(dim=1) + d = torch.sqrt(beta * (x**2 + y**2) + z**2) + + x = x / (alpha * d + (1 - alpha) * z).clip(min=1e-3) + y = y / (alpha * d + (1 - alpha) * z).clip(min=1e-3) + + Xnorm = fx * x + cx + Ynorm = fy * y + cy + + coords = torch.stack([Xnorm, Ynorm], dim=1) + + invalid = ( + 
(coords[:, 0] < 0) + | (coords[:, 0] > W) + | (coords[:, 1] < 0) + | (coords[:, 1] > H) + | (z < 0) + ) + self.projection_mask = (~invalid).unsqueeze(1) + + return coords + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def unproject(self, uv): + u, v = uv.unbind(dim=1) + fx, fy, cx, cy, alpha, beta = self.params.unbind(dim=1) + mx = (u - cx) / fx + my = (v - cy) / fy + r_square = mx**2 + my**2 + valid_mask = r_square < torch.where( + alpha < 0.5, 1e6, 1 / (beta * (2 * alpha - 1)) + ) + sqrt_val = 1 - (2 * alpha - 1) * beta * r_square + mz = (1 - beta * (alpha**2) * r_square) / ( + alpha * torch.sqrt(sqrt_val.clip(min=1e-5)) + (1 - alpha) + ) + coeff = 1 / torch.sqrt(mx**2 + my**2 + mz**2 + 1e-5) + + x = coeff * mx + y = coeff * my + z = coeff * mz + self.unprojection_mask = valid_mask & (z > 1e-3) + + xnorm = torch.stack((x, y, z.clamp(1e-3)), dim=1) + return xnorm + + +class Spherical(Camera): + def __init__(self, params): + # Hfov and Vofv are in radians and halved! + # params: [fx, fy, cx, cy, W, H, HFoV/2, VFoV/2] + # fx,fy,cx,cy = dummy values + super().__init__(params=params, K=None) + + def resize(self, factor): + self.K[..., :2, :] *= factor + self.params[..., :6] *= factor + return self + + def crop(self, left, top, right, bottom): + self.K[..., 0, 2] -= left + self.K[..., 1, 2] -= top + self.params[..., 2] -= left + self.params[..., 3] -= top + W, H = self.params[..., 4], self.params[..., 5] + angle_ratio_W = (W - left - right) / W + angle_ratio_H = (H - top - bottom) / H + + self.params[..., 4] -= left + right + self.params[..., 5] -= top + bottom + + # rescale hfov and vfov + self.params[..., 6] *= angle_ratio_W + self.params[..., 7] *= angle_ratio_H + return self + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def project(self, xyz): + width, height = self.params[..., 4], self.params[..., 5] + hfov, vfov = 2 * self.params[..., 6], 2 * self.params[..., 7] + longitude = torch.atan2(xyz[:, 0], xyz[:, 2]) + latitude = torch.asin(xyz[:, 1] / torch.norm(xyz, dim=1).clamp(min=1e-5)) + + u = longitude / hfov * (width - 1) + (width - 1) / 2 + v = latitude / vfov * (height - 1) + (height - 1) / 2 + + return torch.stack([u, v], dim=1) + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def unproject(self, uv): + u, v = uv.unbind(dim=1) + + width, height = self.params[..., 4], self.params[..., 5] + hfov, vfov = 2 * self.params[..., 6], 2 * self.params[..., 7] + longitude = (u - (width - 1) / 2) / (width - 1) * hfov + latitude = (v - (height - 1) / 2) / (height - 1) * vfov + x = torch.cos(latitude) * torch.sin(longitude) + z = torch.cos(latitude) * torch.cos(longitude) + y = torch.sin(latitude) + unit_sphere = torch.stack([x, y, z], dim=1) + unit_sphere = unit_sphere / torch.norm(unit_sphere, dim=1, keepdim=True).clip( + min=1e-5 + ) + + return unit_sphere + + def reconstruct(self, depth): + id_coords = coords_grid( + 1, depth.shape[-2], depth.shape[-1], device=depth.device + ) + return self.unproject(id_coords) * depth + + def get_new_fov(self, new_shape, original_shape): + new_hfov = 2 * self.params[..., 6] * new_shape[1] / original_shape[1] + new_vfov = 2 * self.params[..., 7] * new_shape[0] / original_shape[0] + return new_hfov, new_vfov + + @property + def hfov(self): + return 2 * self.params[..., 6] + + @property + def vfov(self): + return 2 * self.params[..., 7] + + @property + def max_fov(self): + return 2 * np.pi, 0.9 * np.pi # avoid strong distortion on tops + + +class OPENCV(Camera): + def 
__init__(self, params): + super().__init__(params=params, K=None) + # params: [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, p1, p2, s1, s2, s3, s4] + self.use_radial = self.params[..., 4:10].abs().sum() > 1e-6 + assert ( + self.params[..., 7:10].abs().sum() == 0.0 + ), "Do not support poly division model" + self.use_tangential = self.params[..., 10:12].abs().sum() > 1e-6 + self.use_thin_prism = self.params[..., 12:].abs().sum() > 1e-6 + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def project(self, xyz): + eps = 1e-9 + B, _, H, W = xyz.shape + N = H * W + xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3) + + # Radial correction. + z = xyz[:, :, 2].reshape(B, N, 1) + z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z) + ab = xyz[:, :, :2] / z + r = torch.norm(ab, dim=-1, p=2, keepdim=True) + th = r + + th_pow = torch.cat( + [torch.pow(th, 2 + i * 2) for i in range(3)], dim=-1 + ) # Create powers of th (th^3, th^5, ...) + distortion_coeffs_num = self.params[:, 4:7].reshape(B, 1, 3) + distortion_coeffs_den = self.params[:, 7:10].reshape(B, 1, 3) + th_num = 1 + torch.sum(th_pow * distortion_coeffs_num, dim=-1, keepdim=True) + th_den = 1 + torch.sum(th_pow * distortion_coeffs_den, dim=-1, keepdim=True) + + xr_yr = ab * th_num / th_den + uv_dist = xr_yr + + # Tangential correction. + p0 = self.params[..., -6].reshape(B, 1) + p1 = self.params[..., -5].reshape(B, 1) + xr = xr_yr[:, :, 0].reshape(B, N) + yr = xr_yr[:, :, 1].reshape(B, N) + xr_yr_sq = torch.square(xr_yr) + xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) + yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) + rd_sq = xr_sq + yr_sq + uv_dist_tu = uv_dist[:, :, 0] + ( + (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 + ) + uv_dist_tv = uv_dist[:, :, 1] + ( + (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 + ) + uv_dist = torch.stack( + [uv_dist_tu, uv_dist_tv], dim=-1 + ) # Avoids in-place complaint. + + # Thin Prism correction. + s0 = self.params[..., -4].reshape(B, 1) + s1 = self.params[..., -3].reshape(B, 1) + s2 = self.params[..., -2].reshape(B, 1) + s3 = self.params[..., -1].reshape(B, 1) + rd_4 = torch.square(rd_sq) + uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4) + uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4) + + # Finally, apply standard terms: focal length and camera centers. + if self.params.shape[-1] == 15: + fx_fy = self.params[..., 0].reshape(B, 1, 1) + cx_cy = self.params[..., 1:3].reshape(B, 1, 2) + else: + fx_fy = self.params[..., 0:2].reshape(B, 1, 2) + cx_cy = self.params[..., 2:4].reshape(B, 1, 2) + result = uv_dist * fx_fy + cx_cy + + result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2) + invalid = ( + (result[:, 0] < 0) + | (result[:, 0] > W) + | (result[:, 1] < 0) + | (result[:, 1] > H) + ) + self.projection_mask = (~invalid).unsqueeze(1) + self.overlap_mask = self.mask_overlap_projection(result) + + return result + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def unproject(self, uv, max_iters: int = 10): + eps = 1e-3 + B, _, H, W = uv.shape + N = H * W + uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2) + + if self.params.shape[-1] == 15: + fx_fy = self.params[..., 0].reshape(B, 1, 1) + cx_cy = self.params[..., 1:3].reshape(B, 1, 2) + else: + fx_fy = self.params[..., 0:2].reshape(B, 1, 2) + cx_cy = self.params[..., 2:4].reshape(B, 1, 2) + + uv_dist = (uv - cx_cy) / fx_fy + + # Compute xr_yr using Newton's method. + xr_yr = uv_dist.clone() # Initial guess. 
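+        # The loop below undoes the tangential + thin-prism part of the distortion by
+        # Newton iteration: re-apply the forward distortion to the current estimate,
+        # compare it against the observed distorted coordinates, and step using the
+        # analytic 2x2 Jacobian d(uv_dist)/d(xr_yr). The iteration count is zero (the
+        # loop is skipped) when neither set of coefficients is used.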
+ max_iters_tanprism = ( + max_iters if self.use_thin_prism or self.use_tangential else 0 + ) + + for _ in range(max_iters_tanprism): + uv_dist_est = xr_yr.clone() + xr = xr_yr[..., 0].reshape(B, N) + yr = xr_yr[..., 1].reshape(B, N) + xr_yr_sq = torch.square(xr_yr) + xr_sq = xr_yr_sq[..., 0].reshape(B, N) + yr_sq = xr_yr_sq[..., 1].reshape(B, N) + rd_sq = xr_sq + yr_sq + + if self.use_tangential: + # Tangential terms. + p0 = self.params[..., -6].reshape(B, 1) + p1 = self.params[..., -5].reshape(B, 1) + uv_dist_est[..., 0] = uv_dist_est[..., 0] + ( + (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 + ) + uv_dist_est[..., 1] = uv_dist_est[..., 1] + ( + (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 + ) + + if self.use_thin_prism: + # Thin Prism terms. + s0 = self.params[..., -4].reshape(B, 1) + s1 = self.params[..., -3].reshape(B, 1) + s2 = self.params[..., -2].reshape(B, 1) + s3 = self.params[..., -1].reshape(B, 1) + rd_4 = torch.square(rd_sq) + uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4) + uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4) + + # Compute the derivative of uv_dist w.r.t. xr_yr. + duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2) + + if self.use_tangential: + duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1 + offdiag = 2.0 * (xr * p1 + yr * p0) + duv_dist_dxr_yr[..., 0, 1] = offdiag + duv_dist_dxr_yr[..., 1, 0] = offdiag + duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0 + + if self.use_thin_prism: + xr_yr_sq_norm = xr_sq + yr_sq + temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm) + duv_dist_dxr_yr[..., 0, 0] = duv_dist_dxr_yr[..., 0, 0] + (xr * temp1) + duv_dist_dxr_yr[..., 0, 1] = duv_dist_dxr_yr[..., 0, 1] + (yr * temp1) + temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm) + duv_dist_dxr_yr[..., 1, 0] = duv_dist_dxr_yr[..., 1, 0] + (xr * temp2) + duv_dist_dxr_yr[..., 1, 1] = duv_dist_dxr_yr[..., 1, 1] + (yr * temp2) + + mat = duv_dist_dxr_yr.reshape(-1, 2, 2) + a = mat[:, 0, 0].reshape(-1, 1, 1) + b = mat[:, 0, 1].reshape(-1, 1, 1) + c = mat[:, 1, 0].reshape(-1, 1, 1) + d = mat[:, 1, 1].reshape(-1, 1, 1) + det = 1.0 / ((a * d) - (b * c)) + top = torch.cat([d, -b], dim=-1) + bot = torch.cat([-c, a], dim=-1) + inv = det * torch.cat([top, bot], dim=-2) + inv = inv.reshape(B, N, 2, 2) + diff = uv_dist - uv_dist_est + a = inv[..., 0, 0] + b = inv[..., 0, 1] + c = inv[..., 1, 0] + d = inv[..., 1, 1] + e = diff[..., 0] + f = diff[..., 1] + step = torch.stack([a * e + b * f, c * e + d * f], dim=-1) + # Newton step. + xr_yr = xr_yr + step + + # Compute theta using Newton's method. 
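+        # The radial model maps th -> th * (1 + k1*th^2 + k2*th^4 + k3*th^6); here we
+        # solve th_radial(th) = ||xr_yr|| for th. Newton steps are safeguarded by a
+        # trust region: each step is clipped to the radius `delta`, and `delta` is
+        # enlarged or shrunk according to the ratio of actual to predicted reduction
+        # of the residual before a step is accepted or rejected.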
+ xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1) + th = xr_yr_norm.clone() + max_iters_radial = max_iters if self.use_radial else 0 + c = ( + torch.tensor([2.0 * i + 3 for i in range(3)], device=self.device) + .reshape(1, 1, 3) + .repeat(B, 1, 1) + ) + radial_params_num = self.params[..., 4:7].reshape(B, 1, 3) + + # Trust region parameters + delta = torch.full((B, N, 1), 0.1, device=self.device) # Initial trust radius + delta_max = torch.tensor(1.0, device=self.device) # Maximum trust radius + eta = 0.1 # Acceptable reduction threshold + + for i in range(max_iters_radial): + th_sq = th * th # th^2 + # Compute powers of th^2 up to th^(12) + theta_powers = torch.cat( + [th_sq ** (i + 1) for i in range(3)], dim=-1 + ) # Shape: (B, N, 6) + + # Compute th_radial: radial distortion model applied to th + th_radial = 1.0 + torch.sum( + theta_powers * radial_params_num, dim=-1, keepdim=True + ) + th_radial = th_radial * th # Multiply by th at the end + + # Compute derivative dthd_th + dthd_th = 1.0 + torch.sum( + c * radial_params_num * theta_powers, dim=-1, keepdim=True + ) + dthd_th = dthd_th # Already includes derivative terms + + # Compute residual + residual = th_radial - xr_yr_norm # Shape: (B, N, 1) + residual_norm = torch.norm(residual, dim=2, keepdim=True) # For each pixel + + # Check for convergence + if torch.max(torch.abs(residual)) < eps: + break + + # Avoid division by zero by adding a small epsilon + safe_dthd_th = dthd_th.clone() + zero_derivative_mask = dthd_th.abs() < eps + safe_dthd_th[zero_derivative_mask] = eps + + # Compute Newton's step + step = -residual / safe_dthd_th + + # Compute predicted reduction + predicted_reduction = -(residual * step).sum(dim=2, keepdim=True) + + # Adjust step based on trust region + step_norm = torch.norm(step, dim=2, keepdim=True) + over_trust_mask = step_norm > delta + + # Scale step if it exceeds trust radius + step_scaled = step.clone() + step_scaled[over_trust_mask] = step[over_trust_mask] * ( + delta[over_trust_mask] / step_norm[over_trust_mask] + ) + + # Update theta + th_new = th + step_scaled + + # Compute new residual + th_sq_new = th_new * th_new + theta_powers_new = torch.cat( + [th_sq_new ** (j + 1) for j in range(3)], dim=-1 + ) + th_radial_new = 1.0 + torch.sum( + theta_powers_new * radial_params_num, dim=-1, keepdim=True + ) + th_radial_new = th_radial_new * th_new + residual_new = th_radial_new - xr_yr_norm + residual_new_norm = torch.norm(residual_new, dim=2, keepdim=True) + + # Compute actual reduction + actual_reduction = residual_norm - residual_new_norm + + # Compute ratio of actual to predicted reduction + # predicted_reduction[predicted_reduction.abs() < eps] = eps #* torch.sign(predicted_reduction[predicted_reduction.abs() < eps]) + rho = actual_reduction / predicted_reduction + rho[(actual_reduction == 0) & (predicted_reduction == 0)] = 1.0 + + # Update trust radius delta + delta_update_mask = rho > 0.5 + delta[delta_update_mask] = torch.min( + 2.0 * delta[delta_update_mask], delta_max + ) + + delta_decrease_mask = rho < 0.2 + delta[delta_decrease_mask] = 0.25 * delta[delta_decrease_mask] + + # Accept or reject the step + accept_step_mask = rho > eta + th = torch.where(accept_step_mask, th_new, th) + + # Compute the ray direction using theta and xr_yr. 
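+        # Rescale xr_yr so that its norm equals the undistorted radius `th`, then append
+        # z = 1 to obtain a pinhole-style ray; near the optical axis, where both `th` and
+        # ||xr_yr|| vanish, xr_yr is used directly to avoid a 0/0 division.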
+ close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps) + ray_dir = torch.where(close_to_zero, xr_yr, th / xr_yr_norm * xr_yr) + + ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2) + ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2) + + return ray + + +class Fisheye624(Camera): + def __init__(self, params): + super().__init__(params=params, K=None) + # params: [fx, fy, cx, cy, k1, k2, k3, k4, k5, k6, p1, p2, s1, s2, s3, s4] + self.use_radial = self.params[..., 4:10].abs().sum() > 1e-6 + self.use_tangential = self.params[..., 10:12].abs().sum() > 1e-6 + self.use_thin_prism = self.params[..., 12:].abs().sum() > 1e-6 + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def project(self, xyz): + eps = 1e-9 + B, _, H, W = xyz.shape + N = H * W + xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3) + + # Radial correction. + z = xyz[:, :, 2].reshape(B, N, 1) + z = torch.where(torch.abs(z) < eps, eps * torch.sign(z), z) + ab = xyz[:, :, :2] / z + r = torch.norm(ab, dim=-1, p=2, keepdim=True) + th = torch.atan(r) + th_divr = torch.where(r < eps, torch.ones_like(ab), ab / r) + + th_pow = torch.cat( + [torch.pow(th, 3 + i * 2) for i in range(6)], dim=-1 + ) # Create powers of th (th^3, th^5, ...) + distortion_coeffs = self.params[:, 4:10].reshape(B, 1, 6) + th_k = th + torch.sum(th_pow * distortion_coeffs, dim=-1, keepdim=True) + + xr_yr = th_k * th_divr + uv_dist = xr_yr + + # Tangential correction. + p0 = self.params[..., -6].reshape(B, 1) + p1 = self.params[..., -5].reshape(B, 1) + xr = xr_yr[:, :, 0].reshape(B, N) + yr = xr_yr[:, :, 1].reshape(B, N) + xr_yr_sq = torch.square(xr_yr) + xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) + yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) + rd_sq = xr_sq + yr_sq + uv_dist_tu = uv_dist[:, :, 0] + ( + (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 + ) + uv_dist_tv = uv_dist[:, :, 1] + ( + (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 + ) + uv_dist = torch.stack( + [uv_dist_tu, uv_dist_tv], dim=-1 + ) # Avoids in-place complaint. + + # Thin Prism correction. + s0 = self.params[..., -4].reshape(B, 1) + s1 = self.params[..., -3].reshape(B, 1) + s2 = self.params[..., -2].reshape(B, 1) + s3 = self.params[..., -1].reshape(B, 1) + rd_4 = torch.square(rd_sq) + uv_dist[:, :, 0] = uv_dist[:, :, 0] + (s0 * rd_sq + s1 * rd_4) + uv_dist[:, :, 1] = uv_dist[:, :, 1] + (s2 * rd_sq + s3 * rd_4) + + # Finally, apply standard terms: focal length and camera centers. 
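+        # Map distorted normalized coordinates to pixels: u = fx*x + cx, v = fy*y + cy.
+        # A 15-parameter vector is treated as having a single shared focal length;
+        # otherwise fx and fy are read separately.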
+ if self.params.shape[-1] == 15: + fx_fy = self.params[..., 0].reshape(B, 1, 1) + cx_cy = self.params[..., 1:3].reshape(B, 1, 2) + else: + fx_fy = self.params[..., 0:2].reshape(B, 1, 2) + cx_cy = self.params[..., 2:4].reshape(B, 1, 2) + result = uv_dist * fx_fy + cx_cy + + result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2) + invalid = ( + (result[:, 0] < 0) + | (result[:, 0] > W) + | (result[:, 1] < 0) + | (result[:, 1] > H) + ) + self.projection_mask = (~invalid).unsqueeze(1) + self.overlap_mask = self.mask_overlap_projection(result) + + return result + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def unproject(self, uv, max_iters: int = 10): + eps = 1e-3 + B, _, H, W = uv.shape + N = H * W + uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2) + + if self.params.shape[-1] == 15: + fx_fy = self.params[..., 0].reshape(B, 1, 1) + cx_cy = self.params[..., 1:3].reshape(B, 1, 2) + else: + fx_fy = self.params[..., 0:2].reshape(B, 1, 2) + cx_cy = self.params[..., 2:4].reshape(B, 1, 2) + + uv_dist = (uv - cx_cy) / fx_fy + + # Compute xr_yr using Newton's method. + xr_yr = uv_dist.clone() # Initial guess. + max_iters_tanprism = ( + max_iters if self.use_thin_prism or self.use_tangential else 0 + ) + + for _ in range(max_iters_tanprism): + uv_dist_est = xr_yr.clone() + xr = xr_yr[..., 0].reshape(B, N) + yr = xr_yr[..., 1].reshape(B, N) + xr_yr_sq = torch.square(xr_yr) + xr_sq = xr_yr_sq[..., 0].reshape(B, N) + yr_sq = xr_yr_sq[..., 1].reshape(B, N) + rd_sq = xr_sq + yr_sq + + if self.use_tangential: + # Tangential terms. + p0 = self.params[..., -6].reshape(B, 1) + p1 = self.params[..., -5].reshape(B, 1) + uv_dist_est[..., 0] = uv_dist_est[..., 0] + ( + (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 + ) + uv_dist_est[..., 1] = uv_dist_est[..., 1] + ( + (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 + ) + + if self.use_thin_prism: + # Thin Prism terms. + s0 = self.params[..., -4].reshape(B, 1) + s1 = self.params[..., -3].reshape(B, 1) + s2 = self.params[..., -2].reshape(B, 1) + s3 = self.params[..., -1].reshape(B, 1) + rd_4 = torch.square(rd_sq) + uv_dist_est[:, :, 0] = uv_dist_est[:, :, 0] + (s0 * rd_sq + s1 * rd_4) + uv_dist_est[:, :, 1] = uv_dist_est[:, :, 1] + (s2 * rd_sq + s3 * rd_4) + + # Compute the derivative of uv_dist w.r.t. xr_yr. + duv_dist_dxr_yr = uv.new_ones(B, N, 2, 2) + + if self.use_tangential: + duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1 + offdiag = 2.0 * (xr * p1 + yr * p0) + duv_dist_dxr_yr[..., 0, 1] = offdiag + duv_dist_dxr_yr[..., 1, 0] = offdiag + duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0 + + if self.use_thin_prism: + xr_yr_sq_norm = xr_sq + yr_sq + temp1 = 2.0 * (s0 + 2.0 * s1 * xr_yr_sq_norm) + duv_dist_dxr_yr[..., 0, 0] = duv_dist_dxr_yr[..., 0, 0] + (xr * temp1) + duv_dist_dxr_yr[..., 0, 1] = duv_dist_dxr_yr[..., 0, 1] + (yr * temp1) + temp2 = 2.0 * (s2 + 2.0 * s3 * xr_yr_sq_norm) + duv_dist_dxr_yr[..., 1, 0] = duv_dist_dxr_yr[..., 1, 0] + (xr * temp2) + duv_dist_dxr_yr[..., 1, 1] = duv_dist_dxr_yr[..., 1, 1] + (yr * temp2) + # Compute 2x2 inverse manually here since torch.inverse() is very slow. + # Because this is slow: inv = duv_dist_dxr_yr.inverse() + # About a 10x reduction in speed with above line. 
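+            # Closed form: [[a, b], [c, d]]^-1 = [[d, -b], [-c, a]] / (a*d - b*c),
+            # assembled batched with cat/stack below instead of calling torch.inverse.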
+ mat = duv_dist_dxr_yr.reshape(-1, 2, 2) + a = mat[:, 0, 0].reshape(-1, 1, 1) + b = mat[:, 0, 1].reshape(-1, 1, 1) + c = mat[:, 1, 0].reshape(-1, 1, 1) + d = mat[:, 1, 1].reshape(-1, 1, 1) + det = 1.0 / ((a * d) - (b * c)) + top = torch.cat([d, -b], dim=-1) + bot = torch.cat([-c, a], dim=-1) + inv = det * torch.cat([top, bot], dim=-2) + inv = inv.reshape(B, N, 2, 2) + diff = uv_dist - uv_dist_est + a = inv[..., 0, 0] + b = inv[..., 0, 1] + c = inv[..., 1, 0] + d = inv[..., 1, 1] + e = diff[..., 0] + f = diff[..., 1] + step = torch.stack([a * e + b * f, c * e + d * f], dim=-1) + # Newton step. + xr_yr = xr_yr + step + + # Compute theta using Newton's method. + xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1) + th = xr_yr_norm.clone() + max_iters_radial = max_iters if self.use_radial else 0 + c = ( + torch.tensor([2.0 * i + 3 for i in range(6)], device=self.device) + .reshape(1, 1, 6) + .repeat(B, 1, 1) + ) + radial_params = self.params[..., 4:10].reshape(B, 1, 6) + + # Trust region parameters + delta = torch.full((B, N, 1), 0.1, device=self.device) # Initial trust radius + delta_max = torch.tensor(1.0, device=self.device) # Maximum trust radius + eta = 0.1 # Acceptable reduction threshold + + for i in range(max_iters_radial): + th_sq = th * th + # Compute powers of th^2 up to th^(12) + theta_powers = torch.cat( + [th_sq ** (i + 1) for i in range(6)], dim=-1 + ) # Shape: (B, N, 6) + + # Compute th_radial: radial distortion model applied to th + th_radial = 1.0 + torch.sum( + theta_powers * radial_params, dim=-1, keepdim=True + ) + th_radial = th_radial * th + + # Compute derivative dthd_th + dthd_th = 1.0 + torch.sum( + c * radial_params * theta_powers, dim=-1, keepdim=True + ) + dthd_th = dthd_th + + # Compute residual + residual = th_radial - xr_yr_norm # Shape: (B, N, 1) + residual_norm = torch.norm(residual, dim=2, keepdim=True) + + # Check for convergence + if torch.max(torch.abs(residual)) < eps: + break + + # Avoid division by zero by adding a small epsilon + safe_dthd_th = dthd_th.clone() + zero_derivative_mask = dthd_th.abs() < eps + safe_dthd_th[zero_derivative_mask] = eps + + # Compute Newton's step + step = -residual / safe_dthd_th + + # Compute predicted reduction + predicted_reduction = -(residual * step).sum(dim=2, keepdim=True) + + # Adjust step based on trust region + step_norm = torch.norm(step, dim=2, keepdim=True) + over_trust_mask = step_norm > delta + + # Scale step if it exceeds trust radius + step_scaled = step.clone() + step_scaled[over_trust_mask] = step[over_trust_mask] * ( + delta[over_trust_mask] / step_norm[over_trust_mask] + ) + + # Update theta + th_new = th + step_scaled + + # Compute new residual + th_sq_new = th_new * th_new + theta_powers_new = torch.cat( + [th_sq_new ** (j + 1) for j in range(6)], dim=-1 + ) + th_radial_new = 1.0 + torch.sum( + theta_powers_new * radial_params, dim=-1, keepdim=True + ) + th_radial_new = th_radial_new * th_new + residual_new = th_radial_new - xr_yr_norm + residual_new_norm = torch.norm(residual_new, dim=2, keepdim=True) + + # Compute actual reduction + actual_reduction = residual_norm - residual_new_norm + + # Compute ratio of actual to predicted reduction + rho = actual_reduction / predicted_reduction + rho[(actual_reduction == 0) & (predicted_reduction == 0)] = 1.0 + + # Update trust radius delta + delta_update_mask = rho > 0.5 + delta[delta_update_mask] = torch.min( + 2.0 * delta[delta_update_mask], delta_max + ) + + delta_decrease_mask = rho < 0.2 + delta[delta_decrease_mask] = 0.25 * delta[delta_decrease_mask] + 
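+            # rho compares the achieved reduction with the reduction predicted by
+            # the local model; only steps with rho above the threshold eta are kept.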
+ # Accept or reject the step + accept_step_mask = rho > eta + th = torch.where(accept_step_mask, th_new, th) + + # Compute the ray direction using theta and xr_yr. + close_to_zero = torch.logical_and(th.abs() < eps, xr_yr_norm.abs() < eps) + ray_dir = torch.where(close_to_zero, xr_yr, torch.tan(th) / xr_yr_norm * xr_yr) + + ray = torch.cat([ray_dir, uv.new_ones(B, N, 1)], dim=2) + ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2) + + return ray + + +class MEI(Camera): + def __init__(self, params): + super().__init__(params=params, K=None) + # fx fy cx cy k1 k2 p1 p2 xi + self.use_radial = self.params[..., 4:6].abs().sum() > 1e-6 + self.use_tangential = self.params[..., 6:8].abs().sum() > 1e-6 + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def unproject(self, uv, max_iters: int = 20): + eps = 1e-6 + B, _, H, W = uv.shape + N = H * W + uv = uv.permute(0, 2, 3, 1).reshape(B, N, 2) + + k1, k2, p0, p1, xi = self.params[..., 4:9].unbind(dim=1) + fx_fy = self.params[..., 0:2].reshape(B, 1, 2) + cx_cy = self.params[..., 2:4].reshape(B, 1, 2) + + uv_dist = (uv - cx_cy) / fx_fy + + # Compute xr_yr using Newton's method. + xr_yr = uv_dist.clone() # Initial guess. + max_iters_tangential = max_iters if self.use_tangential else 0 + for _ in range(max_iters_tangential): + uv_dist_est = xr_yr.clone() + + # Tangential terms. + xr = xr_yr[..., 0] + yr = xr_yr[..., 1] + xr_yr_sq = xr_yr**2 + xr_sq = xr_yr_sq[..., 0] + yr_sq = xr_yr_sq[..., 1] + rd_sq = xr_sq + yr_sq + uv_dist_est[..., 0] = uv_dist_est[..., 0] + ( + (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 + ) + uv_dist_est[..., 1] = uv_dist_est[..., 1] + ( + (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 + ) + + # Compute the derivative of uv_dist w.r.t. xr_yr. + duv_dist_dxr_yr = torch.ones((B, N, 2, 2), device=uv.device) + duv_dist_dxr_yr[..., 0, 0] = 1.0 + 6.0 * xr * p0 + 2.0 * yr * p1 + offdiag = 2.0 * (xr * p1 + yr * p0) + duv_dist_dxr_yr[..., 0, 1] = offdiag + duv_dist_dxr_yr[..., 1, 0] = offdiag + duv_dist_dxr_yr[..., 1, 1] = 1.0 + 6.0 * yr * p1 + 2.0 * xr * p0 + + mat = duv_dist_dxr_yr.reshape(-1, 2, 2) + a = mat[:, 0, 0].reshape(-1, 1, 1) + b = mat[:, 0, 1].reshape(-1, 1, 1) + c = mat[:, 1, 0].reshape(-1, 1, 1) + d = mat[:, 1, 1].reshape(-1, 1, 1) + det = 1.0 / ((a * d) - (b * c)) + top = torch.cat([d, -b], dim=-1) + bot = torch.cat([-c, a], dim=-1) + inv = det * torch.cat([top, bot], dim=-2) + inv = inv.reshape(B, N, 2, 2) + + diff = uv_dist - uv_dist_est + a = inv[..., 0, 0] + b = inv[..., 0, 1] + c = inv[..., 1, 0] + d = inv[..., 1, 1] + e = diff[..., 0] + f = diff[..., 1] + step = torch.stack([a * e + b * f, c * e + d * f], dim=-1) + + # Newton step. + xr_yr = xr_yr + step + + # Compute theta using Newton's method. + xr_yr_norm = xr_yr.norm(p=2, dim=2).reshape(B, N, 1) + th = xr_yr_norm.clone() + max_iters_radial = max_iters if self.use_radial else 0 + for _ in range(max_iters_radial): + th_radial = 1.0 + k1 * torch.pow(th, 2) + k2 * torch.pow(th, 4) + dthd_th = 1.0 + 3.0 * k1 * torch.pow(th, 2) + 5.0 * k2 * torch.pow(th, 4) + th_radial = th_radial * th + step = (xr_yr_norm - th_radial) / dthd_th + # handle dthd_th close to 0. + step = torch.where( + torch.abs(dthd_th) > eps, step, torch.sign(step) * eps * 10.0 + ) + th = th + step + + # Compute the ray direction using theta and xr_yr. 
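+        # Lift the undistorted point back to a 3D ray (unified camera model): with
+        # rho^2 = x^2 + y^2 of the recovered point,
+        #     P_z = 1 - xi * (rho^2 + 1) / (xi + sqrt(1 + (1 - xi^2) * rho^2)),
+        # and the xi == 1 case reduces to (1 - rho^2) / 2.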
+ close_to_zero = (torch.abs(th) < eps) & (torch.abs(xr_yr_norm) < eps) + ray_dir = torch.where(close_to_zero, xr_yr, th * xr_yr / xr_yr_norm) + + # Compute the 3D projective ray + rho2_u = ( + ray_dir.norm(p=2, dim=2, keepdim=True) ** 2 + ) # B N 1 # x_c * x_c + y_c * y_c + xi = xi.reshape(B, 1, 1) + sqrt_term = torch.sqrt(1.0 + (1.0 - xi * xi) * rho2_u) + P_z = 1.0 - xi * (rho2_u + 1.0) / (xi + sqrt_term) + + # Special case when xi is 1.0 (unit sphere projection ??) + P_z = torch.where(xi == 1.0, (1.0 - rho2_u) / 2.0, P_z) + + ray = torch.cat([ray_dir, P_z], dim=-1) + ray = ray.reshape(B, H, W, 3).permute(0, 3, 1, 2) + + # remove nans + where_nan = ray.isnan().any(dim=1, keepdim=True).repeat(1, 3, 1, 1) + ray = torch.where(where_nan, torch.zeros_like(ray), ray) + + return ray + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def project(self, xyz): + is_flat = xyz.ndim == 3 + B, N = xyz.shape[:2] + + if not is_flat: + B, _, H, W = xyz.shape + N = H * W + xyz = xyz.permute(0, 2, 3, 1).reshape(B, N, 3) + + k1, k2, p0, p1, xi = self.params[..., 4:].unbind(dim=1) + fx_fy = self.params[..., 0:2].reshape(B, 1, 2) + cx_cy = self.params[..., 2:4].reshape(B, 1, 2) + + norm = xyz.norm(p=2, dim=-1, keepdim=True) + ab = xyz[..., :-1] / (xyz[..., -1:] + xi.reshape(B, 1, 1) * norm) + + # radial correction + r = ab.norm(dim=-1, p=2, keepdim=True) + k1 = self.params[..., 4].reshape(B, 1, 1) + k2 = self.params[..., 5].reshape(B, 1, 1) + # ab / r * th * (1 + k1 * (th ** 2) + k2 * (th**4)) + # but here r = th, no spherical distortion + xr_yr = ab * (1 + k1 * (r**2) + k2 * (r**4)) + + # Tangential correction. + uv_dist = xr_yr + p0 = self.params[:, -3].reshape(B, 1) + p1 = self.params[:, -2].reshape(B, 1) + xr = xr_yr[..., 0].reshape(B, N) + yr = xr_yr[..., 1].reshape(B, N) + xr_yr_sq = torch.square(xr_yr) + xr_sq = xr_yr_sq[:, :, 0].reshape(B, N) + yr_sq = xr_yr_sq[:, :, 1].reshape(B, N) + rd_sq = xr_sq + yr_sq + uv_dist_tu = uv_dist[:, :, 0] + ( + (2.0 * xr_sq + rd_sq) * p0 + 2.0 * xr * yr * p1 + ) + uv_dist_tv = uv_dist[:, :, 1] + ( + (2.0 * yr_sq + rd_sq) * p1 + 2.0 * xr * yr * p0 + ) + uv_dist = torch.stack( + [uv_dist_tu, uv_dist_tv], dim=-1 + ) # Avoids in-place complaint. + + result = uv_dist * fx_fy + cx_cy + + if not is_flat: + result = result.reshape(B, H, W, 2).permute(0, 3, 1, 2) + invalid = ( + (result[:, 0] < 0) + | (result[:, 0] > W) + | (result[:, 1] < 0) + | (result[:, 1] > H) + ) + self.projection_mask = (~invalid).unsqueeze(1) + # creates hole in the middle... ?? 
+ # self.overlap_mask = self.mask_overlap_projection(result) + + return result + + +class BatchCamera(Camera): + def __init__(self, params, K, original_class, cameras): + super().__init__(params, K) + self.original_class = original_class + self.cameras = cameras + + # Delegate these methods to original camera + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def project(self, points_3d): + return torch.cat( + [ + camera.project(points_3d[i : i + 1]) + for i, camera in enumerate(self.cameras) + ] + ) + + @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) + def unproject(self, points_2d): + def recursive_unproject(cameras): + if isinstance(cameras, list): + return [recursive_unproject(camera) for camera in cameras] + else: + return cameras.unproject(points_2d) + + def flatten_and_cat(nested_list): + if isinstance(nested_list[0], list): + return torch.cat( + [flatten_and_cat(sublist) for sublist in nested_list], dim=0 + ) + else: + return torch.cat(nested_list, dim=0) + + unprojected = recursive_unproject(self.cameras) + return flatten_and_cat(unprojected) + + def crop(self, left, top, right=None, bottom=None): + val = torch.cat( + [ + camera.crop(left, top, right, bottom) + for i, camera in enumerate(self.cameras) + ] + ) + return val + + def resize(self, ratio): + val = torch.cat([camera.resize(ratio) for i, camera in enumerate(self.cameras)]) + return val + + def reconstruct(self, depth): + val = torch.cat( + [ + camera.reconstruct(depth[i : i + 1]) + for i, camera in enumerate(self.cameras) + ] + ) + return val + + def get_projection_mask(self): + return torch.cat( + [camera.projection_mask for i, camera in enumerate(self.cameras)] + ) + + def to(self, device, non_blocking=False): + self = super().to(device, non_blocking=non_blocking) + self.cameras = recursive_apply( + self.cameras, lambda camera: camera.to(device, non_blocking=non_blocking) + ) + return self + + def reshape(self, *shape): + # Reshape the intrinsic matrix (K) and params + # we know that the shape of K is (..., 3, 3) and params is (..., 16) + reshaped_K = self.K.reshape(*shape, 3, 3) + reshaped_params = self.params.reshape(*shape, self.params.shape[-1]) + + self.cameras = np.array(self.cameras, dtype=object).reshape(shape).tolist() + self.original_class = ( + np.array(self.original_class, dtype=object).reshape(shape).tolist() + ) + + # Create a new BatchCamera with reshaped K and params + return BatchCamera( + reshaped_params, reshaped_K, self.original_class, self.cameras + ) + + def get_new_fov(self, new_shape, original_shape): + return [ + camera.get_new_fov(new_shape, original_shape) + for i, camera in enumerate(self.cameras) + ] + + def squeeze(self, dim): + return BatchCamera( + self.params.squeeze(dim), + self.K.squeeze(dim), + squeeze_list(self.original_class, dim=dim), + squeeze_list(self.cameras, dim=dim), + ) + + def __getitem__(self, idx): + # If it's an integer index, return a single camera + if isinstance(idx, int): + return self.cameras[idx] + + # If it's a slice, return a new BatchCamera with sliced cameras + elif isinstance(idx, slice): + return BatchCamera( + self.params[idx], + self.K[idx], + self.original_class[idx], + self.cameras[idx], + ) + + raise TypeError(f"Invalid index type: {type(idx)}") + + def __setitem__(self, idx, value): + # If it's an integer index, return a single camera + if isinstance(idx, int): + self.cameras[idx] = value + self.params[idx, :] = 0.0 + self.params[idx, : value.params.shape[1]] = value.params[0] + self.K[idx] = value.K[0] + 
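+            # Keep the per-slot type tag in sync: reuse `original_class` when the
+            # assigned value is itself a BatchCamera, otherwise fall back to the
+            # concrete class name.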
+ self.original_class[idx] = getattr( + value, "original_class", value.__class__.__name__ + ) + + # If it's a slice, return a new BatchCamera with sliced cameras + elif isinstance(idx, slice): + # Update each internal attribute using the slice + self.params[idx] = value.params + self.K[idx] = value.K + self.original_class[idx] = value.original_class + self.cameras[idx] = value.cameras + + def __len__(self): + return len(self.cameras) + + @classmethod + def from_camera(cls, camera): + return cls(camera.params, camera.K, [camera.__class__.__name__], [camera]) + + @property + def is_perspective(self): + return recursive_apply(self.cameras, lambda camera: isinstance(camera, Pinhole)) + + @property + def is_spherical(self): + return recursive_apply( + self.cameras, lambda camera: isinstance(camera, Spherical) + ) + + @property + def is_eucm(self): + return recursive_apply(self.cameras, lambda camera: isinstance(camera, EUCM)) + + @property + def is_fisheye(self): + return recursive_apply( + self.cameras, lambda camera: isinstance(camera, Fisheye624) + ) + + @property + def is_pinhole(self): + return recursive_apply(self.cameras, lambda camera: isinstance(camera, Pinhole)) + + @property + def hfov(self): + return recursive_apply(self.cameras, lambda camera: camera.hfov) + + @property + def vfov(self): + return recursive_apply(self.cameras, lambda camera: camera.vfov) + + @property + def max_fov(self): + return recursive_apply(self.cameras, lambda camera: camera.max_fov) + + +import json +import random +# sampler helpers +from math import log + +import torch.nn as nn + + +def eucm(boundaries, mult, batch, device, dtype): + alpha_min, alpha_max = boundaries[0][0] * mult, boundaries[0][1] * mult + beta_mean, beta_std = boundaries[1][0] * mult, boundaries[1][1] * mult + alpha = ( + torch.rand(batch, device=device, dtype=dtype) * (alpha_max - alpha_min) + + alpha_min + ) + beta = F.softplus( + torch.randn(batch, device=device, dtype=dtype) * beta_std + beta_mean, + beta=log(2), + ) + return alpha, beta + + +def free_fisheye(boundaries, mult, batch, device, dtype): + k1_min, k1_max = boundaries[0][0] * mult, boundaries[0][1] * mult + k2_min, k2_max = boundaries[1][0] * mult, boundaries[1][1] * mult + k3_min, k3_max = boundaries[2][0] * mult, boundaries[2][1] * mult + k4_min, k4_max = boundaries[3][0] * mult, boundaries[3][1] * mult + k5_min, k5_max = boundaries[4][0] * mult, boundaries[4][1] * mult + k6_min, k6_max = boundaries[5][0] * mult, boundaries[5][1] * mult + p1_min, p1_max = boundaries[6][0] * mult, boundaries[6][1] * mult + p2_min, p2_max = boundaries[7][0] * mult, boundaries[7][1] * mult + s1_min, s1_max = boundaries[8][0] * mult, boundaries[8][1] * mult + s2_min, s2_max = boundaries[9][0] * mult, boundaries[9][1] * mult + s3_min, s3_max = boundaries[10][0] * mult, boundaries[10][1] * mult + s4_min, s4_max = boundaries[11][0] * mult, boundaries[11][1] * mult + k1 = torch.rand(batch, device=device, dtype=dtype) * (k1_max - k1_min) + k1_min + k2 = torch.rand(batch, device=device, dtype=dtype) * (k2_max - k2_min) + k2_min + k3 = torch.rand(batch, device=device, dtype=dtype) * (k3_max - k3_min) + k3_min + k4 = torch.rand(batch, device=device, dtype=dtype) * (k4_max - k4_min) + k4_min + k5 = torch.rand(batch, device=device, dtype=dtype) * (k5_max - k5_min) + k5_min + k6 = torch.rand(batch, device=device, dtype=dtype) * (k6_max - k6_min) + k6_min + p1 = torch.rand(batch, device=device, dtype=dtype) * (p1_max - p1_min) + p1_min + p2 = torch.rand(batch, device=device, dtype=dtype) * (p2_max - 
p2_min) + p2_min + s1 = torch.rand(batch, device=device, dtype=dtype) * (s1_max - s1_min) + s1_min + s2 = torch.rand(batch, device=device, dtype=dtype) * (s2_max - s2_min) + s2_min + s3 = torch.rand(batch, device=device, dtype=dtype) * (s3_max - s3_min) + s3_min + s4 = torch.rand(batch, device=device, dtype=dtype) * (s4_max - s4_min) + s4_min + return k1, k2, k3, k4, k5, k6, p1, p2, s1, s2, s3, s4 + + +def mei(boundaries, mult, batch, device, dtype): + k1_min, k1_max = boundaries[0][0] * mult, boundaries[0][1] * mult + k2_min, k2_max = boundaries[1][0] * mult, boundaries[1][1] * mult + p1_min, p1_max = boundaries[2][0] * mult, boundaries[2][1] * mult + p2_min, p2_max = boundaries[3][0] * mult, boundaries[3][1] * mult + xi_min, xi_max = boundaries[4][0] * mult, boundaries[4][1] * mult + k1 = torch.rand(batch, device=device, dtype=dtype) * (k1_max - k1_min) + k1_min + k2 = torch.rand(batch, device=device, dtype=dtype) * (k2_max - k2_min) + k2_min + p1 = torch.rand(batch, device=device, dtype=dtype) * (p1_max - p1_min) + p1_min + p2 = torch.rand(batch, device=device, dtype=dtype) * (p2_max - p2_min) + p2_min + xi = torch.rand(batch, device=device, dtype=dtype) * (xi_max - xi_min) + xi_min + return k1, k2, p1, p2, xi + + +def consistent_fisheye(boundaries, mult, batch, device, dtype): + sign = random.choice([-1, 1]) + return free_fisheye(boundaries, sign * mult, batch, device, dtype) + + +def invert_fisheye(boundaries, mult, batch, device, dtype): + k1_min, k1_max = boundaries[0][0] * mult, boundaries[0][1] * mult + k2_min, k2_max = boundaries[1][0] * mult, boundaries[1][1] * mult + k3_min, k3_max = boundaries[2][0] * mult, boundaries[2][1] * mult + k4_min, k4_max = boundaries[3][0] * mult, boundaries[3][1] * mult + k5_min, k5_max = boundaries[4][0] * mult, boundaries[4][1] * mult + k6_min, k6_max = boundaries[5][0] * mult, boundaries[5][1] * mult + p1_min, p1_max = boundaries[6][0] * mult, boundaries[6][1] * mult + p2_min, p2_max = boundaries[7][0] * mult, boundaries[7][1] * mult + s1_min, s1_max = boundaries[8][0] * mult, boundaries[8][1] * mult + s2_min, s2_max = boundaries[9][0] * mult, boundaries[9][1] * mult + s3_min, s3_max = boundaries[10][0] * mult, boundaries[10][1] * mult + s4_min, s4_max = boundaries[11][0] * mult, boundaries[11][1] * mult + + sign = random.choice([-1, 1]) + k1 = torch.rand(batch, device=device, dtype=dtype) * (k1_max - k1_min) + k1_min + k1 = sign * k1 + k2 = torch.rand(batch, device=device, dtype=dtype) * (k2_max - k2_min) + k2_min + k2 = -1 * sign * k2 + k3 = torch.rand(batch, device=device, dtype=dtype) * (k3_max - k3_min) + k3_min + k3 = sign * k3 + k4 = torch.rand(batch, device=device, dtype=dtype) * (k4_max - k4_min) + k4_min + k4 = -1 * sign * k4 + k5 = torch.rand(batch, device=device, dtype=dtype) * (k5_max - k5_min) + k5_min + k5 = sign * k5 + k6 = torch.rand(batch, device=device, dtype=dtype) * (k6_max - k6_min) + k6_min + k6 = -1 * sign * k6 + + p1 = torch.rand(batch, device=device, dtype=dtype) * (p1_max - p1_min) + p1_min + p2 = torch.rand(batch, device=device, dtype=dtype) * (p2_max - p2_min) + p2_min + s1 = torch.rand(batch, device=device, dtype=dtype) * (s1_max - s1_min) + s1_min + s2 = torch.rand(batch, device=device, dtype=dtype) * (s2_max - s2_min) + s2_min + s3 = torch.rand(batch, device=device, dtype=dtype) * (s3_max - s3_min) + s3_min + s4 = torch.rand(batch, device=device, dtype=dtype) * (s4_max - s4_min) + s4_min + return k1, k2, k3, k4, k5, k6, p1, p2, s1, s2, s3, s4 + + +class CameraSampler(nn.Module): + def __init__(self): + 
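+        # Illustrative sketch of the expected camera_sampler.json layout (inferred
+        # from the keys read below, not the shipped file): parallel lists such as
+        #   {"type": ["Fisheye624", "MEI", ...],
+        #    "fn": ["free_fisheye", "mei", ...],
+        #    "boundaries": [[[min, max], ...], ...],
+        #    "probs": [0.5, 0.5, ...]}
+        # where "type" names a Camera class, "fn" names a sampling function above,
+        # and "probs" weights the multinomial draw in `forward`.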
super().__init__() + with open("camera_sampler.json", "r") as f: + config = json.load(f) + self.camera_type = config["type"] + self.sampling_fn = config["fn"] + self.boundaries = nn.ParameterList( + [ + nn.Parameter(torch.tensor(x), requires_grad=False) + for x in config["boundaries"] + ] + ) + self.probs = nn.Parameter(torch.tensor(config["probs"]), requires_grad=False) + + def forward(self, fx, fy, cx, cy, mult, ratio, H): + selected_idx = torch.multinomial(self.probs, num_samples=1) + device, dtype = fx.device, fx.dtype + + selected_camera = self.camera_type[selected_idx] + selected_sampling_fn = self.sampling_fn[selected_idx] + selected_boundaries = self.boundaries[selected_idx] + if "Fisheye" in selected_camera or "OPENCV" in selected_camera: + mult = mult * ratio + + params = eval(selected_sampling_fn)( + selected_boundaries, mult, len(fx), device, dtype + ) + params = torch.stack([fx, fy, cx, cy, *params], dim=1) + camera = eval(selected_camera)(params=params) + return camera diff --git a/unik3d/utils/chamfer_distance.py b/unik3d/utils/chamfer_distance.py new file mode 100644 index 0000000000000000000000000000000000000000..5f0068f9bfef86acd62d86c8774be3464f8e5636 --- /dev/null +++ b/unik3d/utils/chamfer_distance.py @@ -0,0 +1,158 @@ +import warnings +from typing import Union + +import torch + +try: + from unik3d.ops.knn import knn_points +except ImportError as e: + warnings.warn( + "!! To run evaluation you need KNN. Please compile KNN: " + "`cd unik3d/ops/knn && bash compile.sh`." + ) + knn_points = lambda x: x + + +def _validate_chamfer_reduction_inputs( + batch_reduction: Union[str, None], point_reduction: str +): + """Check the requested reductions are valid. + + Args: + batch_reduction: Reduction operation to apply for the loss across the + batch, can be one of ["mean", "sum"] or None. + point_reduction: Reduction operation to apply for the loss across the + points, can be one of ["mean", "sum"]. + """ + if batch_reduction is not None and batch_reduction not in ["mean", "sum"]: + raise ValueError('batch_reduction must be one of ["mean", "sum"] or None') + if point_reduction not in ["mean", "sum"]: + raise ValueError('point_reduction must be one of ["mean", "sum"]') + + +def _handle_pointcloud_input( + points: torch.Tensor, + lengths: Union[torch.Tensor, None], + normals: Union[torch.Tensor, None], +): + """ + If points is an instance of Pointclouds, retrieve the padded points tensor + along with the number of points per batch and the padded normals. + Otherwise, return the input points (and normals) with the number of points per cloud + set to the size of the second dimension of `points`. + """ + if points.ndim != 3: + raise ValueError("Expected points to be of shape (N, P, D)") + X = points + if lengths is not None and (lengths.ndim != 1 or lengths.shape[0] != X.shape[0]): + raise ValueError("Expected lengths to be of shape (N,)") + if lengths is None: + lengths = torch.full( + (X.shape[0],), X.shape[1], dtype=torch.int64, device=points.device + ) + if normals is not None and normals.ndim != 3: + raise ValueError("Expected normals to be of shape (N, P, 3") + + return X, lengths, normals + + +class ChamferDistance(torch.nn.Module): + def forward( + self, + x, + y, + x_lengths=None, + y_lengths=None, + x_normals=None, + y_normals=None, + weights=None, + batch_reduction: Union[str, None] = "mean", + point_reduction: str = "mean", + ): + """ + Chamfer distance between two pointclouds x and y. 
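+
+        Nearest neighbours are computed with the compiled KNN extension from
+        `unik3d/ops/knn`; if that extension is missing, the stub bound at import
+        time above only emits a warning and cannot serve as a substitute.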
+ + Args: + x: FloatTensor of shape (N, P1, D) or a Pointclouds object representing + a batch of point clouds with at most P1 points in each batch element, + batch size N and feature dimension D. + y: FloatTensor of shape (N, P2, D) or a Pointclouds object representing + a batch of point clouds with at most P2 points in each batch element, + batch size N and feature dimension D. + x_lengths: Optional LongTensor of shape (N,) giving the number of points in each + cloud in x. + y_lengths: Optional LongTensor of shape (N,) giving the number of points in each + cloud in x. + x_normals: Optional FloatTensor of shape (N, P1, D). + y_normals: Optional FloatTensor of shape (N, P2, D). + weights: Optional FloatTensor of shape (N,) giving weights for + batch elements for reduction operation. + batch_reduction: Reduction operation to apply for the loss across the + batch, can be one of ["mean", "sum"] or None. + point_reduction: Reduction operation to apply for the loss across the + points, can be one of ["mean", "sum"]. + + Returns: + 2-element tuple containing + + - **loss**: Tensor giving the reduced distance between the pointclouds + in x and the pointclouds in y. + - **loss_normals**: Tensor giving the reduced cosine distance of normals + between pointclouds in x and pointclouds in y. Returns None if + x_normals and y_normals are None. + """ + _validate_chamfer_reduction_inputs(batch_reduction, point_reduction) + + x, x_lengths, x_normals = _handle_pointcloud_input(x, x_lengths, x_normals) + y, y_lengths, y_normals = _handle_pointcloud_input(y, y_lengths, y_normals) + + return_normals = x_normals is not None and y_normals is not None + + N, P1, D = x.shape + P2 = y.shape[1] + + # Check if inputs are heterogeneous and create a lengths mask. + is_x_heterogeneous = (x_lengths != P1).any() + is_y_heterogeneous = (y_lengths != P2).any() + x_mask = ( + torch.arange(P1, device=x.device)[None] >= x_lengths[:, None] + ) # shape [N, P1] + y_mask = ( + torch.arange(P2, device=y.device)[None] >= y_lengths[:, None] + ) # shape [N, P2] + + if y.shape[0] != N or y.shape[2] != D: + raise ValueError("y does not have the correct shape.") + if weights is not None: + if weights.size(0) != N: + raise ValueError("weights must be of shape (N,).") + if not (weights >= 0).all(): + raise ValueError("weights cannot be negative.") + if weights.sum() == 0.0: + weights = weights.view(N, 1) + if batch_reduction in ["mean", "sum"]: + return ( + (x.sum((1, 2)) * weights).sum() * 0.0, + (x.sum((1, 2)) * weights).sum() * 0.0, + ) + return ( + (x.sum((1, 2)) * weights) * 0.0, + (x.sum((1, 2)) * weights) * 0.0, + ) + + x_nn = knn_points(x, y, lengths1=x_lengths, lengths2=y_lengths, K=1) + y_nn = knn_points(y, x, lengths1=y_lengths, lengths2=x_lengths, K=1) + + cham_x = x_nn.dists[..., 0] # (N, P1) + cham_y = y_nn.dists[..., 0] # (N, P2) + + if is_x_heterogeneous: + cham_x[x_mask] = 0.0 + if is_y_heterogeneous: + cham_y[y_mask] = 0.0 + + if weights is not None: + cham_x *= weights.view(N, 1) + cham_y *= weights.view(N, 1) + + return cham_x, cham_y, x_nn.idx[..., -1], y_nn.idx[..., -1] diff --git a/unik3d/utils/constants.py b/unik3d/utils/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..9e8c70b011e18eb9c53ffd4c075afab8d98b7123 --- /dev/null +++ b/unik3d/utils/constants.py @@ -0,0 +1,42 @@ +import math + +import torch + +NAME_PAGE = "submission3" +OPENAI_DATASET_MEAN = (0.48145466, 0.4578275, 0.40821073) +OPENAI_DATASET_STD = (0.26862954, 0.26130258, 0.27577711) +IMAGENET_DATASET_MEAN = (0.485, 0.456, 
0.406) +IMAGENET_DATASET_STD = (0.229, 0.224, 0.225) +DEPTH_BINS = torch.cat( + ( + torch.logspace(math.log10(0.1), math.log10(180.0), steps=512), + torch.tensor([260.0]), + ), + dim=0, +) +LOGERR_BINS = torch.linspace(-2, 2, steps=128 + 1) +LINERR_BINS = torch.linspace(-50, 50, steps=256 + 1) + +VERBOSE = False +OUTDOOR_DATASETS = __all__ = [ + "Argoverse", + "DDAD", + "DrivingStereo", + "Mapillary", + "BDD", + "A2D2", + "Nuscenes", + "Cityscape", + "KITTI", + "DENSE", + "DIML", + "NianticMapFree", + "DL3DV", + "KITTIMulti", + "Waymo", + "Argoverse2", + "BEDLAM", + "NeRDS360", + "BlendedMVG", + "MegaDepthS", +] diff --git a/unik3d/utils/coordinate.py b/unik3d/utils/coordinate.py new file mode 100644 index 0000000000000000000000000000000000000000..a09129982737d288db26a3717337af136d163c32 --- /dev/null +++ b/unik3d/utils/coordinate.py @@ -0,0 +1,27 @@ +import torch + + +def coords_grid(b, h, w, homogeneous=False, device=None, noisy=False): + pixel_coords_x = torch.linspace(0.5, w - 0.5, w, device=device) + pixel_coords_y = torch.linspace(0.5, h - 0.5, h, device=device) + if noisy: # \pm 0.5px noise + pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5 + pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5 + + stacks = [pixel_coords_x.repeat(h, 1), pixel_coords_y.repeat(w, 1).t()] + if homogeneous: + ones = torch.ones_like(stacks[0]) # [H, W] + stacks.append(ones) + grid = torch.stack(stacks, dim=0).float() # [2, H, W] or [3, H, W] + grid = grid[None].repeat(b, 1, 1, 1) # [B, 2, H, W] or [B, 3, H, W] + if device is not None: + grid = grid.to(device) + + return grid + + +def normalize_coords(coords, h, w): + c = torch.tensor([(w - 1) / 2.0, (h - 1) / 2.0], device=coords.device).view( + 1, 2, 1, 1 + ) + return (coords - c) / c diff --git a/unik3d/utils/distributed.py b/unik3d/utils/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..7bf45f02fb7e0a6420373bf07d926303d62b7e50 --- /dev/null +++ b/unik3d/utils/distributed.py @@ -0,0 +1,247 @@ +import os +import pickle +import platform +import subprocess +import warnings + +import cv2 +import torch +import torch.utils.data.distributed +from torch import distributed as dist +from torch import multiprocessing as mp + +_LOCAL_PROCESS_GROUP = None + + +def is_dist_avail_and_initialized(): + if not dist.is_available(): + return False + if not dist.is_initialized(): + return False + return True + + +def get_rank(): + if not is_dist_avail_and_initialized(): + return 0 + return dist.get_rank() + + +def get_local_rank() -> int: + """ + Returns: + The rank of the current process within the local (per-machine) process group. + """ + if not is_dist_avail_and_initialized(): + return 0 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_rank(group=_LOCAL_PROCESS_GROUP) + + +def get_local_size() -> int: + """ + Returns: + The size of the per-machine process group, + i.e. the number of processes per machine. 
+ """ + if not is_dist_avail_and_initialized(): + return 1 + assert _LOCAL_PROCESS_GROUP is not None + return dist.get_world_size(group=_LOCAL_PROCESS_GROUP) + + +def get_world_size(): + if not is_dist_avail_and_initialized(): + return 1 + return dist.get_world_size() + + +def barrier(): + if not is_dist_avail_and_initialized(): + return + dist.barrier() + + +def is_main_process(): + return get_rank() == 0 + + +def is_rank_zero(args): + return args.rank == 0 + + +def get_dist_info(): + if dist.is_available() and dist.is_initialized(): + rank = dist.get_rank() + world_size = dist.get_world_size() + else: + rank = 0 + world_size = 1 + return rank, world_size + + +def setup_multi_processes(cfg): + """Setup multi-processing environment variables.""" + # set multi-process start method as `fork` to speed up the training + if platform.system() != "Windows": + mp_start_method = cfg.get("mp_start_method", "fork") + current_method = mp.get_start_method(allow_none=True) + if current_method is not None and current_method != mp_start_method: + warnings.warn( + f"Multi-processing start method `{mp_start_method}` is " + f"different from the previous setting `{current_method}`." + f"It will be force set to `{mp_start_method}`. You can change " + f"this behavior by changing `mp_start_method` in your config." + ) + mp.set_start_method(mp_start_method, force=True) + + # disable opencv multithreading to avoid system being overloaded + # opencv_num_threads = cfg.get('opencv_num_threads', 0) + # cv2.setNumThreads(opencv_num_threads) + + # setup OMP threads + # This code is referred from https://github.com/pytorch/pytorch/blob/master/torch/distributed/run.py # noqa + # workers_per_gpu = cfg.get('workers_per_gpu', 4) + + # if 'OMP_NUM_THREADS' not in os.environ and workers_per_gpu > 1: + # omp_num_threads = 1 + # warnings.warn( + # f'Setting OMP_NUM_THREADS environment variable for each process ' + # f'to be {omp_num_threads} in default, to avoid your system being ' + # f'overloaded, please further tune the variable for optimal ' + # f'performance in your application as needed.') + # os.environ['OMP_NUM_THREADS'] = str(omp_num_threads) + + # setup MKL threads + # if 'MKL_NUM_THREADS' not in os.environ and workers_per_gpu > 1: + # mkl_num_threads = os.environ.get('OMP_NUM_THREADS', 1) + # warnings.warn( + # f'Setting MKL_NUM_THREADS environment variable for each process ' + # f'to be {mkl_num_threads} in default, to avoid your system being ' + # f'overloaded, please further tune the variable for optimal ' + # f'performance in your application as needed.') + # os.environ['MKL_NUM_THREADS'] = str(mkl_num_threads) + + +def setup_slurm(backend: str, port: str) -> None: + proc_id = int(os.environ["SLURM_PROCID"]) + ntasks = int(os.environ["SLURM_NTASKS"]) + node_list = os.environ["SLURM_NODELIST"] + + num_gpus = torch.cuda.device_count() + + torch.cuda.set_device(proc_id % num_gpus) + if "MASTER_ADDR" not in os.environ: + addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") + os.environ["MASTER_PORT"] = str(port) + os.environ["MASTER_ADDR"] = addr + else: + addr = os.environ["MASTER_ADDR"] + os.environ["WORLD_SIZE"] = str(ntasks) + os.environ["LOCAL_RANK"] = str(proc_id % num_gpus) + os.environ["RANK"] = str(proc_id) + print( + proc_id, + ntasks, + num_gpus, + proc_id % num_gpus, + node_list, + addr, + os.environ["MASTER_PORT"], + os.system("nvidia-smi -L"), + ) + dist.init_process_group(backend, rank=proc_id, world_size=ntasks) + + +def sync_tensor_across_gpus(t, dim=0, cat=True): + if t is 
None or not (dist.is_available() and dist.is_initialized()): + return t + t = torch.atleast_1d(t) + group = dist.group.WORLD + group_size = torch.distributed.get_world_size(group) + + local_size = torch.tensor(t.size(dim), device=t.device) + all_sizes = [torch.zeros_like(local_size) for _ in range(group_size)] + dist.all_gather(all_sizes, local_size) + max_size = max(all_sizes) + size_diff = max_size.item() - local_size.item() + if size_diff: + padding = torch.zeros(size_diff, device=t.device, dtype=t.dtype) + t = torch.cat((t, padding)) + + gather_t_tensor = [torch.zeros_like(t) for _ in range(group_size)] + dist.all_gather(gather_t_tensor, t) + all_ts = [] + for t, size in zip(gather_t_tensor, all_sizes): + all_ts.append(t[:size]) + if cat: + return torch.cat(all_ts, dim=0) + return all_ts + + +def sync_string_across_gpus(keys: list[str], device, dim=0): + keys_serialized = pickle.dumps(keys, protocol=pickle.HIGHEST_PROTOCOL) + keys_serialized_tensor = ( + torch.frombuffer(keys_serialized, dtype=torch.uint8).clone().to(device) + ) + keys_serialized_tensor = sync_tensor_across_gpus( + keys_serialized_tensor, dim=0, cat=False + ) + keys = [ + key + for keys in keys_serialized_tensor + for key in pickle.loads(bytes(keys.cpu().tolist())) + ] + return keys + + +def create_local_process_group() -> None: + num_workers_per_machine = torch.cuda.device_count() + global _LOCAL_PROCESS_GROUP + assert _LOCAL_PROCESS_GROUP is None + assert get_world_size() % num_workers_per_machine == 0 + num_machines = get_world_size() // num_workers_per_machine + machine_rank = get_rank() // num_workers_per_machine + for i in range(num_machines): + ranks_on_i = list( + range(i * num_workers_per_machine, (i + 1) * num_workers_per_machine) + ) + pg = dist.new_group(ranks_on_i) + if i == machine_rank: + _LOCAL_PROCESS_GROUP = pg + + +def _get_global_gloo_group(): + if dist.get_backend() == "nccl": + return dist.new_group(backend="gloo") + else: + return dist.group.WORLD + + +def all_gather(data, group=None): + if get_world_size() == 1: + return [data] + if group is None: + group = ( + _get_global_gloo_group() + ) # use CPU group by default, to reduce GPU RAM usage. + world_size = dist.get_world_size(group) + if world_size == 1: + return [data] + + output = [None for _ in range(world_size)] + dist.all_gather_object(output, data, group=group) + return output + + +def local_broadcast_process_authkey(): + if get_local_size() == 1: + return + local_rank = get_local_rank() + authkey = bytes(mp.current_process().authkey) + all_keys = all_gather(authkey) + local_leader_key = all_keys[get_rank() - local_rank] + if authkey != local_leader_key: + # print("Process authkey is different from the key of local leader! 
workers are launched independently ??") + # print("Overwriting local authkey ...") + mp.current_process().authkey = local_leader_key diff --git a/unik3d/utils/ema_torch.py b/unik3d/utils/ema_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..a50685b841a3d86bf32f2fdd73c64700233f53a8 --- /dev/null +++ b/unik3d/utils/ema_torch.py @@ -0,0 +1,340 @@ +from __future__ import division, unicode_literals + +import contextlib +import copy +import weakref +from math import tanh +from typing import Iterable, Optional + +import torch + + +class DummyExponentialMovingAverage: + def __init__(self, *args, **kwargs): + pass + + def _get_parameters(self, *args, **kwargs): + pass + + def get_current_decay(self, *args, **kwargs): + pass + + def update(self, *args, **kwargs): + pass + + def copy_to(self, *args, **kwargs): + pass + + def store(self, *args, **kwargs): + return + + def restore(self, *args, **kwargs): + return + + @contextlib.contextmanager + def average_parameters(self, *args, **kwargs): + try: + yield + finally: + pass + + def to(self, *args, **kwargs): + pass + + def state_dict(self, *args, **kwargs): + pass + + def load_state_dict(self, *args, **kwargs): + pass + + +class ExponentialMovingAverage: + """ + Maintains (exponential) moving average of a set of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter` (typically from + `model.parameters()`). + Note that EMA is computed on *all* provided parameters, + regardless of whether or not they have `requires_grad = True`; + this allows a single EMA object to be consistantly used even + if which parameters are trainable changes step to step. + + If you want to some parameters in the EMA, do not pass them + to the object in the first place. For example: + + ExponentialMovingAverage( + parameters=[p for p in model.parameters() if p.requires_grad], + decay=0.9 + ) + + will ignore parameters that do not require grad. + + decay: The exponential decay. + + use_num_updates: Whether to use number of updates when computing + averages. + """ + + def __init__( + self, + parameters: Iterable[torch.nn.Parameter], + decay: float, + use_num_updates: bool = True, + update_after_step: int = 10000, + tau: int = 20000, + switch: bool = False, + save_memory: bool = True, + ): + if decay < 0.0 or decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + self.decay = decay + self.switch = switch # fi keeping EMA params in model after epochs + self.num_updates = 0 if use_num_updates else None + parameters = list(parameters) + self.shadow_params = [p.clone().detach() for p in parameters] + self.collected_params = None + # By maintaining only a weakref to each parameter, + # we maintain the old GC behaviour of ExponentialMovingAverage: + # if the model goes out of scope but the ExponentialMovingAverage + # is kept, no references to the model or its parameters will be + # maintained, and the model will be cleaned up. 
+ self._params_refs = [weakref.ref(p) for p in parameters] + self.update_after_step = update_after_step + self.tau = tau + self.save_memory = save_memory + + def _get_parameters( + self, parameters: Optional[Iterable[torch.nn.Parameter]] + ) -> Iterable[torch.nn.Parameter]: + if parameters is None: + parameters = [p() for p in self._params_refs] + if any(p is None for p in parameters): + raise ValueError( + "(One of) the parameters with which this ExponentialMovingAverage was initialized no longer exists (was garbage collected);" + " please either provide `parameters` explicitly or keep the model to which they belong from being garbage collected." + ) + return parameters + else: + parameters = list(parameters) + if len(parameters) != len(self.shadow_params): + raise ValueError( + "Number of parameters passed as argument is different " + "from number of shadow parameters maintained by this " + "ExponentialMovingAverage" + ) + return parameters + + def get_current_decay(self): + epoch = max(self.num_updates - self.update_after_step - 1, 0.0) + if epoch <= 0: + return 0.0 + value = tanh(epoch / self.tau) * self.decay + return value + + def update(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None: + """ + Update currently maintained parameters. + + Call this every time the parameters are updated, such as the result of + the `optimizer.step()` call. + + Args: + parameters: Iterable of `torch.nn.Parameter`; usually the same set of + parameters used to initialize this object. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + decay = self.get_current_decay() + if self.num_updates is not None: + self.num_updates += 1 + + one_minus_decay = 1.0 - decay + with torch.no_grad(): + for s_param, param in zip(self.shadow_params, parameters): + tmp = s_param - param + # tmp will be a new tensor so we can do in-place + tmp.mul_(one_minus_decay) + s_param.sub_(tmp) + + def copy_to( + self, parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Copy current averaged parameters into given collection of parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored moving averages. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + for s_param, param in zip(self.shadow_params, parameters): + param.data.copy_(s_param.data) + + def store(self, parameters: Optional[Iterable[torch.nn.Parameter]] = None) -> None: + """ + Save the current parameters for restoring later. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + temporarily stored. If `None`, the parameters of with which this + `ExponentialMovingAverage` was initialized will be used. + """ + parameters = self._get_parameters(parameters) + self.collected_params = [param.detach().clone() for param in parameters] + + def restore( + self, parameters: Optional[Iterable[torch.nn.Parameter]] = None + ) -> None: + """ + Restore the parameters stored with the `store` method. + Useful to validate the model with EMA parameters without affecting the + original optimization process. Store the parameters before the + `copy_to` method. After validation (or model saving), use this to + restore the former parameters. + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. 
If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + if self.collected_params is None: + raise RuntimeError( + "This ExponentialMovingAverage has no `store()`ed weights " + "to `restore()`" + ) + parameters = self._get_parameters(parameters) + for c_param, param in zip(self.collected_params, parameters): + param.data.copy_(c_param.data) + + @contextlib.contextmanager + def average_parameters( + self, parameters: Optional[Iterable[torch.nn.Parameter]] = None + ): + r""" + Context manager for validation/inference with averaged parameters. + + Equivalent to: + + ema.store() + ema.copy_to() + try: + ... + finally: + ema.restore() + + Args: + parameters: Iterable of `torch.nn.Parameter`; the parameters to be + updated with the stored parameters. If `None`, the + parameters with which this `ExponentialMovingAverage` was + initialized will be used. + """ + parameters = self._get_parameters(parameters) + self.store(parameters) + self.copy_to(parameters) + try: + yield + finally: + if not self.switch: + self.restore(parameters) + if self.save_memory: + self.collected_params = None + + def to(self, device=None, dtype=None) -> None: + r"""Move internal buffers of the ExponentialMovingAverage to `device`. + + Args: + device: like `device` argument to `torch.Tensor.to` + """ + # .to() on the tensors handles None correctly + self.shadow_params = [ + ( + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + ) + for p in self.shadow_params + ] + if self.collected_params is not None: + self.collected_params = [ + ( + p.to(device=device, dtype=dtype) + if p.is_floating_point() + else p.to(device=device) + ) + for p in self.collected_params + ] + return + + def state_dict(self) -> dict: + r"""Returns the state of the ExponentialMovingAverage as a dict.""" + # Following PyTorch conventions, references to tensors are returned: + # "returns a reference to the state and not its copy!" - + # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict + return { + "decay": self.decay, + "num_updates": self.num_updates, + "shadow_params": self.shadow_params, + "collected_params": self.collected_params, + } + + def load_state_dict(self, state_dict: dict) -> None: + r"""Loads the ExponentialMovingAverage state. + + Args: + state_dict (dict): EMA state. Should be an object returned + from a call to :meth:`state_dict`. 
+ """ + # deepcopy, to be consistent with module API + state_dict = copy.deepcopy(state_dict) + self.decay = state_dict["decay"] + if self.decay < 0.0 or self.decay > 1.0: + raise ValueError("Decay must be between 0 and 1") + self.num_updates = state_dict["num_updates"] + assert self.num_updates is None or isinstance( + self.num_updates, int + ), "Invalid num_updates" + + self.shadow_params = state_dict["shadow_params"] + assert isinstance(self.shadow_params, list), "shadow_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.shadow_params + ), "shadow_params must all be Tensors" + + self.collected_params = state_dict["collected_params"] + if self.collected_params is not None: + assert isinstance( + self.collected_params, list + ), "collected_params must be a list" + assert all( + isinstance(p, torch.Tensor) for p in self.collected_params + ), "collected_params must all be Tensors" + assert len(self.collected_params) == len( + self.shadow_params + ), "collected_params and shadow_params had different lengths" + + if len(self.shadow_params) == len(self._params_refs): + # Consistant with torch.optim.Optimizer, cast things to consistant + # device and dtype with the parameters + params = [p() for p in self._params_refs] + # If parameters have been garbage collected, just load the state + # we were given without change. + if not any(p is None for p in params): + # ^ parameter references are still good + for i, p in enumerate(params): + self.shadow_params[i] = self.shadow_params[i].to( + device=p.device, dtype=p.dtype + ) + if self.collected_params is not None: + self.collected_params[i] = self.collected_params[i].to( + device=p.device, dtype=p.dtype + ) + else: + raise ValueError( + "Tried to `load_state_dict()` with the wrong number of " + "parameters in the saved state." 
+ ) diff --git a/unik3d/utils/evaluation_depth.py b/unik3d/utils/evaluation_depth.py new file mode 100644 index 0000000000000000000000000000000000000000..3ae9feb66384b73832ecc20471513be1a897e80d --- /dev/null +++ b/unik3d/utils/evaluation_depth.py @@ -0,0 +1,337 @@ +from collections import defaultdict +from functools import partial + +import torch +import torch.nn.functional as F +import torchvision.transforms.v2.functional as TF +from PIL import Image + +from unik3d.utils.chamfer_distance import ChamferDistance +from unik3d.utils.constants import DEPTH_BINS + +chamfer_cls = ChamferDistance() + + +def kl_div(gt, pred, eps: float = 1e-6): + depth_bins = DEPTH_BINS.to(gt.device) + gt, pred = torch.bucketize( + gt, boundaries=depth_bins, out_int32=True + ), torch.bucketize(pred, boundaries=depth_bins, out_int32=True) + gt = torch.bincount(gt, minlength=len(depth_bins) + 1) + pred = torch.bincount(pred, minlength=len(depth_bins) + 1) + gt = gt / gt.sum() + pred = pred / pred.sum() + return torch.sum(gt * (torch.log(gt + eps) - torch.log(pred + eps))) + + +def chamfer_dist(tensor1, tensor2): + x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) + y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) + dist1, dist2, idx1, idx2 = chamfer_cls( + tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths + ) + return (torch.sqrt(dist1) + torch.sqrt(dist2)) / 2 + + +def auc(tensor1, tensor2, thresholds): + x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) + y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) + dist1, dist2, idx1, idx2 = chamfer_cls( + tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths + ) + # compute precision recall + precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds] + recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds] + auc_value = torch.trapz( + torch.tensor(precisions, device=tensor1.device), + torch.tensor(recalls, device=tensor1.device), + ) + return auc_value + + +def delta(tensor1, tensor2, exponent): + inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1)) + return (inlier < 1.25**exponent).to(torch.float32).mean() + + +def rho(tensor1, tensor2): + min_deg = 0.5 + tensor1_norm = tensor1 / torch.norm(tensor1, dim=-1, p=2, keepdim=True).clip( + min=1e-6 + ) + tensor2_norm = tensor2 / torch.norm(tensor2, dim=-1, p=2, keepdim=True).clip( + min=1e-6 + ) + max_polar_angle = torch.arccos(tensor1_norm[..., 2]).max() * 180.0 / torch.pi + + if max_polar_angle < 100.0: + threshold = 15.0 + elif max_polar_angle < 190.0: + threshold = 20.0 + else: + threshold = 30.0 + + acos_clip = 1 - 1e-6 + # inner prod of norm vector -> cosine + angular_error = ( + torch.arccos( + (tensor1_norm * tensor2_norm) + .sum(dim=-1) + .clip(min=-acos_clip, max=acos_clip) + ) + * 180.0 + / torch.pi + ) + thresholds = torch.linspace(min_deg, threshold, steps=100, device=tensor1.device) + y_values = [ + (angular_error.abs() <= th).to(torch.float32).mean() for th in thresholds + ] + auc_value = torch.trapz( + torch.tensor(y_values, device=tensor1.device), thresholds + ) / (threshold - min_deg) + return auc_value + + +def tau(tensor1, tensor2, perc): + inlier = torch.maximum((tensor1 / tensor2), (tensor2 / tensor1)) + return (inlier < (1.0 + perc)).to(torch.float32).mean() + + +@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) +def ssi(tensor1, tensor2, qtl=0.05): + stability_mat = 1e-9 * torch.eye(2, device=tensor1.device) + error = (tensor1 
- tensor2).abs() + mask = error < torch.quantile(error, 1 - qtl) + tensor1_mask = tensor1.to(torch.float32)[mask] + tensor2_mask = tensor2.to(torch.float32)[mask] + stability_mat = 1e-4 * torch.eye(2, device=tensor1.device) + tensor2_one = torch.stack([tensor2_mask, torch.ones_like(tensor2_mask)], dim=1) + A = torch.matmul(tensor2_one.T, tensor2_one) + stability_mat + det_A = A[0, 0] * A[1, 1] - A[0, 1] * A[1, 0] + A_inv = (1.0 / det_A) * torch.tensor( + [[A[1, 1], -A[0, 1]], [-A[1, 0], A[0, 0]]], device=tensor1.device + ) + b = tensor2_one.T @ tensor1_mask.unsqueeze(1) + scale_shift = A_inv @ b + scale, shift = scale_shift.squeeze().chunk(2, dim=0) + return tensor2 * scale + shift + + +def si(tensor1, tensor2): + return tensor2 * torch.median(tensor1) / torch.median(tensor2) + + +def arel(tensor1, tensor2): + tensor2 = tensor2 * torch.median(tensor1) / torch.median(tensor2) + return (torch.abs(tensor1 - tensor2) / tensor1).mean() + + +def d_auc(tensor1, tensor2): + exponents = torch.linspace(0.01, 5.0, steps=100, device=tensor1.device) + deltas = [delta(tensor1, tensor2, exponent) for exponent in exponents] + return torch.trapz(torch.tensor(deltas, device=tensor1.device), exponents) / 5.0 + + +def f1_score(tensor1, tensor2, thresholds): + x_lengths = torch.tensor((tensor1.shape[1],), device=tensor1.device) + y_lengths = torch.tensor((tensor2.shape[1],), device=tensor2.device) + dist1, dist2, idx1, idx2 = chamfer_cls( + tensor1, tensor2, x_lengths=x_lengths, y_lengths=y_lengths + ) + # compute precision recall + precisions = [(dist1 < threshold).sum() / dist1.numel() for threshold in thresholds] + recalls = [(dist2 < threshold).sum() / dist2.numel() for threshold in thresholds] + precisions = torch.tensor(precisions, device=tensor1.device) + recalls = torch.tensor(recalls, device=tensor1.device) + f1_thresholds = 2 * precisions * recalls / (precisions + recalls) + f1_thresholds = torch.where( + torch.isnan(f1_thresholds), torch.zeros_like(f1_thresholds), f1_thresholds + ) + f1_value = torch.trapz(f1_thresholds) / len(thresholds) + return f1_value + + +def f1_score_si(tensor1, tensor2, thresholds): + tensor2 = ( + tensor2 + * torch.median(tensor1.norm(dim=-1)) + / torch.median(tensor2.norm(dim=-1)) + ) + f1_value = f1_score(tensor1, tensor2, thresholds) + return f1_value + + +DICT_METRICS = { + "d1": partial(delta, exponent=1.0), + "d2": partial(delta, exponent=2.0), + "d3": partial(delta, exponent=3.0), + "rmse": lambda gt, pred: torch.sqrt(((gt - pred) ** 2).mean()), + "rmselog": lambda gt, pred: torch.sqrt( + ((torch.log(gt) - torch.log(pred)) ** 2).mean() + ), + "arel": lambda gt, pred: (torch.abs(gt - pred) / gt).mean(), + "sqrel": lambda gt, pred: (((gt - pred) ** 2) / gt).mean(), + "log10": lambda gt, pred: torch.abs(torch.log10(pred) - torch.log10(gt)).mean(), + "silog": lambda gt, pred: 100 * torch.std(torch.log(pred) - torch.log(gt)).mean(), + "medianlog": lambda gt, pred: 100 + * (torch.log(pred) - torch.log(gt)).median().abs(), + "d_auc": d_auc, + "tau": partial(tau, perc=0.03), +} + + +DICT_METRICS_3D = { + "MSE_3d": lambda gt, pred, thresholds: torch.norm(gt - pred, dim=0, p=2), + "arel_3d": lambda gt, pred, thresholds: torch.norm(gt - pred, dim=0, p=2) + / torch.norm(gt, dim=0, p=2), + "tau_3d": lambda gt, pred, thresholds: ( + (torch.norm(pred, dim=0, p=2) / torch.norm(gt, dim=0, p=2)).log().abs().exp() + < 1.25 + ) + .float() + .mean(), + "chamfer": lambda gt, pred, thresholds: chamfer_dist( + gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1) + ), + "F1": 
lambda gt, pred, thresholds: f1_score( + gt.unsqueeze(0).permute(0, 2, 1), + pred.unsqueeze(0).permute(0, 2, 1), + thresholds=thresholds, + ), + "F1_si": lambda gt, pred, thresholds: f1_score_si( + gt.unsqueeze(0).permute(0, 2, 1), + pred.unsqueeze(0).permute(0, 2, 1), + thresholds=thresholds, + ), + "rays": lambda gt, pred, thresholds: rho( + gt.unsqueeze(0).permute(0, 2, 1), pred.unsqueeze(0).permute(0, 2, 1) + ), +} + + +DICT_METRICS_FLOW = { + "epe": lambda gt, pred: torch.sqrt(torch.square(gt - pred).sum(dim=0)), + "epe1": lambda gt, pred: torch.sqrt(torch.square(gt - pred).sum(dim=0)) < 1, + "epe3": lambda gt, pred: torch.sqrt(torch.square(gt - pred).sum(dim=0)) < 3, + "epe5": lambda gt, pred: torch.sqrt(torch.square(gt - pred).sum(dim=0)) < 5, +} + + +DICT_METRICS_D = { + "a1": lambda gt, pred: (torch.maximum((gt / pred), (pred / gt)) > 1.25**1.0).to( + torch.float32 + ), + "abs_rel": lambda gt, pred: (torch.abs(gt - pred) / gt), +} + + +def eval_depth( + gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, max_depth=None +): + summary_metrics = defaultdict(list) + # preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear") + for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)): + if max_depth is not None: + mask = mask & (gt <= max_depth) + for name, fn in DICT_METRICS.items(): + if name in ["tau", "d1", "arel"]: + for rescale_fn in ["ssi", "si"]: + summary_metrics[f"{name}_{rescale_fn}"].append( + fn(gt[mask], eval(rescale_fn)(gt[mask], pred[mask])) + ) + summary_metrics[name].append(fn(gt[mask], pred[mask]).mean()) + return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} + + +def eval_3d( + gts: torch.Tensor, preds: torch.Tensor, masks: torch.Tensor, thresholds=None +): + summary_metrics = defaultdict(list) + MAX_PIXELS = 75_000 # 300_000 + ratio = min(1.0, (MAX_PIXELS / masks[0].sum()) ** 0.5) + h_max, w_max = int(gts.shape[-2] * ratio), int(gts.shape[-1] * ratio) + gts = F.interpolate(gts, size=(h_max, w_max), mode="nearest-exact") + preds = F.interpolate(preds, size=(h_max, w_max), mode="nearest-exact") + masks = F.interpolate( + masks.float(), size=(h_max, w_max), mode="nearest-exact" + ).bool() + + for i, (gt, pred, mask) in enumerate(zip(gts, preds, masks)): + if not torch.any(mask): + continue + for name, fn in DICT_METRICS_3D.items(): + summary_metrics[name].append( + fn(gt[:, mask.squeeze()], pred[:, mask.squeeze()], thresholds).mean() + ) + return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} + + +def compute_aucs(gt, pred, mask, uncertainties, steps=50, metrics=["abs_rel"]): + dict_ = {} + x_axis = torch.linspace(0, 1, steps=steps + 1, device=gt.device) + quantiles = torch.linspace(0, 1 - 1 / steps, steps=steps, device=gt.device) + zer = torch.tensor(0.0, device=gt.device) + # revert order (high uncertainty first) + uncertainties = -uncertainties[mask] + gt = gt[mask] + pred = pred[mask] + true_uncert = {metric: -DICT_METRICS_D[metric](gt, pred) for metric in metrics} + # get percentiles for sampling and corresponding subsets + thresholds = torch.quantile(uncertainties, quantiles) + subs = [(uncertainties >= t) for t in thresholds] + + # compute sparsification curves for each metric (add 0 for final sampling) + for metric in metrics: + opt_thresholds = torch.quantile(true_uncert[metric], quantiles) + opt_subs = [(true_uncert[metric] >= t) for t in opt_thresholds] + sparse_curve = torch.stack( + [DICT_METRICS[metric](gt[sub], pred[sub]) for sub in subs] + [zer], dim=0 + ) + opt_curve = 
torch.stack( + [DICT_METRICS[metric](gt[sub], pred[sub]) for sub in opt_subs] + [zer], + dim=0, + ) + rnd_curve = DICT_METRICS[metric](gt, pred) + + dict_[f"AUSE_{metric}"] = torch.trapz(sparse_curve - opt_curve, x=x_axis) + dict_[f"AURG_{metric}"] = rnd_curve - torch.trapz(sparse_curve, x=x_axis) + + return dict_ + + +def eval_depth_uncertainties( + gts: torch.Tensor, + preds: torch.Tensor, + uncertainties: torch.Tensor, + masks: torch.Tensor, + max_depth=None, +): + summary_metrics = defaultdict(list) + preds = F.interpolate(preds, gts.shape[-2:], mode="bilinear") + for i, (gt, pred, mask, uncertainty) in enumerate( + zip(gts, preds, masks, uncertainties) + ): + if max_depth is not None: + mask = torch.logical_and(mask, gt < max_depth) + for name, fn in DICT_METRICS.items(): + summary_metrics[name].append(fn(gt[mask], pred[mask])) + for name, val in compute_aucs(gt, pred, mask, uncertainty).items(): + summary_metrics[name].append(val) + return {name: torch.stack(vals, dim=0) for name, vals in summary_metrics.items()} + + +def lazy_eval_depth( + gt_fns, pred_fns, min_depth=1e-2, max_depth=None, depth_scale=256.0 +): + summary_metrics = defaultdict(list) + for i, (gt_fn, pred_fn) in enumerate(zip(gt_fns, pred_fns)): + gt = TF.pil_to_tensor(Image.open(gt_fn)).to(torch.float32) / depth_scale + pred = TF.pil_to_tensor(Image.open(pred_fn)).to(torch.float32) / depth_scale + mask = gt > min_depth + if max_depth is not None: + mask_2 = gt < max_depth + mask = torch.logical_and(mask, mask_2) + for name, fn in DICT_METRICS.items(): + summary_metrics[name].append(fn(gt[mask], pred[mask])) + + return {name: torch.mean(vals).item() for name, vals in summary_metrics.items()} diff --git a/unik3d/utils/geometric.py b/unik3d/utils/geometric.py new file mode 100644 index 0000000000000000000000000000000000000000..5c61d8e5e7ef954b2423063057eba4dbe72f6704 --- /dev/null +++ b/unik3d/utils/geometric.py @@ -0,0 +1,479 @@ +from typing import Tuple + +import torch +from torch.nn import functional as F + + +# @torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) +def generate_rays( + camera_intrinsics: torch.Tensor, image_shape: Tuple[int, int], noisy: bool = False +): + batch_size, device, dtype = ( + camera_intrinsics.shape[0], + camera_intrinsics.device, + camera_intrinsics.dtype, + ) + # print("CAMERA DTYPE", dtype) + height, width = image_shape + # Generate grid of pixel coordinates + pixel_coords_x = torch.linspace(0, width - 1, width, device=device, dtype=dtype) + pixel_coords_y = torch.linspace(0, height - 1, height, device=device, dtype=dtype) + if noisy: + pixel_coords_x += torch.rand_like(pixel_coords_x) - 0.5 + pixel_coords_y += torch.rand_like(pixel_coords_y) - 0.5 + pixel_coords = torch.stack( + [pixel_coords_x.repeat(height, 1), pixel_coords_y.repeat(width, 1).t()], dim=2 + ) # (H, W, 2) + pixel_coords = pixel_coords + 0.5 + + # Calculate ray directions + intrinsics_inv = torch.inverse(camera_intrinsics.float()).to(dtype) + homogeneous_coords = torch.cat( + [pixel_coords, torch.ones_like(pixel_coords[:, :, :1])], dim=2 + ) # (H, W, 3) + ray_directions = torch.matmul( + intrinsics_inv, homogeneous_coords.permute(2, 0, 1).flatten(1) + ) # (3, H*W) + + # unstable normalization, need float32? 
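+    # A possible safeguard (an assumption, not part of the original code): if
+    # this normalization turns out to be unstable under reduced precision, it
+    # could be computed in float32 and cast back afterwards, e.g.
+    #   ray_directions = F.normalize(ray_directions.float(), dim=1).to(dtype)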
+ ray_directions = F.normalize(ray_directions, dim=1) # (B, 3, H*W) + ray_directions = ray_directions.permute(0, 2, 1) # (B, H*W, 3) + + theta = torch.atan2(ray_directions[..., 0], ray_directions[..., -1]) + phi = torch.acos(ray_directions[..., 1]) + # pitch = torch.asin(ray_directions[..., 1]) + # roll = torch.atan2(ray_directions[..., 0], - ray_directions[..., 1]) + angles = torch.stack([theta, phi], dim=-1) + return ray_directions, angles + + +@torch.jit.script +def spherical_zbuffer_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor: + theta = spherical_tensor[..., 0] # Extract polar angle + phi = spherical_tensor[..., 1] # Extract azimuthal angle + z = spherical_tensor[..., 2] # Extract zbuffer depth + + # y = r * cos(phi) + # x = r * sin(phi) * sin(theta) + # z = r * sin(phi) * cos(theta) + # => + # r = z / sin(phi) / cos(theta) + # y = z / (sin(phi) / cos(phi)) / cos(theta) + # x = z * sin(theta) / cos(theta) + x = z * torch.tan(theta) + y = z / torch.tan(phi) / torch.cos(theta) + + euclidean_tensor = torch.stack((x, y, z), dim=-1) + return euclidean_tensor + + +@torch.jit.script +def spherical_to_euclidean(spherical_tensor: torch.Tensor) -> torch.Tensor: + theta = spherical_tensor[..., 0] # Extract polar angle + phi = spherical_tensor[..., 1] # Extract azimuthal angle + r = spherical_tensor[..., 2] # Extract radius + # x = r * torch.sin(theta) * torch.sin(phi) + # y = r * torch.cos(theta) + # z = r * torch.cos(phi) * torch.sin(theta) + x = r * torch.sin(theta) * torch.cos(phi) + y = r * torch.sin(theta) * torch.sin(phi) + z = r * torch.cos(theta) + euclidean_tensor = torch.stack((x, y, z), dim=-1) + return euclidean_tensor + + +@torch.jit.script +def euclidean_to_spherical(spherical_tensor: torch.Tensor) -> torch.Tensor: + x = spherical_tensor[..., 0] # Extract polar angle + y = spherical_tensor[..., 1] # Extract azimuthal angle + z = spherical_tensor[..., 2] # Extract radius + # y = r * cos(phi) + # x = r * sin(phi) * sin(theta) + # z = r * sin(phi) * cos(theta) + r = torch.sqrt(x**2 + y**2 + z**2) + theta = torch.atan2(x / r, z / r) + phi = torch.acos(y / r) + + euclidean_tensor = torch.stack((theta, phi, r), dim=-1) + return euclidean_tensor + + +@torch.jit.script +def euclidean_to_spherical_zbuffer(euclidean_tensor: torch.Tensor) -> torch.Tensor: + pitch = torch.asin(euclidean_tensor[..., 1]) + yaw = torch.atan2(euclidean_tensor[..., 0], euclidean_tensor[..., -1]) + z = euclidean_tensor[..., 2] # Extract zbuffer depth + euclidean_tensor = torch.stack((pitch, yaw, z), dim=-1) + return euclidean_tensor + + +@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) +def unproject_points( + depth: torch.Tensor, camera_intrinsics: torch.Tensor +) -> torch.Tensor: + """ + Unprojects a batch of depth maps to 3D point clouds using camera intrinsics. + + Args: + depth (torch.Tensor): Batch of depth maps of shape (B, 1, H, W). + camera_intrinsics (torch.Tensor): Camera intrinsic matrix of shape (B, 3, 3). + + Returns: + torch.Tensor: Batch of 3D point clouds of shape (B, 3, H, W). 
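+
+    Example (illustrative only; the sizes below are arbitrary assumptions, not
+    taken from the original code):
+        >>> depth = torch.rand(2, 1, 480, 640)
+        >>> K = torch.eye(3).unsqueeze(0).repeat(2, 1, 1)
+        >>> unproject_points(depth, K).shape
+        torch.Size([2, 3, 480, 640])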
+ """ + batch_size, _, height, width = depth.shape + device = depth.device + + # Create pixel grid + y_coords, x_coords = torch.meshgrid( + torch.arange(height, device=device), + torch.arange(width, device=device), + indexing="ij", + ) + pixel_coords = torch.stack((x_coords, y_coords), dim=-1) # (H, W, 2) + + # Get homogeneous coords (u v 1) + pixel_coords_homogeneous = torch.cat( + (pixel_coords, torch.ones((height, width, 1), device=device)), dim=-1 + ) + pixel_coords_homogeneous = pixel_coords_homogeneous.permute(2, 0, 1).flatten( + 1 + ) # (3, H*W) + # Apply K^-1 @ (u v 1): [B, 3, 3] @ [3, H*W] -> [B, 3, H*W] + camera_intrinsics_inv = camera_intrinsics.clone() + # invert camera intrinsics + camera_intrinsics_inv[:, 0, 0] = 1 / camera_intrinsics_inv[:, 0, 0] + camera_intrinsics_inv[:, 1, 1] = 1 / camera_intrinsics_inv[:, 1, 1] + + unprojected_points = camera_intrinsics_inv @ pixel_coords_homogeneous # (B, 3, H*W) + unprojected_points = unprojected_points.view( + batch_size, 3, height, width + ) # (B, 3, H, W) + unprojected_points = unprojected_points * depth # (B, 3, H, W) + return unprojected_points + + +@torch.jit.script +def project_points( + points_3d: torch.Tensor, + intrinsic_matrix: torch.Tensor, + image_shape: Tuple[int, int], +) -> torch.Tensor: + # Project 3D points onto the image plane via intrinsics (u v w) = (x y z) @ K^T + points_2d = torch.matmul(points_3d, intrinsic_matrix.transpose(1, 2)) + + # Normalize projected points: (u v w) -> (u / w, v / w, 1) + points_2d = points_2d[..., :2] / points_2d[..., 2:] + + # To pixels (rounding!!!), no int as it breaks gradient + points_2d = points_2d.round() + + # pointa need to be inside the image (can it diverge onto all points out???) + valid_mask = ( + (points_2d[..., 0] >= 0) + & (points_2d[..., 0] < image_shape[1]) + & (points_2d[..., 1] >= 0) + & (points_2d[..., 1] < image_shape[0]) + ) + + # Calculate the flat indices of the valid pixels + flat_points_2d = points_2d[..., 0] + points_2d[..., 1] * image_shape[1] + flat_indices = flat_points_2d.long() + + # Create depth maps and counts using scatter_add, (B, H, W) + depth_maps = torch.zeros( + [points_3d.shape[0], *image_shape], device=points_3d.device + ) + counts = torch.zeros([points_3d.shape[0], *image_shape], device=points_3d.device) + + # Loop over batches to apply masks and accumulate depth/count values + for i in range(points_3d.shape[0]): + valid_indices = flat_indices[i, valid_mask[i]] + depth_maps[i].view(-1).scatter_add_( + 0, valid_indices, points_3d[i, valid_mask[i], 2] + ) + counts[i].view(-1).scatter_add_( + 0, valid_indices, torch.ones_like(points_3d[i, valid_mask[i], 2]) + ) + + # Calculate mean depth for each pixel in each batch + mean_depth_maps = depth_maps / counts.clamp(min=1.0) + return mean_depth_maps.reshape(-1, 1, *image_shape) # (B, 1, H, W) + + +@torch.jit.script +def downsample(data: torch.Tensor, downsample_factor: int = 2): + N, _, H, W = data.shape + data = data.view( + N, + H // downsample_factor, + downsample_factor, + W // downsample_factor, + downsample_factor, + 1, + ) + data = data.permute(0, 1, 3, 5, 2, 4).contiguous() + data = data.view(-1, downsample_factor * downsample_factor) + data_tmp = torch.where(data == 0.0, 1e5 * torch.ones_like(data), data) + data = torch.min(data_tmp, dim=-1).values + data = data.view(N, 1, H // downsample_factor, W // downsample_factor) + data = torch.where(data > 1000, torch.zeros_like(data), data) + return data + + +@torch.jit.script +def flat_interpolate( + flat_tensor: torch.Tensor, + old: Tuple[int, int], 
+ new: Tuple[int, int], + antialias: bool = False, + mode: str = "bilinear", +) -> torch.Tensor: + if old[0] == new[0] and old[1] == new[1]: + return flat_tensor + tensor = flat_tensor.view(flat_tensor.shape[0], old[0], old[1], -1).permute( + 0, 3, 1, 2 + ) # b c h w + tensor_interp = F.interpolate( + tensor, + size=(new[0], new[1]), + mode=mode, + align_corners=False, + antialias=antialias, + ) + flat_tensor_interp = tensor_interp.view( + flat_tensor.shape[0], -1, new[0] * new[1] + ).permute( + 0, 2, 1 + ) # b (h w) c + return flat_tensor_interp.contiguous() + + +# # @torch.jit.script +# def displacement_relative_neighbour(gt: torch.Tensor, mask: torch.Tensor = None, kernel_size: int = 7, ndim: int =4): +# pad = kernel_size // 2 +# n_neighbours = int(kernel_size**2) + +# # when torchscipt will support nested generators in listcomp or usage of range +# # in product(range_, range_), then use listcomp, so far speedup ~5% wrt std python +# if mask is None: +# mask = torch.ones_like(gt).bool() + +# lst_gts, lst_masks = [], [] +# for i in range(-kernel_size//2 + 1, kernel_size//2 + 1): +# for j in range(-kernel_size//2 + 1, kernel_size//2 + 1): +# if i != 0 or j != 0: +# lst_gts.append(torch.roll(gt, shifts=(i, j), dims=(-2, -1))) +# lst_masks.append(torch.roll(F.pad(mask, (pad,) * 4), shifts=(i, j), dims=(-2, -1))) +# gts = torch.cat(lst_gts, dim=-3) +# masks = torch.cat(lst_masks, dim=-3) + +# masks = masks[..., pad:-pad, pad:-pad] +# masks[~mask.repeat(*(1,) * (ndim - 3), n_neighbours-1, 1, 1,)] = False # No displacement known if seed is missing +# log_gts = gts.clamp(min=1e-6).log() - gt.repeat(*(1,) * (ndim - 3), n_neighbours-1, 1, 1).clamp(min=1e-6).log() +# return log_gts, masks + + +# @torch.jit.script +# def antidisplacement_relative_neighbour(preds: torch.Tensor, kernel_size: int = 7): +# lst_preds, lst_masks = [], [] +# cnt = 0 +# pad = kernel_size // 2 +# mask = F.pad(torch.ones((preds.shape[0], 1, preds.shape[-2], preds.shape[-1]), device=preds.device), (pad,) * 4) +# for i in range(-kernel_size//2 + 1, kernel_size//2 + 1): +# for j in range(-kernel_size//2 + 1, kernel_size//2 + 1): +# if i != 0 or j !=0: +# lst_preds.append(torch.roll(preds[:, cnt], shifts=(-i, -j), dims=(-2, -1))) +# lst_masks.append(torch.roll(mask, shifts=(-i, -j), dims=(-2, -1))) +# cnt += 1 +# preds_ensamble = torch.stack(lst_preds, dim=1) +# masks = torch.cat(lst_masks, dim=1) +# masks = masks[..., pad:-pad, pad:-pad] +# return preds_ensamble, masks + + +# def unproject(uv, fx, fy, cx, cy, xi=0, alpha=0): +# u, v = uv.unbind(dim=1) +# mx = (u - cx) / fx +# my = (v - cy) / fy +# r_square = mx ** 2 + my ** 2 +# root = 1 - (2 * alpha - 1) * r_square +# valid_mask = root >= 0 +# root[~valid_mask] = 0.0 +# mz = (1 - (alpha ** 2) * r_square) / (alpha * torch.sqrt(root) + (1 - alpha)) +# coeff = (mz * xi + torch.sqrt(mz ** 2 + (1 - xi ** 2) * r_square)) / (mz ** 2 + r_square) + +# x = coeff * mx +# y = coeff * my +# z = coeff * mz - xi +# # z = z.clamp(min=1e-7) + +# x_norm = x / z +# y_norm = y / z +# z_norm = z / z +# xnorm = torch.stack(( x_norm, y_norm, z_norm ), dim=1) +# # print("unproj", xnorm.shape, xnorm[:, -1].mean()) + +# return xnorm, valid_mask.unsqueeze(1).repeat(1, 3, 1, 1) + + +# def project(point3D, fx, fy, cx, cy, xi=0, alpha=0): +# B, C, H, W = point3D.shape +# x, y, z = point3D.unbind(dim=1) +# z = z.clamp(min=1e-7) +# d_1 = torch.sqrt( x ** 2 + y ** 2 + z ** 2 ) +# d_2 = torch.sqrt( x ** 2 + y ** 2 + (xi * d_1 + z) ** 2 ) + +# div = alpha * d_2 + (1 - alpha) * (xi * d_1 + z) +# Xnorm = fx * x / 
div + cx +# Ynorm = fy * y / div + cy + +# coords = torch.stack([Xnorm, Ynorm], dim=1) +# w1 = torch.where(alpha <= 0.5, alpha / (1 - alpha), (1 - alpha) / alpha) +# w2 = w1 + xi / ((2 * w1 * xi + xi ** 2 + 1) ** 0.5) +# valid_mask = z > - w2 * d_1 + +# # Return pixel coordinates +# return coords, valid_mask.unsqueeze(1).repeat(1, 2, 1, 1) + + +@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) +def unproject(uv, fx, fy, cx, cy, alpha=None, beta=None): + uv = uv.float() + fx = fx.float() + fy = fy.float() + cx = cx.float() + cy = cy.float() + u, v = uv.unbind(dim=1) + alpha = torch.zeros_like(fx) if alpha is None else alpha.float() + beta = torch.ones_like(fx) if beta is None else beta.float() + mx = (u - cx) / fx + my = (v - cy) / fy + r_square = mx**2 + my**2 + valid_mask = r_square < torch.where(alpha < 0.5, 1e6, 1 / (beta * (2 * alpha - 1))) + sqrt_val = 1 - (2 * alpha - 1) * beta * r_square + mz = (1 - beta * (alpha**2) * r_square) / ( + alpha * torch.sqrt(sqrt_val.clip(min=1e-5)) + (1 - alpha) + ) + coeff = 1 / torch.sqrt(mx**2 + my**2 + mz**2 + 1e-5) + + x = coeff * mx + y = coeff * my + z = coeff * mz + valid_mask = valid_mask & (z > 1e-3) + + xnorm = torch.stack((x, y, z.clamp(1e-3)), dim=1) + return xnorm, valid_mask.unsqueeze(1) + + +@torch.autocast(device_type="cuda", enabled=False, dtype=torch.float32) +def project(point3D, fx, fy, cx, cy, alpha=None, beta=None): + H, W = point3D.shape[-2:] + alpha = torch.zeros_like(fx) if alpha is None else alpha + beta = torch.ones_like(fx) if beta is None else beta + x, y, z = point3D.unbind(dim=1) + d = torch.sqrt(beta * (x**2 + y**2) + z**2) + + x = x / (alpha * d + (1 - alpha) * z).clip(min=1e-3) + y = y / (alpha * d + (1 - alpha) * z).clip(min=1e-3) + + Xnorm = fx * x + cx + Ynorm = fy * y + cy + + coords = torch.stack([Xnorm, Ynorm], dim=1) + + invalid = ( + (coords[:, 0] < 0) + | (coords[:, 0] > W) + | (coords[:, 1] < 0) + | (coords[:, 1] > H) + | (z < 0) + ) + + # Return pixel coordinates + return coords, (~invalid).unsqueeze(1) + + +def rays2angles(rays: torch.Tensor) -> torch.Tensor: + theta = torch.atan2(rays[..., 0], rays[..., -1]) + phi = torch.acos(rays[..., 1]) + angles = torch.stack([theta, phi], dim=-1) + return angles + + +@torch.jit.script +def dilate(image, kernel_size: int | tuple[int, int]): + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + device, dtype = image.device, image.dtype + padding = (kernel_size[0] // 2, kernel_size[1] // 2) + kernel = torch.ones((1, 1, *kernel_size), dtype=torch.float32, device=image.device) + dilated_image = F.conv2d(image.float(), kernel, padding=padding, stride=1) + dilated_image = torch.where( + dilated_image > 0, + torch.tensor(1.0, device=device), + torch.tensor(0.0, device=device), + ) + return dilated_image.to(dtype) + + +@torch.jit.script +def erode(image, kernel_size: int | tuple[int, int]): + if isinstance(kernel_size, int): + kernel_size = (kernel_size, kernel_size) + device, dtype = image.device, image.dtype + padding = (kernel_size[0] // 2, kernel_size[1] // 2) + kernel = torch.ones((1, 1, *kernel_size), dtype=torch.float32, device=image.device) + eroded_image = F.conv2d(image.float(), kernel, padding=padding, stride=1) + eroded_image = torch.where( + eroded_image == (kernel_size[0] * kernel_size[1]), + torch.tensor(1.0, device=device), + torch.tensor(0.0, device=device), + ) + return eroded_image.to(dtype) + + +@torch.jit.script +def iou(mask1: torch.Tensor, mask2: torch.Tensor) -> torch.Tensor: + device = mask1.device + + # 
Ensure the masks are binary (0 or 1) + mask1 = mask1.to(torch.bool) + mask2 = mask2.to(torch.bool) + + # Compute intersection and union + intersection = torch.sum(mask1 & mask2).to(torch.float32) + union = torch.sum(mask1 | mask2).to(torch.float32) + + # Compute IoU + iou = intersection / union.clip(min=1.0) + + return iou + + +if __name__ == "__main__": + kernel_size = 3 + image = torch.tensor( + [ + [ + [ + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + [1, 1, 1, 1, 1], + ] + ] + ], + dtype=torch.bool, + ) + + print("testing dilate and erode, with image:\n", image, image.shape) + + # Perform dilation + dilated_image = dilate(image, kernel_size) + print("Dilated Image:\n", dilated_image) + + # Perform erosion + eroded_image = erode(image, kernel_size) + print("Eroded Image:\n", eroded_image) diff --git a/unik3d/utils/knn.py b/unik3d/utils/knn.py new file mode 100644 index 0000000000000000000000000000000000000000..e12787497268f37cc022ef14207e084d4d5eb352 --- /dev/null +++ b/unik3d/utils/knn.py @@ -0,0 +1,248 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the BSD-style license found in the +# LICENSE file in the root directory of this source tree. + +# pyre-unsafe + +from collections import namedtuple +from typing import Union + +import torch +from torch.autograd import Function +from torch.autograd.function import once_differentiable + +_KNN = namedtuple("KNN", "dists idx knn") + + +class _knn_points(Function): + """ + Torch autograd Function wrapper for KNN C++/CUDA implementations. + """ + + @staticmethod + # pyre-fixme[14]: `forward` overrides method defined in `Function` inconsistently. + def forward( + ctx, + p1, + p2, + lengths1, + lengths2, + K, + version, + norm: int = 2, + return_sorted: bool = True, + ): + """ + K-Nearest neighbors on point clouds. + + Args: + p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each + containing up to P1 points of dimension D. + p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each + containing up to P2 points of dimension D. + lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the + length of each pointcloud in p1. Or None to indicate that every cloud has + length P1. + lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the + length of each pointcloud in p2. Or None to indicate that every cloud has + length P2. + K: Integer giving the number of nearest neighbors to return. + version: Which KNN implementation to use in the backend. If version=-1, + the correct implementation is selected based on the shapes of the inputs. + norm: (int) indicating the norm. Only supports 1 (for L1) and 2 (for L2). + return_sorted: (bool) whether to return the nearest neighbors sorted in + ascending order of distance. + + Returns: + p1_dists: Tensor of shape (N, P1, K) giving the squared distances to + the nearest neighbors. This is padded with zeros both where a cloud in p2 + has fewer than K points and where a cloud in p1 has fewer than P1 points. + + p1_idx: LongTensor of shape (N, P1, K) giving the indices of the + K nearest neighbors from points in p1 to points in p2. + Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest + neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud + in p2 has fewer than K points and where a cloud in p1 has fewer than P1 points. 
+ """ + if not ((norm == 1) or (norm == 2)): + raise ValueError("Support for 1 or 2 norm.") + + idx, dists = _C.knn_points_idx(p1, p2, lengths1, lengths2, norm, K, version) + + # sort KNN in ascending order if K > 1 + if K > 1 and return_sorted: + if lengths2.min() < K: + P1 = p1.shape[1] + mask = lengths2[:, None] <= torch.arange(K, device=dists.device)[None] + # mask has shape [N, K], true where dists irrelevant + mask = mask[:, None].expand(-1, P1, -1) + # mask has shape [N, P1, K], true where dists irrelevant + dists[mask] = float("inf") + dists, sort_idx = dists.sort(dim=2) + dists[mask] = 0 + else: + dists, sort_idx = dists.sort(dim=2) + idx = idx.gather(2, sort_idx) + + ctx.save_for_backward(p1, p2, lengths1, lengths2, idx) + ctx.mark_non_differentiable(idx) + ctx.norm = norm + return dists, idx + + @staticmethod + @once_differentiable + def backward(ctx, grad_dists, grad_idx): + p1, p2, lengths1, lengths2, idx = ctx.saved_tensors + norm = ctx.norm + # TODO(gkioxari) Change cast to floats once we add support for doubles. + if not (grad_dists.dtype == torch.float32): + grad_dists = grad_dists.float() + if not (p1.dtype == torch.float32): + p1 = p1.float() + if not (p2.dtype == torch.float32): + p2 = p2.float() + grad_p1, grad_p2 = _C.knn_points_backward( + p1, p2, lengths1, lengths2, idx, norm, grad_dists + ) + return grad_p1, grad_p2, None, None, None, None, None, None + + +def knn_points( + p1: torch.Tensor, + p2: torch.Tensor, + lengths1: Union[torch.Tensor, None] = None, + lengths2: Union[torch.Tensor, None] = None, + norm: int = 2, + K: int = 1, + version: int = -1, + return_nn: bool = False, + return_sorted: bool = True, +) -> _KNN: + """ + K-Nearest neighbors on point clouds. + + Args: + p1: Tensor of shape (N, P1, D) giving a batch of N point clouds, each + containing up to P1 points of dimension D. + p2: Tensor of shape (N, P2, D) giving a batch of N point clouds, each + containing up to P2 points of dimension D. + lengths1: LongTensor of shape (N,) of values in the range [0, P1], giving the + length of each pointcloud in p1. Or None to indicate that every cloud has + length P1. + lengths2: LongTensor of shape (N,) of values in the range [0, P2], giving the + length of each pointcloud in p2. Or None to indicate that every cloud has + length P2. + norm: Integer indicating the norm of the distance. Supports only 1 for L1, 2 for L2. + K: Integer giving the number of nearest neighbors to return. + version: Which KNN implementation to use in the backend. If version=-1, + the correct implementation is selected based on the shapes of the inputs. + return_nn: If set to True returns the K nearest neighbors in p2 for each point in p1. + return_sorted: (bool) whether to return the nearest neighbors sorted in + ascending order of distance. + + Returns: + dists: Tensor of shape (N, P1, K) giving the squared distances to + the nearest neighbors. This is padded with zeros both where a cloud in p2 + has fewer than K points and where a cloud in p1 has fewer than P1 points. + + idx: LongTensor of shape (N, P1, K) giving the indices of the + K nearest neighbors from points in p1 to points in p2. + Concretely, if `p1_idx[n, i, k] = j` then `p2[n, j]` is the k-th nearest + neighbors to `p1[n, i]` in `p2[n]`. This is padded with zeros both where a cloud + in p2 has fewer than K points and where a cloud in p1 has fewer than P1 + points. + + nn: Tensor of shape (N, P1, K, D) giving the K nearest neighbors in p2 for + each point in p1. 
Concretely, `p2_nn[n, i, k]` gives the k-th nearest neighbor + for `p1[n, i]`. Returned if `return_nn` is True. + The nearest neighbors are collected using `knn_gather` + + .. code-block:: + + p2_nn = knn_gather(p2, p1_idx, lengths2) + + which is a helper function that allows indexing any tensor of shape (N, P2, U) with + the indices `p1_idx` returned by `knn_points`. The output is a tensor + of shape (N, P1, K, U). + + """ + if p1.shape[0] != p2.shape[0]: + raise ValueError("pts1 and pts2 must have the same batch dimension.") + if p1.shape[2] != p2.shape[2]: + raise ValueError("pts1 and pts2 must have the same point dimension.") + + p1 = p1.contiguous() + p2 = p2.contiguous() + + P1 = p1.shape[1] + P2 = p2.shape[1] + + if lengths1 is None: + lengths1 = torch.full((p1.shape[0],), P1, dtype=torch.int64, device=p1.device) + if lengths2 is None: + lengths2 = torch.full((p1.shape[0],), P2, dtype=torch.int64, device=p1.device) + + p1_dists, p1_idx = _knn_points.apply( + p1, p2, lengths1, lengths2, K, version, norm, return_sorted + ) + + p2_nn = None + if return_nn: + p2_nn = knn_gather(p2, p1_idx, lengths2) + + return _KNN(dists=p1_dists, idx=p1_idx, knn=p2_nn if return_nn else None) + + +def knn_gather( + x: torch.Tensor, idx: torch.Tensor, lengths: Union[torch.Tensor, None] = None +): + """ + A helper function for knn that allows indexing a tensor x with the indices `idx` + returned by `knn_points`. + + For example, if `dists, idx = knn_points(p, x, lengths_p, lengths, K)` + where p is a tensor of shape (N, L, D) and x a tensor of shape (N, M, D), + then one can compute the K nearest neighbors of p with `p_nn = knn_gather(x, idx, lengths)`. + It can also be applied for any tensor x of shape (N, M, U) where U != D. + + Args: + x: Tensor of shape (N, M, U) containing U-dimensional features to + be gathered. + idx: LongTensor of shape (N, L, K) giving the indices returned by `knn_points`. + lengths: LongTensor of shape (N,) of values in the range [0, M], giving the + length of each example in the batch in x. Or None to indicate that every + example has length M. + Returns: + x_out: Tensor of shape (N, L, K, U) resulting from gathering the elements of x + with idx, s.t. `x_out[n, l, k] = x[n, idx[n, l, k]]`. + If `k > lengths[n]` then `x_out[n, l, k]` is filled with 0.0. 
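+
+    Example (illustrative; the sizes are arbitrary assumptions):
+        >>> x = torch.rand(1, 8, 3)                        # (N, M, U)
+        >>> idx = torch.zeros(1, 5, 2, dtype=torch.int64)  # (N, L, K)
+        >>> knn_gather(x, idx).shape
+        torch.Size([1, 5, 2, 3])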
+ """ + N, M, U = x.shape + _N, L, K = idx.shape + + if N != _N: + raise ValueError("x and idx must have same batch dimension.") + + if lengths is None: + lengths = torch.full((x.shape[0],), M, dtype=torch.int64, device=x.device) + + idx_expanded = idx[:, :, :, None].expand(-1, -1, -1, U) + # idx_expanded has shape [N, L, K, U] + + x_out = x[:, :, None].expand(-1, -1, K, -1).gather(1, idx_expanded) + # p2_nn has shape [N, L, K, U] + + needs_mask = lengths.min() < K + if needs_mask: + # mask has shape [N, K], true where idx is irrelevant because + # there is less number of points in p2 than K + mask = lengths[:, None] <= torch.arange(K, device=x.device)[None] + + # expand mask to shape [N, L, K, U] + mask = mask[:, None].expand(-1, L, -1) + mask = mask[:, :, :, None].expand(-1, -1, -1, U) + x_out[mask] = 0.0 + + return x_out diff --git a/unik3d/utils/misc.py b/unik3d/utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..0d144b0cf119a781c2e13998db842717b0141ed9 --- /dev/null +++ b/unik3d/utils/misc.py @@ -0,0 +1,630 @@ +from functools import wraps +from time import time + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange, reduce, repeat +from scipy import interpolate + + +@torch.jit.script +def max_stack(tensors: list[torch.Tensor]) -> torch.Tensor: + if len(tensors) == 1: + return tensors[0] + return torch.stack(tensors, dim=-1).max(dim=-1).values + + +def last_stack(tensors: list[torch.Tensor]) -> torch.Tensor: + return tensors[-1] + + +def first_stack(tensors: list[torch.Tensor]) -> torch.Tensor: + return tensors[0] + + +@torch.jit.script +def softmax_stack( + tensors: list[torch.Tensor], temperature: float = 1.0 +) -> torch.Tensor: + if len(tensors) == 1: + return tensors[0] + return F.softmax(torch.stack(tensors, dim=-1) / temperature, dim=-1).sum(dim=-1) + + +@torch.jit.script +def mean_stack(tensors: list[torch.Tensor]) -> torch.Tensor: + if len(tensors) == 1: + return tensors[0] + return torch.stack(tensors, dim=-1).mean(dim=-1) + + +@torch.jit.script +def sum_stack(tensors: list[torch.Tensor]) -> torch.Tensor: + if len(tensors) == 1: + return tensors[0] + return torch.stack(tensors, dim=-1).sum(dim=-1) + + +def convert_module_to_f16(l): + """ + Convert primitive modules to float16. + """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.half() + if l.bias is not None: + l.bias.data = l.bias.data.half() + + +def convert_module_to_f32(l): + """ + Convert primitive modules to float32, undoing convert_module_to_f16(). 
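+    Typically applied to a whole model via ``model.apply(convert_module_to_f32)``.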
+ """ + if isinstance(l, (nn.Conv1d, nn.Conv2d, nn.Conv3d)): + l.weight.data = l.weight.data.float() + if l.bias is not None: + l.bias.data = l.bias.data.float() + + +def format_seconds(seconds): + minutes, seconds = divmod(seconds, 60) + hours, minutes = divmod(minutes, 60) + return f"{hours:d}:{minutes:02d}:{seconds:02d}" + + +def get_params(module, lr, wd): + skip_list = {} + skip_keywords = {} + if hasattr(module, "no_weight_decay"): + skip_list = module.no_weight_decay() + if hasattr(module, "no_weight_decay_keywords"): + skip_keywords = module.no_weight_decay_keywords() + has_decay = [] + no_decay = [] + for name, param in module.named_parameters(): + if not param.requires_grad: + continue # frozen weights + if ( + (name in skip_list) + or any((kw in name for kw in skip_keywords)) + or len(param.shape) == 1 + or name.endswith(".gamma") + or name.endswith(".beta") + or name.endswith(".bias") + ): + no_decay.append(param) + else: + has_decay.append(param) + + group1 = { + "params": has_decay, + "weight_decay": wd, + "lr": lr, + "weight_decay_init": wd, + "weight_decay_base": wd, + "lr_base": lr, + } + group2 = { + "params": no_decay, + "weight_decay": 0.0, + "lr": lr, + "weight_decay_init": 0.0, + "weight_decay_base": 0.0, + "weight_decay_final": 0.0, + "lr_base": lr, + } + return [group1, group2], [lr, lr] + + +def get_num_layer_for_swin(var_name, num_max_layer, layers_per_stage): + if var_name in ("cls_token", "mask_token", "pos_embed", "absolute_pos_embed"): + return 0 + elif var_name.startswith("patch_embed"): + return 0 + elif var_name.startswith("layers"): + if var_name.split(".")[2] == "blocks": + stage_id = int(var_name.split(".")[1]) + layer_id = int(var_name.split(".")[3]) + sum(layers_per_stage[:stage_id]) + return layer_id + 1 + elif var_name.split(".")[2] == "downsample": + stage_id = int(var_name.split(".")[1]) + layer_id = sum(layers_per_stage[: stage_id + 1]) + return layer_id + else: + return num_max_layer - 1 + + +def get_params_layerdecayswin(module, lr, wd, ld): + skip_list = {} + skip_keywords = {} + if hasattr(module, "no_weight_decay"): + skip_list = module.no_weight_decay() + if hasattr(module, "no_weight_decay_keywords"): + skip_keywords = module.no_weight_decay_keywords() + layers_per_stage = module.depths + num_layers = sum(layers_per_stage) + 1 + lrs = [] + params = [] + for name, param in module.named_parameters(): + if not param.requires_grad: + print(f"{name} frozen") + continue # frozen weights + layer_id = get_num_layer_for_swin(name, num_layers, layers_per_stage) + lr_cur = lr * ld ** (num_layers - layer_id - 1) + # if (name in skip_list) or any((kw in name for kw in skip_keywords)) or len(param.shape) == 1 or name.endswith(".bias"): + if (name in skip_list) or any((kw in name for kw in skip_keywords)): + wd_cur = 0.0 + else: + wd_cur = wd + params.append({"params": param, "weight_decay": wd_cur, "lr": lr_cur}) + lrs.append(lr_cur) + return params, lrs + + +def log(t, eps: float = 1e-5): + return torch.log(t.clamp(min=eps)) + + +def l2norm(t): + return F.normalize(t, dim=-1) + + +def exists(val): + return val is not None + + +def identity(t, *args, **kwargs): + return t + + +def divisible_by(numer, denom): + return (numer % denom) == 0 + + +def first(arr, d=None): + if len(arr) == 0: + return d + return arr[0] + + +def default(val, d): + if exists(val): + return val + return d() if callable(d) else d + + +def maybe(fn): + @wraps(fn) + def inner(x): + if not exists(x): + return x + return fn(x) + + return inner + + +def once(fn): + called = False + + 
@wraps(fn) + def inner(x): + nonlocal called + if called: + return + called = True + return fn(x) + + return inner + + +def _many(fn): + @wraps(fn) + def inner(tensors, pattern, **kwargs): + return (fn(tensor, pattern, **kwargs) for tensor in tensors) + + return inner + + +rearrange_many = _many(rearrange) +repeat_many = _many(repeat) +reduce_many = _many(reduce) + + +def load_pretrained(state_dict, checkpoint): + checkpoint_model = checkpoint["model"] + if any([True if "encoder." in k else False for k in checkpoint_model.keys()]): + checkpoint_model = { + k.replace("encoder.", ""): v + for k, v in checkpoint_model.items() + if k.startswith("encoder.") + } + print("Detect pre-trained model, remove [encoder.] prefix.") + else: + print("Detect non-pre-trained model, pass without doing anything.") + print(f">>>>>>>>>> Remapping pre-trained keys for SWIN ..........") + checkpoint = load_checkpoint_swin(state_dict, checkpoint_model) + + +def load_checkpoint_swin(model, checkpoint_model): + state_dict = model.state_dict() + # Geometric interpolation when pre-trained patch size mismatch with fine-tuned patch size + all_keys = list(checkpoint_model.keys()) + for key in all_keys: + if "relative_position_bias_table" in key: + relative_position_bias_table_pretrained = checkpoint_model[key] + relative_position_bias_table_current = state_dict[key] + L1, nH1 = relative_position_bias_table_pretrained.size() + L2, nH2 = relative_position_bias_table_current.size() + if nH1 != nH2: + print(f"Error in loading {key}, passing......") + else: + if L1 != L2: + print(f"{key}: Interpolate relative_position_bias_table using geo.") + src_size = int(L1**0.5) + dst_size = int(L2**0.5) + + def geometric_progression(a, r, n): + return a * (1.0 - r**n) / (1.0 - r) + + left, right = 1.01, 1.5 + while right - left > 1e-6: + q = (left + right) / 2.0 + gp = geometric_progression(1, q, src_size // 2) + if gp > dst_size // 2: + right = q + else: + left = q + + # if q > 1.090307: + # q = 1.090307 + + dis = [] + cur = 1 + for i in range(src_size // 2): + dis.append(cur) + cur += q ** (i + 1) + + r_ids = [-_ for _ in reversed(dis)] + + x = r_ids + [0] + dis + y = r_ids + [0] + dis + + t = dst_size // 2.0 + dx = np.arange(-t, t + 0.1, 1.0) + dy = np.arange(-t, t + 0.1, 1.0) + + print("Original positions = %s" % str(x)) + print("Target positions = %s" % str(dx)) + + all_rel_pos_bias = [] + + for i in range(nH1): + z = ( + relative_position_bias_table_pretrained[:, i] + .view(src_size, src_size) + .float() + .numpy() + ) + f_cubic = interpolate.interp2d(x, y, z, kind="cubic") + all_rel_pos_bias.append( + torch.Tensor(f_cubic(dx, dy)) + .contiguous() + .view(-1, 1) + .to(relative_position_bias_table_pretrained.device) + ) + + new_rel_pos_bias = torch.cat(all_rel_pos_bias, dim=-1) + checkpoint_model[key] = new_rel_pos_bias + + # delete relative_position_index since we always re-init it + relative_position_index_keys = [ + k for k in checkpoint_model.keys() if "relative_position_index" in k + ] + for k in relative_position_index_keys: + del checkpoint_model[k] + + # delete relative_coords_table since we always re-init it + relative_coords_table_keys = [ + k for k in checkpoint_model.keys() if "relative_coords_table" in k + ] + for k in relative_coords_table_keys: + del checkpoint_model[k] + + # # re-map keys due to name change + rpe_mlp_keys = [k for k in checkpoint_model.keys() if "cpb_mlp" in k] + for k in rpe_mlp_keys: + checkpoint_model[k.replace("cpb_mlp", "rpe_mlp")] = checkpoint_model.pop(k) + + # delete attn_mask since we 
always re-init it + attn_mask_keys = [k for k in checkpoint_model.keys() if "attn_mask" in k] + for k in attn_mask_keys: + del checkpoint_model[k] + + encoder_keys = [k for k in checkpoint_model.keys() if k.startswith("encoder.")] + for k in encoder_keys: + checkpoint_model[k.replace("encoder.", "")] = checkpoint_model.pop(k) + + return checkpoint_model + + +def add_padding_metas(out, image_metas): + device = out.device + # left, right, top, bottom + paddings = [img_meta.get("paddings", [0] * 4) for img_meta in image_metas] + paddings = torch.stack(paddings).to(device) + outs = [F.pad(o, padding, value=0.0) for padding, o in zip(paddings, out)] + return torch.stack(outs) + + +# left, right, top, bottom +def remove_padding(out, paddings): + H, W = out.shape[-2:] + outs = [ + o[..., padding[2] : H - padding[3], padding[0] : W - padding[1]] + for padding, o in zip(paddings, out) + ] + return torch.stack(outs) + + +def remove_padding_metas(out, image_metas): + B, C, H, W = out.shape + device = out.device + # left, right, top, bottom + paddings = [ + torch.tensor(img_meta.get("paddings", [0] * 4)) for img_meta in image_metas + ] + return remove_padding(out, paddings) + + +def ssi_helper(tensor1, tensor2): + stability_mat = 1e-4 * torch.eye(2, device=tensor1.device) + tensor2_one = torch.stack([tensor2, torch.ones_like(tensor2)], dim=1) + scale_shift = torch.inverse(tensor2_one.T @ tensor2_one + stability_mat) @ ( + tensor2_one.T @ tensor1.unsqueeze(1) + ) + scale, shift = scale_shift.squeeze().chunk(2, dim=0) + return scale, shift + + +def calculate_mean_values(names, values): + # Create a defaultdict to store sum and count for each name + name_values = {name: {} for name in names} + + # Iterate through the lists and accumulate values for each name + for name, value in zip(names, values): + name_values[name]["sum"] = name_values[name].get("sum", 0.0) + value + name_values[name]["count"] = name_values[name].get("count", 0.0) + 1 + + # Calculate mean values and create the output dictionary + output_dict = { + name: name_values[name]["sum"] / name_values[name]["count"] + for name in name_values + } + + return output_dict + + +def remove_leading_dim(infos): + if isinstance(infos, dict): + return {k: remove_leading_dim(v) for k, v in infos.items()} + elif isinstance(infos, torch.Tensor): + return infos.squeeze(0) + else: + return infos + + +def recursive_index(infos, index): + if isinstance(infos, dict): + return {k: recursive_index(v, index) for k, v in infos.items()} + elif isinstance(infos, torch.Tensor): + return infos[index] + else: + return infos + + +def to_cpu(infos): + if isinstance(infos, dict): + return {k: to_cpu(v) for k, v in infos.items()} + elif isinstance(infos, torch.Tensor): + return infos.detach() + else: + return infos + + +def masked_mean( + data: torch.Tensor, + mask: torch.Tensor | None = None, + dim: list[int] | None = None, + keepdim: bool = False, +) -> torch.Tensor: + dim = dim if dim is not None else list(range(data.dim())) + if mask is None: + return data.mean(dim=dim, keepdim=keepdim) + mask = mask.float() + mask_sum = torch.sum(mask, dim=dim, keepdim=True) + mask_mean = torch.sum(data * mask, dim=dim, keepdim=True) / torch.clamp( + mask_sum, min=1.0 + ) + return mask_mean.squeeze(dim) if not keepdim else mask_mean + + +class ProfileMethod: + def __init__(self, model, func_name, track_statistics=True, verbose=False): + self.model = model + self.func_name = func_name + self.verbose = verbose + self.track_statistics = track_statistics + self.timings = [] + + def 
__enter__(self): + # Start timing + if self.verbose: + if torch.cuda.is_available(): + torch.cuda.synchronize() + self.start_time = time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.verbose: + if torch.cuda.is_available(): + torch.cuda.synchronize() + + self.end_time = time() + + elapsed_time = self.end_time - self.start_time + + self.timings.append(elapsed_time) + if self.track_statistics and len(self.timings) > 25: + + # Compute statistics if tracking + timings_array = np.array(self.timings) + mean_time = np.mean(timings_array) + std_time = np.std(timings_array) + quantiles = np.percentile(timings_array, [0, 25, 50, 75, 100]) + print( + f"{self.model.__class__.__name__}.{self.func_name} took {elapsed_time:.4f} seconds" + ) + print(f"Mean Time: {mean_time:.4f} seconds") + print(f"Std Time: {std_time:.4f} seconds") + print( + f"Quantiles: Min={quantiles[0]:.4f}, 25%={quantiles[1]:.4f}, Median={quantiles[2]:.4f}, 75%={quantiles[3]:.4f}, Max={quantiles[4]:.4f}" + ) + + else: + print( + f"{self.model.__class__.__name__}.{self.func_name} took {elapsed_time:.4f} seconds" + ) + + +def profile_method(track_statistics=True, verbose=False): + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + with ProfileMethod(self, func.__name__, track_statistics, verbose): + return func(self, *args, **kwargs) + + return wrapper + + return decorator + + +class ProfileFunction: + def __init__(self, func_name, track_statistics=True, verbose=False): + self.func_name = func_name + self.verbose = verbose + self.track_statistics = track_statistics + self.timings = [] + + def __enter__(self): + # Start timing + if self.verbose: + if torch.cuda.is_available(): + torch.cuda.synchronize() + self.start_time = time() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if self.verbose: + if torch.cuda.is_available(): + torch.cuda.synchronize() + + self.end_time = time() + + elapsed_time = self.end_time - self.start_time + + self.timings.append(elapsed_time) + if self.track_statistics and len(self.timings) > 25: + + # Compute statistics if tracking + timings_array = np.array(self.timings) + mean_time = np.mean(timings_array) + std_time = np.std(timings_array) + quantiles = np.percentile(timings_array, [0, 25, 50, 75, 100]) + print(f"{self.func_name} took {elapsed_time:.4f} seconds") + print(f"Mean Time: {mean_time:.4f} seconds") + print(f"Std Time: {std_time:.4f} seconds") + print( + f"Quantiles: Min={quantiles[0]:.4f}, 25%={quantiles[1]:.4f}, Median={quantiles[2]:.4f}, 75%={quantiles[3]:.4f}, Max={quantiles[4]:.4f}" + ) + + else: + print(f"{self.func_name} took {elapsed_time:.4f} seconds") + + +def profile_function(track_statistics=True, verbose=False): + def decorator(func): + @wraps(func) + def wrapper(self, *args, **kwargs): + with ProfileFunction(func.__name__, track_statistics, verbose): + return func(self, *args, **kwargs) + + return wrapper + + return decorator + + +def recursive_apply(inputs, func): + if isinstance(inputs, list): + return [recursive_apply(camera, func) for camera in inputs] + else: + return func(inputs) + + +def squeeze_list(nested_list, dim, current_dim=0): + # If the current dimension is in the list of indices to squeeze + if isinstance(nested_list, list) and len(nested_list) == 1 and current_dim == dim: + return squeeze_list(nested_list[0], dim, current_dim + 1) + elif isinstance(nested_list, list): + return [squeeze_list(item, dim, current_dim + 1) for item in nested_list] + else: + return nested_list + + +def 
match_gt(tensor1, tensor2, padding1, padding2, mode: str = "bilinear"): + """ + Transform each item in tensor1 batch to match tensor2's dimensions and padding. + + Args: + tensor1 (torch.Tensor): The input tensor to transform, with shape (batch_size, channels, height, width). + tensor2 (torch.Tensor): The target tensor to match, with shape (batch_size, channels, height, width). + padding1 (tuple): Padding applied to tensor1 (pad_left, pad_right, pad_top, pad_bottom). + padding2 (tuple): Desired padding to be applied to match tensor2 (pad_left, pad_right, pad_top, pad_bottom). + + Returns: + torch.Tensor: The batch of transformed tensors matching tensor2's size and padding. + """ + # Get batch size + batch_size = len(tensor1) + src_dtype = tensor1[0].dtype + tgt_dtype = tensor2[0].dtype + + # List to store transformed tensors + transformed_tensors = [] + + for i in range(batch_size): + item1 = tensor1[i] + item2 = tensor2[i] + + h1, w1 = item1.shape[1], item1.shape[2] + pad1_l, pad1_r, pad1_t, pad1_b = ( + padding1[i] if padding1 is not None else (0, 0, 0, 0) + ) + pad2_l, pad2_r, pad2_t, pad2_b = ( + padding2[i] if padding2 is not None else (0, 0, 0, 0) + ) + item1_unpadded = item1[:, pad1_t : h1 - pad1_b, pad1_l : w1 - pad1_r] + + h2, w2 = ( + item2.shape[1] - pad2_t - pad2_b, + item2.shape[2] - pad2_l - pad2_r, + ) + + item1_resized = F.interpolate( + item1_unpadded.unsqueeze(0).to(tgt_dtype), size=(h2, w2), mode=mode + ) + item1_padded = F.pad(item1_resized, (pad2_l, pad2_r, pad2_t, pad2_b)) + transformed_tensors.append(item1_padded) + + transformed_batch = torch.cat(transformed_tensors) + return transformed_batch.to(src_dtype) diff --git a/unik3d/utils/pose.py b/unik3d/utils/pose.py new file mode 100644 index 0000000000000000000000000000000000000000..f878bb257bcde1f98d2ce611ed85eeac1a8fcb17 --- /dev/null +++ b/unik3d/utils/pose.py @@ -0,0 +1,225 @@ +import torch +from torch.nn import functional as F + + +def quaternion_to_R(quaternions): + """ + Convert rotations given as quaternions to rotation matrices. + + Args: + quaternions: quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + r, i, j, k = torch.unbind(quaternions, -1) + two_s = 2.0 / (quaternions * quaternions).sum(-1) + + o = torch.stack( + ( + 1 - two_s * (j * j + k * k), + two_s * (i * j - k * r), + two_s * (i * k + j * r), + two_s * (i * j + k * r), + 1 - two_s * (i * i + k * k), + two_s * (j * k - i * r), + two_s * (i * k - j * r), + two_s * (j * k + i * r), + 1 - two_s * (i * i + j * j), + ), + -1, + ) + return o.reshape(quaternions.shape[:-1] + (3, 3)) + + +def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor: + """ + Convert a unit quaternion to a standard form: one in which the real + part is non negative. + + Args: + quaternions: Quaternions with real part first, + as tensor of shape (..., 4). + + Returns: + Standardized quaternions as tensor of shape (..., 4). + """ + return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions) + + +def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor: + """ + Returns torch.sqrt(torch.max(0, x)) + but with a zero subgradient where x is 0. + """ + ret = torch.zeros_like(x) + positive_mask = x > 0 + ret[positive_mask] = torch.sqrt(x[positive_mask]) + return ret + + +def R_to_quaternion(matrix: torch.Tensor) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to quaternions. + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). 
+ + Returns: + quaternions with real part first, as tensor of shape (..., 4). + """ + if matrix.size(-1) != 3 or matrix.size(-2) != 3: + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + batch_dim = matrix.shape[:-2] + m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind( + matrix.reshape(batch_dim + (9,)), dim=-1 + ) + + q_abs = _sqrt_positive_part( + torch.stack( + [ + 1.0 + m00 + m11 + m22, + 1.0 + m00 - m11 - m22, + 1.0 - m00 + m11 - m22, + 1.0 - m00 - m11 + m22, + ], + dim=-1, + ) + ) + # we produce the desired quaternion multiplied by each of r, i, j, k + quat_by_rijk = torch.stack( + [ + torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1), + torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1), + torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1), + torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1), + ], + dim=-2, + ) + + # We floor here at 0.1 but the exact level is not important; if q_abs is small, + # the candidate won't be picked. + flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device) + quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr)) + + # if not for numerical problems, quat_candidates[i] should be same (up to a sign), + # forall i; we pick the best-conditioned one (with the largest denominator) + out = quat_candidates[ + F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, : + ].reshape(batch_dim + (4,)) + out = standardize_quaternion(out) + return out + + +def Rt_to_pose(R, t): + assert R.shape[-2:] == (3, 3), "The last two dimensions of R must be 3x3" + assert t.shape[-2:] == (3, 1), "The last dimension of t must be 3" + + # Create the pose matrix + pose = torch.cat([R, t], dim=-1) + pose = F.pad(pose, (0, 0, 0, 1), value=0) + pose[..., 3, 3] = 1 + + return pose + + +def pose_to_Rt(pose): + assert pose.shape[-2:] == (4, 4), "The last two dimensions of pose must be 4x4" + + # Extract the rotation matrix and translation vector + R = pose[..., :3, :3] + t = pose[..., :3, 3:] + + return R, t + + +def relative_pose(pose1, pose2): + # Compute world_to_cam for pose1 + pose1_inv = invert_pose(pose1) + + # Relative pose as cam_to_world_2 -> world_to_cam_1 => cam2_to_cam1 + relative_pose = pose1_inv @ pose2 + + return relative_pose + + +@torch.autocast(device_type="cuda", dtype=torch.float32) +def invert_pose(pose): + R, t = pose_to_Rt(pose) + R_inv = R.transpose(-2, -1) + t_inv = -torch.matmul(R_inv, t) + pose_inv = Rt_to_pose(R_inv, t_inv) + return pose_inv + + +def apply_pose_transformation(point_cloud, pose): + reshape = point_cloud.ndim > 3 + shapes = point_cloud.shape + # Extract rotation and translation from pose + R, t = pose_to_Rt(pose) + + # Apply the pose transformation + if reshape: + point_cloud = point_cloud.reshape(shapes[0], -1, shapes[-1]) + transformed_points = torch.matmul(point_cloud, R.transpose(-2, -1)) + t.transpose( + -2, -1 + ) + if reshape: + transformed_points = transformed_points.reshape(shapes) + return transformed_points + + +def euler2mat(roll, pitch, yaw) -> torch.Tensor: + """ + Convert Euler angles (roll, pitch, yaw) to a 3x3 rotation matrix. + + Args: + euler_angles (torch.Tensor): Tensor of shape (N, 3) representing roll, pitch, yaw in radians. + - roll: rotation around z-axis + - pitch: rotation around x-axis + - yaw: rotation around y-axis + Returns: + torch.Tensor: Tensor of shape (N, 3, 3) representing the rotation matrices. 
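+
+    Example (illustrative; zero angles are an arbitrary choice):
+        >>> R = euler2mat(torch.zeros(1), torch.zeros(1), torch.zeros(1))
+        >>> R.shape
+        torch.Size([1, 3, 3])
+        >>> torch.allclose(R[0], torch.eye(3))
+        True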
+ """ + + cos_r, sin_r = torch.cos(roll), torch.sin(roll) # Roll + cos_p, sin_p = torch.cos(pitch), torch.sin(pitch) # Pitch + cos_y, sin_y = torch.cos(yaw), torch.sin(yaw) # Yaw + + # Rotation matrices + R_z = torch.zeros((roll.shape[0], 3, 3), device=roll.device) + R_y = torch.zeros_like(R_z) + R_x = torch.zeros_like(R_z) + + # Z-axis (roll) + R_z[:, 0, 0], R_z[:, 0, 1], R_z[:, 1, 0], R_z[:, 1, 1], R_z[:, 2, 2] = ( + cos_y, + -sin_y, + sin_y, + cos_y, + 1.0, + ) + + # Y-axis (yaw) + R_y[:, 0, 0], R_y[:, 0, 2], R_y[:, 2, 0], R_y[:, 2, 2], R_y[:, 1, 1] = ( + cos_p, + sin_p, + -sin_p, + cos_p, + 1.0, + ) + + # X-axis (pitch) + R_x[:, 1, 1], R_x[:, 1, 2], R_x[:, 2, 1], R_x[:, 2, 2], R_x[:, 0, 0] = ( + cos_r, + -sin_r, + sin_r, + cos_r, + 1.0, + ) + + # Combine rotations: R = R_z * R_y * R_x + rotation_matrix = torch.matmul(torch.matmul(R_z, R_y), R_x) + return rotation_matrix diff --git a/unik3d/utils/positional_embedding.py b/unik3d/utils/positional_embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..6e3ac5727f8e12c2f155de5da66352326ce9bdbd --- /dev/null +++ b/unik3d/utils/positional_embedding.py @@ -0,0 +1,269 @@ +from math import log, pi +from typing import Optional + +import torch +import torch.nn as nn +from einops import rearrange, repeat + + +class PositionEmbeddingSine(nn.Module): + + def __init__( + self, num_pos_feats=64, temperature=10000, normalize=False, scale=None + ): + super().__init__() + self.num_pos_feats = num_pos_feats + self.temperature = temperature + self.normalize = normalize + if scale is not None and normalize is False: + raise ValueError("normalize should be True if scale is passed") + if scale is None: + scale = 2 * pi + self.scale = scale + + def forward( + self, x: torch.Tensor, mask: Optional[torch.Tensor] = None + ) -> torch.Tensor: + if mask is None: + mask = torch.zeros( + (x.size(0), x.size(2), x.size(3)), device=x.device, dtype=torch.bool + ) + not_mask = ~mask + y_embed = not_mask.cumsum(1, dtype=torch.float32) + x_embed = not_mask.cumsum(2, dtype=torch.float32) + if self.normalize: + eps = 1e-6 + y_embed = y_embed / (y_embed[:, -1:, :] + eps) * self.scale + x_embed = x_embed / (x_embed[:, :, -1:] + eps) * self.scale + + dim_t = torch.arange(self.num_pos_feats, dtype=torch.float32, device=x.device) + dim_t = self.temperature ** ( + 2 * torch.div(dim_t, 2, rounding_mode="floor") / self.num_pos_feats + ) + + pos_x = x_embed[:, :, :, None] / dim_t + pos_y = y_embed[:, :, :, None] / dim_t + pos_x = torch.stack( + (pos_x[:, :, :, 0::2].sin(), pos_x[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos_y = torch.stack( + (pos_y[:, :, :, 0::2].sin(), pos_y[:, :, :, 1::2].cos()), dim=4 + ).flatten(3) + pos = torch.cat((pos_y, pos_x), dim=3).permute(0, 3, 1, 2) + return pos + + def __repr__(self, _repr_indent=4): + head = "Positional encoding " + self.__class__.__name__ + body = [ + "num_pos_feats: {}".format(self.num_pos_feats), + "temperature: {}".format(self.temperature), + "normalize: {}".format(self.normalize), + "scale: {}".format(self.scale), + ] + # _repr_indent = 4 + lines = [head] + [" " * _repr_indent + line for line in body] + return "\n".join(lines) + + +class LearnedSinusoidalPosEmb(nn.Module): + def __init__(self, dim): + super().__init__() + assert (dim % 2) == 0 + half_dim = dim // 2 + self.weights = nn.Parameter(torch.randn(half_dim)) + + def forward(self, x): + x = rearrange(x, "b -> b 1") + freqs = x * rearrange(self.weights, "d -> 1 d") * 2 * pi + fouriered = torch.cat((freqs.sin(), freqs.cos()), dim=-1) + fouriered 
= torch.cat((x, fouriered), dim=-1) + return fouriered + + +def broadcat(tensors, dim=-1): + num_tensors = len(tensors) + shape_lens = set(list(map(lambda t: len(t.shape), tensors))) + assert len(shape_lens) == 1, "tensors must all have the same number of dimensions" + shape_len = list(shape_lens)[0] + dim = (dim + shape_len) if dim < 0 else dim + dims = list(zip(*map(lambda t: list(t.shape), tensors))) + expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim] + assert all( + [*map(lambda t: len(set(t[1])) <= 2, expandable_dims)] + ), "invalid dimensions for broadcastable concatentation" + max_dims = list(map(lambda t: (t[0], max(t[1])), expandable_dims)) + expanded_dims = list(map(lambda t: (t[0], (t[1],) * num_tensors), max_dims)) + expanded_dims.insert(dim, (dim, dims[dim])) + expandable_shapes = list(zip(*map(lambda t: t[1], expanded_dims))) + tensors = list(map(lambda t: t[0].expand(*t[1]), zip(tensors, expandable_shapes))) + return torch.cat(tensors, dim=dim) + + +def rotate_half(x): + x = rearrange(x, "... (d r) -> ... d r", r=2) + x1, x2 = x.unbind(dim=-1) + x = torch.stack((-x2, x1), dim=-1) + return rearrange(x, "... d r -> ... (d r)") + + +class VisionRotaryEmbedding(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs_h = torch.einsum("..., f -> ... f", t, freqs) + freqs_h = repeat(freqs_h, "... n -> ... (n r)", r=2) + + freqs_w = torch.einsum("..., f -> ... f", t, freqs) + freqs_w = repeat(freqs_w, "... n -> ... 
(n r)", r=2) + + freqs = broadcat((freqs_h[:, None, :], freqs_w[None, :, :]), dim=-1) + + self.register_buffer("freqs_cos", freqs.cos()) + self.register_buffer("freqs_sin", freqs.sin()) + + print("======== shape of rope freq", self.freqs_cos.shape, "========") + + def forward(self, t, start_index=0): + rot_dim = self.freqs_cos.shape[-1] + end_index = start_index + rot_dim + assert ( + rot_dim <= t.shape[-1] + ), f"feature dimension {t.shape[-1]} is not of sufficient size to rotate in all the positions {rot_dim}" + t_left, t, t_right = ( + t[..., :start_index], + t[..., start_index:end_index], + t[..., end_index:], + ) + t = (t * self.freqs_cos) + (rotate_half(t) * self.freqs_sin) + return torch.cat((t_left, t, t_right), dim=-1) + + +class VisionRotaryEmbeddingFast(nn.Module): + def __init__( + self, + dim, + pt_seq_len, + ft_seq_len=None, + custom_freqs=None, + freqs_for="lang", + theta=10000, + max_freq=10, + num_freqs=1, + ): + super().__init__() + if custom_freqs: + freqs = custom_freqs + elif freqs_for == "lang": + freqs = 1.0 / ( + theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim) + ) + elif freqs_for == "pixel": + freqs = torch.linspace(1.0, max_freq / 2, dim // 2) * pi + elif freqs_for == "constant": + freqs = torch.ones(num_freqs).float() + else: + raise ValueError(f"unknown modality {freqs_for}") + + if ft_seq_len is None: + ft_seq_len = pt_seq_len + t = torch.arange(ft_seq_len) / ft_seq_len * pt_seq_len + + freqs = torch.einsum("..., f -> ... f", t, freqs) + freqs = repeat(freqs, "... n -> ... (n r)", r=2) + freqs = broadcat((freqs[:, None, :], freqs[None, :, :]), dim=-1) + + freqs_cos = freqs.cos().view(-1, freqs.shape[-1]) + freqs_sin = freqs.sin().view(-1, freqs.shape[-1]) + + self.register_buffer("freqs_cos", freqs_cos) + self.register_buffer("freqs_sin", freqs_sin) + + def forward(self, t): + return t * self.freqs_cos + rotate_half(t) * self.freqs_sin + + +from math import log2 + + +def generate_fourier_features( + x: torch.Tensor, + dim: int = 512, + max_freq: int = 64, + use_cos: bool = False, + use_log: bool = False, + cat_orig: bool = False, +): + x_orig = x + device, dtype, input_dim = x.device, x.dtype, x.shape[-1] + num_bands = dim // (2 * input_dim) if use_cos else dim // input_dim + + if use_log: + scales = 2.0 ** torch.linspace( + 0.0, log2(max_freq), steps=num_bands, device=device, dtype=dtype + ) + else: + scales = torch.linspace( + 1.0, max_freq / 2, num_bands, device=device, dtype=dtype + ) + + x = x.unsqueeze(-1) + scales = scales[(*((None,) * (len(x.shape) - 1)), Ellipsis)] + + x = x * scales * pi + x = torch.cat( + ( + [x.sin(), x.cos()] + if use_cos + else [ + x.sin(), + ] + ), + dim=-1, + ) + x = x.flatten(-2) + if cat_orig: + return torch.cat((x, x_orig), dim=-1) + return x + + +# from PIL import Image +# from unik3d.utils import image_grid, colorize +# if __name__ == "__main__": +# H, W = 512, 512 +# resolution = 128 +# mesh = torch.meshgrid(torch.linspace(-1, 1, H), torch.linspace(-1, 1, W)) +# mesh = torch.stack(mesh, dim=0).unsqueeze(0) +# mesh = mesh.view(1, 2, -1).permute(0, 2, 1) + +# features = generate_fourier_features(mesh, dim=32, max_freq=resolution, use_log=True) +# channels = features.shape[-1] +# print(features.shape) + +# features = features[0].view(H, W, channels).permute(2, 0, 1).numpy() +# Image.fromarray(image_grid([colorize(1+x, 0.0, 2.0, "viridis") for x in features], rows=8, cols=4)).save(f"tmp_{resolution}.png") diff --git a/unik3d/utils/sht.py b/unik3d/utils/sht.py new file mode 100644 index 
0000000000000000000000000000000000000000..ba8d0cf5708ba4b87397158ebb5181e8911c1f62 --- /dev/null +++ b/unik3d/utils/sht.py @@ -0,0 +1,1639 @@ +"""Real spherical harmonics in Cartesian form for PyTorch. + +This is an autogenerated file. See +https://github.com/cheind/torch-spherical-harmonics +for more information. +""" + +import torch + + +def rsh_cart_0(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 0. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,1) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + ], + -1, + ) + + +def rsh_cart_1(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 1. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,4) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + ], + -1, + ) + + +def rsh_cart_2(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 2. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,9) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + ], + -1, + ) + + +@torch.autocast(device_type="cuda", enabled=True, dtype=torch.float32) +def rsh_cart_3(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 3. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,16) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. 
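+
+        Example (illustrative only, not part of the autogenerated source):
+            >>> dirs = torch.nn.functional.normalize(torch.randn(8, 3), dim=-1)
+            >>> sh = rsh_cart_3(dirs)  # (8, 16); Y_n^m sits at index n*(n+1) + m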
+ """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + ], + -1, + ) + + +def rsh_cart_4(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 4. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,25) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + z4 = z2**2 + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + ], + -1, + ) + + +def rsh_cart_5(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 5. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,36) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. 
+ """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + z4 = z2**2 + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 8.30264925952416 * xy * z * (x2 - y2), + 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), + 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.241571547304372 + * y + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + -1.24747010616985 * z * (1.5 * z2 - 0.5) + + 1.6840846433293 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.498988042467941 * z, + 0.241571547304372 + * x + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), + 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), + -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + ], + -1, + ) + + +def rsh_cart_6(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 6. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,49) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. 
+ """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + z4 = z2**2 + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 8.30264925952416 * xy * z * (x2 - y2), + 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), + 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.241571547304372 + * y + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + -1.24747010616985 * z * (1.5 * z2 - 0.5) + + 1.6840846433293 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.498988042467941 * z, + 0.241571547304372 + * x + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), + 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), + -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 4.09910463115149 * x**4 * xy + - 13.6636821038383 * xy**3 + + 4.09910463115149 * xy * y**4, + -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), + 0.00584892228263444 + * y + * (3.0 * x2 - y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0701870673916132 + * xy + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.221950995245231 + * y + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + -1.48328138624466 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.86469659985043 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.953538034014426 * z2 + - 0.317846011338142, + 0.221950995245231 + * x + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) 
+ 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + 0.0350935336958066 + * (x2 - y2) + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.00584892228263444 + * x + * (x2 - 3.0 * y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), + -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 0.683184105191914 * x2**3 + + 10.2477615778787 * x2 * y4 + - 10.2477615778787 * x4 * y2 + - 0.683184105191914 * y2**3, + ], + -1, + ) + + +def rsh_cart_7(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 7. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,64) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + z4 = z2**2 + + return torch.stack( + [ + xyz.new_tensor(0.282094791773878).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 8.30264925952416 * xy * z * (x2 - y2), + 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), + 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.241571547304372 + * y + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + -1.24747010616985 * z * (1.5 * z2 - 0.5) + + 1.6840846433293 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.498988042467941 * z, + 0.241571547304372 + * x + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), + 2.07566231488104 * z * (-6.0 * x2 * y2 + x4 + y4), + -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 4.09910463115149 * x**4 * xy + - 13.6636821038383 * xy**3 + + 4.09910463115149 * xy * y**4, + -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * 
x4 + y4), + 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), + 0.00584892228263444 + * y + * (3.0 * x2 - y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0701870673916132 + * xy + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.221950995245231 + * y + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + -1.48328138624466 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.86469659985043 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.953538034014426 * z2 + - 0.317846011338142, + 0.221950995245231 + * x + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + 0.0350935336958066 + * (x2 - y2) + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.00584892228263444 + * x + * (x2 - 3.0 * y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), + -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 0.683184105191914 * x2**3 + + 10.2477615778787 * x2 * y4 + - 10.2477615778787 * x4 * y2 + - 0.683184105191914 * y2**3, + -0.707162732524596 + * y + * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), + 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), + 9.98394571852353e-5 + * y + * (5197.5 - 67567.5 * z2) + * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00239614697244565 + * xy + * (x2 - y2) + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z), + 0.00397356022507413 + * y + * (3.0 * x2 - y2) + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ), + 0.0561946276120613 + * xy + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ), + 0.206472245902897 + * y + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ), + 1.24862677781952 * z * (1.5 * z2 - 0.5) + - 1.68564615005635 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 2.02901851395672 + * z + * ( + -1.45833333333333 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.83333333333333 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.9375 * z2 + - 0.3125 + ) + - 0.499450711127808 * z, + 0.206472245902897 + * x + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ), + 
0.0280973138060306 + * (x2 - y2) + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ), + 0.00397356022507413 + * x + * (x2 - 3.0 * y2) + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ), + 0.000599036743111412 + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + * (-6.0 * x2 * y2 + x4 + y4), + 9.98394571852353e-5 + * x + * (5197.5 - 67567.5 * z2) + * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), + -0.707162732524596 + * x + * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), + ], + -1, + ) + + +# @torch.jit.script +def rsh_cart_8(xyz: torch.Tensor): + """Computes all real spherical harmonics up to degree 8. + + This is an autogenerated method. See + https://github.com/cheind/torch-spherical-harmonics + for more information. + + Params: + xyz: (N,...,3) tensor of points on the unit sphere + + Returns: + rsh: (N,...,81) real spherical harmonics + projections of input. Ynm is found at index + `n*(n+1) + m`, with `0 <= n <= degree` and + `-n <= m <= n`. + """ + x = xyz[..., 0] + y = xyz[..., 1] + z = xyz[..., 2] + + x2 = x**2 + y2 = y**2 + z2 = z**2 + xy = x * y + xz = x * z + yz = y * z + x4 = x2**2 + y4 = y2**2 + # z4 = z2**2 + return torch.stack( + [ + 0.282094791773878 * torch.ones(1, device=xyz.device).expand(xyz.shape[:-1]), + -0.48860251190292 * y, + 0.48860251190292 * z, + -0.48860251190292 * x, + 1.09254843059208 * xy, + -1.09254843059208 * yz, + 0.94617469575756 * z2 - 0.31539156525252, + -1.09254843059208 * xz, + 0.54627421529604 * x2 - 0.54627421529604 * y2, + -0.590043589926644 * y * (3.0 * x2 - y2), + 2.89061144264055 * xy * z, + 0.304697199642977 * y * (1.5 - 7.5 * z2), + 1.24392110863372 * z * (1.5 * z2 - 0.5) - 0.497568443453487 * z, + 0.304697199642977 * x * (1.5 - 7.5 * z2), + 1.44530572132028 * z * (x2 - y2), + -0.590043589926644 * x * (x2 - 3.0 * y2), + 2.5033429417967 * xy * (x2 - y2), + -1.77013076977993 * yz * (3.0 * x2 - y2), + 0.126156626101008 * xy * (52.5 * z2 - 7.5), + 0.267618617422916 * y * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 1.48099765681286 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 0.952069922236839 * z2 + + 0.317356640745613, + 0.267618617422916 * x * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z), + 0.063078313050504 * (x2 - y2) * (52.5 * z2 - 7.5), + -1.77013076977993 * xz * (x2 - 3.0 * y2), + -3.75501441269506 * x2 * y2 + + 0.625835735449176 * x4 + + 0.625835735449176 * y4, + -0.65638205684017 * y * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 8.30264925952416 * xy * z * (x2 - y2), + 0.00931882475114763 * y * (52.5 - 472.5 * z2) * (3.0 * x2 - y2), + 0.0913054625709205 * xy * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.241571547304372 + * y + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + -1.24747010616985 * z * (1.5 * z2 - 0.5) + + 1.6840846433293 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.498988042467941 * z, + 0.241571547304372 + * x + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ), + 0.0456527312854602 * (x2 - y2) * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z), + 0.00931882475114763 * x * (52.5 - 472.5 * z2) * (x2 - 3.0 * y2), + 2.07566231488104 * z * (-6.0 * x2 * y2 + 
x4 + y4), + -0.65638205684017 * x * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 4.09910463115149 * x**4 * xy + - 13.6636821038383 * xy**3 + + 4.09910463115149 * xy * y**4, + -2.36661916223175 * yz * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00427144889505798 * xy * (x2 - y2) * (5197.5 * z2 - 472.5), + 0.00584892228263444 + * y + * (3.0 * x2 - y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0701870673916132 + * xy + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.221950995245231 + * y + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + -1.48328138624466 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.86469659985043 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.953538034014426 * z2 + - 0.317846011338142, + 0.221950995245231 + * x + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ), + 0.0350935336958066 + * (x2 - y2) + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ), + 0.00584892228263444 + * x + * (x2 - 3.0 * y2) + * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z), + 0.0010678622237645 * (5197.5 * z2 - 472.5) * (-6.0 * x2 * y2 + x4 + y4), + -2.36661916223175 * xz * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 0.683184105191914 * x2**3 + + 10.2477615778787 * x2 * y4 + - 10.2477615778787 * x4 * y2 + - 0.683184105191914 * y2**3, + -0.707162732524596 + * y + * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), + 2.6459606618019 * z * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), + 9.98394571852353e-5 + * y + * (5197.5 - 67567.5 * z2) + * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00239614697244565 + * xy + * (x2 - y2) + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z), + 0.00397356022507413 + * y + * (3.0 * x2 - y2) + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ), + 0.0561946276120613 + * xy + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ), + 0.206472245902897 + * y + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ), + 1.24862677781952 * z * (1.5 * z2 - 0.5) + - 1.68564615005635 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 2.02901851395672 + * z + * ( + -1.45833333333333 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.83333333333333 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.9375 * z2 + - 0.3125 + ) + - 0.499450711127808 * z, + 0.206472245902897 + * x + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 
2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ), + 0.0280973138060306 + * (x2 - y2) + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ), + 0.00397356022507413 + * x + * (x2 - 3.0 * y2) + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ), + 0.000599036743111412 + * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + * (-6.0 * x2 * y2 + x4 + y4), + 9.98394571852353e-5 + * x + * (5197.5 - 67567.5 * z2) + * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 2.6459606618019 * z * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), + -0.707162732524596 + * x + * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), + 5.83141328139864 * xy * (x2**3 + 7.0 * x2 * y4 - 7.0 * x4 * y2 - y2**3), + -2.91570664069932 + * yz + * (7.0 * x2**3 + 21.0 * x2 * y4 - 35.0 * x4 * y2 - y2**3), + 7.87853281621404e-6 + * (1013512.5 * z2 - 67567.5) + * (6.0 * x**4 * xy - 20.0 * xy**3 + 6.0 * xy * y**4), + 5.10587282657803e-5 + * y + * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z) + * (-10.0 * x2 * y2 + 5.0 * x4 + y4), + 0.00147275890257803 + * xy + * (x2 - y2) + * ( + 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + - 14293.125 * z2 + + 1299.375 + ), + 0.0028519853513317 + * y + * (3.0 * x2 - y2) + * ( + -7.33333333333333 * z * (52.5 - 472.5 * z2) + + 3.0 + * z + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ) + - 560.0 * z + ), + 0.0463392770473559 + * xy + * ( + -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + + 2.5 + * z + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ) + + 137.8125 * z2 + - 19.6875 + ), + 0.193851103820053 + * y + * ( + 3.2 * z * (1.5 - 7.5 * z2) + - 2.51428571428571 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + + 2.14285714285714 + * z + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ) + + 5.48571428571429 * z + ), + 1.48417251362228 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.86581687426801 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 2.1808249179756 + * z + * ( + 1.14285714285714 * z * (1.5 * z2 - 0.5) + - 1.54285714285714 + * z + * ( + 1.75 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + - 1.125 * z2 + + 0.375 + ) + + 1.85714285714286 + * z + * ( + -1.45833333333333 + * z + * (1.66666666666667 * z * (1.5 * z2 - 0.5) - 0.666666666666667 * z) + + 1.83333333333333 + * z + * ( + -1.33333333333333 * z * (1.5 * z2 - 0.5) + + 1.8 + * z + * ( + 1.75 + * z + * ( + 1.66666666666667 * z * (1.5 * z2 - 0.5) + - 0.666666666666667 * z + ) + - 1.125 * z2 + + 0.375 + ) + + 0.533333333333333 * z + ) + + 0.9375 * z2 + - 0.3125 + ) + - 
0.457142857142857 * z + ) + - 0.954110901614325 * z2 + + 0.318036967204775, + 0.193851103820053 + * x + * ( + 3.2 * z * (1.5 - 7.5 * z2) + - 2.51428571428571 + * z + * ( + 2.25 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + + 2.14285714285714 + * z + * ( + -2.625 * z * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 2.16666666666667 + * z + * ( + -2.8 * z * (1.5 - 7.5 * z2) + + 2.2 + * z + * ( + 2.25 + * z + * (2.33333333333333 * z * (1.5 - 7.5 * z2) + 4.0 * z) + + 9.375 * z2 + - 1.875 + ) + - 4.8 * z + ) + - 10.9375 * z2 + + 2.1875 + ) + + 5.48571428571429 * z + ), + 0.0231696385236779 + * (x2 - y2) + * ( + -4.125 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + + 2.5 + * z + * ( + -4.8 * z * (52.5 * z2 - 7.5) + + 2.6 + * z + * ( + 2.75 * z * (3.0 * z * (52.5 * z2 - 7.5) - 30.0 * z) + - 91.875 * z2 + + 13.125 + ) + + 48.0 * z + ) + + 137.8125 * z2 + - 19.6875 + ), + 0.0028519853513317 + * x + * (x2 - 3.0 * y2) + * ( + -7.33333333333333 * z * (52.5 - 472.5 * z2) + + 3.0 + * z + * ( + 3.25 * z * (3.66666666666667 * z * (52.5 - 472.5 * z2) + 280.0 * z) + + 1063.125 * z2 + - 118.125 + ) + - 560.0 * z + ), + 0.000368189725644507 + * (-6.0 * x2 * y2 + x4 + y4) + * ( + 3.75 * z * (4.33333333333333 * z * (5197.5 * z2 - 472.5) - 3150.0 * z) + - 14293.125 * z2 + + 1299.375 + ), + 5.10587282657803e-5 + * x + * (5.0 * z * (5197.5 - 67567.5 * z2) + 41580.0 * z) + * (-10.0 * x2 * y2 + x4 + 5.0 * y4), + 7.87853281621404e-6 + * (1013512.5 * z2 - 67567.5) + * (x2**3 + 15.0 * x2 * y4 - 15.0 * x4 * y2 - y2**3), + -2.91570664069932 + * xz + * (x2**3 + 35.0 * x2 * y4 - 21.0 * x4 * y2 - 7.0 * y2**3), + -20.4099464848952 * x2**3 * y2 + - 20.4099464848952 * x2 * y2**3 + + 0.72892666017483 * x4**2 + + 51.0248662122381 * x4 * y4 + + 0.72892666017483 * y4**2, + ], + -1, + ) + + +__all__ = [ + "rsh_cart_0", + "rsh_cart_1", + "rsh_cart_2", + "rsh_cart_3", + "rsh_cart_4", + "rsh_cart_5", + "rsh_cart_6", + "rsh_cart_7", + "rsh_cart_8", +] + + +from typing import Optional + +import torch + + +class SphHarm(torch.nn.Module): + def __init__(self, m, n, dtype=torch.float32) -> None: + super().__init__() + self.dtype = dtype + m = torch.tensor(list(range(-m + 1, m))) + n = torch.tensor(list(range(n))) + self.is_normalized = False + vals = torch.cartesian_prod(m, n).T + vals = vals[:, vals[0] <= vals[1]] + m, n = vals.unbind(0) + + self.register_buffer("m", tensor=m) + self.register_buffer("n", tensor=n) + self.register_buffer("l_max", tensor=torch.max(self.n)) + + f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d = self._init_legendre() + self.register_buffer("f_a", tensor=f_a) + self.register_buffer("f_b", tensor=f_b) + self.register_buffer("d0_mask_3d", tensor=d0_mask_3d) + self.register_buffer("d1_mask_3d", tensor=d1_mask_3d) + self.register_buffer("initial_value", tensor=initial_value) + + @property + def device(self): + return next(self.buffers()).device + + def forward(self, points: torch.Tensor) -> torch.Tensor: + """Computes the spherical harmonics.""" + # Y_l^m = (-1) ^ m c_l^m P_l^m(cos(theta)) exp(i m phi) + B, N, D = points.shape + dtype = points.dtype + theta, phi = points.view(-1, D).to(self.dtype).unbind(-1) + cos_colatitude = torch.cos(phi) + legendre = self._gen_associated_legendre(cos_colatitude) + vals = torch.stack([self.m.abs(), self.n], dim=0) + vals = torch.cat( + [ + vals.repeat(1, theta.shape[0]), + torch.arange(theta.shape[0], device=theta.device) + .unsqueeze(0) + .repeat_interleave(vals.shape[1], dim=1), + ], + dim=0, + ) + legendre_vals = 
legendre[vals[0], vals[1], vals[2]] + legendre_vals = legendre_vals.reshape(-1, theta.shape[0]) + angle = torch.outer(self.m.abs(), theta) + vandermonde = torch.complex(torch.cos(angle), torch.sin(angle)) + harmonics = torch.complex( + legendre_vals * torch.real(vandermonde), + legendre_vals * torch.imag(vandermonde), + ) + + # Negative order. + m = self.m.unsqueeze(-1) + harmonics = torch.where( + m < 0, (-1.0) ** m.abs() * torch.conj(harmonics), harmonics + ) + harmonics = harmonics.permute(1, 0).reshape(B, N, -1).to(dtype) + return harmonics + + def _gen_recurrence_mask(self) -> tuple[torch.Tensor, torch.Tensor]: + """Generates mask for recurrence relation on the remaining entries. + + The remaining entries are with respect to the diagonal and offdiagonal + entries. + + Args: + l_max: see `gen_normalized_legendre`. + Returns: + torch.Tensors representing the mask used by the recurrence relations. + """ + + # Computes all coefficients. + m_mat, l_mat = torch.meshgrid( + torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype), + torch.arange(0, self.l_max + 1, device=self.device, dtype=self.dtype), + indexing="ij", + ) + if self.is_normalized: + c0 = l_mat * l_mat + c1 = m_mat * m_mat + c2 = 2.0 * l_mat + c3 = (l_mat - 1.0) * (l_mat - 1.0) + d0 = torch.sqrt((4.0 * c0 - 1.0) / (c0 - c1)) + d1 = torch.sqrt(((c2 + 1.0) * (c3 - c1)) / ((c2 - 3.0) * (c0 - c1))) + else: + d0 = (2.0 * l_mat - 1.0) / (l_mat - m_mat) + d1 = (l_mat + m_mat - 1.0) / (l_mat - m_mat) + + d0_mask_indices = torch.triu_indices(self.l_max + 1, 1) + d1_mask_indices = torch.triu_indices(self.l_max + 1, 2) + + d_zeros = torch.zeros( + (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device + ) + d_zeros[d0_mask_indices] = d0[d0_mask_indices] + d0_mask = d_zeros + + d_zeros = torch.zeros( + (self.l_max + 1, self.l_max + 1), dtype=self.dtype, device=self.device + ) + d_zeros[d1_mask_indices] = d1[d1_mask_indices] + d1_mask = d_zeros + + # Creates a 3D mask that contains 1s on the diagonal plane and 0s elsewhere. + i = torch.arange(self.l_max + 1, device=self.device)[:, None, None] + j = torch.arange(self.l_max + 1, device=self.device)[None, :, None] + k = torch.arange(self.l_max + 1, device=self.device)[None, None, :] + mask = (i + j - k == 0).to(self.dtype) + d0_mask_3d = torch.einsum("jk,ijk->ijk", d0_mask, mask) + d1_mask_3d = torch.einsum("jk,ijk->ijk", d1_mask, mask) + return (d0_mask_3d, d1_mask_3d) + + def _recursive(self, i: int, p_val: torch.Tensor, x: torch.Tensor) -> torch.Tensor: + coeff_0 = self.d0_mask_3d[i] + coeff_1 = self.d1_mask_3d[i] + h = torch.einsum( + "ij,ijk->ijk", + coeff_0, + torch.einsum("ijk,k->ijk", torch.roll(p_val, shifts=1, dims=1), x), + ) - torch.einsum("ij,ijk->ijk", coeff_1, torch.roll(p_val, shifts=2, dims=1)) + p_val = p_val + h + return p_val + + def _init_legendre(self): + a_idx = torch.arange(1, self.l_max + 1, dtype=self.dtype, device=self.device) + b_idx = torch.arange(self.l_max, dtype=self.dtype, device=self.device) + if self.is_normalized: + # The initial value p(0,0). + initial_value: torch.Tensor = torch.tensor( + 0.5 / (torch.pi**0.5), device=self.device + ) + f_a = torch.cumprod(-1 * torch.sqrt(1.0 + 0.5 / a_idx), dim=0) + f_b = torch.sqrt(2.0 * b_idx + 3.0) + else: + # The initial value p(0,0). 
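+            # Unnormalized ALFs start from P_0^0(x) = 1. The diagonal recurrence
+            # P_l^l = -(2l - 1) * sqrt(1 - x^2) * P_{l-1}^{l-1} is folded into
+            # f_a = cumprod(1 - 2l) = (-1)^l (2l - 1)!!, while the off-diagonal
+            # step P_{l+1}^l = (2l + 1) * x * P_l^l uses f_b = 2l + 1.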
+ initial_value = torch.tensor(1.0, device=self.device) + f_a = torch.cumprod(1.0 - 2.0 * a_idx, dim=0) + f_b = 2.0 * b_idx + 1.0 + + d0_mask_3d, d1_mask_3d = self._gen_recurrence_mask() + return f_a, f_b, initial_value, d0_mask_3d, d1_mask_3d + + def _gen_associated_legendre(self, x: torch.Tensor) -> torch.Tensor: + r"""Computes associated Legendre functions (ALFs) of the first kind. + + The ALFs of the first kind are used in spherical harmonics. The spherical + harmonic of degree `l` and order `m` can be written as + `Y_l^m(θ, φ) = N_l^m * P_l^m(cos(θ)) * exp(i m φ)`, where `N_l^m` is the + normalization factor and θ and φ are the colatitude and longitude, + repectively. `N_l^m` is chosen in the way that the spherical harmonics form + a set of orthonormal basis function of L^2(S^2). For the computational + efficiency of spherical harmonics transform, the normalization factor is + used in the computation of the ALFs. In addition, normalizing `P_l^m` + avoids overflow/underflow and achieves better numerical stability. Three + recurrence relations are used in the computation. + + Args: + l_max: The maximum degree of the associated Legendre function. Both the + degrees and orders are `[0, 1, 2, ..., l_max]`. + x: A vector of type `float32`, `float64` containing the sampled points in + spherical coordinates, at which the ALFs are computed; `x` is essentially + `cos(θ)`. For the numerical integration used by the spherical harmonics + transforms, `x` contains the quadrature points in the interval of + `[-1, 1]`. There are several approaches to provide the quadrature points: + Gauss-Legendre method (`scipy.special.roots_legendre`), Gauss-Chebyshev + method (`scipy.special.roots_chebyu`), and Driscoll & Healy + method (Driscoll, James R., and Dennis M. Healy. "Computing Fourier + transforms and convolutions on the 2-sphere." Advances in applied + mathematics 15, no. 2 (1994): 202-250.). The Gauss-Legendre quadrature + points are nearly equal-spaced along θ and provide exact discrete + orthogonality, (P^m)^T W P_m = I, where `T` represents the transpose + operation, `W` is a diagonal matrix containing the quadrature weights, + and `I` is the identity matrix. The Gauss-Chebyshev points are equally + spaced, which only provide approximate discrete orthogonality. The + Driscoll & Healy qudarture points are equally spaced and provide the + exact discrete orthogonality. The number of sampling points is required to + be twice as the number of frequency points (modes) in the Driscoll & Healy + approach, which enables FFT and achieves a fast spherical harmonics + transform. + is_normalized: True if the associated Legendre functions are normalized. + With normalization, `N_l^m` is applied such that the spherical harmonics + form a set of orthonormal basis functions of L^2(S^2). + + Returns: + The 3D array of shape `(l_max + 1, l_max + 1, len(x))` containing the values + of the ALFs at `x`; the dimensions in the sequence of order, degree, and + evalution points. + """ + p = torch.zeros( + (self.l_max + 1, self.l_max + 1, x.shape[0]), dtype=x.dtype, device=x.device + ) + p[0, 0] = self.initial_value + + # Compute the diagonal entries p(l,l) with recurrence. 
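+        # y[l-1] = (1 - x^2)^(l/2) = sin(theta)^l via a cumulative product; scaling
+        # by f_a (see _init_legendre) gives P_l^l for l = 1..l_max, which is then
+        # written onto the diagonal p[l, l, :].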
+ y = torch.cumprod( + torch.broadcast_to(torch.sqrt(1.0 - x * x), (self.l_max, x.shape[0])), dim=0 + ) + p_diag = self.initial_value * torch.einsum("i,ij->ij", self.f_a, y) + # torch.diag_indices(l_max + 1) + diag_indices = torch.stack( + [torch.arange(0, self.l_max + 1, device=x.device)] * 2, dim=0 + ) + p[(diag_indices[0][1:], diag_indices[1][1:])] = p_diag + + diag_indices = torch.stack( + [torch.arange(0, self.l_max, device=x.device)] * 2, dim=0 + ) + + # Compute the off-diagonal entries with recurrence. + p_offdiag = torch.einsum( + "ij,ij->ij", + torch.einsum("i,j->ij", self.f_b, x), + p[(diag_indices[0], diag_indices[1])], + ) # p[torch.diag_indices(l_max)]) + p[(diag_indices[0][: self.l_max], diag_indices[1][: self.l_max] + 1)] = ( + p_offdiag + ) + + # Compute the remaining entries with recurrence. + if self.l_max > 1: + for i in range(2, self.l_max + 1): + p = self._recursive(i, p, x) + return p diff --git a/unik3d/utils/validation.py b/unik3d/utils/validation.py new file mode 100644 index 0000000000000000000000000000000000000000..972ff00f58a12571197f8d133ea3b17df0adb12d --- /dev/null +++ b/unik3d/utils/validation.py @@ -0,0 +1,329 @@ +import json +import os +from collections import defaultdict +from typing import Any, Dict, List, Optional + +import numpy as np +import torch +import torch.distributed as dist +import torch.utils.data.distributed +import wandb +from PIL import Image +from torch.nn import functional as F +from torch.utils.data import DataLoader +from tqdm import tqdm + +from unik3d.utils.distributed import barrier, get_world_size, is_main_process +from unik3d.utils.misc import remove_leading_dim, remove_padding, ssi_helper +from unik3d.utils.visualization import colorize, image_grid + + +def stack_mixedshape_numpy(tensor_list, dim=0): + max_rows = max(tensor.shape[0] for tensor in tensor_list) + max_columns = max(tensor.shape[1] for tensor in tensor_list) + + padded_tensors = [] + for tensor in tensor_list: + rows, columns, *_ = tensor.shape + pad_rows = max_rows - rows + pad_columns = max_columns - columns + + padded_tensor = np.pad( + tensor, ((0, pad_rows), (0, pad_columns), (0, 0)), mode="constant" + ) + padded_tensors.append(padded_tensor) + + return np.stack(padded_tensors, axis=dim) + + +def original_image(batch): + paddings = [ + torch.tensor(pads) + for img_meta in batch["img_metas"] + for pads in img_meta.get("paddings", [[0] * 4]) + ] + paddings = torch.stack(paddings).to(batch["data"]["image"].device)[ + ..., [0, 2, 1, 3] + ] # lrtb + + T, _, H, W = batch["data"]["depth"].shape + batch["data"]["image"] = F.interpolate( + batch["data"]["image"], + (H + paddings[0][2] + paddings[0][3], W + paddings[0][1] + paddings[0][2]), + mode="bilinear", + align_corners=False, + antialias=True, + ) + batch["data"]["image"] = remove_padding( + batch["data"]["image"], paddings.repeat(T, 1) + ) + return batch + + +def original_image_inv(batch, preds=None): + paddings = [ + torch.tensor(pads) + for img_meta in batch["img_metas"] + for pads in img_meta.get("padding_size", [[0] * 4]) + ] + T, _, H, W = batch["data"]["depth"].shape + batch["data"]["image"] = remove_padding(batch["data"]["image"], paddings * T) + batch["data"]["image"] = F.interpolate( + batch["data"]["image"], + (H, W), + mode="bilinear", + align_corners=False, + antialias=True, + ) + + if preds is not None: + for key in ["depth"]: + if key in preds: + preds[key] = remove_padding(preds[key], paddings * T) + preds[key] = F.interpolate( + preds[key], + (H, W), + mode="bilinear", + align_corners=False, + 
antialias=True, + ) + + return batch, preds + + +def aggregate_metrics(metrics_all, exclude_fn=lambda name: False): + aggregate_name = "".join( + [name_ds[:3] for name_ds in metrics_all.keys() if not exclude_fn(name_ds)] + ) + metrics_aggregate = defaultdict(list) + for name_ds, metrics in metrics_all.items(): + if exclude_fn(name_ds): + continue + for metrics_name, metrics_value in metrics.items(): + metrics_aggregate[metrics_name].append(metrics_value) + return { + **{aggregate_name: {k: sum(v) / len(v) for k, v in metrics_aggregate.items()}}, + **metrics_all, + } + + +GROUPS = { + "SFoV": ["KITTI", "NYUv2Depth", "DiodeIndoor", "ETH3D", "IBims"], + "SFoVDi": ["DiodeIndoor_F", "ETH3D_F", "IBims_F"], + "LFoV": ["ADT", "KITTI360", "ScanNetpp_F"], +} + + +def aggregate_metrics_camera(metrics_all): + available_groups = { + k: v for k, v in GROUPS.items() if any([name in metrics_all for name in v]) + } + for group_name, group_datasets in available_groups.items(): + metrics_aggregate = defaultdict(list) + for dataset_name in group_datasets: + if dataset_name not in metrics_all: + print( + f"Dataset {dataset_name} not used for aggregation of {group_name}" + ) + continue + for metrics_name, metrics_value in metrics_all[dataset_name].items(): + metrics_aggregate[metrics_name].append(metrics_value) + metrics_all[group_name] = { + k: sum(v) / len(v) for k, v in metrics_aggregate.items() + } + return metrics_all + + +def log_metrics(metrics_all, step): + for name_ds, metrics in metrics_all.items(): + for metrics_name, metrics_value in metrics.items(): + try: + wandb.log( + {f"Metrics/{name_ds}/{metrics_name}": metrics_value}, step=step + ) + except: + print(f"Metrics/{name_ds}/{metrics_name} {round(metrics_value, 4)}") + + +def log_artifacts(artifacts_all, step, run_id): + for ds_name, artifacts in artifacts_all.items(): + rgbs, gts = artifacts["rgbs"], artifacts["gts"] + logging_imgs = [ + *rgbs, + *gts, + *[ + x + for k, v in artifacts.items() + if ("rgbs" not in k and "gts" not in k) + for x in v + ], + ] + artifacts_grid = image_grid(logging_imgs, len(artifacts), len(rgbs)) + try: + wandb.log({f"{ds_name}_test": [wandb.Image(artifacts_grid)]}, step=step) + except: + print(f"Error while saving artifacts at step {step}") + + +def show(vals, dataset, ssi_depth=False): + output_artifacts, additionals = {}, {} + predictions, gts, errors, images = [], [], [], [] + for v in vals: + image = v["image"][0].unsqueeze(0) + gt = v["depth"][0].unsqueeze(0) + prediction = v["depth_pred"][0].unsqueeze(0) + # Downsample for memory and viz + # if any([x in dataset.__class__.__name__ for x in ["DDAD", "Argoverse", "Waymo", "DrivingStereo"]]): + # gt = F.interpolate(gt, scale_factor=0.5, mode="nearest-exact") + # # Dilate for a better visualization + # gt[gt < 1e-4] = dilate(gt)[gt < 1e-4] + H, W = gt.shape[-2:] + aspect_ratio = H / W + new_W = int((300_000 / aspect_ratio) ** 0.5) + new_H = int(aspect_ratio * new_W) + gt = F.interpolate(gt, (new_H, new_W), mode="nearest-exact") + + # Format predictions and errors for every metrics used + prediction = F.interpolate( + prediction, + gt.shape[-2:], + mode="bilinear", + align_corners=False, + antialias=True, + ) + error = torch.zeros_like(prediction) + error[gt > dataset.min_depth] = ( + 4 + * dataset.max_depth + * torch.abs(gt - prediction)[gt > dataset.min_depth] + / gt[gt > dataset.min_depth] + ) + if ssi_depth: + scale, shift = ssi_helper(gt[gt > 0], prediction[gt > 0]) + prediction = (prediction * scale + shift).clip(0.0, dataset.max_depth) + prediction = colorize( 
+ prediction.squeeze().cpu().detach().numpy(), + vmin=dataset.min_depth, + vmax=dataset.max_depth, + cmap="magma_r", + ) + error = error.clip(0.0, dataset.max_depth).cpu().detach().numpy() + error = colorize(error.squeeze(), vmin=0.001, vmax=1.0, cmap="coolwarm") + errors.append(error) + predictions.append(prediction) + + image = F.interpolate( + image, gt.shape[-2:], mode="bilinear", align_corners=False, antialias=True + ) + image = image.cpu().detach() * dataset.normalization_stats["std"].view( + 1, -1, 1, 1 + ) + dataset.normalization_stats["mean"].view(1, -1, 1, 1) + image = ( + (255 * image) + .clip(0.0, 255.0) + .to(torch.uint8) + .permute(0, 2, 3, 1) + .numpy() + .squeeze() + ) + gt = gt.clip(0.0, dataset.max_depth).cpu().detach().numpy() + gt = colorize( + gt.squeeze(), vmin=dataset.min_depth, vmax=dataset.max_depth, cmap="magma_r" + ) + gts.append(gt) + images.append(image) + + for name, additional in v.get("infos", {}).items(): + if name not in additionals: + additionals[name] = [] + if additional[0].shape[0] == 3: + val = ( + (127.5 * (additional[0] + 1)) + .clip(0, 255) + .to(torch.uint8) + .cpu() + .detach() + .permute(1, 2, 0) + .numpy() + ) + else: + val = colorize( + additional[0].cpu().detach().squeeze().numpy(), + 0.0, + dataset.max_depth, + ) + additionals[name].append(val) + + output_artifacts.update( + { + f"predictions": stack_mixedshape_numpy(predictions), + f"errors": stack_mixedshape_numpy(errors), + "rgbs": stack_mixedshape_numpy(images), + "gts": stack_mixedshape_numpy(gts), + **{k: stack_mixedshape_numpy(v) for k, v in additionals.items()}, + } + ) + return output_artifacts + + +METRIC_B = "F1" +INVERT = True +SSI_VISUALIZATION = True + + +def validate( + model, + test_loaders: Dict[str, DataLoader], + step, + run_id, + context, + idxs=(1, 100, 150, 1000), +): + + metrics_all, predictions_select = {}, {} + world_size = get_world_size() + for name_ds, test_loader in test_loaders.items(): + idxs = [idx % len(test_loader.dataset) for idx in idxs] + ds_show = [] + for i, batch in enumerate(test_loader): + with context: + batch["data"] = { + k: v.to(model.device) for k, v in batch["data"].items() + } + preds = model(batch["data"], batch["img_metas"]) + + if batch["data"]["image"].ndim == 5: + batch["data"] = remove_leading_dim(batch["data"]) + if preds["depth"].ndim == 5: + preds = remove_leading_dim(preds) + batch = original_image(batch) + test_loader.dataset.accumulate_metrics( + inputs=batch["data"], + preds=preds, + keyframe_idx=batch["img_metas"][0].get("keyframe_idx"), + ) + + # for prediction images logging + if i * world_size in idxs: + ii = (len(preds["depth"]) + 1) // 2 - 1 + slice_ = slice(ii, ii + 1) + batch["data"] = {k: v[slice_] for k, v in batch["data"].items()} + preds["depth"] = preds["depth"][slice_] + ds_show.append({**batch["data"], **{"depth_pred": preds["depth"]}}) + + barrier() + + metrics_all[name_ds] = test_loader.dataset.get_evaluation() + predictions_select[name_ds] = show( + ds_show, test_loader.dataset, ssi_depth=SSI_VISUALIZATION + ) + + barrier() + if is_main_process(): + log_artifacts(artifacts_all=predictions_select, step=step, run_id=run_id) + metrics_all = aggregate_metrics( + metrics_all, exclude_fn=lambda name: "mono" in name + ) + metrics_all = aggregate_metrics_camera(metrics_all) + log_metrics(metrics_all=metrics_all, step=step) + return metrics_all diff --git a/unik3d/utils/visualization.py b/unik3d/utils/visualization.py new file mode 100644 index 
0000000000000000000000000000000000000000..89b04ea0b29255d3169963a1d381edaea19a6618 --- /dev/null +++ b/unik3d/utils/visualization.py @@ -0,0 +1,251 @@ +import os + +import matplotlib.pyplot as plt +import numpy as np +import torch +import torch.nn.functional as F +import wandb +from PIL import Image + +from unik3d.utils.distributed import get_rank +from unik3d.utils.misc import ssi_helper + + +def colorize( + value: np.ndarray, vmin: float = None, vmax: float = None, cmap: str = "magma_r" +): + # if already RGB, do nothing + if value.ndim > 2: + if value.shape[-1] > 1: + return value + value = value[..., 0] + invalid_mask = value < 0.0001 + # normalize + vmin = value.min() if vmin is None else vmin + vmax = value.max() if vmax is None else vmax + value = (value - vmin) / (vmax - vmin) # vmin..vmax + + # set color + cmapper = plt.get_cmap(cmap) + value = cmapper(value, bytes=True) # (nxmx4) + value[invalid_mask] = 0 + img = value[..., :3] + return img + + +def image_grid(imgs: list[np.ndarray], rows: int, cols: int) -> np.ndarray: + if not len(imgs): + return None + assert len(imgs) == rows * cols + h, w = imgs[0].shape[:2] + grid = Image.new("RGB", size=(cols * w, rows * h)) + + for i, img in enumerate(imgs): + grid.paste( + Image.fromarray(img.astype(np.uint8)).resize( + (w, h), resample=Image.BILINEAR + ), + box=(i % cols * w, i // cols * h), + ) + + return np.array(grid) + + +def get_pointcloud_from_rgbd( + image: np.array, + depth: np.array, + mask: np.ndarray, + intrinsic_matrix: np.array, + extrinsic_matrix: np.array = None, +): + depth = np.array(depth).squeeze() + mask = np.array(mask).squeeze() + # Mask the depth array + masked_depth = np.ma.masked_where(mask == False, depth) + # masked_depth = np.ma.masked_greater(masked_depth, 8000) + # Create idx array + idxs = np.indices(masked_depth.shape) + u_idxs = idxs[1] + v_idxs = idxs[0] + # Get only non-masked depth and idxs + z = masked_depth[~masked_depth.mask] + compressed_u_idxs = u_idxs[~masked_depth.mask] + compressed_v_idxs = v_idxs[~masked_depth.mask] + image = np.stack( + [image[..., i][~masked_depth.mask] for i in range(image.shape[-1])], axis=-1 + ) + + # Calculate local position of each point + # Apply vectorized math to depth using compressed arrays + cx = intrinsic_matrix[0, 2] + fx = intrinsic_matrix[0, 0] + x = (compressed_u_idxs - cx) * z / fx + cy = intrinsic_matrix[1, 2] + fy = intrinsic_matrix[1, 1] + # Flip y as we want +y pointing up not down + y = -((compressed_v_idxs - cy) * z / fy) + + # # Apply camera_matrix to pointcloud as to get the pointcloud in world coords + # if extrinsic_matrix is not None: + # # Calculate camera pose from extrinsic matrix + # camera_matrix = np.linalg.inv(extrinsic_matrix) + # # Create homogenous array of vectors by adding 4th entry of 1 + # # At the same time flip z as for eye space the camera is looking down the -z axis + # w = np.ones(z.shape) + # x_y_z_eye_hom = np.vstack((x, y, -z, w)) + # # Transform the points from eye space to world space + # x_y_z_world = np.dot(camera_matrix, x_y_z_eye_hom)[:3] + # return x_y_z_world.T + # else: + x_y_z_local = np.stack((x, y, z), axis=-1) + return np.concatenate([x_y_z_local, image], axis=-1) + + +def save_file_ply(xyz, rgb, pc_file): + if rgb.max() < 1.001: + rgb = rgb * 255.0 + rgb = rgb.astype(np.uint8) + # print(rgb) + with open(pc_file, "w") as f: + # headers + f.writelines( + [ + "ply\n" "format ascii 1.0\n", + "element vertex {}\n".format(xyz.shape[0]), + "property float x\n", + "property float y\n", + "property float z\n", + 
"property uchar red\n", + "property uchar green\n", + "property uchar blue\n", + "end_header\n", + ] + ) + + for i in range(xyz.shape[0]): + str_v = "{:10.6f} {:10.6f} {:10.6f} {:d} {:d} {:d}\n".format( + xyz[i, 0], xyz[i, 1], xyz[i, 2], rgb[i, 0], rgb[i, 1], rgb[i, 2] + ) + f.write(str_v) + + +# really awful fct... FIXME + + +def train_artifacts(rgbs, gts, preds, infos={}): + # interpolate to same shape, will be distorted! FIXME TODO + shape = rgbs[0].shape[-2:] + gts = F.interpolate(gts, shape, mode="nearest-exact") + + rgbs = [ + (127.5 * (rgb + 1)) + .clip(0, 255) + .to(torch.uint8) + .cpu() + .detach() + .permute(1, 2, 0) + .numpy() + for rgb in rgbs + ] + new_gts, new_preds = [], [] + num_additional, additionals = 0, [] + + if len(gts) > 0: + for i, gt in enumerate(gts): + # scale, shift = ssi_helper(gts[i][gts[i]>0].cpu().detach(), preds[i][gts[i]>0].cpu().detach()) + scale, shift = 1, 0 + up = torch.quantile( + torch.log(1 + gts[i][gts[i] > 0]).float().cpu().detach(), 0.98 + ).item() + down = torch.quantile( + torch.log(1 + gts[i][gts[i] > 0]).float().cpu().detach(), 0.02 + ).item() + gt = gts[i].cpu().detach().squeeze().numpy() + pred = (preds[i].cpu().detach() * scale + shift).squeeze().numpy() + new_gts.append( + colorize(np.log(1.0 + gt), vmin=down, vmax=up) + ) # , vmin=vmin, vmax=vmax)) + new_preds.append( + colorize(np.log(1.0 + pred), vmin=down, vmax=up) + ) # , vmin=vmin, vmax=vmax)) + + gts, preds = new_gts, new_preds + else: + preds = [ + colorize(pred.cpu().detach().squeeze().numpy(), 0.0, 80.0) + for i, pred in enumerate(preds) + ] + + for name, info in infos.items(): + num_additional += 1 + if info.shape[1] == 3: + additionals.extend( + [ + (127.5 * (x + 1)) + .clip(0, 255) + .to(torch.uint8) + .cpu() + .detach() + .permute(1, 2, 0) + .numpy() + for x in info + ] + ) + else: # must be depth! + additionals.extend( + [ + colorize(x.cpu().detach().squeeze().numpy()) + for i, x in enumerate(info) + ] + ) + + num_rows = 2 + int(len(gts) > 0) + num_additional + + artifacts_grid = image_grid( + [*rgbs, *gts, *preds, *additionals], num_rows, len(rgbs) + ) + return artifacts_grid + + +def log_train_artifacts(rgbs, gts, preds, step, infos={}): + artifacts_grid = train_artifacts(rgbs, gts, preds, infos) + try: + wandb.log({f"training": [wandb.Image(artifacts_grid)]}, step=step) + except: + Image.fromarray(artifacts_grid).save( + os.path.join( + os.environ.get("TMPDIR", "/tmp"), + f"{get_rank()}_art_grid{step}.png", + ) + ) + print("Logging training images failed") + + +def plot_quiver(flow, spacing, margin=0, **kwargs): + """Plots less dense quiver field. + + Args: + ax: Matplotlib axis + flow: motion vectors + spacing: space (px) between each arrow in grid + margin: width (px) of enclosing region without arrows + kwargs: quiver kwargs (default: angles="xy", scale_units="xy") + """ + h, w, *_ = flow.shape + + nx = int((w - 2 * margin) / spacing) + ny = int((h - 2 * margin) / spacing) + + x = np.linspace(margin, w - margin - 1, nx, dtype=np.int64) + y = np.linspace(margin, h - margin - 1, ny, dtype=np.int64) + + flow = flow[np.ix_(y, x)] + u = flow[:, :, 0] + v = flow[:, :, 1] + + kwargs = {**dict(angles="xy", scale_units="xy"), **kwargs} + fig, ax = plt.subplots(figsize=(10, 10)) + ax.quiver(x, y, u, v, **kwargs) + + # ax.set_ylim(sorted(ax.get_ylim(), reverse=True)) + return fig, ax