diff --git a/build_docker.sh b/build_docker.sh
index a5aea45e6ff5024b71818dea6f4e7cfb0d0ae6c0..50d857a6f9deadefb85cab7b12442920d1734290 100644
--- a/build_docker.sh
+++ b/build_docker.sh
@@ -1,3 +1,4 @@
 docker build -t image-matching-webui:latest . --no-cache
 docker tag image-matching-webui:latest vincentqin/image-matching-webui:latest
 docker push vincentqin/image-matching-webui:latest
+ 
\ No newline at end of file
diff --git a/hloc/matchers/roma.py b/hloc/matchers/roma.py
index 1f9bcb0edff59453680f9309f9ead6b364d8c8ad..d91fbb8dcc35354c75ad30d2753c8ed85fb82da5 100644
--- a/hloc/matchers/roma.py
+++ b/hloc/matchers/roma.py
@@ -6,7 +6,7 @@ from PIL import Image
 from ..utils.base_model import BaseModel
 from .. import logger
 
-roma_path = Path(__file__).parent / "../../third_party/Roma"
+roma_path = Path(__file__).parent / "../../third_party/RoMa"
 sys.path.append(str(roma_path))
 
 from roma.models.model_zoo.roma_models import roma_model
@@ -63,6 +63,8 @@ class Roma(BaseModel):
             weights=weights,
             dinov2_weights=dinov2_weights,
             device=device,
+            #temp fix issue: https://github.com/Parskatt/RoMa/issues/26
+            amp_dtype=torch.float32,
         )
         logger.info(f"Load Roma model done.")
 
diff --git a/third_party/Roma/.gitignore b/third_party/RoMa/.gitignore
similarity index 100%
rename from third_party/Roma/.gitignore
rename to third_party/RoMa/.gitignore
diff --git a/third_party/RoMa/LICENSE b/third_party/RoMa/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..ca95157052a76debc473afb395bffae0c1329e63
--- /dev/null
+++ b/third_party/RoMa/LICENSE
@@ -0,0 +1,21 @@
+MIT License
+
+Copyright (c) 2023 Johan Edstedt
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in all
+copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+SOFTWARE.
diff --git a/third_party/RoMa/README.md b/third_party/RoMa/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a3b6484b2a6c19af426b731396c5c91331f99ada
--- /dev/null
+++ b/third_party/RoMa/README.md
@@ -0,0 +1,92 @@
+# 
+<p align="center">
+  <h1 align="center"> <ins>RoMa</ins> đŸ›ïž:<br> Robust Dense Feature Matching <br> ⭐CVPR 2024⭐</h1>
+  <p align="center">
+    <a href="https://scholar.google.com/citations?user=Ul-vMR0AAAAJ">Johan Edstedt</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=HS2WuHkAAAAJ">Qiyu Sun</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=FUE3Wd0AAAAJ">Georg Bökman</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=6WRQpCQAAAAJ">MÄrten WadenbÀck</a>
+    ·
+    <a href="https://scholar.google.com/citations?user=lkWfR08AAAAJ">Michael Felsberg</a>
+  </p>
+  <h2 align="center"><p>
+    <a href="https://arxiv.org/abs/2305.15404" align="center">Paper</a> | 
+    <a href="https://parskatt.github.io/RoMa" align="center">Project Page</a>
+  </p></h2>
+  <div align="center"></div>
+</p>
+<br/>
+<p align="center">
+    <img src="https://github.com/Parskatt/RoMa/assets/22053118/15d8fea7-aa6d-479f-8a93-350d950d006b" alt="example" width=80%>
+    <br>
+    <em>RoMa is the robust dense feature matcher capable of estimating pixel-dense warps and reliable certainties for almost any image pair.</em>
+</p>
+
+## Setup/Install
+In your python environment (tested on Linux python 3.10), run:
+```bash
+pip install -e .
+```
+## Demo / How to Use
+We provide two demos in the [demos folder](demo).
+Here's the gist of it:
+```python
+from roma import roma_outdoor
+roma_model = roma_outdoor(device=device)
+# Match
+warp, certainty = roma_model.match(imA_path, imB_path, device=device)
+# Sample matches for estimation
+matches, certainty = roma_model.sample(warp, certainty)
+# Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1])
+kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+# Find a fundamental matrix (or anything else of interest)
+F, mask = cv2.findFundamentalMat(
+    kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
+)
+```
+
+**New**: You can also match arbitrary keypoints with RoMa. A demo for this will be added soon.
+## Settings
+
+### Resolution
+By default RoMa uses an initial resolution of (560,560) which is then upsampled to (864,864). 
+You can change this at construction (see roma_outdoor kwargs).
+You can also change this later, by changing the roma_model.w_resized, roma_model.h_resized, and roma_model.upsample_res.
+
+### Sampling
+roma_model.sample_thresh controls the thresholding used when sampling matches for estimation. In certain cases a lower or higher threshold may improve results.
+
+
+## Reproducing Results
+The experiments in the paper are provided in the [experiments folder](experiments).
+
+### Training
+1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets.
+2. Run the relevant experiment, e.g.,
+```bash
+torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py
+```
+### Testing
+```bash
+python experiments/roma_outdoor.py --only_test --benchmark mega-1500
+```
+## License
+All our code except DINOv2 is MIT license.
+DINOv2 has an Apache 2 license [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE).
+
+## Acknowledgement
+Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
+
+## BibTeX
+If you find our models useful, please consider citing our paper!
+```
+@article{edstedt2024roma,
+title={{RoMa: Robust Dense Feature Matching}},
+author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and WadenbÀck, MÄrten and Felsberg, Michael},
+journal={IEEE Conference on Computer Vision and Pattern Recognition},
+year={2024}
+}
+```
diff --git a/third_party/Roma/assets/sacre_coeur_A.jpg b/third_party/RoMa/assets/sacre_coeur_A.jpg
similarity index 100%
rename from third_party/Roma/assets/sacre_coeur_A.jpg
rename to third_party/RoMa/assets/sacre_coeur_A.jpg
diff --git a/third_party/Roma/assets/sacre_coeur_B.jpg b/third_party/RoMa/assets/sacre_coeur_B.jpg
similarity index 100%
rename from third_party/Roma/assets/sacre_coeur_B.jpg
rename to third_party/RoMa/assets/sacre_coeur_B.jpg
diff --git a/third_party/RoMa/assets/toronto_A.jpg b/third_party/RoMa/assets/toronto_A.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..450622c06c06b5bdcb4b20150ec4b5e8e34f9787
--- /dev/null
+++ b/third_party/RoMa/assets/toronto_A.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:40270c227df93f0f31b55e0f2ff38eb24f47940c4800c83758a74a5dfd7346ec
+size 525339
diff --git a/third_party/RoMa/assets/toronto_B.jpg b/third_party/RoMa/assets/toronto_B.jpg
new file mode 100644
index 0000000000000000000000000000000000000000..6a8c7907bfc9bcd88f9d9deaa6e148e18a764d12
--- /dev/null
+++ b/third_party/RoMa/assets/toronto_B.jpg
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:a2c07550ed87e40fca8c38076eb3a81395d760a88bf0b8615167704107deff2f
+size 286466
diff --git a/third_party/Roma/data/.gitignore b/third_party/RoMa/data/.gitignore
similarity index 100%
rename from third_party/Roma/data/.gitignore
rename to third_party/RoMa/data/.gitignore
diff --git a/third_party/RoMa/demo/demo_3D_effect.py b/third_party/RoMa/demo/demo_3D_effect.py
new file mode 100644
index 0000000000000000000000000000000000000000..5afd6e5ce0fdd32788160e8c24df0b26a27f34dd
--- /dev/null
+++ b/third_party/RoMa/demo/demo_3D_effect.py
@@ -0,0 +1,46 @@
+from PIL import Image
+import torch
+import torch.nn.functional as F
+import numpy as np
+from roma.utils.utils import tensor_to_pil
+
+from roma import roma_outdoor
+
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+    parser.add_argument("--save_path", default="demo/gif/roma_warp_toronto", type=str)
+
+    args, _ = parser.parse_known_args()
+    im1_path = args.im_A_path
+    im2_path = args.im_B_path
+    save_path = args.save_path
+
+    # Create model
+    roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
+    roma_model.symmetric = False
+
+    H, W = roma_model.get_output_resolution()
+
+    im1 = Image.open(im1_path).resize((W, H))
+    im2 = Image.open(im2_path).resize((W, H))
+
+    # Match
+    warp, certainty = roma_model.match(im1_path, im2_path, device=device)
+    # Sampling not needed, but can be done with model.sample(warp, certainty)
+    x1 = (torch.tensor(np.array(im1)) / 255).to(device).permute(2, 0, 1)
+    x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
+
+    coords_A, coords_B = warp[...,:2], warp[...,2:]
+    for i, x in enumerate(np.linspace(0,2*np.pi,200)):
+        t = (1 + np.cos(x))/2
+        interp_warp = (1-t)*coords_A + t*coords_B
+        im2_transfer_rgb = F.grid_sample(
+        x2[None], interp_warp[None], mode="bilinear", align_corners=False
+        )[0]
+        tensor_to_pil(im2_transfer_rgb, unnormalize=False).save(f"{save_path}_{i:03d}.jpg")
\ No newline at end of file
diff --git a/third_party/Roma/demo/demo_fundamental.py b/third_party/RoMa/demo/demo_fundamental.py
similarity index 76%
rename from third_party/Roma/demo/demo_fundamental.py
rename to third_party/RoMa/demo/demo_fundamental.py
index a71fd5532412fb4c65eb109e8e9f83813c11fd85..31618d4b06cd56fdd4be9065fb00b826a19e10f9 100644
--- a/third_party/Roma/demo/demo_fundamental.py
+++ b/third_party/RoMa/demo/demo_fundamental.py
@@ -3,12 +3,11 @@ import torch
 import cv2
 from roma import roma_outdoor
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 if __name__ == "__main__":
     from argparse import ArgumentParser
-
     parser = ArgumentParser()
     parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
     parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
@@ -20,6 +19,7 @@ if __name__ == "__main__":
     # Create model
     roma_model = roma_outdoor(device=device)
 
+
     W_A, H_A = Image.open(im1_path).size
     W_B, H_B = Image.open(im2_path).size
 
@@ -27,12 +27,7 @@ if __name__ == "__main__":
     warp, certainty = roma_model.match(im1_path, im2_path, device=device)
     # Sample matches for estimation
     matches, certainty = roma_model.sample(warp, certainty)
-    kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
+    kpts1, kpts2 = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)    
     F, mask = cv2.findFundamentalMat(
-        kpts1.cpu().numpy(),
-        kpts2.cpu().numpy(),
-        ransacReprojThreshold=0.2,
-        method=cv2.USAC_MAGSAC,
-        confidence=0.999999,
-        maxIters=10000,
-    )
+        kpts1.cpu().numpy(), kpts2.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
+    )
\ No newline at end of file
diff --git a/third_party/Roma/demo/demo_match.py b/third_party/RoMa/demo/demo_match.py
similarity index 56%
rename from third_party/Roma/demo/demo_match.py
rename to third_party/RoMa/demo/demo_match.py
index 69eb07ffb0b480db99252bbb03a9858964e8d5f0..80dfcd252e6665246a1b21cca7c8c64a183fa0e2 100644
--- a/third_party/Roma/demo/demo_match.py
+++ b/third_party/RoMa/demo/demo_match.py
@@ -4,20 +4,17 @@ import torch.nn.functional as F
 import numpy as np
 from roma.utils.utils import tensor_to_pil
 
-from roma import roma_indoor
+from roma import roma_outdoor
 
-device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
 
 
 if __name__ == "__main__":
     from argparse import ArgumentParser
-
     parser = ArgumentParser()
-    parser.add_argument("--im_A_path", default="assets/sacre_coeur_A.jpg", type=str)
-    parser.add_argument("--im_B_path", default="assets/sacre_coeur_B.jpg", type=str)
-    parser.add_argument(
-        "--save_path", default="demo/dkmv3_warp_sacre_coeur.jpg", type=str
-    )
+    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+    parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
 
     args, _ = parser.parse_known_args()
     im1_path = args.im_A_path
@@ -25,7 +22,7 @@ if __name__ == "__main__":
     save_path = args.save_path
 
     # Create model
-    roma_model = roma_indoor(device=device)
+    roma_model = roma_outdoor(device=device, coarse_res=560, upsample_res=(864, 1152))
 
     H, W = roma_model.get_output_resolution()
 
@@ -39,12 +36,12 @@ if __name__ == "__main__":
     x2 = (torch.tensor(np.array(im2)) / 255).to(device).permute(2, 0, 1)
 
     im2_transfer_rgb = F.grid_sample(
-        x2[None], warp[:, :W, 2:][None], mode="bilinear", align_corners=False
+    x2[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
     )[0]
     im1_transfer_rgb = F.grid_sample(
-        x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+    x1[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
     )[0]
-    warp_im = torch.cat((im2_transfer_rgb, im1_transfer_rgb), dim=2)
-    white_im = torch.ones((H, 2 * W), device=device)
+    warp_im = torch.cat((im2_transfer_rgb,im1_transfer_rgb),dim=2)
+    white_im = torch.ones((H,2*W),device=device)
     vis_im = certainty * warp_im + (1 - certainty) * white_im
-    tensor_to_pil(vis_im, unnormalize=False).save(save_path)
+    tensor_to_pil(vis_im, unnormalize=False).save(save_path)
\ No newline at end of file
diff --git a/third_party/RoMa/demo/demo_match_opencv_sift.py b/third_party/RoMa/demo/demo_match_opencv_sift.py
new file mode 100644
index 0000000000000000000000000000000000000000..3196fcfaab248f6c4c6247a0afb4db745206aee8
--- /dev/null
+++ b/third_party/RoMa/demo/demo_match_opencv_sift.py
@@ -0,0 +1,43 @@
+from PIL import Image
+import numpy as np
+
+import numpy as np
+import cv2 as cv
+import matplotlib.pyplot as plt
+
+
+
+if __name__ == "__main__":
+    from argparse import ArgumentParser
+    parser = ArgumentParser()
+    parser.add_argument("--im_A_path", default="assets/toronto_A.jpg", type=str)
+    parser.add_argument("--im_B_path", default="assets/toronto_B.jpg", type=str)
+    parser.add_argument("--save_path", default="demo/roma_warp_toronto.jpg", type=str)
+
+    args, _ = parser.parse_known_args()
+    im1_path = args.im_A_path
+    im2_path = args.im_B_path
+    save_path = args.save_path
+
+    img1 = cv.imread(im1_path,cv.IMREAD_GRAYSCALE)          # queryImage
+    img2 = cv.imread(im2_path,cv.IMREAD_GRAYSCALE) # trainImage
+    # Initiate SIFT detector
+    sift = cv.SIFT_create()
+    # find the keypoints and descriptors with SIFT
+    kp1, des1 = sift.detectAndCompute(img1,None)
+    kp2, des2 = sift.detectAndCompute(img2,None)
+    # BFMatcher with default params
+    bf = cv.BFMatcher()
+    matches = bf.knnMatch(des1,des2,k=2)
+    # Apply ratio test
+    good = []
+    for m,n in matches:
+        if m.distance < 0.75*n.distance:
+            good.append([m])
+    # cv.drawMatchesKnn expects list of lists as matches.
+    draw_params = dict(matchColor = (255,0,0), # draw matches in red color
+                   singlePointColor = None,
+                   flags = 2)
+
+    img3 = cv.drawMatchesKnn(img1,kp1,img2,kp2,good,None,**draw_params)
+    Image.fromarray(img3).save("demo/sift_matches.png")
diff --git a/third_party/RoMa/demo/gif/.gitignore b/third_party/RoMa/demo/gif/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..c96a04f008ee21e260b28f7701595ed59e2839e3
--- /dev/null
+++ b/third_party/RoMa/demo/gif/.gitignore
@@ -0,0 +1,2 @@
+*
+!.gitignore
\ No newline at end of file
diff --git a/third_party/Roma/pretrained/dinov2_vitl14_pretrain.pth b/third_party/RoMa/pretrained/dinov2_vitl14_pretrain.pth
similarity index 100%
rename from third_party/Roma/pretrained/dinov2_vitl14_pretrain.pth
rename to third_party/RoMa/pretrained/dinov2_vitl14_pretrain.pth
diff --git a/third_party/Roma/pretrained/roma_outdoor.pth b/third_party/RoMa/pretrained/roma_outdoor.pth
similarity index 100%
rename from third_party/Roma/pretrained/roma_outdoor.pth
rename to third_party/RoMa/pretrained/roma_outdoor.pth
diff --git a/third_party/Roma/requirements.txt b/third_party/RoMa/requirements.txt
similarity index 65%
rename from third_party/Roma/requirements.txt
rename to third_party/RoMa/requirements.txt
index 12addf0d0eb74e6cac0da6bca704eac0b28990d7..f0dbab3d4cb35a5f00e3dbc8e3f8b00a3e578428 100644
--- a/third_party/Roma/requirements.txt
+++ b/third_party/RoMa/requirements.txt
@@ -10,4 +10,4 @@ matplotlib
 h5py
 wandb
 timm
-xformers # Optional, used for memefficient attention
\ No newline at end of file
+#xformers # Optional, used for memefficient attention
\ No newline at end of file
diff --git a/third_party/Roma/roma/__init__.py b/third_party/RoMa/roma/__init__.py
similarity index 62%
rename from third_party/Roma/roma/__init__.py
rename to third_party/RoMa/roma/__init__.py
index a3c12d5247b93a83882edfb45bd127db794e791f..a7c96481e0a808b68c7b3054a3e34fa0b5c45ab9 100644
--- a/third_party/Roma/roma/__init__.py
+++ b/third_party/RoMa/roma/__init__.py
@@ -2,7 +2,7 @@ import os
 from .models import roma_outdoor, roma_indoor
 
 DEBUG_MODE = False
-RANK = int(os.environ.get("RANK", default=0))
+RANK = int(os.environ.get('RANK', default = 0))
 GLOBAL_STEP = 0
 STEP_SIZE = 1
-LOCAL_RANK = -1
+LOCAL_RANK = -1
\ No newline at end of file
diff --git a/third_party/Roma/roma/benchmarks/__init__.py b/third_party/RoMa/roma/benchmarks/__init__.py
similarity index 100%
rename from third_party/Roma/roma/benchmarks/__init__.py
rename to third_party/RoMa/roma/benchmarks/__init__.py
diff --git a/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py b/third_party/RoMa/roma/benchmarks/hpatches_sequences_homog_benchmark.py
similarity index 91%
rename from third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py
rename to third_party/RoMa/roma/benchmarks/hpatches_sequences_homog_benchmark.py
index 6417d4d54798360a027a0d11d50fc65cdfae015a..2154a471c73d9e883c3ba8ed1b90d708f4950a63 100644
--- a/third_party/Roma/roma/benchmarks/hpatches_sequences_homog_benchmark.py
+++ b/third_party/RoMa/roma/benchmarks/hpatches_sequences_homog_benchmark.py
@@ -53,7 +53,7 @@ class HpatchesHomogBenchmark:
         )
         return im_A_coords, im_A_to_im_B
 
-    def benchmark(self, model, model_name=None):
+    def benchmark(self, model, model_name = None):
         n_matches = []
         homog_dists = []
         for seq_idx, seq_name in tqdm(
@@ -69,7 +69,9 @@ class HpatchesHomogBenchmark:
                 H = np.loadtxt(
                     os.path.join(self.seqs_path, seq_name, "H_1_" + str(im_idx))
                 )
-                dense_matches, dense_certainty = model.match(im_A_path, im_B_path)
+                dense_matches, dense_certainty = model.match(
+                    im_A_path, im_B_path
+                )
                 good_matches, _ = model.sample(dense_matches, dense_certainty, 5000)
                 pos_a, pos_b = self.convert_coordinates(
                     good_matches[:, :2], good_matches[:, 2:], w1, h1, w2, h2
@@ -78,9 +80,9 @@ class HpatchesHomogBenchmark:
                     H_pred, inliers = cv2.findHomography(
                         pos_a,
                         pos_b,
-                        method=cv2.RANSAC,
-                        confidence=0.99999,
-                        ransacReprojThreshold=3 * min(w2, h2) / 480,
+                        method = cv2.RANSAC,
+                        confidence = 0.99999,
+                        ransacReprojThreshold = 3 * min(w2, h2) / 480,
                     )
                 except:
                     H_pred = None
diff --git a/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py b/third_party/RoMa/roma/benchmarks/megadepth_dense_benchmark.py
similarity index 81%
rename from third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py
rename to third_party/RoMa/roma/benchmarks/megadepth_dense_benchmark.py
index f51a77e15510572b8f594dbc7713a0f348a33fd8..0600d354b1d0dfa7f8e2b0f8882a4cc08fafeed9 100644
--- a/third_party/Roma/roma/benchmarks/megadepth_dense_benchmark.py
+++ b/third_party/RoMa/roma/benchmarks/megadepth_dense_benchmark.py
@@ -6,11 +6,8 @@ from roma.utils import warp_kpts
 from torch.utils.data import ConcatDataset
 import roma
 
-
 class MegadepthDenseBenchmark:
-    def __init__(
-        self, data_root="data/megadepth", h=384, w=512, num_samples=2000
-    ) -> None:
+    def __init__(self, data_root="data/megadepth", h = 384, w = 512, num_samples = 2000) -> None:
         mega = MegadepthBuilder(data_root=data_root)
         self.dataset = ConcatDataset(
             mega.build_scenes(split="test_loftr", ht=h, wt=w)
@@ -52,15 +49,13 @@ class MegadepthDenseBenchmark:
             pck_3_tot = 0.0
             pck_5_tot = 0.0
             sampler = torch.utils.data.WeightedRandomSampler(
-                torch.ones(len(self.dataset)),
-                replacement=False,
-                num_samples=self.num_samples,
+                torch.ones(len(self.dataset)), replacement=False, num_samples=self.num_samples
             )
             B = batch_size
             dataloader = torch.utils.data.DataLoader(
                 self.dataset, batch_size=B, num_workers=batch_size, sampler=sampler
             )
-            for idx, data in tqdm.tqdm(enumerate(dataloader), disable=roma.RANK > 0):
+            for idx, data in tqdm.tqdm(enumerate(dataloader), disable = roma.RANK > 0):
                 im_A, im_B, depth1, depth2, T_1to2, K1, K2 = (
                     data["im_A"],
                     data["im_B"],
@@ -77,36 +72,25 @@ class MegadepthDenseBenchmark:
                 if roma.DEBUG_MODE:
                     from roma.utils.utils import tensor_to_pil
                     import torch.nn.functional as F
-
                     path = "vis"
                     H, W = model.get_output_resolution()
-                    white_im = torch.ones((B, 1, H, W), device="cuda")
+                    white_im = torch.ones((B,1,H,W),device="cuda")
                     im_B_transfer_rgb = F.grid_sample(
-                        im_B.cuda(),
-                        matches[:, :, :W, 2:],
-                        mode="bilinear",
-                        align_corners=False,
+                        im_B.cuda(), matches[:,:,:W, 2:], mode="bilinear", align_corners=False
                     )
                     warp_im = im_B_transfer_rgb
-                    c_b = certainty[
-                        :, None
-                    ]  # (certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
+                    c_b = certainty[:,None]#(certainty*0.9 + 0.1*torch.ones_like(certainty))[:,None]
                     vis_im = c_b * warp_im + (1 - c_b) * white_im
                     for b in range(B):
                         import os
-
-                        os.makedirs(
-                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}", exist_ok=True
-                        )
+                        os.makedirs(f"{path}/{model.name}/{idx}_{b}_{H}_{W}",exist_ok=True)
                         tensor_to_pil(vis_im[b], unnormalize=True).save(
-                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg"
-                        )
+                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/warp.jpg")
                         tensor_to_pil(im_A[b].cuda(), unnormalize=True).save(
-                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg"
-                        )
+                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_A.jpg")
                         tensor_to_pil(im_B[b].cuda(), unnormalize=True).save(
-                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg"
-                        )
+                            f"{path}/{model.name}/{idx}_{b}_{H}_{W}/im_B.jpg")
+
 
                 gd_tot, pck_1_tot, pck_3_tot, pck_5_tot = (
                     gd_tot + gd.mean(),
diff --git a/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py b/third_party/RoMa/roma/benchmarks/megadepth_pose_estimation_benchmark.py
similarity index 69%
rename from third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py
rename to third_party/RoMa/roma/benchmarks/megadepth_pose_estimation_benchmark.py
index 5d936a07d550763d0378a23ea83c79cec5d373fe..217aebab4cb73471cc156de9e8d3d882a1b2af95 100644
--- a/third_party/Roma/roma/benchmarks/megadepth_pose_estimation_benchmark.py
+++ b/third_party/RoMa/roma/benchmarks/megadepth_pose_estimation_benchmark.py
@@ -7,9 +7,8 @@ import torch.nn.functional as F
 import roma
 import kornia.geometry.epipolar as kepi
 
-
 class MegaDepthPoseEstimationBenchmark:
-    def __init__(self, data_root="data/megadepth", scene_names=None) -> None:
+    def __init__(self, data_root="data/megadepth", scene_names = None) -> None:
         if scene_names is None:
             self.scene_names = [
                 "0015_0.1_0.3.npz",
@@ -26,22 +25,13 @@ class MegaDepthPoseEstimationBenchmark:
         ]
         self.data_root = data_root
 
-    def benchmark(
-        self,
-        model,
-        model_name=None,
-        resolution=None,
-        scale_intrinsics=True,
-        calibrated=True,
-    ):
-        H, W = model.get_output_resolution()
+    def benchmark(self, model, model_name = None):
         with torch.no_grad():
             data_root = self.data_root
             tot_e_t, tot_e_R, tot_e_pose = [], [], []
             thresholds = [5, 10, 20]
             for scene_ind in range(len(self.scenes)):
                 import os
-
                 scene_name = os.path.splitext(self.scene_names[scene_ind])[0]
                 scene = self.scenes[scene_ind]
                 pairs = scene["pair_infos"]
@@ -58,22 +48,21 @@ class MegaDepthPoseEstimationBenchmark:
                     T2 = poses[idx2].copy()
                     R2, t2 = T2[:3, :3], T2[:3, 3]
                     R, t = compute_relative_pose(R1, t1, R2, t2)
-                    T1_to_2 = np.concatenate((R, t[:, None]), axis=-1)
+                    T1_to_2 = np.concatenate((R,t[:,None]), axis=-1)
                     im_A_path = f"{data_root}/{im_paths[idx1]}"
                     im_B_path = f"{data_root}/{im_paths[idx2]}"
                     dense_matches, dense_certainty = model.match(
                         im_A_path, im_B_path, K1.copy(), K2.copy(), T1_to_2.copy()
                     )
-                    sparse_matches, _ = model.sample(
-                        dense_matches, dense_certainty, 5000
+                    sparse_matches,_ = model.sample(
+                        dense_matches, dense_certainty, 5_000
                     )
-
+                    
                     im_A = Image.open(im_A_path)
                     w1, h1 = im_A.size
                     im_B = Image.open(im_B_path)
                     w2, h2 = im_B.size
-
-                    if scale_intrinsics:
+                    if True: # Note: we keep this true as it was used in DKM/RoMa papers. There is very little difference compared to setting to False. 
                         scale1 = 1200 / max(w1, h1)
                         scale2 = 1200 / max(w2, h2)
                         w1, h1 = scale1 * w1, scale1 * h1
@@ -82,42 +71,23 @@ class MegaDepthPoseEstimationBenchmark:
                         K1[:2] = K1[:2] * scale1
                         K2[:2] = K2[:2] * scale2
 
-                    kpts1 = sparse_matches[:, :2]
-                    kpts1 = np.stack(
-                        (
-                            w1 * (kpts1[:, 0] + 1) / 2,
-                            h1 * (kpts1[:, 1] + 1) / 2,
-                        ),
-                        axis=-1,
-                    )
-                    kpts2 = sparse_matches[:, 2:]
-                    kpts2 = np.stack(
-                        (
-                            w2 * (kpts2[:, 0] + 1) / 2,
-                            h2 * (kpts2[:, 1] + 1) / 2,
-                        ),
-                        axis=-1,
-                    )
-
+                    kpts1, kpts2 = model.to_pixel_coordinates(sparse_matches, h1, w1, h2, w2)
+                    kpts1, kpts2 = kpts1.cpu().numpy(), kpts2.cpu().numpy()
                     for _ in range(5):
                         shuffling = np.random.permutation(np.arange(len(kpts1)))
                         kpts1 = kpts1[shuffling]
                         kpts2 = kpts2[shuffling]
                         try:
-                            threshold = 0.5
-                            if calibrated:
-                                norm_threshold = threshold / (
-                                    np.mean(np.abs(K1[:2, :2]))
-                                    + np.mean(np.abs(K2[:2, :2]))
-                                )
-                                R_est, t_est, mask = estimate_pose(
-                                    kpts1,
-                                    kpts2,
-                                    K1,
-                                    K2,
-                                    norm_threshold,
-                                    conf=0.99999,
-                                )
+                            threshold = 0.5 
+                            norm_threshold = threshold / (np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
+                            R_est, t_est, mask = estimate_pose(
+                                kpts1,
+                                kpts2,
+                                K1,
+                                K2,
+                                norm_threshold,
+                                conf=0.99999,
+                            )
                             T1_to_2_est = np.concatenate((R_est, t_est), axis=-1)  #
                             e_t, e_R = compute_pose_error(T1_to_2_est, R, t)
                             e_pose = max(e_t, e_R)
diff --git a/third_party/Roma/roma/benchmarks/scannet_benchmark.py b/third_party/RoMa/roma/benchmarks/scannet_benchmark.py
similarity index 79%
rename from third_party/Roma/roma/benchmarks/scannet_benchmark.py
rename to third_party/RoMa/roma/benchmarks/scannet_benchmark.py
index 3187c2acf79f5af8f64397f55f6df40af327945b..853af0d0ebef4dfefe2632eb49e4156ea791ee76 100644
--- a/third_party/Roma/roma/benchmarks/scannet_benchmark.py
+++ b/third_party/RoMa/roma/benchmarks/scannet_benchmark.py
@@ -10,7 +10,7 @@ class ScanNetBenchmark:
     def __init__(self, data_root="data/scannet") -> None:
         self.data_root = data_root
 
-    def benchmark(self, model, model_name=None):
+    def benchmark(self, model, model_name = None):
         model.train(False)
         with torch.no_grad():
             data_root = self.data_root
@@ -24,20 +24,20 @@ class ScanNetBenchmark:
                 scene = pairs[pairind]
                 scene_name = f"scene0{scene[0]}_00"
                 im_A_path = osp.join(
-                    self.data_root,
-                    "scans_test",
-                    scene_name,
-                    "color",
-                    f"{scene[2]}.jpg",
-                )
+                        self.data_root,
+                        "scans_test",
+                        scene_name,
+                        "color",
+                        f"{scene[2]}.jpg",
+                    )
                 im_A = Image.open(im_A_path)
                 im_B_path = osp.join(
-                    self.data_root,
-                    "scans_test",
-                    scene_name,
-                    "color",
-                    f"{scene[3]}.jpg",
-                )
+                        self.data_root,
+                        "scans_test",
+                        scene_name,
+                        "color",
+                        f"{scene[3]}.jpg",
+                    )
                 im_B = Image.open(im_B_path)
                 T_gt = rel_pose[pairind].reshape(3, 4)
                 R, t = T_gt[:3, :3], T_gt[:3, 3]
@@ -76,20 +76,24 @@ class ScanNetBenchmark:
 
                 offset = 0.5
                 kpts1 = sparse_matches[:, :2]
-                kpts1 = np.stack(
-                    (
-                        w1 * (kpts1[:, 0] + 1) / 2 - offset,
-                        h1 * (kpts1[:, 1] + 1) / 2 - offset,
-                    ),
-                    axis=-1,
+                kpts1 = (
+                    np.stack(
+                        (
+                            w1 * (kpts1[:, 0] + 1) / 2 - offset,
+                            h1 * (kpts1[:, 1] + 1) / 2 - offset,
+                        ),
+                        axis=-1,
+                    )
                 )
                 kpts2 = sparse_matches[:, 2:]
-                kpts2 = np.stack(
-                    (
-                        w2 * (kpts2[:, 0] + 1) / 2 - offset,
-                        h2 * (kpts2[:, 1] + 1) / 2 - offset,
-                    ),
-                    axis=-1,
+                kpts2 = (
+                    np.stack(
+                        (
+                            w2 * (kpts2[:, 0] + 1) / 2 - offset,
+                            h2 * (kpts2[:, 1] + 1) / 2 - offset,
+                        ),
+                        axis=-1,
+                    )
                 )
                 for _ in range(5):
                     shuffling = np.random.permutation(np.arange(len(kpts1)))
@@ -97,8 +101,7 @@ class ScanNetBenchmark:
                     kpts2 = kpts2[shuffling]
                     try:
                         norm_threshold = 0.5 / (
-                            np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2]))
-                        )
+                        np.mean(np.abs(K1[:2, :2])) + np.mean(np.abs(K2[:2, :2])))
                         R_est, t_est, mask = estimate_pose(
                             kpts1,
                             kpts2,
diff --git a/third_party/Roma/roma/checkpointing/__init__.py b/third_party/RoMa/roma/checkpointing/__init__.py
similarity index 100%
rename from third_party/Roma/roma/checkpointing/__init__.py
rename to third_party/RoMa/roma/checkpointing/__init__.py
diff --git a/third_party/Roma/roma/checkpointing/checkpoint.py b/third_party/RoMa/roma/checkpointing/checkpoint.py
similarity index 96%
rename from third_party/Roma/roma/checkpointing/checkpoint.py
rename to third_party/RoMa/roma/checkpointing/checkpoint.py
index 6372d89fe86c00c7acedf015886717bfeca7bb1f..8995efeb54f4d558127ea63423fa958c64e9088f 100644
--- a/third_party/Roma/roma/checkpointing/checkpoint.py
+++ b/third_party/RoMa/roma/checkpointing/checkpoint.py
@@ -7,7 +7,6 @@ import gc
 
 import roma
 
-
 class CheckPoint:
     def __init__(self, dir=None, name="tmp"):
         self.name = name
@@ -20,7 +19,7 @@ class CheckPoint:
         optimizer,
         lr_scheduler,
         n,
-    ):
+        ):
         if roma.RANK == 0:
             assert model is not None
             if isinstance(model, (DataParallel, DistributedDataParallel)):
@@ -33,14 +32,14 @@ class CheckPoint:
             }
             torch.save(states, self.dir + self.name + f"_latest.pth")
             logger.info(f"Saved states {list(states.keys())}, at step {n}")
-
+    
     def load(
         self,
         model,
         optimizer,
         lr_scheduler,
         n,
-    ):
+        ):
         if os.path.exists(self.dir + self.name + f"_latest.pth") and roma.RANK == 0:
             states = torch.load(self.dir + self.name + f"_latest.pth")
             if "model" in states:
@@ -58,4 +57,4 @@ class CheckPoint:
             del states
             gc.collect()
             torch.cuda.empty_cache()
-        return model, optimizer, lr_scheduler, n
+        return model, optimizer, lr_scheduler, n
\ No newline at end of file
diff --git a/third_party/Roma/roma/datasets/__init__.py b/third_party/RoMa/roma/datasets/__init__.py
similarity index 52%
rename from third_party/Roma/roma/datasets/__init__.py
rename to third_party/RoMa/roma/datasets/__init__.py
index 6a11f122e222f0a9eded4afd3dd0b900826063e8..b60c709926a4a7bd019b73eac10879063a996c90 100644
--- a/third_party/Roma/roma/datasets/__init__.py
+++ b/third_party/RoMa/roma/datasets/__init__.py
@@ -1,2 +1,2 @@
 from .megadepth import MegadepthBuilder
-from .scannet import ScanNetBuilder
+from .scannet import ScanNetBuilder
\ No newline at end of file
diff --git a/third_party/Roma/roma/datasets/megadepth.py b/third_party/RoMa/roma/datasets/megadepth.py
similarity index 75%
rename from third_party/Roma/roma/datasets/megadepth.py
rename to third_party/RoMa/roma/datasets/megadepth.py
index 75cb72ded02c80d1ad6bce0d0269626ee49a9275..5deee5ac30c439a9f300c0ad2271f141931020c0 100644
--- a/third_party/Roma/roma/datasets/megadepth.py
+++ b/third_party/RoMa/roma/datasets/megadepth.py
@@ -10,7 +10,6 @@ import roma
 from roma.utils import *
 import math
 
-
 class MegadepthScene:
     def __init__(
         self,
@@ -23,20 +22,18 @@ class MegadepthScene:
         shake_t=0,
         rot_prob=0.0,
         normalize=True,
-        max_num_pairs=100_000,
-        scene_name=None,
-        use_horizontal_flip_aug=False,
-        use_single_horizontal_flip_aug=False,
-        colorjiggle_params=None,
-        random_eraser=None,
-        use_randaug=False,
-        randaug_params=None,
-        randomize_size=False,
+        max_num_pairs = 100_000,
+        scene_name = None,
+        use_horizontal_flip_aug = False,
+        use_single_horizontal_flip_aug = False,
+        colorjiggle_params = None,
+        random_eraser = None,
+        use_randaug = False,
+        randaug_params = None,
+        randomize_size = False,
     ) -> None:
         self.data_root = data_root
-        self.scene_name = (
-            os.path.splitext(scene_name)[0] + f"_{min_overlap}_{max_overlap}"
-        )
+        self.scene_name = os.path.splitext(scene_name)[0]+f"_{min_overlap}_{max_overlap}"
         self.image_paths = scene_info["image_paths"]
         self.depth_paths = scene_info["depth_paths"]
         self.intrinsics = scene_info["intrinsics"]
@@ -54,18 +51,18 @@ class MegadepthScene:
             self.overlaps = self.overlaps[pairinds]
         if randomize_size:
             area = ht * wt
-            s = int(16 * (math.sqrt(area) // 16))
-            sizes = ((ht, wt), (s, s), (wt, ht))
+            s = int(16 * (math.sqrt(area)//16))
+            sizes = ((ht,wt), (s,s), (wt,ht))
             choice = roma.RANK % 3
-            ht, wt = sizes[choice]
+            ht, wt = sizes[choice] 
         # counts, bins = np.histogram(self.overlaps,20)
         # print(counts)
         self.im_transform_ops = get_tuple_transform_ops(
-            resize=(ht, wt),
-            normalize=normalize,
-            colorjiggle_params=colorjiggle_params,
+            resize=(ht, wt), normalize=normalize, colorjiggle_params = colorjiggle_params,
         )
-        self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt))
+        self.depth_transform_ops = get_depth_tuple_transform_ops(
+                resize=(ht, wt)
+            )
         self.wt, self.ht = wt, ht
         self.shake_t = shake_t
         self.random_eraser = random_eraser
@@ -78,19 +75,17 @@ class MegadepthScene:
     def load_im(self, im_path):
         im = Image.open(im_path)
         return im
-
-    def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
+    
+    def horizontal_flip(self, im_A, im_B, depth_A, depth_B,  K_A, K_B):
         im_A = im_A.flip(-1)
         im_B = im_B.flip(-1)
-        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
-        flip_mat = torch.tensor([[-1, 0, self.wt], [0, 1, 0], [0, 0, 1.0]]).to(
-            K_A.device
-        )
-        K_A = flip_mat @ K_A
-        K_B = flip_mat @ K_B
-
+        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) 
+        flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
+        K_A = flip_mat@K_A  
+        K_B = flip_mat@K_B  
+        
         return im_A, im_B, depth_A, depth_B, K_A, K_B
-
+    
     def load_depth(self, depth_ref, crop=None):
         depth = np.array(h5py.File(depth_ref, "r")["depth"])
         return torch.from_numpy(depth)
@@ -145,31 +140,29 @@ class MegadepthScene:
         depth_A, depth_B = self.depth_transform_ops(
             (depth_A[None, None], depth_B[None, None])
         )
-
-        [im_A, im_B, depth_A, depth_B], t = self.rand_shake(
-            im_A, im_B, depth_A, depth_B
-        )
+        
+        [im_A, im_B, depth_A, depth_B], t = self.rand_shake(im_A, im_B, depth_A, depth_B)
         K1[:2, 2] += t
         K2[:2, 2] += t
-
+        
         im_A, im_B = im_A[None], im_B[None]
         if self.random_eraser is not None:
             im_A, depth_A = self.random_eraser(im_A, depth_A)
             im_B, depth_B = self.random_eraser(im_B, depth_B)
-
+                
         if self.use_horizontal_flip_aug:
             if np.random.rand() > 0.5:
-                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(
-                    im_A, im_B, depth_A, depth_B, K1, K2
-                )
+                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
         if self.use_single_horizontal_flip_aug:
             if np.random.rand() > 0.5:
                 im_B, depth_B, K2 = self.single_horizontal_flip(im_B, depth_B, K2)
-
+        
         if roma.DEBUG_MODE:
-            tensor_to_pil(im_A[0], unnormalize=True).save(f"vis/im_A.jpg")
-            tensor_to_pil(im_B[0], unnormalize=True).save(f"vis/im_B.jpg")
-
+            tensor_to_pil(im_A[0], unnormalize=True).save(
+                            f"vis/im_A.jpg")
+            tensor_to_pil(im_B[0], unnormalize=True).save(
+                            f"vis/im_B.jpg")
+            
         data_dict = {
             "im_A": im_A[0],
             "im_A_identifier": self.image_paths[idx1].split("/")[-1].split(".jpg")[0],
@@ -182,53 +175,25 @@ class MegadepthScene:
             "T_1to2": T_1to2,
             "im_A_path": im_A_ref,
             "im_B_path": im_B_ref,
+            
         }
         return data_dict
 
 
 class MegadepthBuilder:
-    def __init__(
-        self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore=True
-    ) -> None:
+    def __init__(self, data_root="data/megadepth", loftr_ignore=True, imc21_ignore = True) -> None:
         self.data_root = data_root
         self.scene_info_root = os.path.join(data_root, "prep_scene_info")
         self.all_scenes = os.listdir(self.scene_info_root)
         self.test_scenes = ["0017.npy", "0004.npy", "0048.npy", "0013.npy"]
         # LoFTR did the D2-net preprocessing differently than we did and got more ignore scenes, can optionially ignore those
-        self.loftr_ignore_scenes = set(
-            [
-                "0121.npy",
-                "0133.npy",
-                "0168.npy",
-                "0178.npy",
-                "0229.npy",
-                "0349.npy",
-                "0412.npy",
-                "0430.npy",
-                "0443.npy",
-                "1001.npy",
-                "5014.npy",
-                "5015.npy",
-                "5016.npy",
-            ]
-        )
-        self.imc21_scenes = set(
-            [
-                "0008.npy",
-                "0019.npy",
-                "0021.npy",
-                "0024.npy",
-                "0025.npy",
-                "0032.npy",
-                "0063.npy",
-                "1589.npy",
-            ]
-        )
+        self.loftr_ignore_scenes = set(['0121.npy', '0133.npy', '0168.npy', '0178.npy', '0229.npy', '0349.npy', '0412.npy', '0430.npy', '0443.npy', '1001.npy', '5014.npy', '5015.npy', '5016.npy'])
+        self.imc21_scenes = set(['0008.npy', '0019.npy', '0021.npy', '0024.npy', '0025.npy', '0032.npy', '0063.npy', '1589.npy'])
         self.test_scenes_loftr = ["0015.npy", "0022.npy"]
         self.loftr_ignore = loftr_ignore
         self.imc21_ignore = imc21_ignore
 
-    def build_scenes(self, split="train", min_overlap=0.0, scene_names=None, **kwargs):
+    def build_scenes(self, split="train", min_overlap=0.0, scene_names = None, **kwargs):
         if split == "train":
             scene_names = set(self.all_scenes) - set(self.test_scenes)
         elif split == "train_loftr":
@@ -252,11 +217,7 @@ class MegadepthBuilder:
             ).item()
             scenes.append(
                 MegadepthScene(
-                    self.data_root,
-                    scene_info,
-                    min_overlap=min_overlap,
-                    scene_name=scene_name,
-                    **kwargs,
+                    self.data_root, scene_info, min_overlap=min_overlap,scene_name = scene_name, **kwargs
                 )
             )
         return scenes
diff --git a/third_party/RoMa/roma/datasets/scannet.py b/third_party/RoMa/roma/datasets/scannet.py
new file mode 100644
index 0000000000000000000000000000000000000000..704ea57259afdfbbca627ad143bee97a0a79d41c
--- /dev/null
+++ b/third_party/RoMa/roma/datasets/scannet.py
@@ -0,0 +1,160 @@
+import os
+import random
+from PIL import Image
+import cv2
+import h5py
+import numpy as np
+import torch
+from torch.utils.data import (
+    Dataset,
+    DataLoader,
+    ConcatDataset)
+
+import torchvision.transforms.functional as tvf
+import kornia.augmentation as K
+import os.path as osp
+import matplotlib.pyplot as plt
+import roma
+from roma.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
+from roma.utils.transforms import GeometricSequential
+from tqdm import tqdm
+
+class ScanNetScene:
+    def __init__(self, data_root, scene_info, ht = 384, wt = 512, min_overlap=0., shake_t = 0, rot_prob=0.,use_horizontal_flip_aug = False,
+) -> None:
+        self.scene_root = osp.join(data_root,"scans","scans_train")
+        self.data_names = scene_info['name']
+        self.overlaps = scene_info['score']
+        # Only sample 10s
+        valid = (self.data_names[:,-2:] % 10).sum(axis=-1) == 0
+        self.overlaps = self.overlaps[valid]
+        self.data_names = self.data_names[valid]
+        if len(self.data_names) > 10000:
+            pairinds = np.random.choice(np.arange(0,len(self.data_names)),10000,replace=False)
+            self.data_names = self.data_names[pairinds]
+            self.overlaps = self.overlaps[pairinds]
+        self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
+        self.depth_transform_ops = get_depth_tuple_transform_ops(resize=(ht, wt), normalize=False)
+        self.wt, self.ht = wt, ht
+        self.shake_t = shake_t
+        self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
+        self.use_horizontal_flip_aug = use_horizontal_flip_aug
+
+    def load_im(self, im_B, crop=None):
+        im = Image.open(im_B)
+        return im
+    
+    def load_depth(self, depth_ref, crop=None):
+        depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
+        depth = depth / 1000
+        depth = torch.from_numpy(depth).float()  # (h, w)
+        return depth
+
+    def __len__(self):
+        return len(self.data_names)
+    
+    def scale_intrinsic(self, K, wi, hi):
+        sx, sy = self.wt / wi, self.ht /  hi
+        sK = torch.tensor([[sx, 0, 0],
+                        [0, sy, 0],
+                        [0, 0, 1]])
+        return sK@K
+
+    def horizontal_flip(self, im_A, im_B, depth_A, depth_B,  K_A, K_B):
+        im_A = im_A.flip(-1)
+        im_B = im_B.flip(-1)
+        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1) 
+        flip_mat = torch.tensor([[-1, 0, self.wt],[0,1,0],[0,0,1.]]).to(K_A.device)
+        K_A = flip_mat@K_A  
+        K_B = flip_mat@K_B  
+        
+        return im_A, im_B, depth_A, depth_B, K_A, K_B
+    def read_scannet_pose(self,path):
+        """ Read ScanNet's Camera2World pose and transform it to World2Camera.
+        
+        Returns:
+            pose_w2c (np.ndarray): (4, 4)
+        """
+        cam2world = np.loadtxt(path, delimiter=' ')
+        world2cam = np.linalg.inv(cam2world)
+        return world2cam
+
+
+    def read_scannet_intrinsic(self,path):
+        """ Read ScanNet's intrinsic matrix and return the 3x3 matrix.
+        """
+        intrinsic = np.loadtxt(path, delimiter=' ')
+        return torch.tensor(intrinsic[:-1, :-1], dtype = torch.float)
+
+    def __getitem__(self, pair_idx):
+        # read intrinsics of original size
+        data_name = self.data_names[pair_idx]
+        scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
+        scene_name = f'scene{scene_name:04d}_{scene_sub_name:02d}'
+        
+        # read the intrinsic of depthmap
+        K1 = K2 =  self.read_scannet_intrinsic(osp.join(self.scene_root,
+                       scene_name,
+                       'intrinsic', 'intrinsic_color.txt'))#the depth K is not the same, but doesnt really matter
+        # read and compute relative poses
+        T1 =  self.read_scannet_pose(osp.join(self.scene_root,
+                       scene_name,
+                       'pose', f'{stem_name_1}.txt'))
+        T2 =  self.read_scannet_pose(osp.join(self.scene_root,
+                       scene_name,
+                       'pose', f'{stem_name_2}.txt'))
+        T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[:4, :4]  # (4, 4)
+
+        # Load positive pair data
+        im_A_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_1}.jpg')
+        im_B_ref = os.path.join(self.scene_root, scene_name, 'color', f'{stem_name_2}.jpg')
+        depth_A_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_1}.png')
+        depth_B_ref = os.path.join(self.scene_root, scene_name, 'depth', f'{stem_name_2}.png')
+
+        im_A = self.load_im(im_A_ref)
+        im_B = self.load_im(im_B_ref)
+        depth_A = self.load_depth(depth_A_ref)
+        depth_B = self.load_depth(depth_B_ref)
+
+        # Recompute camera intrinsic matrix due to the resize
+        K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
+        K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
+        # Process images
+        im_A, im_B = self.im_transform_ops((im_A, im_B))
+        depth_A, depth_B = self.depth_transform_ops((depth_A[None,None], depth_B[None,None]))
+        if self.use_horizontal_flip_aug:
+            if np.random.rand() > 0.5:
+                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(im_A, im_B, depth_A, depth_B, K1, K2)
+
+        data_dict = {'im_A': im_A,
+                    'im_B': im_B,
+                    'im_A_depth': depth_A[0,0],
+                    'im_B_depth': depth_B[0,0],
+                    'K1': K1,
+                    'K2': K2,
+                    'T_1to2':T_1to2,
+                    }
+        return data_dict
+
+
+class ScanNetBuilder:
+    def __init__(self, data_root = 'data/scannet') -> None:
+        self.data_root = data_root
+        self.scene_info_root = os.path.join(data_root,'scannet_indices')
+        self.all_scenes = os.listdir(self.scene_info_root)
+        
+    def build_scenes(self, split = 'train', min_overlap=0., **kwargs):
+        # Note: split doesn't matter here as we always use same scannet_train scenes
+        scene_names = self.all_scenes
+        scenes = []
+        for scene_name in tqdm(scene_names, disable = roma.RANK > 0):
+            scene_info = np.load(os.path.join(self.scene_info_root,scene_name), allow_pickle=True)
+            scenes.append(ScanNetScene(self.data_root, scene_info, min_overlap=min_overlap, **kwargs))
+        return scenes
+    
+    def weight_scenes(self, concat_dataset, alpha=.5):
+        ns = []
+        for d in concat_dataset.datasets:
+            ns.append(len(d))
+        ws = torch.cat([torch.ones(n)/n**alpha for n in ns])
+        return ws
diff --git a/third_party/RoMa/roma/losses/__init__.py b/third_party/RoMa/roma/losses/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e08abacfc0f83d7de0f2ddc0583766a80bf53cf
--- /dev/null
+++ b/third_party/RoMa/roma/losses/__init__.py
@@ -0,0 +1 @@
+from .robust_loss import RobustLosses
\ No newline at end of file
diff --git a/third_party/RoMa/roma/losses/robust_loss.py b/third_party/RoMa/roma/losses/robust_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..b932b2706f619c083485e1be0d86eec44ead83ef
--- /dev/null
+++ b/third_party/RoMa/roma/losses/robust_loss.py
@@ -0,0 +1,157 @@
+from einops.einops import rearrange
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from roma.utils.utils import get_gt_warp
+import wandb
+import roma
+import math
+
+class RobustLosses(nn.Module):
+    def __init__(
+        self,
+        robust=False,
+        center_coords=False,
+        scale_normalize=False,
+        ce_weight=0.01,
+        local_loss=True,
+        local_dist=4.0,
+        local_largest_scale=8,
+        smooth_mask = False,
+        depth_interpolation_mode = "bilinear",
+        mask_depth_loss = False,
+        relative_depth_error_threshold = 0.05,
+        alpha = 1.,
+        c = 1e-3,
+    ):
+        super().__init__()
+        self.robust = robust  # measured in pixels
+        self.center_coords = center_coords
+        self.scale_normalize = scale_normalize
+        self.ce_weight = ce_weight
+        self.local_loss = local_loss
+        self.local_dist = local_dist
+        self.local_largest_scale = local_largest_scale
+        self.smooth_mask = smooth_mask
+        self.depth_interpolation_mode = depth_interpolation_mode
+        self.mask_depth_loss = mask_depth_loss
+        self.relative_depth_error_threshold = relative_depth_error_threshold
+        self.avg_overlap = dict()
+        self.alpha = alpha
+        self.c = c
+
+    def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale):
+        with torch.no_grad():
+            B, C, H, W = scale_gm_cls.shape
+            device = x2.device
+            cls_res = round(math.sqrt(C))
+            G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
+            G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2)
+            GT = (G[None,:,None,None,:]-x2[:,None]).norm(dim=-1).min(dim=1).indices
+        cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction  = 'none')[prob > 0.99]
+        if not torch.any(cls_loss):
+            cls_loss = (certainty_loss * 0.0)  # Prevent issues where prob is 0 everywhere
+
+        certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:,0], prob)
+        losses = {
+            f"gm_certainty_loss_{scale}": certainty_loss.mean(),
+            f"gm_cls_loss_{scale}": cls_loss.mean(),
+        }
+        wandb.log(losses, step = roma.GLOBAL_STEP)
+        return losses
+
+    def delta_cls_loss(self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale):
+        with torch.no_grad():
+            B, C, H, W = delta_cls.shape
+            device = x2.device
+            cls_res = round(math.sqrt(C))
+            G = torch.meshgrid(*[torch.linspace(-1+1/cls_res, 1 - 1/cls_res, steps = cls_res,device = device) for _ in range(2)])
+            G = torch.stack((G[1], G[0]), dim = -1).reshape(C,2) * offset_scale
+            GT = (G[None,:,None,None,:] + flow_pre_delta[:,None] - x2[:,None]).norm(dim=-1).min(dim=1).indices
+        cls_loss = F.cross_entropy(delta_cls, GT, reduction  = 'none')[prob > 0.99]
+        if not torch.any(cls_loss):
+            cls_loss = (certainty_loss * 0.0)  # Prevent issues where prob is 0 everywhere
+        certainty_loss = F.binary_cross_entropy_with_logits(certainty[:,0], prob)
+        losses = {
+            f"delta_certainty_loss_{scale}": certainty_loss.mean(),
+            f"delta_cls_loss_{scale}": cls_loss.mean(),
+        }
+        wandb.log(losses, step = roma.GLOBAL_STEP)
+        return losses
+
+    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode = "delta"):
+        epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1)
+        if scale == 1:
+            pck_05 = (epe[prob > 0.99] < 0.5 * (2/512)).float().mean()
+            wandb.log({"train_pck_05": pck_05}, step = roma.GLOBAL_STEP)
+
+        ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
+        a = self.alpha
+        cs = self.c * scale
+        x = epe[prob > 0.99]
+        reg_loss = cs**a * ((x/(cs))**2 + 1**2)**(a/2)
+        if not torch.any(reg_loss):
+            reg_loss = (ce_loss * 0.0)  # Prevent issues where prob is 0 everywhere
+        losses = {
+            f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
+            f"{mode}_regression_loss_{scale}": reg_loss.mean(),
+        }
+        wandb.log(losses, step = roma.GLOBAL_STEP)
+        return losses
+
+    def forward(self, corresps, batch):
+        scales = list(corresps.keys())
+        tot_loss = 0.0
+        # scale_weights due to differences in scale for regression gradients and classification gradients
+        scale_weights = {1:1, 2:1, 4:1, 8:1, 16:1}
+        for scale in scales:
+            scale_corresps = corresps[scale]
+            scale_certainty, flow_pre_delta, delta_cls, offset_scale, scale_gm_cls, scale_gm_certainty, flow, scale_gm_flow = (
+                scale_corresps["certainty"],
+                scale_corresps["flow_pre_delta"],
+                scale_corresps.get("delta_cls"),
+                scale_corresps.get("offset_scale"),
+                scale_corresps.get("gm_cls"),
+                scale_corresps.get("gm_certainty"),
+                scale_corresps["flow"],
+                scale_corresps.get("gm_flow"),
+
+            )
+            flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
+            b, h, w, d = flow_pre_delta.shape
+            gt_warp, gt_prob = get_gt_warp(                
+            batch["im_A_depth"],
+            batch["im_B_depth"],
+            batch["T_1to2"],
+            batch["K1"],
+            batch["K2"],
+            H=h,
+            W=w,
+        )
+            x2 = gt_warp.float()
+            prob = gt_prob
+            
+            if self.local_largest_scale >= scale:
+                prob = prob * (
+                        F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[:, 0]
+                        < (2 / 512) * (self.local_dist[scale] * scale))
+            
+            if scale_gm_cls is not None:
+                gm_cls_losses = self.gm_cls_loss(x2, prob, scale_gm_cls, scale_gm_certainty, scale)
+                gm_loss = self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"] + gm_cls_losses[f"gm_cls_loss_{scale}"]
+                tot_loss = tot_loss + scale_weights[scale] * gm_loss
+            elif scale_gm_flow is not None:
+                gm_flow_losses = self.regression_loss(x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode = "gm")
+                gm_loss = self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"] + gm_flow_losses[f"gm_regression_loss_{scale}"]
+                tot_loss = tot_loss + scale_weights[scale] * gm_loss
+            
+            if delta_cls is not None:
+                delta_cls_losses = self.delta_cls_loss(x2, prob, flow_pre_delta, delta_cls, scale_certainty, scale, offset_scale)
+                delta_cls_loss = self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"] + delta_cls_losses[f"delta_cls_loss_{scale}"]
+                tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
+            else:
+                delta_regression_losses = self.regression_loss(x2, prob, flow, scale_certainty, scale)
+                reg_loss = self.ce_weight * delta_regression_losses[f"delta_certainty_loss_{scale}"] + delta_regression_losses[f"delta_regression_loss_{scale}"]
+                tot_loss = tot_loss + scale_weights[scale] * reg_loss
+            prev_epe = (flow.permute(0,2,3,1) - x2).norm(dim=-1).detach()
+        return tot_loss
diff --git a/third_party/RoMa/roma/models/__init__.py b/third_party/RoMa/roma/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f20461e2f3a1722e558cefab94c5164be8842c3
--- /dev/null
+++ b/third_party/RoMa/roma/models/__init__.py
@@ -0,0 +1 @@
+from .model_zoo import roma_outdoor, roma_indoor
\ No newline at end of file
diff --git a/third_party/Roma/roma/models/encoders.py b/third_party/RoMa/roma/models/encoders.py
similarity index 83%
rename from third_party/Roma/roma/models/encoders.py
rename to third_party/RoMa/roma/models/encoders.py
index 3b9a1a1791ec7b2f1352be1984d5232911366c0e..643360c9d61766f9f411a74bdf3a6f1114326bcb 100644
--- a/third_party/Roma/roma/models/encoders.py
+++ b/third_party/RoMa/roma/models/encoders.py
@@ -8,7 +8,8 @@ import gc
 
 
 class ResNet50(nn.Module):
-    def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False) -> None:
+    def __init__(self, pretrained=False, high_res = False, weights = None, 
+                 dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False, amp_dtype = torch.float16) -> None:
         super().__init__()
         if dilation is None:
             dilation = [False,False,False]
@@ -24,10 +25,7 @@ class ResNet50(nn.Module):
         self.freeze_bn = freeze_bn
         self.early_exit = early_exit
         self.amp = amp
-        if not torch.cuda.is_available():
-            self.amp_dtype = torch.float32
-        else:
-            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        self.amp_dtype = amp_dtype
 
     def forward(self, x, **kwargs):
         with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
@@ -59,14 +57,11 @@ class ResNet50(nn.Module):
                 pass
 
 class VGG19(nn.Module):
-    def __init__(self, pretrained=False, amp = False) -> None:
+    def __init__(self, pretrained=False, amp = False, amp_dtype = torch.float16) -> None:
         super().__init__()
         self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
         self.amp = amp
-        if not torch.cuda.is_available():
-            self.amp_dtype = torch.float32
-        else:
-            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        self.amp_dtype = amp_dtype
 
     def forward(self, x, **kwargs):
         with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
@@ -80,7 +75,7 @@ class VGG19(nn.Module):
             return feats
 
 class CNNandDinov2(nn.Module):
-    def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None):
+    def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None, amp_dtype = torch.float16):
         super().__init__()
         if dinov2_weights is None:
             dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
@@ -100,10 +95,7 @@ class CNNandDinov2(nn.Module):
         else:
             self.cnn = VGG19(**cnn_kwargs)
         self.amp = amp
-        if not torch.cuda.is_available():
-            self.amp_dtype = torch.float32
-        else:
-            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+        self.amp_dtype = amp_dtype
         if self.amp:
             dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
         self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
diff --git a/third_party/Roma/roma/models/matcher.py b/third_party/RoMa/roma/models/matcher.py
similarity index 83%
rename from third_party/Roma/roma/models/matcher.py
rename to third_party/RoMa/roma/models/matcher.py
index b68f2984e2d4515c2cf0a864213de27e714383fa..25a89c8dd99bc1eca8c591dbbc3b5ddbd987829c 100644
--- a/third_party/Roma/roma/models/matcher.py
+++ b/third_party/RoMa/roma/models/matcher.py
@@ -7,6 +7,7 @@ import torch.nn.functional as F
 from einops import rearrange
 import warnings
 from warnings import warn
+from PIL import Image
 
 import roma
 from roma.utils import get_tuple_transform_ops
@@ -37,6 +38,7 @@ class ConvRefiner(nn.Module):
         sample_mode = "bilinear",
         norm_type = nn.BatchNorm2d,
         bn_momentum = 0.1,
+        amp_dtype = torch.float16,
     ):
         super().__init__()
         self.bn_momentum = bn_momentum
@@ -71,12 +73,8 @@ class ConvRefiner(nn.Module):
         self.disable_local_corr_grad = disable_local_corr_grad
         self.is_classifier = is_classifier
         self.sample_mode = sample_mode
-        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        if not torch.cuda.is_available():
-            self.amp_dtype = torch.float32
-        else:
-            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-
+        self.amp_dtype = amp_dtype
+        
     def create_block(
         self,
         in_dim,
@@ -113,8 +111,8 @@ class ConvRefiner(nn.Module):
             if self.has_displacement_emb:
                 im_A_coords = torch.meshgrid(
                 (
-                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=self.device),
-                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=self.device),
+                    torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=x.device),
+                    torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=x.device),
                 )
                 )
                 im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
@@ -278,7 +276,7 @@ class Decoder(nn.Module):
     def __init__(
         self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
         num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
-        flow_upsample_mode = "bilinear"
+        flow_upsample_mode = "bilinear", amp_dtype = torch.float16,
     ):
         super().__init__()
         self.embedding_decoder = embedding_decoder
@@ -300,11 +298,8 @@ class Decoder(nn.Module):
         self.displacement_dropout_p = displacement_dropout_p
         self.gm_warp_dropout_p = gm_warp_dropout_p
         self.flow_upsample_mode = flow_upsample_mode
-        if not torch.cuda.is_available():
-            self.amp_dtype = torch.float32
-        else:
-            self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
-
+        self.amp_dtype = amp_dtype
+        
     def get_placeholder_flow(self, b, h, w, device):
         coarse_coords = torch.meshgrid(
             (
@@ -367,7 +362,7 @@ class Decoder(nn.Module):
             corresps[ins] = {}
             f1_s, f2_s = f1[ins], f2[ins]
             if new_scale in self.proj:
-                with torch.autocast("cuda", self.amp_dtype):
+                with torch.autocast("cuda", dtype = self.amp_dtype):
                     f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
 
             if ins in coarse_scales:
@@ -429,11 +424,12 @@ class RegressionMatcher(nn.Module):
         decoder,
         h=448,
         w=448,
-        sample_mode = "threshold",
+        sample_mode = "threshold_balanced",
         upsample_preds = False,
         symmetric = False,
         name = None,
         attenuate_cert = None,
+        recrop_upsample = False,
     ):
         super().__init__()
         self.attenuate_cert = attenuate_cert
@@ -448,6 +444,7 @@ class RegressionMatcher(nn.Module):
         self.upsample_res = (14*16*6, 14*16*6)
         self.symmetric = symmetric
         self.sample_thresh = 0.05
+        self.recrop_upsample = recrop_upsample
             
     def get_output_resolution(self):
         if not self.upsample_preds:
@@ -527,12 +524,62 @@ class RegressionMatcher(nn.Module):
                                 scale_factor=scale_factor)
         return corresps
     
-    def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
-        kpts_A, kpts_B = matches[...,:2], matches[...,2:]
+    def to_pixel_coordinates(self, coords, H_A, W_A, H_B, W_B):
+        if isinstance(coords, (list, tuple)):
+            kpts_A, kpts_B = coords[0], coords[1]
+        else:
+            kpts_A, kpts_B = coords[...,:2], coords[...,2:]
         kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
         kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
         return kpts_A, kpts_B
+    
+    def to_normalized_coordinates(self, coords, H_A, W_A, H_B, W_B):
+        if isinstance(coords, (list, tuple)):
+            kpts_A, kpts_B = coords[0], coords[1]
+        else:
+            kpts_A, kpts_B = coords[...,:2], coords[...,2:]
+        kpts_A = torch.stack((2/W_A * kpts_A[...,0] - 1, 2/H_A * kpts_A[...,1] - 1),axis=-1)
+        kpts_B = torch.stack((2/W_B * kpts_B[...,0] - 1, 2/H_B * kpts_B[...,1] - 1),axis=-1)
+        return kpts_A, kpts_B
 
+    def match_keypoints(self, x_A, x_B, warp, certainty, return_tuple = True, return_inds = False):
+        x_A_to_B = F.grid_sample(warp[...,-2:].permute(2,0,1)[None], x_A[None,None], align_corners = False, mode = "bilinear")[0,:,0].mT
+        cert_A_to_B = F.grid_sample(certainty[None,None,...], x_A[None,None], align_corners = False, mode = "bilinear")[0,0,0]
+        D = torch.cdist(x_A_to_B, x_B)
+        inds_A, inds_B = torch.nonzero((D == D.min(dim=-1, keepdim = True).values) * (D == D.min(dim=-2, keepdim = True).values) * (cert_A_to_B[:,None] > self.sample_thresh), as_tuple = True)
+        
+        if return_tuple:
+            if return_inds:
+                return inds_A, inds_B
+            else:
+                return x_A[inds_A], x_B[inds_B]
+        else:
+            if return_inds:
+                return torch.cat((inds_A, inds_B),dim=-1)
+            else:
+                return torch.cat((x_A[inds_A], x_B[inds_B]),dim=-1)
+    
+    def get_roi(self, certainty, W, H, thr = 0.025):
+        raise NotImplementedError("WIP, disable for now")
+        hs,ws = certainty.shape
+        certainty = certainty/certainty.sum(dim=(-1,-2))
+        cum_certainty_w = certainty.cumsum(dim=-1).sum(dim=-2)
+        cum_certainty_h = certainty.cumsum(dim=-2).sum(dim=-1)
+        print(cum_certainty_w)
+        print(torch.min(torch.nonzero(cum_certainty_w > thr)))
+        print(torch.min(torch.nonzero(cum_certainty_w < thr)))
+        left = int(W/ws * torch.min(torch.nonzero(cum_certainty_w > thr)))
+        right = int(W/ws * torch.max(torch.nonzero(cum_certainty_w < 1 - thr)))
+        top = int(H/hs * torch.min(torch.nonzero(cum_certainty_h > thr)))
+        bottom = int(H/hs * torch.max(torch.nonzero(cum_certainty_h < 1 - thr)))
+        print(left, right, top, bottom)
+        return left, top, right, bottom
+
+    def recrop(self, certainty, image_path):
+        roi = self.get_roi(certainty, *Image.open(image_path).size)
+        return Image.open(image_path).convert("RGB").crop(roi)
+        
+    @torch.inference_mode()
     def match(
         self,
         im_A_path,
@@ -543,9 +590,8 @@ class RegressionMatcher(nn.Module):
     ):
         if device is None:
             device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-        from PIL import Image
         if isinstance(im_A_path, (str, os.PathLike)):
-            im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
+            im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
         else:
             # Assume its not a path
             im_A, im_B = im_A_path, im_B_path
@@ -597,7 +643,14 @@ class RegressionMatcher(nn.Module):
                 test_transform = get_tuple_transform_ops(
                     resize=(hs, ws), normalize=True
                 )
-                im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
+                if self.recrop_upsample:
+                    certainty = corresps[finest_scale]["certainty"]
+                    print(certainty.shape)
+                    im_A = self.recrop(certainty[0,0], im_A_path)
+                    im_B = self.recrop(certainty[1,0], im_B_path)
+                    #TODO: need to adjust corresps when doing this
+                else:
+                    im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
                 im_A, im_B = test_transform((im_A, im_B))
                 im_A, im_B = im_A[None].to(device), im_B[None].to(device)
                 scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
@@ -653,4 +706,30 @@ class RegressionMatcher(nn.Module):
                     warp[0],
                     certainty[0, 0],
                 )
+                
+    def visualize_warp(self, warp, certainty, im_A = None, im_B = None, im_A_path = None, im_B_path = None, device = "cuda", symmetric = True, save_path = None):
+        assert symmetric == True, "Currently assuming bidirectional warp, might update this if someone complains ;)"
+        H,W2,_ = warp.shape
+        W = W2//2 if symmetric else W2
+        if im_A is None:
+            from PIL import Image
+            im_A, im_B = Image.open(im_A_path).convert("RGB"), Image.open(im_B_path).convert("RGB")
+        im_A = im_A.resize((W,H))
+        im_B = im_B.resize((W,H))
+            
+        x_A = (torch.tensor(np.array(im_A)) / 255).to(device).permute(2, 0, 1)
+        x_B = (torch.tensor(np.array(im_B)) / 255).to(device).permute(2, 0, 1)
 
+        im_A_transfer_rgb = F.grid_sample(
+        x_B[None], warp[:,:W, 2:][None], mode="bilinear", align_corners=False
+        )[0]
+        im_B_transfer_rgb = F.grid_sample(
+        x_A[None], warp[:, W:, :2][None], mode="bilinear", align_corners=False
+        )[0]
+        warp_im = torch.cat((im_A_transfer_rgb,im_B_transfer_rgb),dim=2)
+        white_im = torch.ones((H,2*W),device=device)
+        vis_im = certainty * warp_im + (1 - certainty) * white_im
+        if save_path is not None:
+            from roma.utils import tensor_to_pil
+            tensor_to_pil(vis_im, unnormalize=False).save(save_path)
+        return vis_im
diff --git a/third_party/RoMa/roma/models/model_zoo/__init__.py b/third_party/RoMa/roma/models/model_zoo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..49ca7b8557cb8f6948bca28c631e39d899e49177
--- /dev/null
+++ b/third_party/RoMa/roma/models/model_zoo/__init__.py
@@ -0,0 +1,53 @@
+from typing import Union
+import torch
+from .roma_models import roma_model
+
+weight_urls = {
+    "roma": {
+        "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
+        "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
+    },
+    "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", #hopefully this doesnt change :D
+}
+
+def roma_outdoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
+    if isinstance(coarse_res, int):
+        coarse_res = (coarse_res, coarse_res)
+    if isinstance(upsample_res, int):    
+        upsample_res = (upsample_res, upsample_res)
+
+    assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
+    assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
+    
+    if weights is None:
+        weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["outdoor"],
+                                                     map_location=device)
+    if dinov2_weights is None:
+        dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
+                                                     map_location=device)
+    model = roma_model(resolution=coarse_res, upsample_preds=True,
+               weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
+    model.upsample_res = upsample_res
+    print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
+    return model
+
+def roma_indoor(device, weights=None, dinov2_weights=None, coarse_res: Union[int,tuple[int,int]] = 560, upsample_res: Union[int,tuple[int,int]] = 864, amp_dtype: torch.dtype = torch.float16):
+    if isinstance(coarse_res, int):
+        coarse_res = (coarse_res, coarse_res)
+    if isinstance(upsample_res, int):    
+        upsample_res = (upsample_res, upsample_res)
+
+    assert coarse_res[0] % 14 == 0, "Needs to be multiple of 14 for backbone"
+    assert coarse_res[1] % 14 == 0, "Needs to be multiple of 14 for backbone"
+    
+    if weights is None:
+        weights = torch.hub.load_state_dict_from_url(weight_urls["roma"]["indoor"],
+                                                     map_location=device)
+    if dinov2_weights is None:
+        dinov2_weights = torch.hub.load_state_dict_from_url(weight_urls["dinov2"],
+                                                     map_location=device)
+    model = roma_model(resolution=coarse_res, upsample_preds=True,
+               weights=weights,dinov2_weights = dinov2_weights,device=device, amp_dtype=amp_dtype)
+    model.upsample_res = upsample_res
+    print(f"Using coarse resolution {coarse_res}, and upsample res {model.upsample_res}")
+    return model
diff --git a/third_party/RoMa/roma/models/model_zoo/roma_models.py b/third_party/RoMa/roma/models/model_zoo/roma_models.py
new file mode 100644
index 0000000000000000000000000000000000000000..13f8d872f3aad6ef42b090b123f77a96ff1ce68f
--- /dev/null
+++ b/third_party/RoMa/roma/models/model_zoo/roma_models.py
@@ -0,0 +1,160 @@
+import warnings
+import torch.nn as nn
+import torch
+from roma.models.matcher import *
+from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
+from roma.models.encoders import *
+
+def roma_model(resolution, upsample_preds, device = None, weights=None, dinov2_weights=None, amp_dtype: torch.dtype=torch.float16, **kwargs):
+    # roma weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters
+    #torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul TODO: these probably ruin stuff, should be careful
+    #torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
+    warnings.filterwarnings('ignore', category=UserWarning, message='TypedStorage is deprecated')
+    gp_dim = 512
+    feat_dim = 512
+    decoder_dim = gp_dim + feat_dim
+    cls_to_coord_res = 64
+    coordinate_decoder = TransformerDecoder(
+        nn.Sequential(*[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]), 
+        decoder_dim, 
+        cls_to_coord_res**2 + 1,
+        is_classifier=True,
+        amp = True,
+        pos_enc = False,)
+    dw = True
+    hidden_blocks = 8
+    kernel_size = 5
+    displacement_emb = "linear"
+    disable_local_corr_grad = True
+    
+    conv_refiner = nn.ModuleDict(
+        {
+            "16": ConvRefiner(
+                2 * 512+128+(2*7+1)**2,
+                2 * 512+128+(2*7+1)**2,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=128,
+                local_corr_radius = 7,
+                corr_in_other = True,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "8": ConvRefiner(
+                2 * 512+64+(2*3+1)**2,
+                2 * 512+64+(2*3+1)**2,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=64,
+                local_corr_radius = 3,
+                corr_in_other = True,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "4": ConvRefiner(
+                2 * 256+32+(2*2+1)**2,
+                2 * 256+32+(2*2+1)**2,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=32,
+                local_corr_radius = 2,
+                corr_in_other = True,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "2": ConvRefiner(
+                2 * 64+16,
+                128+16,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks=hidden_blocks,
+                displacement_emb=displacement_emb,
+                displacement_emb_dim=16,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+            "1": ConvRefiner(
+                2 * 9 + 6,
+                24,
+                2 + 1,
+                kernel_size=kernel_size,
+                dw=dw,
+                hidden_blocks = hidden_blocks,
+                displacement_emb = displacement_emb,
+                displacement_emb_dim = 6,
+                amp = True,
+                disable_local_corr_grad = disable_local_corr_grad,
+                bn_momentum = 0.01,
+            ),
+        }
+    )
+    kernel_temperature = 0.2
+    learn_temperature = False
+    no_cov = True
+    kernel = CosKernel
+    only_attention = False
+    basis = "fourier"
+    gp16 = GP(
+        kernel,
+        T=kernel_temperature,
+        learn_temperature=learn_temperature,
+        only_attention=only_attention,
+        gp_dim=gp_dim,
+        basis=basis,
+        no_cov=no_cov,
+    )
+    gps = nn.ModuleDict({"16": gp16})
+    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
+    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
+    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
+    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
+    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
+    proj = nn.ModuleDict({
+        "16": proj16,
+        "8": proj8,
+        "4": proj4,
+        "2": proj2,
+        "1": proj1,
+        })
+    displacement_dropout_p = 0.0
+    gm_warp_dropout_p = 0.0
+    decoder = Decoder(coordinate_decoder, 
+                      gps, 
+                      proj, 
+                      conv_refiner, 
+                      detach=True, 
+                      scales=["16", "8", "4", "2", "1"], 
+                      displacement_dropout_p = displacement_dropout_p,
+                      gm_warp_dropout_p = gm_warp_dropout_p)
+    
+    encoder = CNNandDinov2(
+        cnn_kwargs = dict(
+            pretrained=False,
+            amp = True),
+        amp = True,
+        use_vgg = True,
+        dinov2_weights = dinov2_weights,
+        amp_dtype=amp_dtype,
+    )
+    h,w = resolution
+    symmetric = True
+    attenuate_cert = True
+    sample_mode = "threshold_balanced"
+    matcher = RegressionMatcher(encoder, decoder, h=h, w=w, upsample_preds=upsample_preds, 
+                                symmetric = symmetric, attenuate_cert = attenuate_cert, sample_mode = sample_mode, **kwargs).to(device)
+    matcher.load_state_dict(weights)
+    return matcher
diff --git a/third_party/RoMa/roma/models/transformer/__init__.py b/third_party/RoMa/roma/models/transformer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c93008ecdaab3fa19d7166b213f8d4f664bf65d5
--- /dev/null
+++ b/third_party/RoMa/roma/models/transformer/__init__.py
@@ -0,0 +1,47 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from roma.utils.utils import get_grid
+from .layers.block import Block
+from .layers.attention import MemEffAttention
+from .dinov2 import vit_large
+
+class TransformerDecoder(nn.Module):
+    def __init__(self, blocks, hidden_dim, out_dim, is_classifier = False, *args, 
+                 amp = False, pos_enc = True, learned_embeddings = False, embedding_dim = None, amp_dtype = torch.float16, **kwargs) -> None:
+        super().__init__(*args, **kwargs)
+        self.blocks = blocks
+        self.to_out = nn.Linear(hidden_dim, out_dim)
+        self.hidden_dim = hidden_dim
+        self.out_dim = out_dim
+        self._scales = [16]
+        self.is_classifier = is_classifier
+        self.amp = amp
+        self.amp_dtype = amp_dtype
+        self.pos_enc = pos_enc
+        self.learned_embeddings = learned_embeddings
+        if self.learned_embeddings:
+            self.learned_pos_embeddings = nn.Parameter(nn.init.kaiming_normal_(torch.empty((1, hidden_dim, embedding_dim, embedding_dim))))
+
+    def scales(self):
+        return self._scales.copy()
+
+    def forward(self, gp_posterior, features, old_stuff, new_scale):
+        with torch.autocast("cuda", dtype=self.amp_dtype, enabled=self.amp):
+            B,C,H,W = gp_posterior.shape
+            x = torch.cat((gp_posterior, features), dim = 1)
+            B,C,H,W = x.shape
+            grid = get_grid(B, H, W, x.device).reshape(B,H*W,2)
+            if self.learned_embeddings:
+                pos_enc = F.interpolate(self.learned_pos_embeddings, size = (H,W), mode = 'bilinear', align_corners = False).permute(0,2,3,1).reshape(1,H*W,C)
+            else:
+                pos_enc = 0
+            tokens = x.reshape(B,C,H*W).permute(0,2,1) + pos_enc
+            z = self.blocks(tokens)
+            out = self.to_out(z)
+            out = out.permute(0,2,1).reshape(B, self.out_dim, H, W)
+            warp, certainty = out[:, :-1], out[:, -1:]
+            return warp, certainty, None
+
+
diff --git a/third_party/Roma/roma/models/transformer/dinov2.py b/third_party/RoMa/roma/models/transformer/dinov2.py
similarity index 82%
rename from third_party/Roma/roma/models/transformer/dinov2.py
rename to third_party/RoMa/roma/models/transformer/dinov2.py
index 1c27c65b5061cc0113792e40b96eaf7f4266ce18..b556c63096d17239c8603d5fe626c331963099fd 100644
--- a/third_party/Roma/roma/models/transformer/dinov2.py
+++ b/third_party/RoMa/roma/models/transformer/dinov2.py
@@ -18,29 +18,16 @@ import torch.nn as nn
 import torch.utils.checkpoint
 from torch.nn.init import trunc_normal_
 
-from .layers import (
-    Mlp,
-    PatchEmbed,
-    SwiGLUFFNFused,
-    MemEffAttention,
-    NestedTensorBlock as Block,
-)
-
-
-def named_apply(
-    fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False
-) -> nn.Module:
+from .layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
     if not depth_first and include_root:
         fn(module=module, name=name)
     for child_name, child_module in module.named_children():
         child_name = ".".join((name, child_name)) if name else child_name
-        named_apply(
-            fn=fn,
-            module=child_module,
-            name=child_name,
-            depth_first=depth_first,
-            include_root=True,
-        )
+        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
     if depth_first and include_root:
         fn(module=module, name=name)
     return module
@@ -100,33 +87,22 @@ class DinoVisionTransformer(nn.Module):
         super().__init__()
         norm_layer = partial(nn.LayerNorm, eps=1e-6)
 
-        self.num_features = (
-            self.embed_dim
-        ) = embed_dim  # num_features for consistency with other models
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
         self.num_tokens = 1
         self.n_blocks = depth
         self.num_heads = num_heads
         self.patch_size = patch_size
 
-        self.patch_embed = embed_layer(
-            img_size=img_size,
-            patch_size=patch_size,
-            in_chans=in_chans,
-            embed_dim=embed_dim,
-        )
+        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
         num_patches = self.patch_embed.num_patches
 
         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
-        self.pos_embed = nn.Parameter(
-            torch.zeros(1, num_patches + self.num_tokens, embed_dim)
-        )
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
 
         if drop_path_uniform is True:
             dpr = [drop_path_rate] * depth
         else:
-            dpr = [
-                x.item() for x in torch.linspace(0, drop_path_rate, depth)
-            ]  # stochastic depth decay rule
+            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
 
         if ffn_layer == "mlp":
             ffn_layer = Mlp
@@ -163,9 +139,7 @@ class DinoVisionTransformer(nn.Module):
             chunksize = depth // block_chunks
             for i in range(0, depth, chunksize):
                 # this is to keep the block index consistent if we chunk the block list
-                chunked_blocks.append(
-                    [nn.Identity()] * i + blocks_list[i : i + chunksize]
-                )
+                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
             self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
         else:
             self.chunked_blocks = False
@@ -179,7 +153,7 @@ class DinoVisionTransformer(nn.Module):
         self.init_weights()
         for param in self.parameters():
             param.requires_grad = False
-
+    
     @property
     def device(self):
         return self.cls_token.device
@@ -206,29 +180,20 @@ class DinoVisionTransformer(nn.Module):
         w0, h0 = w0 + 0.1, h0 + 0.1
 
         patch_pos_embed = nn.functional.interpolate(
-            patch_pos_embed.reshape(
-                1, int(math.sqrt(N)), int(math.sqrt(N)), dim
-            ).permute(0, 3, 1, 2),
+            patch_pos_embed.reshape(1, int(math.sqrt(N)), int(math.sqrt(N)), dim).permute(0, 3, 1, 2),
             scale_factor=(w0 / math.sqrt(N), h0 / math.sqrt(N)),
             mode="bicubic",
         )
 
-        assert (
-            int(w0) == patch_pos_embed.shape[-2]
-            and int(h0) == patch_pos_embed.shape[-1]
-        )
+        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
-        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(
-            previous_dtype
-        )
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
 
     def prepare_tokens_with_masks(self, x, masks=None):
         B, nc, w, h = x.shape
         x = self.patch_embed(x)
         if masks is not None:
-            x = torch.where(
-                masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x
-            )
+            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
 
         x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
         x = x + self.interpolate_pos_encoding(x, w, h)
@@ -236,10 +201,7 @@ class DinoVisionTransformer(nn.Module):
         return x
 
     def forward_features_list(self, x_list, masks_list):
-        x = [
-            self.prepare_tokens_with_masks(x, masks)
-            for x, masks in zip(x_list, masks_list)
-        ]
+        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
         for blk in self.blocks:
             x = blk(x)
 
@@ -278,34 +240,26 @@ class DinoVisionTransformer(nn.Module):
         x = self.prepare_tokens_with_masks(x)
         # If n is an int, take the n last blocks. If it's a list, take them
         output, total_block_len = [], len(self.blocks)
-        blocks_to_take = (
-            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
-        )
+        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
         for i, blk in enumerate(self.blocks):
             x = blk(x)
             if i in blocks_to_take:
                 output.append(x)
-        assert len(output) == len(
-            blocks_to_take
-        ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
         return output
 
     def _get_intermediate_layers_chunked(self, x, n=1):
         x = self.prepare_tokens_with_masks(x)
         output, i, total_block_len = [], 0, len(self.blocks[-1])
         # If n is an int, take the n last blocks. If it's a list, take them
-        blocks_to_take = (
-            range(total_block_len - n, total_block_len) if isinstance(n, int) else n
-        )
+        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
         for block_chunk in self.blocks:
             for blk in block_chunk[i:]:  # Passing the nn.Identity()
                 x = blk(x)
                 if i in blocks_to_take:
                     output.append(x)
                 i += 1
-        assert len(output) == len(
-            blocks_to_take
-        ), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
         return output
 
     def get_intermediate_layers(
@@ -327,9 +281,7 @@ class DinoVisionTransformer(nn.Module):
         if reshape:
             B, _, w, h = x.shape
             outputs = [
-                out.reshape(B, w // self.patch_size, h // self.patch_size, -1)
-                .permute(0, 3, 1, 2)
-                .contiguous()
+                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
                 for out in outputs
             ]
         if return_class_token:
@@ -404,4 +356,4 @@ def vit_giant2(patch_size=16, **kwargs):
         block_fn=partial(Block, attn_class=MemEffAttention),
         **kwargs,
     )
-    return model
+    return model
\ No newline at end of file
diff --git a/third_party/Roma/roma/models/transformer/layers/__init__.py b/third_party/RoMa/roma/models/transformer/layers/__init__.py
similarity index 100%
rename from third_party/Roma/roma/models/transformer/layers/__init__.py
rename to third_party/RoMa/roma/models/transformer/layers/__init__.py
diff --git a/third_party/Roma/roma/models/transformer/layers/attention.py b/third_party/RoMa/roma/models/transformer/layers/attention.py
similarity index 93%
rename from third_party/Roma/roma/models/transformer/layers/attention.py
rename to third_party/RoMa/roma/models/transformer/layers/attention.py
index 12f388719bf5f171d59aee238d902bb7915f864b..1f9b0c94b40967dfdff4f261c127cbd21328c905 100644
--- a/third_party/Roma/roma/models/transformer/layers/attention.py
+++ b/third_party/RoMa/roma/models/transformer/layers/attention.py
@@ -48,11 +48,7 @@ class Attention(nn.Module):
 
     def forward(self, x: Tensor) -> Tensor:
         B, N, C = x.shape
-        qkv = (
-            self.qkv(x)
-            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
-            .permute(2, 0, 3, 1, 4)
-        )
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
 
         q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
         attn = q @ k.transpose(-2, -1)
diff --git a/third_party/Roma/roma/models/transformer/layers/block.py b/third_party/RoMa/roma/models/transformer/layers/block.py
similarity index 83%
rename from third_party/Roma/roma/models/transformer/layers/block.py
rename to third_party/RoMa/roma/models/transformer/layers/block.py
index 1b5f5158f073788d3d5fe3e09742d4485ef26441..25488f57cc0ad3c692f86b62555f6668e2a66db1 100644
--- a/third_party/Roma/roma/models/transformer/layers/block.py
+++ b/third_party/RoMa/roma/models/transformer/layers/block.py
@@ -62,9 +62,7 @@ class Block(nn.Module):
             attn_drop=attn_drop,
             proj_drop=drop,
         )
-        self.ls1 = (
-            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
-        )
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
 
         self.norm2 = norm_layer(dim)
@@ -76,9 +74,7 @@ class Block(nn.Module):
             drop=drop,
             bias=ffn_bias,
         )
-        self.ls2 = (
-            LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
-        )
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
         self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
 
         self.sample_drop_ratio = drop_path
@@ -131,9 +127,7 @@ def drop_add_residual_stochastic_depth(
     residual_scale_factor = b / sample_subset_size
 
     # 3) add the residual
-    x_plus_residual = torch.index_add(
-        x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
-    )
+    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
     return x_plus_residual.view_as(x)
 
 
@@ -149,16 +143,10 @@ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None
     if scaling_vector is None:
         x_flat = x.flatten(1)
         residual = residual.flatten(1)
-        x_plus_residual = torch.index_add(
-            x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor
-        )
+        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
     else:
         x_plus_residual = scaled_index_add(
-            x,
-            brange,
-            residual.to(dtype=x.dtype),
-            scaling=scaling_vector,
-            alpha=residual_scale_factor,
+            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
         )
     return x_plus_residual
 
@@ -170,11 +158,7 @@ def get_attn_bias_and_cat(x_list, branges=None):
     """
     this will perform the index select, cat the tensors, and provide the attn_bias from cache
     """
-    batch_sizes = (
-        [b.shape[0] for b in branges]
-        if branges is not None
-        else [x.shape[0] for x in x_list]
-    )
+    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
     all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
     if all_shapes not in attn_bias_cache.keys():
         seqlens = []
@@ -186,9 +170,7 @@ def get_attn_bias_and_cat(x_list, branges=None):
         attn_bias_cache[all_shapes] = attn_bias
 
     if branges is not None:
-        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(
-            1, -1, x_list[0].shape[-1]
-        )
+        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
     else:
         tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
         cat_tensors = torch.cat(tensors_bs1, dim=1)
@@ -203,9 +185,7 @@ def drop_add_residual_stochastic_depth_list(
     scaling_vector=None,
 ) -> Tensor:
     # 1) generate random set of indices for dropping samples in the batch
-    branges_scales = [
-        get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list
-    ]
+    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
     branges = [s[0] for s in branges_scales]
     residual_scale_factors = [s[1] for s in branges_scales]
 
@@ -216,14 +196,8 @@ def drop_add_residual_stochastic_depth_list(
     residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
 
     outputs = []
-    for x, brange, residual, residual_scale_factor in zip(
-        x_list, branges, residual_list, residual_scale_factors
-    ):
-        outputs.append(
-            add_residual(
-                x, brange, residual, residual_scale_factor, scaling_vector
-            ).view_as(x)
-        )
+    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
     return outputs
 
 
@@ -246,17 +220,13 @@ class NestedTensorBlock(Block):
                 x_list,
                 residual_func=attn_residual_func,
                 sample_drop_ratio=self.sample_drop_ratio,
-                scaling_vector=self.ls1.gamma
-                if isinstance(self.ls1, LayerScale)
-                else None,
+                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
             )
             x_list = drop_add_residual_stochastic_depth_list(
                 x_list,
                 residual_func=ffn_residual_func,
                 sample_drop_ratio=self.sample_drop_ratio,
-                scaling_vector=self.ls2.gamma
-                if isinstance(self.ls1, LayerScale)
-                else None,
+                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
             )
             return x_list
         else:
@@ -276,9 +246,7 @@ class NestedTensorBlock(Block):
         if isinstance(x_or_x_list, Tensor):
             return super().forward(x_or_x_list)
         elif isinstance(x_or_x_list, list):
-            assert (
-                XFORMERS_AVAILABLE
-            ), "Please install xFormers for nested tensors usage"
+            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
             return self.forward_nested(x_or_x_list)
         else:
             raise AssertionError
diff --git a/third_party/Roma/roma/models/transformer/layers/dino_head.py b/third_party/RoMa/roma/models/transformer/layers/dino_head.py
similarity index 85%
rename from third_party/Roma/roma/models/transformer/layers/dino_head.py
rename to third_party/RoMa/roma/models/transformer/layers/dino_head.py
index 1147dd3a3c046aee8d427b42b1055f38a218275b..7212db92a4fd8d4c7230e284e551a0234e9d8623 100644
--- a/third_party/Roma/roma/models/transformer/layers/dino_head.py
+++ b/third_party/RoMa/roma/models/transformer/layers/dino_head.py
@@ -23,14 +23,7 @@ class DINOHead(nn.Module):
     ):
         super().__init__()
         nlayers = max(nlayers, 1)
-        self.mlp = _build_mlp(
-            nlayers,
-            in_dim,
-            bottleneck_dim,
-            hidden_dim=hidden_dim,
-            use_bn=use_bn,
-            bias=mlp_bias,
-        )
+        self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
         self.apply(self._init_weights)
         self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
         self.last_layer.weight_g.data.fill_(1)
@@ -49,9 +42,7 @@ class DINOHead(nn.Module):
         return x
 
 
-def _build_mlp(
-    nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True
-):
+def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
     if nlayers == 1:
         return nn.Linear(in_dim, bottleneck_dim, bias=bias)
     else:
diff --git a/third_party/Roma/roma/models/transformer/layers/drop_path.py b/third_party/RoMa/roma/models/transformer/layers/drop_path.py
similarity index 90%
rename from third_party/Roma/roma/models/transformer/layers/drop_path.py
rename to third_party/RoMa/roma/models/transformer/layers/drop_path.py
index a23ba7325d0fd154d5885573770956042ce2311d..af05625984dd14682cc96a63bf0c97bab1f123b1 100644
--- a/third_party/Roma/roma/models/transformer/layers/drop_path.py
+++ b/third_party/RoMa/roma/models/transformer/layers/drop_path.py
@@ -16,9 +16,7 @@ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
     if drop_prob == 0.0 or not training:
         return x
     keep_prob = 1 - drop_prob
-    shape = (x.shape[0],) + (1,) * (
-        x.ndim - 1
-    )  # work with diff dim tensors, not just 2D ConvNets
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
     if keep_prob > 0.0:
         random_tensor.div_(keep_prob)
diff --git a/third_party/Roma/roma/models/transformer/layers/layer_scale.py b/third_party/RoMa/roma/models/transformer/layers/layer_scale.py
similarity index 100%
rename from third_party/Roma/roma/models/transformer/layers/layer_scale.py
rename to third_party/RoMa/roma/models/transformer/layers/layer_scale.py
diff --git a/third_party/Roma/roma/models/transformer/layers/mlp.py b/third_party/RoMa/roma/models/transformer/layers/mlp.py
similarity index 100%
rename from third_party/Roma/roma/models/transformer/layers/mlp.py
rename to third_party/RoMa/roma/models/transformer/layers/mlp.py
diff --git a/third_party/Roma/roma/models/transformer/layers/patch_embed.py b/third_party/RoMa/roma/models/transformer/layers/patch_embed.py
similarity index 81%
rename from third_party/Roma/roma/models/transformer/layers/patch_embed.py
rename to third_party/RoMa/roma/models/transformer/layers/patch_embed.py
index 837f952cf9a463444feeb146e0d5b539102ee26c..574abe41175568d700a389b8b96d1ba554914779 100644
--- a/third_party/Roma/roma/models/transformer/layers/patch_embed.py
+++ b/third_party/RoMa/roma/models/transformer/layers/patch_embed.py
@@ -63,21 +63,15 @@ class PatchEmbed(nn.Module):
 
         self.flatten_embedding = flatten_embedding
 
-        self.proj = nn.Conv2d(
-            in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW
-        )
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
 
     def forward(self, x: Tensor) -> Tensor:
         _, _, H, W = x.shape
         patch_H, patch_W = self.patch_size
 
-        assert (
-            H % patch_H == 0
-        ), f"Input image height {H} is not a multiple of patch height {patch_H}"
-        assert (
-            W % patch_W == 0
-        ), f"Input image width {W} is not a multiple of patch width: {patch_W}"
+        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
 
         x = self.proj(x)  # B C H W
         H, W = x.size(2), x.size(3)
@@ -89,13 +83,7 @@ class PatchEmbed(nn.Module):
 
     def flops(self) -> float:
         Ho, Wo = self.patches_resolution
-        flops = (
-            Ho
-            * Wo
-            * self.embed_dim
-            * self.in_chans
-            * (self.patch_size[0] * self.patch_size[1])
-        )
+        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
         if self.norm is not None:
             flops += Ho * Wo * self.embed_dim
         return flops
diff --git a/third_party/Roma/roma/models/transformer/layers/swiglu_ffn.py b/third_party/RoMa/roma/models/transformer/layers/swiglu_ffn.py
similarity index 100%
rename from third_party/Roma/roma/models/transformer/layers/swiglu_ffn.py
rename to third_party/RoMa/roma/models/transformer/layers/swiglu_ffn.py
diff --git a/third_party/Roma/roma/train/__init__.py b/third_party/RoMa/roma/train/__init__.py
similarity index 100%
rename from third_party/Roma/roma/train/__init__.py
rename to third_party/RoMa/roma/train/__init__.py
diff --git a/third_party/Roma/roma/train/train.py b/third_party/RoMa/roma/train/train.py
similarity index 65%
rename from third_party/Roma/roma/train/train.py
rename to third_party/RoMa/roma/train/train.py
index eb3deaf1792a315d1cce77a2ee0fd50ae9e98ac1..5556f7ebf9b6378e1395c125dde093f5e55e7141 100644
--- a/third_party/Roma/roma/train/train.py
+++ b/third_party/RoMa/roma/train/train.py
@@ -4,62 +4,41 @@ import roma
 import torch
 import wandb
 
-
-def log_param_statistics(named_parameters, norm_type=2):
+def log_param_statistics(named_parameters, norm_type = 2):
     named_parameters = list(named_parameters)
     grads = [p.grad for n, p in named_parameters if p.grad is not None]
-    weight_norms = [
-        p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None
-    ]
-    names = [n for n, p in named_parameters if p.grad is not None]
+    weight_norms = [p.norm(p=norm_type) for n, p in named_parameters if p.grad is not None]
+    names = [n for n,p in named_parameters if p.grad is not None]
     param_norm = torch.stack(weight_norms).norm(p=norm_type)
     device = grads[0].device
-    grad_norms = torch.stack(
-        [torch.norm(g.detach(), norm_type).to(device) for g in grads]
-    )
+    grad_norms = torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads])
     nans_or_infs = torch.isinf(grad_norms) | torch.isnan(grad_norms)
     nan_inf_names = [name for name, naninf in zip(names, nans_or_infs) if naninf]
     total_grad_norm = torch.norm(grad_norms, norm_type)
     if torch.any(nans_or_infs):
         print(f"These params have nan or inf grads: {nan_inf_names}")
-    wandb.log({"grad_norm": total_grad_norm.item()}, step=roma.GLOBAL_STEP)
-    wandb.log({"param_norm": param_norm.item()}, step=roma.GLOBAL_STEP)
-
+    wandb.log({"grad_norm": total_grad_norm.item()}, step = roma.GLOBAL_STEP)
+    wandb.log({"param_norm": param_norm.item()}, step = roma.GLOBAL_STEP)
 
-def train_step(
-    train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm=1.0, **kwargs
-):
+def train_step(train_batch, model, objective, optimizer, grad_scaler, grad_clip_norm = 1.,**kwargs):
     optimizer.zero_grad()
     out = model(train_batch)
     l = objective(out, train_batch)
     grad_scaler.scale(l).backward()
     grad_scaler.unscale_(optimizer)
     log_param_statistics(model.named_parameters())
-    torch.nn.utils.clip_grad_norm_(
-        model.parameters(), grad_clip_norm
-    )  # what should max norm be?
+    torch.nn.utils.clip_grad_norm_(model.parameters(), grad_clip_norm) # what should max norm be?
     grad_scaler.step(optimizer)
     grad_scaler.update()
-    wandb.log({"grad_scale": grad_scaler._scale.item()}, step=roma.GLOBAL_STEP)
-    if grad_scaler._scale < 1.0:
-        grad_scaler._scale = torch.tensor(1.0).to(grad_scaler._scale)
-    roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE  # increment global step
+    wandb.log({"grad_scale": grad_scaler._scale.item()}, step = roma.GLOBAL_STEP)
+    if grad_scaler._scale < 1.:
+        grad_scaler._scale = torch.tensor(1.).to(grad_scaler._scale)
+    roma.GLOBAL_STEP = roma.GLOBAL_STEP + roma.STEP_SIZE # increment global step
     return {"train_out": out, "train_loss": l.item()}
 
 
 def train_k_steps(
-    n_0,
-    k,
-    dataloader,
-    model,
-    objective,
-    optimizer,
-    lr_scheduler,
-    grad_scaler,
-    progress_bar=True,
-    grad_clip_norm=1.0,
-    warmup=None,
-    ema_model=None,
+    n_0, k, dataloader, model, objective, optimizer, lr_scheduler, grad_scaler, progress_bar=True, grad_clip_norm = 1., warmup = None, ema_model = None,
 ):
     for n in tqdm(range(n_0, n_0 + k), disable=(not progress_bar) or roma.RANK > 0):
         batch = next(dataloader)
@@ -73,7 +52,7 @@ def train_k_steps(
             lr_scheduler=lr_scheduler,
             grad_scaler=grad_scaler,
             n=n,
-            grad_clip_norm=grad_clip_norm,
+            grad_clip_norm = grad_clip_norm,
         )
         if ema_model is not None:
             ema_model.update()
@@ -82,10 +61,7 @@ def train_k_steps(
                 lr_scheduler.step()
         else:
             lr_scheduler.step()
-        [
-            wandb.log({f"lr_group_{grp}": lr})
-            for grp, lr in enumerate(lr_scheduler.get_last_lr())
-        ]
+        [wandb.log({f"lr_group_{grp}": lr}) for grp, lr in enumerate(lr_scheduler.get_last_lr())]
 
 
 def train_epoch(
diff --git a/third_party/Roma/roma/utils/__init__.py b/third_party/RoMa/roma/utils/__init__.py
similarity index 100%
rename from third_party/Roma/roma/utils/__init__.py
rename to third_party/RoMa/roma/utils/__init__.py
diff --git a/third_party/RoMa/roma/utils/kde.py b/third_party/RoMa/roma/utils/kde.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ee1378282965ab091b77c2a97f0e80bd13d4637
--- /dev/null
+++ b/third_party/RoMa/roma/utils/kde.py
@@ -0,0 +1,8 @@
+import torch
+
+def kde(x, std = 0.1):
+    # use a gaussian kernel to estimate density
+    x = x.half() # Do it in half precision TODO: remove hardcoding
+    scores = (-torch.cdist(x,x)**2/(2*std**2)).exp()
+    density = scores.sum(dim=-1)
+    return density
\ No newline at end of file
diff --git a/third_party/RoMa/roma/utils/local_correlation.py b/third_party/RoMa/roma/utils/local_correlation.py
new file mode 100644
index 0000000000000000000000000000000000000000..2919595b93aef10c6f95938e5bf104705ee0cbb6
--- /dev/null
+++ b/third_party/RoMa/roma/utils/local_correlation.py
@@ -0,0 +1,44 @@
+import torch
+import torch.nn.functional as F
+
+def local_correlation(
+    feature0,
+    feature1,
+    local_radius,
+    padding_mode="zeros",
+    flow = None,
+    sample_mode = "bilinear",
+):
+    r = local_radius
+    K = (2*r+1)**2
+    B, c, h, w = feature0.size()
+    corr = torch.empty((B,K,h,w), device = feature0.device, dtype=feature0.dtype)
+    if flow is None:
+        # If flow is None, assume feature0 and feature1 are aligned
+        coords = torch.meshgrid(
+                (
+                    torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=feature0.device),
+                    torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=feature0.device),
+                ))
+        coords = torch.stack((coords[1], coords[0]), dim=-1)[
+            None
+        ].expand(B, h, w, 2)
+    else:
+        coords = flow.permute(0,2,3,1) # If using flow, sample around flow target.
+    local_window = torch.meshgrid(
+                (
+                    torch.linspace(-2*local_radius/h, 2*local_radius/h, 2*r+1, device=feature0.device),
+                    torch.linspace(-2*local_radius/w, 2*local_radius/w, 2*r+1, device=feature0.device),
+                ))
+    local_window = torch.stack((local_window[1], local_window[0]), dim=-1)[
+            None
+        ].expand(1, 2*r+1, 2*r+1, 2).reshape(1, (2*r+1)**2, 2)
+    for _ in range(B):
+        with torch.no_grad():
+            local_window_coords = (coords[_,:,:,None]+local_window[:,None,None]).reshape(1,h,w*(2*r+1)**2,2)
+            window_feature = F.grid_sample(
+                feature1[_:_+1], local_window_coords, padding_mode=padding_mode, align_corners=False, mode = sample_mode, #
+            )
+            window_feature = window_feature.reshape(c,h,w,(2*r+1)**2)
+        corr[_] = (feature0[_,...,None]/(c**.5)*window_feature).sum(dim=0).permute(2,0,1)
+    return corr
diff --git a/third_party/Roma/roma/utils/transforms.py b/third_party/RoMa/roma/utils/transforms.py
similarity index 94%
rename from third_party/Roma/roma/utils/transforms.py
rename to third_party/RoMa/roma/utils/transforms.py
index b33c3f30f422bca6a81aa201952b7bb2d3d906bf..ea6476bd816a31df36f7d1b5417853637b65474b 100644
--- a/third_party/Roma/roma/utils/transforms.py
+++ b/third_party/RoMa/roma/utils/transforms.py
@@ -16,9 +16,7 @@ class GeometricSequential:
         for t in self.transforms:
             if np.random.rand() < t.p:
                 M = M.matmul(
-                    t.compute_transformation(
-                        x, t.generate_parameters((b, c, h, w)), None
-                    )
+                    t.compute_transformation(x, t.generate_parameters((b, c, h, w)), None)
                 )
         return (
             warp_perspective(
@@ -106,14 +104,15 @@ class RandomPerspective(K.RandomPerspective):
         return dict(start_points=start_points, end_points=end_points)
 
 
+
 class RandomErasing:
-    def __init__(self, p=0.0, scale=0.0) -> None:
+    def __init__(self, p = 0., scale = 0.) -> None:
         self.p = p
         self.scale = scale
-        self.random_eraser = K.RandomErasing(scale=(0.02, scale), p=p)
-
+        self.random_eraser = K.RandomErasing(scale = (0.02, scale), p = p)
     def __call__(self, image, depth):
         if self.p > 0:
             image = self.random_eraser(image)
             depth = self.random_eraser(depth, params=self.random_eraser._params)
         return image, depth
+        
\ No newline at end of file
diff --git a/third_party/Roma/roma/utils/utils.py b/third_party/RoMa/roma/utils/utils.py
similarity index 73%
rename from third_party/Roma/roma/utils/utils.py
rename to third_party/RoMa/roma/utils/utils.py
index 969e1003419f3b7f05874830b79de73363017f01..d7717b2ee37417c4082706ad58143b7ebfc34624 100644
--- a/third_party/Roma/roma/utils/utils.py
+++ b/third_party/RoMa/roma/utils/utils.py
@@ -9,14 +9,13 @@ import torch.nn.functional as F
 from PIL import Image
 import kornia
 
-
 def recover_pose(E, kpts0, kpts1, K0, K1, mask):
     best_num_inliers = 0
-    K0inv = np.linalg.inv(K0[:2, :2])
-    K1inv = np.linalg.inv(K1[:2, :2])
+    K0inv = np.linalg.inv(K0[:2,:2])
+    K1inv = np.linalg.inv(K1[:2,:2])
 
-    kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
-    kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
+    kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
+    kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
 
     for _E in np.split(E, len(E) / 3):
         n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
@@ -26,16 +25,17 @@ def recover_pose(E, kpts0, kpts1, K0, K1, mask):
     return ret
 
 
+
 # Code taken from https://github.com/PruneTruong/DenseMatching/blob/40c29a6b5c35e86b9509e65ab0cd12553d998e5f/validation/utils_pose_estimation.py
 # --- GEOMETRY ---
 def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
     if len(kpts0) < 5:
         return None
-    K0inv = np.linalg.inv(K0[:2, :2])
-    K1inv = np.linalg.inv(K1[:2, :2])
+    K0inv = np.linalg.inv(K0[:2,:2])
+    K1inv = np.linalg.inv(K1[:2,:2])
 
-    kpts0 = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
-    kpts1 = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
+    kpts0 = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
+    kpts1 = (K1inv @ (kpts1-K1[None,:2,2]).T).T
     E, mask = cv2.findEssentialMat(
         kpts0, kpts1, np.eye(3), threshold=norm_thresh, prob=conf
     )
@@ -51,40 +51,31 @@ def estimate_pose(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
                 ret = (R, t, mask.ravel() > 0)
     return ret
 
-
 def estimate_pose_uncalibrated(kpts0, kpts1, K0, K1, norm_thresh, conf=0.99999):
     if len(kpts0) < 5:
         return None
     method = cv2.USAC_ACCURATE
     F, mask = cv2.findFundamentalMat(
-        kpts0,
-        kpts1,
-        ransacReprojThreshold=norm_thresh,
-        confidence=conf,
-        method=method,
-        maxIters=10000,
+        kpts0, kpts1, ransacReprojThreshold=norm_thresh, confidence=conf, method=method, maxIters=10000
     )
-    E = K1.T @ F @ K0
+    E = K1.T@F@K0
     ret = None
     if E is not None:
         best_num_inliers = 0
-        K0inv = np.linalg.inv(K0[:2, :2])
-        K1inv = np.linalg.inv(K1[:2, :2])
-
-        kpts0_n = (K0inv @ (kpts0 - K0[None, :2, 2]).T).T
-        kpts1_n = (K1inv @ (kpts1 - K1[None, :2, 2]).T).T
+        K0inv = np.linalg.inv(K0[:2,:2])
+        K1inv = np.linalg.inv(K1[:2,:2])
 
+        kpts0_n = (K0inv @ (kpts0-K0[None,:2,2]).T).T 
+        kpts1_n = (K1inv @ (kpts1-K1[None,:2,2]).T).T
+ 
         for _E in np.split(E, len(E) / 3):
-            n, R, t, _ = cv2.recoverPose(
-                _E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask
-            )
+            n, R, t, _ = cv2.recoverPose(_E, kpts0_n, kpts1_n, np.eye(3), 1e9, mask=mask)
             if n > best_num_inliers:
                 best_num_inliers = n
                 ret = (R, t, mask.ravel() > 0)
     return ret
 
-
-def unnormalize_coords(x_n, h, w):
+def unnormalize_coords(x_n,h,w):
     x = torch.stack(
         (w * (x_n[..., 0] + 1) / 2, h * (x_n[..., 1] + 1) / 2), dim=-1
     )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
@@ -164,7 +155,6 @@ def get_depth_tuple_transform_ops_nearest_exact(resize=None):
         ops.append(TupleResizeNearestExact(resize))
     return TupleCompose(ops)
 
-
 def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
     ops = []
     if resize:
@@ -172,9 +162,7 @@ def get_depth_tuple_transform_ops(resize=None, normalize=True, unscale=False):
     return TupleCompose(ops)
 
 
-def get_tuple_transform_ops(
-    resize=None, normalize=True, unscale=False, clahe=False, colorjiggle_params=None
-):
+def get_tuple_transform_ops(resize=None, normalize=True, unscale=False, clahe = False, colorjiggle_params = None):
     ops = []
     if resize:
         ops.append(TupleResize(resize))
@@ -185,7 +173,6 @@ def get_tuple_transform_ops(
         )  # Imagenet mean/std
     return TupleCompose(ops)
 
-
 class ToTensorScaled(object):
     """Convert a RGB PIL Image to a CHW ordered Tensor, scale the range to [0, 1]"""
 
@@ -234,15 +221,11 @@ class TupleToTensorUnscaled(object):
     def __repr__(self):
         return "TupleToTensorUnscaled()"
 
-
 class TupleResizeNearestExact:
     def __init__(self, size):
         self.size = size
-
     def __call__(self, im_tuple):
-        return [
-            F.interpolate(im, size=self.size, mode="nearest-exact") for im in im_tuple
-        ]
+        return [F.interpolate(im, size = self.size, mode = 'nearest-exact') for im in im_tuple]
 
     def __repr__(self):
         return "TupleResizeNearestExact(size={})".format(self.size)
@@ -252,19 +235,17 @@ class TupleResize(object):
     def __init__(self, size, mode=InterpolationMode.BICUBIC):
         self.size = size
         self.resize = transforms.Resize(size, mode)
-
     def __call__(self, im_tuple):
         return [self.resize(im) for im in im_tuple]
 
     def __repr__(self):
         return "TupleResize(size={})".format(self.size)
-
-
+    
 class Normalize:
-    def __call__(self, im):
-        mean = im.mean(dim=(1, 2), keepdims=True)
-        std = im.std(dim=(1, 2), keepdims=True)
-        return (im - mean) / std
+    def __call__(self,im):
+        mean = im.mean(dim=(1,2), keepdims=True)
+        std = im.std(dim=(1,2), keepdims=True)
+        return (im-mean)/std
 
 
 class TupleNormalize(object):
@@ -274,7 +255,7 @@ class TupleNormalize(object):
         self.normalize = transforms.Normalize(mean=mean, std=std)
 
     def __call__(self, im_tuple):
-        c, h, w = im_tuple[0].shape
+        c,h,w = im_tuple[0].shape
         if c > 3:
             warnings.warn(f"Number of channels c={c} > 3, assuming first 3 are rgb")
         return [self.normalize(im[:3]) for im in im_tuple]
@@ -300,82 +281,50 @@ class TupleCompose(object):
         format_string += "\n)"
         return format_string
 
-
 @torch.no_grad()
-def cls_to_flow(cls, deterministic_sampling=True):
-    B, C, H, W = cls.shape
+def cls_to_flow(cls, deterministic_sampling = True):
+    B,C,H,W = cls.shape
     device = cls.device
     res = round(math.sqrt(C))
-    G = torch.meshgrid(
-        *[
-            torch.linspace(-1 + 1 / res, 1 - 1 / res, steps=res, device=device)
-            for _ in range(2)
-        ]
-    )
-    G = torch.stack([G[1], G[0]], dim=-1).reshape(C, 2)
+    G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)])
+    G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
     if deterministic_sampling:
         sampled_cls = cls.max(dim=1).indices
     else:
-        sampled_cls = torch.multinomial(
-            cls.permute(0, 2, 3, 1).reshape(B * H * W, C).softmax(dim=-1), 1
-        ).reshape(B, H, W)
+        sampled_cls = torch.multinomial(cls.permute(0,2,3,1).reshape(B*H*W,C).softmax(dim=-1), 1).reshape(B,H,W)
     flow = G[sampled_cls]
     return flow
 
-
 @torch.no_grad()
 def cls_to_flow_refine(cls):
-    B, C, H, W = cls.shape
+    B,C,H,W = cls.shape
     device = cls.device
     res = round(math.sqrt(C))
-    G = torch.meshgrid(
-        *[
-            torch.linspace(-1 + 1 / res, 1 - 1 / res, steps=res, device=device)
-            for _ in range(2)
-        ]
-    )
-    G = torch.stack([G[1], G[0]], dim=-1).reshape(C, 2)
+    G = torch.meshgrid(*[torch.linspace(-1+1/res, 1-1/res, steps = res, device = device) for _ in range(2)])
+    G = torch.stack([G[1],G[0]],dim=-1).reshape(C,2)
     cls = cls.softmax(dim=1)
     mode = cls.max(dim=1).indices
-
-    index = (
-        torch.stack((mode - 1, mode, mode + 1, mode - res, mode + res), dim=1)
-        .clamp(0, C - 1)
-        .long()
-    )
-    neighbours = torch.gather(cls, dim=1, index=index)[..., None]
-    flow = (
-        neighbours[:, 0] * G[index[:, 0]]
-        + neighbours[:, 1] * G[index[:, 1]]
-        + neighbours[:, 2] * G[index[:, 2]]
-        + neighbours[:, 3] * G[index[:, 3]]
-        + neighbours[:, 4] * G[index[:, 4]]
-    )
-    tot_prob = neighbours.sum(dim=1)
+    
+    index = torch.stack((mode-1, mode, mode+1, mode - res, mode + res), dim = 1).clamp(0,C - 1).long()
+    neighbours = torch.gather(cls, dim = 1, index = index)[...,None]
+    flow = neighbours[:,0] * G[index[:,0]] + neighbours[:,1] * G[index[:,1]] + neighbours[:,2] * G[index[:,2]] + neighbours[:,3] * G[index[:,3]] + neighbours[:,4] * G[index[:,4]]
+    tot_prob = neighbours.sum(dim=1)  
     flow = flow / tot_prob
     return flow
 
 
-def get_gt_warp(
-    depth1,
-    depth2,
-    T_1to2,
-    K1,
-    K2,
-    depth_interpolation_mode="bilinear",
-    relative_depth_error_threshold=0.05,
-    H=None,
-    W=None,
-):
-
+def get_gt_warp(depth1, depth2, T_1to2, K1, K2, depth_interpolation_mode = 'bilinear', relative_depth_error_threshold = 0.05, H = None, W = None):
+    
     if H is None:
-        B, H, W = depth1.shape
+        B,H,W = depth1.shape
     else:
         B = depth1.shape[0]
     with torch.no_grad():
         x1_n = torch.meshgrid(
             *[
-                torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=depth1.device)
+                torch.linspace(
+                    -1 + 1 / n, 1 - 1 / n, n, device=depth1.device
+                )
                 for n in (B, H, W)
             ]
         )
@@ -387,27 +336,15 @@ def get_gt_warp(
             T_1to2.double(),
             K1.double(),
             K2.double(),
-            depth_interpolation_mode=depth_interpolation_mode,
-            relative_depth_error_threshold=relative_depth_error_threshold,
+            depth_interpolation_mode = depth_interpolation_mode,
+            relative_depth_error_threshold = relative_depth_error_threshold,
         )
         prob = mask.float().reshape(B, H, W)
         x2 = x2.reshape(B, H, W, 2)
         return x2, prob
 
-
 @torch.no_grad()
-def warp_kpts(
-    kpts0,
-    depth0,
-    depth1,
-    T_0to1,
-    K0,
-    K1,
-    smooth_mask=False,
-    return_relative_depth_error=False,
-    depth_interpolation_mode="bilinear",
-    relative_depth_error_threshold=0.05,
-):
+def warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, smooth_mask = False, return_relative_depth_error = False, depth_interpolation_mode = "bilinear", relative_depth_error_threshold = 0.05):
     """Warp kpts0 from I0 to I1 with depth, K and Rt
     Also check covisibility and depth consistency.
     Depth is consistent if relative error < 0.2 (hard-coded).
@@ -432,44 +369,26 @@ def warp_kpts(
         # Inspired by approach in inloc, try to fill holes from bilinear interpolation by nearest neighbour interpolation
         if smooth_mask:
             raise NotImplementedError("Combined bilinear and NN warp not implemented")
-        valid_bilinear, warp_bilinear = warp_kpts(
-            kpts0,
-            depth0,
-            depth1,
-            T_0to1,
-            K0,
-            K1,
-            smooth_mask=smooth_mask,
-            return_relative_depth_error=return_relative_depth_error,
-            depth_interpolation_mode="bilinear",
-            relative_depth_error_threshold=relative_depth_error_threshold,
-        )
-        valid_nearest, warp_nearest = warp_kpts(
-            kpts0,
-            depth0,
-            depth1,
-            T_0to1,
-            K0,
-            K1,
-            smooth_mask=smooth_mask,
-            return_relative_depth_error=return_relative_depth_error,
-            depth_interpolation_mode="nearest-exact",
-            relative_depth_error_threshold=relative_depth_error_threshold,
-        )
-        nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest)
+        valid_bilinear, warp_bilinear = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, 
+                  smooth_mask = smooth_mask, 
+                  return_relative_depth_error = return_relative_depth_error, 
+                  depth_interpolation_mode = "bilinear",
+                  relative_depth_error_threshold = relative_depth_error_threshold)
+        valid_nearest, warp_nearest = warp_kpts(kpts0, depth0, depth1, T_0to1, K0, K1, 
+                  smooth_mask = smooth_mask, 
+                  return_relative_depth_error = return_relative_depth_error, 
+                  depth_interpolation_mode = "nearest-exact",
+                  relative_depth_error_threshold = relative_depth_error_threshold)
+        nearest_valid_bilinear_invalid = (~valid_bilinear).logical_and(valid_nearest) 
         warp = warp_bilinear.clone()
-        warp[nearest_valid_bilinear_invalid] = warp_nearest[
-            nearest_valid_bilinear_invalid
-        ]
+        warp[nearest_valid_bilinear_invalid] = warp_nearest[nearest_valid_bilinear_invalid]
         valid = valid_bilinear | valid_nearest
         return valid, warp
-
-    kpts0_depth = F.grid_sample(
-        depth0[:, None],
-        kpts0[:, :, None],
-        mode=depth_interpolation_mode,
-        align_corners=False,
-    )[:, 0, :, 0]
+        
+        
+    kpts0_depth = F.grid_sample(depth0[:, None], kpts0[:, :, None], mode = depth_interpolation_mode, align_corners=False)[
+        :, 0, :, 0
+    ]
     kpts0 = torch.stack(
         (w * (kpts0[..., 0] + 1) / 2, h * (kpts0[..., 1] + 1) / 2), dim=-1
     )  # [-1+1/h, 1-1/h] -> [0.5, h-0.5]
@@ -508,26 +427,22 @@ def warp_kpts(
     # w_kpts0[~covisible_mask, :] = -5 # xd
 
     w_kpts0_depth = F.grid_sample(
-        depth1[:, None],
-        w_kpts0[:, :, None],
-        mode=depth_interpolation_mode,
-        align_corners=False,
+        depth1[:, None], w_kpts0[:, :, None], mode=depth_interpolation_mode, align_corners=False
     )[:, 0, :, 0]
-
+    
     relative_depth_error = (
         (w_kpts0_depth - w_kpts0_depth_computed) / w_kpts0_depth
     ).abs()
     if not smooth_mask:
         consistent_mask = relative_depth_error < relative_depth_error_threshold
     else:
-        consistent_mask = (-relative_depth_error / smooth_mask).exp()
+        consistent_mask = (-relative_depth_error/smooth_mask).exp()
     valid_mask = nonzero_mask * covisible_mask * consistent_mask
     if return_relative_depth_error:
         return relative_depth_error, w_kpts0
     else:
         return valid_mask, w_kpts0
 
-
 imagenet_mean = torch.tensor([0.485, 0.456, 0.406])
 imagenet_std = torch.tensor([0.229, 0.224, 0.225])
 
@@ -547,9 +462,7 @@ def numpy_to_pil(x: np.ndarray):
 
 def tensor_to_pil(x, unnormalize=False):
     if unnormalize:
-        x = x * (imagenet_std[:, None, None].to(x.device)) + (
-            imagenet_mean[:, None, None].to(x.device)
-        )
+        x = x * (imagenet_std[:, None, None].to(x.device)) + (imagenet_mean[:, None, None].to(x.device))
     x = x.detach().permute(1, 2, 0).cpu().numpy()
     x = np.clip(x, 0.0, 1.0)
     return numpy_to_pil(x)
@@ -579,63 +492,73 @@ def compute_relative_pose(R1, t1, R2, t2):
     trans = -rots @ t1 + t2
     return rots, trans
 
-
 @torch.no_grad()
 def reset_opt(opt):
     for group in opt.param_groups:
-        for p in group["params"]:
+        for p in group['params']:
             if p.requires_grad:
                 state = opt.state[p]
                 # State initialization
 
                 # Exponential moving average of gradient values
-                state["exp_avg"] = torch.zeros_like(p)
+                state['exp_avg'] = torch.zeros_like(p)
                 # Exponential moving average of squared gradient values
-                state["exp_avg_sq"] = torch.zeros_like(p)
+                state['exp_avg_sq'] = torch.zeros_like(p)
                 # Exponential moving average of gradient difference
-                state["exp_avg_diff"] = torch.zeros_like(p)
+                state['exp_avg_diff'] = torch.zeros_like(p)
 
 
 def flow_to_pixel_coords(flow, h1, w1):
-    flow = torch.stack(
-        (
-            w1 * (flow[..., 0] + 1) / 2,
-            h1 * (flow[..., 1] + 1) / 2,
-        ),
-        axis=-1,
+    flow = (
+        torch.stack(
+            (
+                w1 * (flow[..., 0] + 1) / 2,
+                h1 * (flow[..., 1] + 1) / 2,
+            ),
+            axis=-1,
+        )
     )
     return flow
 
+to_pixel_coords = flow_to_pixel_coords # just an alias
 
 def flow_to_normalized_coords(flow, h1, w1):
-    flow = torch.stack(
-        (
-            2 * (flow[..., 0]) / w1 - 1,
-            2 * (flow[..., 1]) / h1 - 1,
-        ),
-        axis=-1,
+    flow = (
+        torch.stack(
+            (
+                2 * (flow[..., 0]) / w1 - 1,
+                2 * (flow[..., 1]) / h1 - 1,
+            ),
+            axis=-1,
+        )
     )
     return flow
 
+to_normalized_coords = flow_to_normalized_coords # just an alias
 
 def warp_to_pixel_coords(warp, h1, w1, h2, w2):
     warp1 = warp[..., :2]
-    warp1 = torch.stack(
-        (
-            w1 * (warp1[..., 0] + 1) / 2,
-            h1 * (warp1[..., 1] + 1) / 2,
-        ),
-        axis=-1,
+    warp1 = (
+        torch.stack(
+            (
+                w1 * (warp1[..., 0] + 1) / 2,
+                h1 * (warp1[..., 1] + 1) / 2,
+            ),
+            axis=-1,
+        )
     )
     warp2 = warp[..., 2:]
-    warp2 = torch.stack(
-        (
-            w2 * (warp2[..., 0] + 1) / 2,
-            h2 * (warp2[..., 1] + 1) / 2,
-        ),
-        axis=-1,
+    warp2 = (
+        torch.stack(
+            (
+                w2 * (warp2[..., 0] + 1) / 2,
+                h2 * (warp2[..., 1] + 1) / 2,
+            ),
+            axis=-1,
+        )
     )
-    return torch.cat((warp1, warp2), dim=-1)
+    return torch.cat((warp1,warp2), dim=-1)
+
 
 
 def signed_point_line_distance(point, line, eps: float = 1e-9):
@@ -656,9 +579,7 @@ def signed_point_line_distance(point, line, eps: float = 1e-9):
     if not line.shape[-1] == 3:
         raise ValueError(f"lines must be a (*, 3) tensor. Got {line.shape}")
 
-    numerator = (
-        line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2]
-    )
+    numerator = (line[..., 0] * point[..., 0] + line[..., 1] * point[..., 1] + line[..., 2])
     denominator = line[..., :2].norm(dim=-1)
 
     return numerator / (denominator + eps)
@@ -682,7 +603,6 @@ def signed_left_to_right_epipolar_distance(pts1, pts2, Fm):
         the computed Symmetrical distance with shape :math:`(*, N)`.
     """
     import kornia
-
     if (len(Fm.shape) < 3) or not Fm.shape[-2:] == (3, 3):
         raise ValueError(f"Fm must be a (*, 3, 3) tensor. Got {Fm.shape}")
 
@@ -694,10 +614,12 @@ def signed_left_to_right_epipolar_distance(pts1, pts2, Fm):
 
     return signed_point_line_distance(pts2, line1_in_2)
 
-
 def get_grid(b, h, w, device):
     grid = torch.meshgrid(
-        *[torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device) for n in (b, h, w)]
+        *[
+            torch.linspace(-1 + 1 / n, 1 - 1 / n, n, device=device)
+            for n in (b, h, w)
+        ]
     )
     grid = torch.stack((grid[2], grid[1]), dim=-1).reshape(b, h, w, 2)
     return grid
diff --git a/third_party/Roma/setup.py b/third_party/RoMa/setup.py
similarity index 61%
rename from third_party/Roma/setup.py
rename to third_party/RoMa/setup.py
index ae777c0e5a41f0e4b03a838d19bc9a2bb04d4617..fe2e6dc4be62254f702e34422e07468b00195dd2 100644
--- a/third_party/Roma/setup.py
+++ b/third_party/RoMa/setup.py
@@ -1,8 +1,8 @@
-from setuptools import setup
+from setuptools import setup, find_packages
 
 setup(
     name="roma",
-    packages=["roma"],
+    packages=find_packages(include=("roma*",)),
     version="0.0.1",
     author="Johan Edstedt",
     install_requires=open("requirements.txt", "r").read().split("\n"),
diff --git a/third_party/Roma/LICENSE b/third_party/Roma/LICENSE
deleted file mode 100644
index a115f899f8d09ef3b1def4a16c7bae1a0bd50fbe..0000000000000000000000000000000000000000
--- a/third_party/Roma/LICENSE
+++ /dev/null
@@ -1,400 +0,0 @@
-
-Attribution-NonCommercial 4.0 International
-
-=======================================================================
-
-Creative Commons Corporation ("Creative Commons") is not a law firm and
-does not provide legal services or legal advice. Distribution of
-Creative Commons public licenses does not create a lawyer-client or
-other relationship. Creative Commons makes its licenses and related
-information available on an "as-is" basis. Creative Commons gives no
-warranties regarding its licenses, any material licensed under their
-terms and conditions, or any related information. Creative Commons
-disclaims all liability for damages resulting from their use to the
-fullest extent possible.
-
-Using Creative Commons Public Licenses
-
-Creative Commons public licenses provide a standard set of terms and
-conditions that creators and other rights holders may use to share
-original works of authorship and other material subject to copyright
-and certain other rights specified in the public license below. The
-following considerations are for informational purposes only, are not
-exhaustive, and do not form part of our licenses.
-
-     Considerations for licensors: Our public licenses are
-     intended for use by those authorized to give the public
-     permission to use material in ways otherwise restricted by
-     copyright and certain other rights. Our licenses are
-     irrevocable. Licensors should read and understand the terms
-     and conditions of the license they choose before applying it.
-     Licensors should also secure all rights necessary before
-     applying our licenses so that the public can reuse the
-     material as expected. Licensors should clearly mark any
-     material not subject to the license. This includes other CC-
-     licensed material, or material used under an exception or
-     limitation to copyright. More considerations for licensors:
-	wiki.creativecommons.org/Considerations_for_licensors
-
-     Considerations for the public: By using one of our public
-     licenses, a licensor grants the public permission to use the
-     licensed material under specified terms and conditions. If
-     the licensor's permission is not necessary for any reason--for
-     example, because of any applicable exception or limitation to
-     copyright--then that use is not regulated by the license. Our
-     licenses grant only permissions under copyright and certain
-     other rights that a licensor has authority to grant. Use of
-     the licensed material may still be restricted for other
-     reasons, including because others have copyright or other
-     rights in the material. A licensor may make special requests,
-     such as asking that all changes be marked or described.
-     Although not required by our licenses, you are encouraged to
-     respect those requests where reasonable. More_considerations
-     for the public: 
-	wiki.creativecommons.org/Considerations_for_licensees
-
-=======================================================================
-
-Creative Commons Attribution-NonCommercial 4.0 International Public
-License
-
-By exercising the Licensed Rights (defined below), You accept and agree
-to be bound by the terms and conditions of this Creative Commons
-Attribution-NonCommercial 4.0 International Public License ("Public
-License"). To the extent this Public License may be interpreted as a
-contract, You are granted the Licensed Rights in consideration of Your
-acceptance of these terms and conditions, and the Licensor grants You
-such rights in consideration of benefits the Licensor receives from
-making the Licensed Material available under these terms and
-conditions.
-
-Section 1 -- Definitions.
-
-  a. Adapted Material means material subject to Copyright and Similar
-     Rights that is derived from or based upon the Licensed Material
-     and in which the Licensed Material is translated, altered,
-     arranged, transformed, or otherwise modified in a manner requiring
-     permission under the Copyright and Similar Rights held by the
-     Licensor. For purposes of this Public License, where the Licensed
-     Material is a musical work, performance, or sound recording,
-     Adapted Material is always produced where the Licensed Material is
-     synched in timed relation with a moving image.
-
-  b. Adapter's License means the license You apply to Your Copyright
-     and Similar Rights in Your contributions to Adapted Material in
-     accordance with the terms and conditions of this Public License.
-
-  c. Copyright and Similar Rights means copyright and/or similar rights
-     closely related to copyright including, without limitation,
-     performance, broadcast, sound recording, and Sui Generis Database
-     Rights, without regard to how the rights are labeled or
-     categorized. For purposes of this Public License, the rights
-     specified in Section 2(b)(1)-(2) are not Copyright and Similar
-     Rights.
-  d. Effective Technological Measures means those measures that, in the
-     absence of proper authority, may not be circumvented under laws
-     fulfilling obligations under Article 11 of the WIPO Copyright
-     Treaty adopted on December 20, 1996, and/or similar international
-     agreements.
-
-  e. Exceptions and Limitations means fair use, fair dealing, and/or
-     any other exception or limitation to Copyright and Similar Rights
-     that applies to Your use of the Licensed Material.
-
-  f. Licensed Material means the artistic or literary work, database,
-     or other material to which the Licensor applied this Public
-     License.
-
-  g. Licensed Rights means the rights granted to You subject to the
-     terms and conditions of this Public License, which are limited to
-     all Copyright and Similar Rights that apply to Your use of the
-     Licensed Material and that the Licensor has authority to license.
-
-  h. Licensor means the individual(s) or entity(ies) granting rights
-     under this Public License.
-
-  i. NonCommercial means not primarily intended for or directed towards
-     commercial advantage or monetary compensation. For purposes of
-     this Public License, the exchange of the Licensed Material for
-     other material subject to Copyright and Similar Rights by digital
-     file-sharing or similar means is NonCommercial provided there is
-     no payment of monetary compensation in connection with the
-     exchange.
-
-  j. Share means to provide material to the public by any means or
-     process that requires permission under the Licensed Rights, such
-     as reproduction, public display, public performance, distribution,
-     dissemination, communication, or importation, and to make material
-     available to the public including in ways that members of the
-     public may access the material from a place and at a time
-     individually chosen by them.
-
-  k. Sui Generis Database Rights means rights other than copyright
-     resulting from Directive 96/9/EC of the European Parliament and of
-     the Council of 11 March 1996 on the legal protection of databases,
-     as amended and/or succeeded, as well as other essentially
-     equivalent rights anywhere in the world.
-
-  l. You means the individual or entity exercising the Licensed Rights
-     under this Public License. Your has a corresponding meaning.
-
-Section 2 -- Scope.
-
-  a. License grant.
-
-       1. Subject to the terms and conditions of this Public License,
-          the Licensor hereby grants You a worldwide, royalty-free,
-          non-sublicensable, non-exclusive, irrevocable license to
-          exercise the Licensed Rights in the Licensed Material to:
-
-            a. reproduce and Share the Licensed Material, in whole or
-               in part, for NonCommercial purposes only; and
-
-            b. produce, reproduce, and Share Adapted Material for
-               NonCommercial purposes only.
-
-       2. Exceptions and Limitations. For the avoidance of doubt, where
-          Exceptions and Limitations apply to Your use, this Public
-          License does not apply, and You do not need to comply with
-          its terms and conditions.
-
-       3. Term. The term of this Public License is specified in Section
-          6(a).
-
-       4. Media and formats; technical modifications allowed. The
-          Licensor authorizes You to exercise the Licensed Rights in
-          all media and formats whether now known or hereafter created,
-          and to make technical modifications necessary to do so. The
-          Licensor waives and/or agrees not to assert any right or
-          authority to forbid You from making technical modifications
-          necessary to exercise the Licensed Rights, including
-          technical modifications necessary to circumvent Effective
-          Technological Measures. For purposes of this Public License,
-          simply making modifications authorized by this Section 2(a)
-          (4) never produces Adapted Material.
-
-       5. Downstream recipients.
-
-            a. Offer from the Licensor -- Licensed Material. Every
-               recipient of the Licensed Material automatically
-               receives an offer from the Licensor to exercise the
-               Licensed Rights under the terms and conditions of this
-               Public License.
-
-            b. No downstream restrictions. You may not offer or impose
-               any additional or different terms or conditions on, or
-               apply any Effective Technological Measures to, the
-               Licensed Material if doing so restricts exercise of the
-               Licensed Rights by any recipient of the Licensed
-               Material.
-
-       6. No endorsement. Nothing in this Public License constitutes or
-          may be construed as permission to assert or imply that You
-          are, or that Your use of the Licensed Material is, connected
-          with, or sponsored, endorsed, or granted official status by,
-          the Licensor or others designated to receive attribution as
-          provided in Section 3(a)(1)(A)(i).
-
-  b. Other rights.
-
-       1. Moral rights, such as the right of integrity, are not
-          licensed under this Public License, nor are publicity,
-          privacy, and/or other similar personality rights; however, to
-          the extent possible, the Licensor waives and/or agrees not to
-          assert any such rights held by the Licensor to the limited
-          extent necessary to allow You to exercise the Licensed
-          Rights, but not otherwise.
-
-       2. Patent and trademark rights are not licensed under this
-          Public License.
-
-       3. To the extent possible, the Licensor waives any right to
-          collect royalties from You for the exercise of the Licensed
-          Rights, whether directly or through a collecting society
-          under any voluntary or waivable statutory or compulsory
-          licensing scheme. In all other cases the Licensor expressly
-          reserves any right to collect such royalties, including when
-          the Licensed Material is used other than for NonCommercial
-          purposes.
-
-Section 3 -- License Conditions.
-
-Your exercise of the Licensed Rights is expressly made subject to the
-following conditions.
-
-  a. Attribution.
-
-       1. If You Share the Licensed Material (including in modified
-          form), You must:
-
-            a. retain the following if it is supplied by the Licensor
-               with the Licensed Material:
-
-                 i. identification of the creator(s) of the Licensed
-                    Material and any others designated to receive
-                    attribution, in any reasonable manner requested by
-                    the Licensor (including by pseudonym if
-                    designated);
-
-                ii. a copyright notice;
-
-               iii. a notice that refers to this Public License;
-
-                iv. a notice that refers to the disclaimer of
-                    warranties;
-
-                 v. a URI or hyperlink to the Licensed Material to the
-                    extent reasonably practicable;
-
-            b. indicate if You modified the Licensed Material and
-               retain an indication of any previous modifications; and
-
-            c. indicate the Licensed Material is licensed under this
-               Public License, and include the text of, or the URI or
-               hyperlink to, this Public License.
-
-       2. You may satisfy the conditions in Section 3(a)(1) in any
-          reasonable manner based on the medium, means, and context in
-          which You Share the Licensed Material. For example, it may be
-          reasonable to satisfy the conditions by providing a URI or
-          hyperlink to a resource that includes the required
-          information.
-
-       3. If requested by the Licensor, You must remove any of the
-          information required by Section 3(a)(1)(A) to the extent
-          reasonably practicable.
-
-       4. If You Share Adapted Material You produce, the Adapter's
-          License You apply must not prevent recipients of the Adapted
-          Material from complying with this Public License.
-
-Section 4 -- Sui Generis Database Rights.
-
-Where the Licensed Rights include Sui Generis Database Rights that
-apply to Your use of the Licensed Material:
-
-  a. for the avoidance of doubt, Section 2(a)(1) grants You the right
-     to extract, reuse, reproduce, and Share all or a substantial
-     portion of the contents of the database for NonCommercial purposes
-     only;
-
-  b. if You include all or a substantial portion of the database
-     contents in a database in which You have Sui Generis Database
-     Rights, then the database in which You have Sui Generis Database
-     Rights (but not its individual contents) is Adapted Material; and
-
-  c. You must comply with the conditions in Section 3(a) if You Share
-     all or a substantial portion of the contents of the database.
-
-For the avoidance of doubt, this Section 4 supplements and does not
-replace Your obligations under this Public License where the Licensed
-Rights include other Copyright and Similar Rights.
-
-Section 5 -- Disclaimer of Warranties and Limitation of Liability.
-
-  a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
-     EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
-     AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
-     ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
-     IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
-     WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
-     PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
-     ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
-     KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
-     ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
-
-  b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
-     TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
-     NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
-     INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
-     COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
-     USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
-     ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
-     DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
-     IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
-
-  c. The disclaimer of warranties and limitation of liability provided
-     above shall be interpreted in a manner that, to the extent
-     possible, most closely approximates an absolute disclaimer and
-     waiver of all liability.
-
-Section 6 -- Term and Termination.
-
-  a. This Public License applies for the term of the Copyright and
-     Similar Rights licensed here. However, if You fail to comply with
-     this Public License, then Your rights under this Public License
-     terminate automatically.
-
-  b. Where Your right to use the Licensed Material has terminated under
-     Section 6(a), it reinstates:
-
-       1. automatically as of the date the violation is cured, provided
-          it is cured within 30 days of Your discovery of the
-          violation; or
-
-       2. upon express reinstatement by the Licensor.
-
-     For the avoidance of doubt, this Section 6(b) does not affect any
-     right the Licensor may have to seek remedies for Your violations
-     of this Public License.
-
-  c. For the avoidance of doubt, the Licensor may also offer the
-     Licensed Material under separate terms or conditions or stop
-     distributing the Licensed Material at any time; however, doing so
-     will not terminate this Public License.
-
-  d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
-     License.
-
-Section 7 -- Other Terms and Conditions.
-
-  a. The Licensor shall not be bound by any additional or different
-     terms or conditions communicated by You unless expressly agreed.
-
-  b. Any arrangements, understandings, or agreements regarding the
-     Licensed Material not stated herein are separate from and
-     independent of the terms and conditions of this Public License.
-
-Section 8 -- Interpretation.
-
-  a. For the avoidance of doubt, this Public License does not, and
-     shall not be interpreted to, reduce, limit, restrict, or impose
-     conditions on any use of the Licensed Material that could lawfully
-     be made without permission under this Public License.
-
-  b. To the extent possible, if any provision of this Public License is
-     deemed unenforceable, it shall be automatically reformed to the
-     minimum extent necessary to make it enforceable. If the provision
-     cannot be reformed, it shall be severed from this Public License
-     without affecting the enforceability of the remaining terms and
-     conditions.
-
-  c. No term or condition of this Public License will be waived and no
-     failure to comply consented to unless expressly agreed to by the
-     Licensor.
-
-  d. Nothing in this Public License constitutes or may be interpreted
-     as a limitation upon, or waiver of, any privileges and immunities
-     that apply to the Licensor or You, including from the legal
-     processes of any jurisdiction or authority.
-
-=======================================================================
-
-Creative Commons is not a party to its public
-licenses. Notwithstanding, Creative Commons may elect to apply one of
-its public licenses to material it publishes and in those instances
-will be considered the “Licensor.” The text of the Creative Commons
-public licenses is dedicated to the public domain under the CC0 Public
-Domain Dedication. Except for the limited purpose of indicating that
-material is shared under a Creative Commons public license or as
-otherwise permitted by the Creative Commons policies published at
-creativecommons.org/policies, Creative Commons does not authorize the
-use of the trademark "Creative Commons" or any other trademark or logo
-of Creative Commons without its prior written consent including,
-without limitation, in connection with any unauthorized modifications
-to any of its public licenses or any other arrangements,
-understandings, or agreements concerning use of licensed material. For
-the avoidance of doubt, this paragraph does not form part of the
-public licenses.
-
-Creative Commons may be contacted at creativecommons.org.
diff --git a/third_party/Roma/README.md b/third_party/Roma/README.md
deleted file mode 100644
index 5e984366c8f7af37615d7666f34cd82a90073fee..0000000000000000000000000000000000000000
--- a/third_party/Roma/README.md
+++ /dev/null
@@ -1,63 +0,0 @@
-# RoMa: Revisiting Robust Losses for Dense Feature Matching
-### [Project Page (TODO)](https://parskatt.github.io/RoMa) | [Paper](https://arxiv.org/abs/2305.15404)
-<br/>
-
-> RoMa: Revisiting Robust Lossses for Dense Feature Matching  
-> [Johan Edstedt](https://scholar.google.com/citations?user=Ul-vMR0AAAAJ), [Qiyu Sun](https://scholar.google.com/citations?user=HS2WuHkAAAAJ), [Georg Bökman](https://scholar.google.com/citations?user=FUE3Wd0AAAAJ), [MÄrten WadenbÀck](https://scholar.google.com/citations?user=6WRQpCQAAAAJ), [Michael Felsberg](https://scholar.google.com/citations?&user=lkWfR08AAAAJ)  
-> Arxiv 2023
-
-**NOTE!!! Very early code, there might be bugs**
-
-The codebase is in the [roma folder](roma).
-
-## Setup/Install
-In your python environment (tested on Linux python 3.10), run:
-```bash
-pip install -e .
-```
-## Demo / How to Use
-We provide two demos in the [demos folder](demo).
-Here's the gist of it:
-```python
-from roma import roma_outdoor
-roma_model = roma_outdoor(device=device)
-# Match
-warp, certainty = roma_model.match(imA_path, imB_path, device=device)
-# Sample matches for estimation
-matches, certainty = roma_model.sample(warp, certainty)
-# Convert to pixel coordinates (RoMa produces matches in [-1,1]x[-1,1])
-kptsA, kptsB = roma_model.to_pixel_coordinates(matches, H_A, W_A, H_B, W_B)
-# Find a fundamental matrix (or anything else of interest)
-F, mask = cv2.findFundamentalMat(
-    kptsA.cpu().numpy(), kptsB.cpu().numpy(), ransacReprojThreshold=0.2, method=cv2.USAC_MAGSAC, confidence=0.999999, maxIters=10000
-)
-```
-## Reproducing Results
-The experiments in the paper are provided in the [experiments folder](experiments).
-
-### Training
-1. First follow the instructions provided here: https://github.com/Parskatt/DKM for downloading and preprocessing datasets.
-2. Run the relevant experiment, e.g.,
-```bash
-torchrun --nproc_per_node=4 --nnodes=1 --rdzv_backend=c10d experiments/roma_outdoor.py
-```
-### Testing
-```bash
-python experiments/roma_outdoor.py --only_test --benchmark mega-1500
-```
-## License
-Due to our dependency on [DINOv2](https://github.com/facebookresearch/dinov2/blob/main/LICENSE), the license is sadly non-commercial only for the moment.
-
-## Acknowledgement
-Our codebase builds on the code in [DKM](https://github.com/Parskatt/DKM).
-
-## BibTeX
-If you find our models useful, please consider citing our paper!
-```
-@article{edstedt2023roma,
-title={{RoMa}: Revisiting Robust Lossses for Dense Feature Matching},
-author={Edstedt, Johan and Sun, Qiyu and Bökman, Georg and WadenbÀck, MÄrten and Felsberg, Michael},
-journal={arXiv preprint arXiv:2305.15404},
-year={2023}
-}
-```
diff --git a/third_party/Roma/roma/datasets/scannet.py b/third_party/Roma/roma/datasets/scannet.py
deleted file mode 100644
index 91bea57c9d1ae2773c11a9c8d47f31026a2c227b..0000000000000000000000000000000000000000
--- a/third_party/Roma/roma/datasets/scannet.py
+++ /dev/null
@@ -1,191 +0,0 @@
-import os
-import random
-from PIL import Image
-import cv2
-import h5py
-import numpy as np
-import torch
-from torch.utils.data import Dataset, DataLoader, ConcatDataset
-
-import torchvision.transforms.functional as tvf
-import kornia.augmentation as K
-import os.path as osp
-import matplotlib.pyplot as plt
-import roma
-from roma.utils import get_depth_tuple_transform_ops, get_tuple_transform_ops
-from roma.utils.transforms import GeometricSequential
-from tqdm import tqdm
-
-
-class ScanNetScene:
-    def __init__(
-        self,
-        data_root,
-        scene_info,
-        ht=384,
-        wt=512,
-        min_overlap=0.0,
-        shake_t=0,
-        rot_prob=0.0,
-        use_horizontal_flip_aug=False,
-    ) -> None:
-        self.scene_root = osp.join(data_root, "scans", "scans_train")
-        self.data_names = scene_info["name"]
-        self.overlaps = scene_info["score"]
-        # Only sample 10s
-        valid = (self.data_names[:, -2:] % 10).sum(axis=-1) == 0
-        self.overlaps = self.overlaps[valid]
-        self.data_names = self.data_names[valid]
-        if len(self.data_names) > 10000:
-            pairinds = np.random.choice(
-                np.arange(0, len(self.data_names)), 10000, replace=False
-            )
-            self.data_names = self.data_names[pairinds]
-            self.overlaps = self.overlaps[pairinds]
-        self.im_transform_ops = get_tuple_transform_ops(resize=(ht, wt), normalize=True)
-        self.depth_transform_ops = get_depth_tuple_transform_ops(
-            resize=(ht, wt), normalize=False
-        )
-        self.wt, self.ht = wt, ht
-        self.shake_t = shake_t
-        self.H_generator = GeometricSequential(K.RandomAffine(degrees=90, p=rot_prob))
-        self.use_horizontal_flip_aug = use_horizontal_flip_aug
-
-    def load_im(self, im_B, crop=None):
-        im = Image.open(im_B)
-        return im
-
-    def load_depth(self, depth_ref, crop=None):
-        depth = cv2.imread(str(depth_ref), cv2.IMREAD_UNCHANGED)
-        depth = depth / 1000
-        depth = torch.from_numpy(depth).float()  # (h, w)
-        return depth
-
-    def __len__(self):
-        return len(self.data_names)
-
-    def scale_intrinsic(self, K, wi, hi):
-        sx, sy = self.wt / wi, self.ht / hi
-        sK = torch.tensor([[sx, 0, 0], [0, sy, 0], [0, 0, 1]])
-        return sK @ K
-
-    def horizontal_flip(self, im_A, im_B, depth_A, depth_B, K_A, K_B):
-        im_A = im_A.flip(-1)
-        im_B = im_B.flip(-1)
-        depth_A, depth_B = depth_A.flip(-1), depth_B.flip(-1)
-        flip_mat = torch.tensor([[-1, 0, self.wt], [0, 1, 0], [0, 0, 1.0]]).to(
-            K_A.device
-        )
-        K_A = flip_mat @ K_A
-        K_B = flip_mat @ K_B
-
-        return im_A, im_B, depth_A, depth_B, K_A, K_B
-
-    def read_scannet_pose(self, path):
-        """Read ScanNet's Camera2World pose and transform it to World2Camera.
-
-        Returns:
-            pose_w2c (np.ndarray): (4, 4)
-        """
-        cam2world = np.loadtxt(path, delimiter=" ")
-        world2cam = np.linalg.inv(cam2world)
-        return world2cam
-
-    def read_scannet_intrinsic(self, path):
-        """Read ScanNet's intrinsic matrix and return the 3x3 matrix."""
-        intrinsic = np.loadtxt(path, delimiter=" ")
-        return torch.tensor(intrinsic[:-1, :-1], dtype=torch.float)
-
-    def __getitem__(self, pair_idx):
-        # read intrinsics of original size
-        data_name = self.data_names[pair_idx]
-        scene_name, scene_sub_name, stem_name_1, stem_name_2 = data_name
-        scene_name = f"scene{scene_name:04d}_{scene_sub_name:02d}"
-
-        # read the intrinsic of depthmap
-        K1 = K2 = self.read_scannet_intrinsic(
-            osp.join(self.scene_root, scene_name, "intrinsic", "intrinsic_color.txt")
-        )  # the depth K is not the same, but doesnt really matter
-        # read and compute relative poses
-        T1 = self.read_scannet_pose(
-            osp.join(self.scene_root, scene_name, "pose", f"{stem_name_1}.txt")
-        )
-        T2 = self.read_scannet_pose(
-            osp.join(self.scene_root, scene_name, "pose", f"{stem_name_2}.txt")
-        )
-        T_1to2 = torch.tensor(np.matmul(T2, np.linalg.inv(T1)), dtype=torch.float)[
-            :4, :4
-        ]  # (4, 4)
-
-        # Load positive pair data
-        im_A_ref = os.path.join(
-            self.scene_root, scene_name, "color", f"{stem_name_1}.jpg"
-        )
-        im_B_ref = os.path.join(
-            self.scene_root, scene_name, "color", f"{stem_name_2}.jpg"
-        )
-        depth_A_ref = os.path.join(
-            self.scene_root, scene_name, "depth", f"{stem_name_1}.png"
-        )
-        depth_B_ref = os.path.join(
-            self.scene_root, scene_name, "depth", f"{stem_name_2}.png"
-        )
-
-        im_A = self.load_im(im_A_ref)
-        im_B = self.load_im(im_B_ref)
-        depth_A = self.load_depth(depth_A_ref)
-        depth_B = self.load_depth(depth_B_ref)
-
-        # Recompute camera intrinsic matrix due to the resize
-        K1 = self.scale_intrinsic(K1, im_A.width, im_A.height)
-        K2 = self.scale_intrinsic(K2, im_B.width, im_B.height)
-        # Process images
-        im_A, im_B = self.im_transform_ops((im_A, im_B))
-        depth_A, depth_B = self.depth_transform_ops(
-            (depth_A[None, None], depth_B[None, None])
-        )
-        if self.use_horizontal_flip_aug:
-            if np.random.rand() > 0.5:
-                im_A, im_B, depth_A, depth_B, K1, K2 = self.horizontal_flip(
-                    im_A, im_B, depth_A, depth_B, K1, K2
-                )
-
-        data_dict = {
-            "im_A": im_A,
-            "im_B": im_B,
-            "im_A_depth": depth_A[0, 0],
-            "im_B_depth": depth_B[0, 0],
-            "K1": K1,
-            "K2": K2,
-            "T_1to2": T_1to2,
-        }
-        return data_dict
-
-
-class ScanNetBuilder:
-    def __init__(self, data_root="data/scannet") -> None:
-        self.data_root = data_root
-        self.scene_info_root = os.path.join(data_root, "scannet_indices")
-        self.all_scenes = os.listdir(self.scene_info_root)
-
-    def build_scenes(self, split="train", min_overlap=0.0, **kwargs):
-        # Note: split doesn't matter here as we always use same scannet_train scenes
-        scene_names = self.all_scenes
-        scenes = []
-        for scene_name in tqdm(scene_names, disable=roma.RANK > 0):
-            scene_info = np.load(
-                os.path.join(self.scene_info_root, scene_name), allow_pickle=True
-            )
-            scenes.append(
-                ScanNetScene(
-                    self.data_root, scene_info, min_overlap=min_overlap, **kwargs
-                )
-            )
-        return scenes
-
-    def weight_scenes(self, concat_dataset, alpha=0.5):
-        ns = []
-        for d in concat_dataset.datasets:
-            ns.append(len(d))
-        ws = torch.cat([torch.ones(n) / n**alpha for n in ns])
-        return ws
diff --git a/third_party/Roma/roma/losses/__init__.py b/third_party/Roma/roma/losses/__init__.py
deleted file mode 100644
index 12cb6d40b90ca3ccf712321f78c033401db865fb..0000000000000000000000000000000000000000
--- a/third_party/Roma/roma/losses/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .robust_loss import RobustLosses
diff --git a/third_party/Roma/roma/losses/robust_loss.py b/third_party/Roma/roma/losses/robust_loss.py
deleted file mode 100644
index cd9fd5bbc9c2d01bb6dd40823e350b588bd598b3..0000000000000000000000000000000000000000
--- a/third_party/Roma/roma/losses/robust_loss.py
+++ /dev/null
@@ -1,222 +0,0 @@
-from einops.einops import rearrange
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from roma.utils.utils import get_gt_warp
-import wandb
-import roma
-import math
-
-
-class RobustLosses(nn.Module):
-    def __init__(
-        self,
-        robust=False,
-        center_coords=False,
-        scale_normalize=False,
-        ce_weight=0.01,
-        local_loss=True,
-        local_dist=4.0,
-        local_largest_scale=8,
-        smooth_mask=False,
-        depth_interpolation_mode="bilinear",
-        mask_depth_loss=False,
-        relative_depth_error_threshold=0.05,
-        alpha=1.0,
-        c=1e-3,
-    ):
-        super().__init__()
-        self.robust = robust  # measured in pixels
-        self.center_coords = center_coords
-        self.scale_normalize = scale_normalize
-        self.ce_weight = ce_weight
-        self.local_loss = local_loss
-        self.local_dist = local_dist
-        self.local_largest_scale = local_largest_scale
-        self.smooth_mask = smooth_mask
-        self.depth_interpolation_mode = depth_interpolation_mode
-        self.mask_depth_loss = mask_depth_loss
-        self.relative_depth_error_threshold = relative_depth_error_threshold
-        self.avg_overlap = dict()
-        self.alpha = alpha
-        self.c = c
-
-    def gm_cls_loss(self, x2, prob, scale_gm_cls, gm_certainty, scale):
-        with torch.no_grad():
-            B, C, H, W = scale_gm_cls.shape
-            device = x2.device
-            cls_res = round(math.sqrt(C))
-            G = torch.meshgrid(
-                *[
-                    torch.linspace(
-                        -1 + 1 / cls_res, 1 - 1 / cls_res, steps=cls_res, device=device
-                    )
-                    for _ in range(2)
-                ]
-            )
-            G = torch.stack((G[1], G[0]), dim=-1).reshape(C, 2)
-            GT = (
-                (G[None, :, None, None, :] - x2[:, None])
-                .norm(dim=-1)
-                .min(dim=1)
-                .indices
-            )
-        cls_loss = F.cross_entropy(scale_gm_cls, GT, reduction="none")[prob > 0.99]
-        if not torch.any(cls_loss):
-            cls_loss = certainty_loss * 0.0  # Prevent issues where prob is 0 everywhere
-
-        certainty_loss = F.binary_cross_entropy_with_logits(gm_certainty[:, 0], prob)
-        losses = {
-            f"gm_certainty_loss_{scale}": certainty_loss.mean(),
-            f"gm_cls_loss_{scale}": cls_loss.mean(),
-        }
-        wandb.log(losses, step=roma.GLOBAL_STEP)
-        return losses
-
-    def delta_cls_loss(
-        self, x2, prob, flow_pre_delta, delta_cls, certainty, scale, offset_scale
-    ):
-        with torch.no_grad():
-            B, C, H, W = delta_cls.shape
-            device = x2.device
-            cls_res = round(math.sqrt(C))
-            G = torch.meshgrid(
-                *[
-                    torch.linspace(
-                        -1 + 1 / cls_res, 1 - 1 / cls_res, steps=cls_res, device=device
-                    )
-                    for _ in range(2)
-                ]
-            )
-            G = torch.stack((G[1], G[0]), dim=-1).reshape(C, 2) * offset_scale
-            GT = (
-                (G[None, :, None, None, :] + flow_pre_delta[:, None] - x2[:, None])
-                .norm(dim=-1)
-                .min(dim=1)
-                .indices
-            )
-        cls_loss = F.cross_entropy(delta_cls, GT, reduction="none")[prob > 0.99]
-        if not torch.any(cls_loss):
-            cls_loss = certainty_loss * 0.0  # Prevent issues where prob is 0 everywhere
-        certainty_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
-        losses = {
-            f"delta_certainty_loss_{scale}": certainty_loss.mean(),
-            f"delta_cls_loss_{scale}": cls_loss.mean(),
-        }
-        wandb.log(losses, step=roma.GLOBAL_STEP)
-        return losses
-
-    def regression_loss(self, x2, prob, flow, certainty, scale, eps=1e-8, mode="delta"):
-        epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1)
-        if scale == 1:
-            pck_05 = (epe[prob > 0.99] < 0.5 * (2 / 512)).float().mean()
-            wandb.log({"train_pck_05": pck_05}, step=roma.GLOBAL_STEP)
-
-        ce_loss = F.binary_cross_entropy_with_logits(certainty[:, 0], prob)
-        a = self.alpha
-        cs = self.c * scale
-        x = epe[prob > 0.99]
-        reg_loss = cs**a * ((x / (cs)) ** 2 + 1**2) ** (a / 2)
-        if not torch.any(reg_loss):
-            reg_loss = ce_loss * 0.0  # Prevent issues where prob is 0 everywhere
-        losses = {
-            f"{mode}_certainty_loss_{scale}": ce_loss.mean(),
-            f"{mode}_regression_loss_{scale}": reg_loss.mean(),
-        }
-        wandb.log(losses, step=roma.GLOBAL_STEP)
-        return losses
-
-    def forward(self, corresps, batch):
-        scales = list(corresps.keys())
-        tot_loss = 0.0
-        # scale_weights due to differences in scale for regression gradients and classification gradients
-        scale_weights = {1: 1, 2: 1, 4: 1, 8: 1, 16: 1}
-        for scale in scales:
-            scale_corresps = corresps[scale]
-            (
-                scale_certainty,
-                flow_pre_delta,
-                delta_cls,
-                offset_scale,
-                scale_gm_cls,
-                scale_gm_certainty,
-                flow,
-                scale_gm_flow,
-            ) = (
-                scale_corresps["certainty"],
-                scale_corresps["flow_pre_delta"],
-                scale_corresps.get("delta_cls"),
-                scale_corresps.get("offset_scale"),
-                scale_corresps.get("gm_cls"),
-                scale_corresps.get("gm_certainty"),
-                scale_corresps["flow"],
-                scale_corresps.get("gm_flow"),
-            )
-            flow_pre_delta = rearrange(flow_pre_delta, "b d h w -> b h w d")
-            b, h, w, d = flow_pre_delta.shape
-            gt_warp, gt_prob = get_gt_warp(
-                batch["im_A_depth"],
-                batch["im_B_depth"],
-                batch["T_1to2"],
-                batch["K1"],
-                batch["K2"],
-                H=h,
-                W=w,
-            )
-            x2 = gt_warp.float()
-            prob = gt_prob
-
-            if self.local_largest_scale >= scale:
-                prob = prob * (
-                    F.interpolate(prev_epe[:, None], size=(h, w), mode="nearest-exact")[
-                        :, 0
-                    ]
-                    < (2 / 512) * (self.local_dist[scale] * scale)
-                )
-
-            if scale_gm_cls is not None:
-                gm_cls_losses = self.gm_cls_loss(
-                    x2, prob, scale_gm_cls, scale_gm_certainty, scale
-                )
-                gm_loss = (
-                    self.ce_weight * gm_cls_losses[f"gm_certainty_loss_{scale}"]
-                    + gm_cls_losses[f"gm_cls_loss_{scale}"]
-                )
-                tot_loss = tot_loss + scale_weights[scale] * gm_loss
-            elif scale_gm_flow is not None:
-                gm_flow_losses = self.regression_loss(
-                    x2, prob, scale_gm_flow, scale_gm_certainty, scale, mode="gm"
-                )
-                gm_loss = (
-                    self.ce_weight * gm_flow_losses[f"gm_certainty_loss_{scale}"]
-                    + gm_flow_losses[f"gm_regression_loss_{scale}"]
-                )
-                tot_loss = tot_loss + scale_weights[scale] * gm_loss
-
-            if delta_cls is not None:
-                delta_cls_losses = self.delta_cls_loss(
-                    x2,
-                    prob,
-                    flow_pre_delta,
-                    delta_cls,
-                    scale_certainty,
-                    scale,
-                    offset_scale,
-                )
-                delta_cls_loss = (
-                    self.ce_weight * delta_cls_losses[f"delta_certainty_loss_{scale}"]
-                    + delta_cls_losses[f"delta_cls_loss_{scale}"]
-                )
-                tot_loss = tot_loss + scale_weights[scale] * delta_cls_loss
-            else:
-                delta_regression_losses = self.regression_loss(
-                    x2, prob, flow, scale_certainty, scale
-                )
-                reg_loss = (
-                    self.ce_weight
-                    * delta_regression_losses[f"delta_certainty_loss_{scale}"]
-                    + delta_regression_losses[f"delta_regression_loss_{scale}"]
-                )
-                tot_loss = tot_loss + scale_weights[scale] * reg_loss
-            prev_epe = (flow.permute(0, 2, 3, 1) - x2).norm(dim=-1).detach()
-        return tot_loss
diff --git a/third_party/Roma/roma/models/__init__.py b/third_party/Roma/roma/models/__init__.py
deleted file mode 100644
index 3918d67063b9ab7a8ced80c22a5e74f95ff7fd4a..0000000000000000000000000000000000000000
--- a/third_party/Roma/roma/models/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .model_zoo import roma_outdoor, roma_indoor
diff --git a/third_party/Roma/roma/models/model_zoo/__init__.py b/third_party/Roma/roma/models/model_zoo/__init__.py
deleted file mode 100644
index 2ef0b6cf03473500d4198521764cd6dc9ccba784..0000000000000000000000000000000000000000
--- a/third_party/Roma/roma/models/model_zoo/__init__.py
+++ /dev/null
@@ -1,46 +0,0 @@
-import torch
-from .roma_models import roma_model
-
-weight_urls = {
-    "roma": {
-        "outdoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_outdoor.pth",
-        "indoor": "https://github.com/Parskatt/storage/releases/download/roma/roma_indoor.pth",
-    },
-    "dinov2": "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",  # hopefully this doesnt change :D
-}
-
-
-def roma_outdoor(device, weights=None, dinov2_weights=None):
-    if weights is None:
-        weights = torch.hub.load_state_dict_from_url(
-            weight_urls["roma"]["outdoor"], map_location=device
-        )
-    if dinov2_weights is None:
-        dinov2_weights = torch.hub.load_state_dict_from_url(
-            weight_urls["dinov2"], map_location=device
-        )
-    return roma_model(
-        resolution=(14 * 8 * 6, 14 * 8 * 6),
-        upsample_preds=True,
-        weights=weights,
-        dinov2_weights=dinov2_weights,
-        device=device,
-    )
-
-
-def roma_indoor(device, weights=None, dinov2_weights=None):
-    if weights is None:
-        weights = torch.hub.load_state_dict_from_url(
-            weight_urls["roma"]["indoor"], map_location=device
-        )
-    if dinov2_weights is None:
-        dinov2_weights = torch.hub.load_state_dict_from_url(
-            weight_urls["dinov2"], map_location=device
-        )
-    return roma_model(
-        resolution=(14 * 8 * 5, 14 * 8 * 5),
-        upsample_preds=False,
-        weights=weights,
-        dinov2_weights=dinov2_weights,
-        device=device,
-    )
diff --git a/third_party/Roma/roma/models/model_zoo/roma_models.py b/third_party/Roma/roma/models/model_zoo/roma_models.py
deleted file mode 100644
index f98ee44f5e2ebd7e43a8e4b17f99b6ed0e85c93a..0000000000000000000000000000000000000000
--- a/third_party/Roma/roma/models/model_zoo/roma_models.py
+++ /dev/null
@@ -1,175 +0,0 @@
-import warnings
-import torch.nn as nn
-from roma.models.matcher import *
-from roma.models.transformer import Block, TransformerDecoder, MemEffAttention
-from roma.models.encoders import *
-
-
-def roma_model(
-    resolution, upsample_preds, device=None, weights=None, dinov2_weights=None, **kwargs
-):
-    # roma weights and dinov2 weights are loaded seperately, as dinov2 weights are not parameters
-    torch.backends.cuda.matmul.allow_tf32 = True  # allow tf32 on matmul
-    torch.backends.cudnn.allow_tf32 = True  # allow tf32 on cudnn
-    warnings.filterwarnings(
-        "ignore", category=UserWarning, message="TypedStorage is deprecated"
-    )
-    gp_dim = 512
-    feat_dim = 512
-    decoder_dim = gp_dim + feat_dim
-    cls_to_coord_res = 64
-    coordinate_decoder = TransformerDecoder(
-        nn.Sequential(
-            *[Block(decoder_dim, 8, attn_class=MemEffAttention) for _ in range(5)]
-        ),
-        decoder_dim,
-        cls_to_coord_res**2 + 1,
-        is_classifier=True,
-        amp=True,
-        pos_enc=False,
-    )
-    dw = True
-    hidden_blocks = 8
-    kernel_size = 5
-    displacement_emb = "linear"
-    disable_local_corr_grad = True
-
-    conv_refiner = nn.ModuleDict(
-        {
-            "16": ConvRefiner(
-                2 * 512 + 128 + (2 * 7 + 1) ** 2,
-                2 * 512 + 128 + (2 * 7 + 1) ** 2,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks=hidden_blocks,
-                displacement_emb=displacement_emb,
-                displacement_emb_dim=128,
-                local_corr_radius=7,
-                corr_in_other=True,
-                amp=True,
-                disable_local_corr_grad=disable_local_corr_grad,
-                bn_momentum=0.01,
-            ),
-            "8": ConvRefiner(
-                2 * 512 + 64 + (2 * 3 + 1) ** 2,
-                2 * 512 + 64 + (2 * 3 + 1) ** 2,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks=hidden_blocks,
-                displacement_emb=displacement_emb,
-                displacement_emb_dim=64,
-                local_corr_radius=3,
-                corr_in_other=True,
-                amp=True,
-                disable_local_corr_grad=disable_local_corr_grad,
-                bn_momentum=0.01,
-            ),
-            "4": ConvRefiner(
-                2 * 256 + 32 + (2 * 2 + 1) ** 2,
-                2 * 256 + 32 + (2 * 2 + 1) ** 2,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks=hidden_blocks,
-                displacement_emb=displacement_emb,
-                displacement_emb_dim=32,
-                local_corr_radius=2,
-                corr_in_other=True,
-                amp=True,
-                disable_local_corr_grad=disable_local_corr_grad,
-                bn_momentum=0.01,
-            ),
-            "2": ConvRefiner(
-                2 * 64 + 16,
-                128 + 16,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks=hidden_blocks,
-                displacement_emb=displacement_emb,
-                displacement_emb_dim=16,
-                amp=True,
-                disable_local_corr_grad=disable_local_corr_grad,
-                bn_momentum=0.01,
-            ),
-            "1": ConvRefiner(
-                2 * 9 + 6,
-                24,
-                2 + 1,
-                kernel_size=kernel_size,
-                dw=dw,
-                hidden_blocks=hidden_blocks,
-                displacement_emb=displacement_emb,
-                displacement_emb_dim=6,
-                amp=True,
-                disable_local_corr_grad=disable_local_corr_grad,
-                bn_momentum=0.01,
-            ),
-        }
-    )
-    kernel_temperature = 0.2
-    learn_temperature = False
-    no_cov = True
-    kernel = CosKernel
-    only_attention = False
-    basis = "fourier"
-    gp16 = GP(
-        kernel,
-        T=kernel_temperature,
-        learn_temperature=learn_temperature,
-        only_attention=only_attention,
-        gp_dim=gp_dim,
-        basis=basis,
-        no_cov=no_cov,
-    )
-    gps = nn.ModuleDict({"16": gp16})
-    proj16 = nn.Sequential(nn.Conv2d(1024, 512, 1, 1), nn.BatchNorm2d(512))
-    proj8 = nn.Sequential(nn.Conv2d(512, 512, 1, 1), nn.BatchNorm2d(512))
-    proj4 = nn.Sequential(nn.Conv2d(256, 256, 1, 1), nn.BatchNorm2d(256))
-    proj2 = nn.Sequential(nn.Conv2d(128, 64, 1, 1), nn.BatchNorm2d(64))
-    proj1 = nn.Sequential(nn.Conv2d(64, 9, 1, 1), nn.BatchNorm2d(9))
-    proj = nn.ModuleDict(
-        {
-            "16": proj16,
-            "8": proj8,
-            "4": proj4,
-            "2": proj2,
-            "1": proj1,
-        }
-    )
-    displacement_dropout_p = 0.0
-    gm_warp_dropout_p = 0.0
-    decoder = Decoder(
-        coordinate_decoder,
-        gps,
-        proj,
-        conv_refiner,
-        detach=True,
-        scales=["16", "8", "4", "2", "1"],
-        displacement_dropout_p=displacement_dropout_p,
-        gm_warp_dropout_p=gm_warp_dropout_p,
-    )
-
-    encoder = CNNandDinov2(
-        cnn_kwargs=dict(pretrained=False, amp=True),
-        amp=True,
-        use_vgg=True,
-        dinov2_weights=dinov2_weights,
-    )
-    h, w = resolution
-    symmetric = True
-    attenuate_cert = True
-    matcher = RegressionMatcher(
-        encoder,
-        decoder,
-        h=h,
-        w=w,
-        upsample_preds=upsample_preds,
-        symmetric=symmetric,
-        attenuate_cert=attenuate_cert,
-        **kwargs
-    ).to(device)
-    matcher.load_state_dict(weights)
-    return matcher
diff --git a/third_party/Roma/roma/models/transformer/__init__.py b/third_party/Roma/roma/models/transformer/__init__.py
deleted file mode 100644
index b1409045ef9c5dddef88484762137b9a2ab79cd5..0000000000000000000000000000000000000000
--- a/third_party/Roma/roma/models/transformer/__init__.py
+++ /dev/null
@@ -1,79 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from roma.utils.utils import get_grid
-from .layers.block import Block
-from .layers.attention import MemEffAttention
-from .dinov2 import vit_large
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
-class TransformerDecoder(nn.Module):
-    def __init__(
-        self,
-        blocks,
-        hidden_dim,
-        out_dim,
-        is_classifier=False,
-        *args,
-        amp=False,
-        pos_enc=True,
-        learned_embeddings=False,
-        embedding_dim=None,
-        **kwargs
-    ) -> None:
-        super().__init__(*args, **kwargs)
-        self.blocks = blocks
-        self.to_out = nn.Linear(hidden_dim, out_dim)
-        self.hidden_dim = hidden_dim
-        self.out_dim = out_dim
-        self._scales = [16]
-        self.is_classifier = is_classifier
-        self.amp = amp
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
-
-        self.pos_enc = pos_enc
-        self.learned_embeddings = learned_embeddings
-        if self.learned_embeddings:
-            self.learned_pos_embeddings = nn.Parameter(
-                nn.init.kaiming_normal_(
-                    torch.empty((1, hidden_dim, embedding_dim, embedding_dim))
-                )
-            )
-
-    def scales(self):
-        return self._scales.copy()
-
-    def forward(self, gp_posterior, features, old_stuff, new_scale):
-        with torch.autocast(device, dtype=self.amp_dtype, enabled=self.amp):
-            B, C, H, W = gp_posterior.shape
-            x = torch.cat((gp_posterior, features), dim=1)
-            B, C, H, W = x.shape
-            grid = get_grid(B, H, W, x.device).reshape(B, H * W, 2)
-            if self.learned_embeddings:
-                pos_enc = (
-                    F.interpolate(
-                        self.learned_pos_embeddings,
-                        size=(H, W),
-                        mode="bilinear",
-                        align_corners=False,
-                    )
-                    .permute(0, 2, 3, 1)
-                    .reshape(1, H * W, C)
-                )
-            else:
-                pos_enc = 0
-            tokens = x.reshape(B, C, H * W).permute(0, 2, 1) + pos_enc
-            z = self.blocks(tokens)
-            out = self.to_out(z)
-            out = out.permute(0, 2, 1).reshape(B, self.out_dim, H, W)
-            warp, certainty = out[:, :-1], out[:, -1:]
-            return warp, certainty, None
diff --git a/third_party/Roma/roma/utils/kde.py b/third_party/Roma/roma/utils/kde.py
deleted file mode 100644
index eff7c72dad4a3f90f5ff79d2630427de89838fc5..0000000000000000000000000000000000000000
--- a/third_party/Roma/roma/utils/kde.py
+++ /dev/null
@@ -1,9 +0,0 @@
-import torch
-
-
-def kde(x, std=0.1):
-    # use a gaussian kernel to estimate density
-    x = x.half()  # Do it in half precision
-    scores = (-torch.cdist(x, x) ** 2 / (2 * std**2)).exp()
-    density = scores.sum(dim=-1)
-    return density
diff --git a/third_party/Roma/roma/utils/local_correlation.py b/third_party/Roma/roma/utils/local_correlation.py
deleted file mode 100644
index 603ab524333c29fbc284a73065847645f3100847..0000000000000000000000000000000000000000
--- a/third_party/Roma/roma/utils/local_correlation.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import torch
-import torch.nn.functional as F
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
-def local_correlation(
-    feature0,
-    feature1,
-    local_radius,
-    padding_mode="zeros",
-    flow=None,
-    sample_mode="bilinear",
-):
-    r = local_radius
-    K = (2 * r + 1) ** 2
-    B, c, h, w = feature0.size()
-    feature0 = feature0.half()
-    feature1 = feature1.half()
-    corr = torch.empty((B, K, h, w), device=feature0.device, dtype=feature0.dtype)
-    if flow is None:
-        # If flow is None, assume feature0 and feature1 are aligned
-        coords = torch.meshgrid(
-            (
-                torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
-                torch.linspace(-1 + 1 / w, 1 - 1 / w, w, device=device),
-            )
-        )
-        coords = torch.stack((coords[1], coords[0]), dim=-1)[None].expand(B, h, w, 2)
-    else:
-        coords = flow.permute(0, 2, 3, 1)  # If using flow, sample around flow target.
-    local_window = torch.meshgrid(
-        (
-            torch.linspace(
-                -2 * local_radius / h, 2 * local_radius / h, 2 * r + 1, device=device
-            ),
-            torch.linspace(
-                -2 * local_radius / w, 2 * local_radius / w, 2 * r + 1, device=device
-            ),
-        )
-    )
-    local_window = (
-        torch.stack((local_window[1], local_window[0]), dim=-1)[None]
-        .expand(1, 2 * r + 1, 2 * r + 1, 2)
-        .reshape(1, (2 * r + 1) ** 2, 2)
-    )
-    for _ in range(B):
-        with torch.no_grad():
-            local_window_coords = (
-                (coords[_, :, :, None] + local_window[:, None, None])
-                .reshape(1, h, w * (2 * r + 1) ** 2, 2)
-                .float()
-            )
-            window_feature = F.grid_sample(
-                feature1[_ : _ + 1].float(),
-                local_window_coords,
-                padding_mode=padding_mode,
-                align_corners=False,
-                mode=sample_mode,  #
-            )
-            window_feature = window_feature.reshape(c, h, w, (2 * r + 1) ** 2)
-        corr[_] = (
-            (feature0[_, ..., None] / (c**0.5) * window_feature)
-            .sum(dim=0)
-            .permute(2, 0, 1)
-        )
-    torch.cuda.empty_cache()
-    return corr